Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/ray/test_mock_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ def _run_mock_test(self, mock_controller_cls, error_name: str):
status = ray.get(self.test_dataflow.get_replaybuffer_status.remote())
print(f"[{error_name}] Completed rollouts: {completed_rollouts}, Status: {status}")
self.assertEqual(len(completed_rollouts[0]), 0, f"[{error_name}] Expected no rollouts to complete successfully.")
self.assertEqual(status["rollout_finished_count"], 0, f"[{error_name}] Completed count in buffer should be 0.")
self.assertEqual(status["rollout_paused_count"], 0, f"[{error_name}] Expected no rollouts to be interrupted.")
self.assertEqual(status["remain_completed_samples_count"], 0, f"[{error_name}] Completed count in buffer should be 0.")
self.assertEqual(status["remain_aborted_samples_count"], 0, f"[{error_name}] Expected no rollouts to be interrupted.")
ray.get(self.test_env.shutdown.remote())

@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
Expand Down
24 changes: 12 additions & 12 deletions tests/ray/test_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,9 @@ def test_lmdeploy_async_dataflow(self):
finished_samples_count = sum(1 for data in responses[0] for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length")
self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
status = ray.get(self.test_flow.get_replaybuffer_status.remote())
finished_count = status["rollout_finished_count"] # 已经去掉了data_flow返回的数量
paused_count = status["rollout_paused_count"]
sample_count = status["action_count"]
finished_count = status["remain_completed_samples_count"] # 已经去掉了data_flow返回的数量
paused_count = status["remain_aborted_samples_count"]
sample_count = status["sample_from_dataset_count"] + status["sample_from_aborted_count"] + status["sample_from_expired_count"]
self.assertEqual(len(responses) + finished_count + paused_count, sample_count)
self.assertEqual(len(responses), self.dataflow_cfg.global_batch_size)

Expand All @@ -167,9 +167,9 @@ def test_lmdeploy_async_dataflow(self):
finished_resume_samples_count = sum(1 for data in response_resume[0] for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length")
self.assertEqual(finished_resume_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
status = ray.get(self.test_flow.get_replaybuffer_status.remote())
finished_count = status["rollout_finished_count"]
paused_count = status["rollout_paused_count"]
sample_count = status["action_count"]
finished_count = status["remain_completed_samples_count"]
paused_count = status["remain_aborted_samples_count"]
sample_count = status["sample_from_dataset_count"] + status["sample_from_aborted_count"] + status["sample_from_expired_count"]
self.assertEqual(len(response_resume) + finished_count + paused_count, sample_count)
self.assertEqual(len(response_resume), self.dataflow_cfg.global_batch_size)
ray.get(self.test_env.shutdown.remote())
Expand Down Expand Up @@ -236,9 +236,9 @@ def test_sglang_async_dataflow(self):
finished_samples_count = sum(1 for data in responses[0] for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length")
self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
status = ray.get(self.test_flow.get_replaybuffer_status.remote())
finished_count = status["rollout_finished_count"] # 已经去掉了data_flow返回的数量
paused_count = status["rollout_paused_count"]
sample_count = status["action_count"]
finished_count = status["remain_completed_samples_count"] # 已经去掉了data_flow返回的数量
paused_count = status["remain_aborted_samples_count"]
sample_count = status["sample_from_dataset_count"] + status["sample_from_aborted_count"] + status["sample_from_expired_count"]
self.assertEqual(len(responses) + finished_count + paused_count, sample_count)
self.assertEqual(len(responses), self.dataflow_cfg.global_batch_size)

Expand All @@ -247,9 +247,9 @@ def test_sglang_async_dataflow(self):
finished_resume_samples_count = sum(1 for data in response_resume[0] for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length")
self.assertEqual(finished_resume_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
status = ray.get(self.test_flow.get_replaybuffer_status.remote())
finished_count = status["rollout_finished_count"]
paused_count = status["rollout_paused_count"]
sample_count = status["action_count"]
finished_count = status["remain_completed_samples_count"]
paused_count = status["remain_aborted_samples_count"]
sample_count = status["sample_from_dataset_count"] + status["sample_from_aborted_count"] + status["sample_from_expired_count"]
self.assertEqual(len(response_resume) + finished_count + paused_count, sample_count)
self.assertEqual(len(response_resume), self.dataflow_cfg.global_batch_size)
ray.get(self.test_env.shutdown.remote())
Expand Down
70 changes: 54 additions & 16 deletions xtuner/v1/ray/dataflow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,37 @@ class DataFlowConfig(BaseModel):
int,
Parameter(help="Target number of samples to collect before stopping."),
] = 8
enable_partial_rollout: Annotated[
int, Parameter(help="Whether to enable async rollout_controller. 1 for enabled, 0 for disabled")
] = 0
sample_params: Annotated[SampleParams, Parameter(help="Parameters for sampling from the model.")] = SampleParams()
extra_params: Annotated[Dict, Parameter(help="Extra parameters for rollout.")] = {}
# async params
staleness_threshold: Annotated[
float,
Parameter(
help="The maximum allowed threshold of stale (expired) samples in a training batch. Must be between 0.0 and 1.0."
),
] = 0.0
enable_partial_rollout: Annotated[
bool,
Parameter(help="Whether to enable partial rollout for asynchronous data generation."),
] = False
tail_batch_candidate_steps: Annotated[
int,
Parameter(
help="Number of rollout steps after which a sample becomes a candidate for the tail batch. Set to 0 to disable."
),
] = 0
tail_batch_trigger_size: Annotated[
Optional[int],
Parameter(
help="Number of candidate samples needed in the queue to trigger a tail batch operation. Set to 0 to disable."
),
] = None
worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir"

def model_post_init(self, __context: Any) -> None:
    """Finish configuration once the model's fields have been populated.

    Creates the worker log directory (including missing parents) and, when
    no tail-batch trigger size was supplied, defaults it to the global batch
    size. An explicit value, including ``0``, is left untouched.
    """
    log_dir = self.worker_log_dir
    log_dir.mkdir(parents=True, exist_ok=True)

    # Only ``None`` means "not configured"; 0 is a legitimate explicit value.
    if self.tail_batch_trigger_size is not None:
        return
    self.tail_batch_trigger_size = self.global_batch_size


@ray.remote
Expand Down Expand Up @@ -106,17 +128,20 @@ def __init__(
postprocessor (Optional[Callable]): An optional function to
post-process the generated samples.
"""
self.logger = get_logger(log_dir=dataflow_cfg.worker_log_dir, tag="DataFlow")
self.env = env
self.config = dataflow_cfg
replay_buffer_cfg.worker_log_dir = self.config.worker_log_dir
replay_buffer_cfg.enable_partial_rollout = self.config.enable_partial_rollout
replay_buffer_cfg.tail_batch_candidate_steps = self.config.tail_batch_candidate_steps
replay_buffer_cfg.tail_batch_trigger_size = self.config.tail_batch_trigger_size
self.replay_buffer = ReplayBuffer.remote(replay_buffer_cfg) # type: ignore[attr-defined]
self.env_controller = environment
self.finished_samples_count = 0
self.skipped_sample_count = 0
self.failed_sample_count = 0
self.logger = get_logger(log_dir=self.config.worker_log_dir, tag="DataFlow")
self.filtered_samples_count = 0
self.target_batch_size = self.config.global_batch_size
self.logger.info(f"DataFlowConfig:\n{self.config.model_dump_json(indent=2)}")
rollout_info = ray.get(self.env_controller.get_rollout_info.remote()) # type: ignore[attr-defined]
self.worker_url_list = list(rollout_info["server_url_dict"].values())
self.logger.info(f"DataFlow connected to active rollout workers url: {self.worker_url_list}")
Expand All @@ -140,28 +165,33 @@ def __init__(
f"Dataflow max_concurrent is set to {self.config.max_concurrent}, we proposed to set max_concurrent to {max_concurrent} based on rollout_max_batch_size_per_instance."
)
self.enable_partial_rollout = self.config.enable_partial_rollout
self.logger.info(f"DataFlowConfig:\n{self.config.model_dump_json(indent=2)}")

def _prepare(
    self,
    global_batch_size: Optional[int] = None,
    sample_params: Optional[SampleParams] = None,
    extra_params: Optional[Dict] = None,
):
    """Set up DataFlow's per-run state before a new round of generation.

    Args:
        global_batch_size: Target number of samples for this run. ``None``
            or a non-positive value falls back to
            ``config.global_batch_size``.
        sample_params: Sampling parameters for this run; a falsy value
            falls back to ``config.sample_params``.
        extra_params: Extra rollout parameters for this run; a falsy value
            (including an empty dict) falls back to ``config.extra_params``.
    """
    if global_batch_size and global_batch_size > 0:
        self.target_batch_size = global_batch_size
    else:
        self.target_batch_size = self.config.global_batch_size

    # The replay buffer supplies the starting values for this target size,
    # so the finished counter is seeded from it rather than reset to zero.
    # NOTE(review): this is a blocking ray.get and the visible caller, run(),
    # is async -- confirm that blocking the actor's event loop here is fine.
    self.sample_from_expired_storage, self.finished_samples_count = ray.get(
        self.replay_buffer.get_prerun_state.remote(self.target_batch_size)
    )
    self.skipped_sample_count = 0
    self.failed_sample_count = 0
    self.filtered_samples_count = 0

    self.sample_params = sample_params if sample_params else self.config.sample_params
    self.extra_params = extra_params if extra_params else self.config.extra_params
    logger_msg = (
        f"DataFlow states for new generations: target_batch_size={self.target_batch_size}, "
        f"sample_params: {self.sample_params}, extra_params: {self.extra_params}, "
        f"sample_from_expired_storage={self.sample_from_expired_storage}, finished_samples_count={self.finished_samples_count}, "
    )
    self.logger.info(logger_msg)

# Pre-rename name, kept as an alias so call sites that still use it keep working.
_reset_internal_states = _prepare

Expand Down Expand Up @@ -191,7 +221,7 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
# TODO(@duanyanhui): More fine-grained control over group data generation:
# Pass n to the inference engine to ensure that the same data is processed by the same server, improving efficiency.
group_data_items = await self.replay_buffer.sample.remote( # type: ignore[attr-defined]
self.env, self.enable_partial_rollout, self.config.prompt_repeat_k
self.env, self.config.prompt_repeat_k
)
assert len(group_data_items) > 0, "Sampled empty group data items from replay buffer."
action_id = group_data_items[0].uid.action_id
Expand All @@ -207,6 +237,8 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
group_data_items = await self.replay_buffer.post_processor.remote(group_data_items) # type: ignore[attr-defined]
if len(group_data_items) > 0:
await self.replay_buffer.add.remote(group_data_items) # type: ignore[attr-defined]
else:
self.filtered_samples_count += 1
self.logger.debug(f"Worker task completed successfully for {action_id}.")
elif group_state == RolloutState.ABORTED:
await self.replay_buffer.add.remote(group_data_items) # type: ignore[attr-defined]
Expand Down Expand Up @@ -257,15 +289,15 @@ async def concurrent_task_runner(self):
while len(waiting_tasks) < self.config.max_concurrent:
# In async mode, we keep spawning. In sync mode, we stop if we have enough tasks in flight.
if (
not self.config.enable_partial_rollout
not self.enable_partial_rollout
and self.finished_samples_count + len(waiting_tasks) >= self.target_batch_size
):
break
task = create_task(self.worker_task())
waiting_tasks.add(task)

_, pending_tasks = await asyncio.wait(waiting_tasks, timeout=0.1, return_when=asyncio.FIRST_COMPLETED)
self.finished_samples_count = ray.get(self.replay_buffer.get_finished_samples.remote())
self.finished_samples_count = ray.get(self.replay_buffer.get_completed_samples_count.remote())
waiting_tasks = pending_tasks

pbar.n = self.finished_samples_count
Expand Down Expand Up @@ -340,7 +372,7 @@ async def run(
Returns:
List[RLDataFlowItem]: A list of collected training samples.
"""
self._reset_internal_states(global_batch_size=num, sample_params=sample_params, extra_params=extra_params)
self._prepare(global_batch_size=num, sample_params=sample_params, extra_params=extra_params)

if resume:
assert resume_path, "Resuming is enabled but no resume path is provided."
Expand All @@ -357,7 +389,13 @@ async def run(
return await self.replay_buffer.get_samples.remote(self.target_batch_size) # type: ignore[attr-defined]

def logging_replaybuffer_state(self):
    """Log the replay buffer status together with DataFlow's own counters.

    The status is obtained through ``get_replaybuffer_status`` and written
    to ``self.logger`` at INFO level in a single message, followed by the
    finished, skipped, failed and filtered sample counters held on this
    object.
    """
    status = self.get_replaybuffer_status()
    # One literal instead of repeated ``+=``; the emitted text is unchanged.
    logging_msg = (
        f"ReplayBuffer Status: {status}"
        f", finished_samples_count: {self.finished_samples_count}, "
        f"skipped_samples_count: {self.skipped_sample_count}, "
        f"failed_samples_count: {self.failed_sample_count}, "
        f"filtered_samples_count: {self.filtered_samples_count}, "
    )
    self.logger.info(logging_msg)

def get_replaybuffer_status(self):
    """Fetch the replay buffer's status and return it.

    Issues the remote ``status`` call on the replay buffer and blocks until
    its result is available.
    """
    status_ref = self.replay_buffer.status.remote()
    return ray.get(status_ref)
Expand Down
Loading