diff --git a/tests/ray/test_mock_rollout.py b/tests/ray/test_mock_rollout.py
index ad678069d..d5d1e5dff 100644
--- a/tests/ray/test_mock_rollout.py
+++ b/tests/ray/test_mock_rollout.py
@@ -118,8 +118,8 @@ def _run_mock_test(self, mock_controller_cls, error_name: str):
         status = ray.get(self.test_dataflow.get_replaybuffer_status.remote())
         print(f"[{error_name}] Completed rollouts: {completed_rollouts}, Status: {status}")
         self.assertEqual(len(completed_rollouts[0]), 0, f"[{error_name}] Expected no rollouts to complete successfully.")
-        self.assertEqual(status["rollout_finished_count"], 0, f"[{error_name}] Completed count in buffer should be 0.")
-        self.assertEqual(status["rollout_paused_count"], 0, f"[{error_name}] Expected no rollouts to be interrupted.")
+        self.assertEqual(status["remain_completed_samples_count"], 0, f"[{error_name}] Completed count in buffer should be 0.")
+        self.assertEqual(status["remain_aborted_samples_count"], 0, f"[{error_name}] Expected no rollouts to be interrupted.")
         ray.get(self.test_env.shutdown.remote())
     @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
diff --git a/tests/ray/test_rollout.py b/tests/ray/test_rollout.py
index e3d9f5ac6..772b23cb2 100644
--- a/tests/ray/test_rollout.py
+++ b/tests/ray/test_rollout.py
@@ -156,9 +156,9 @@ def test_lmdeploy_async_dataflow(self):
         finished_samples_count = sum(1 for data in responses[0] for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length")
         self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
         status = ray.get(self.test_flow.get_replaybuffer_status.remote())
-        finished_count = status["rollout_finished_count"] # 已经去掉了data_flow返回的数量
-        paused_count = status["rollout_paused_count"]
-        sample_count = status["action_count"]
+        finished_count = status["remain_completed_samples_count"] # samples already returned by data_flow are excluded
+        paused_count = status["remain_aborted_samples_count"]
+        sample_count = status["sample_from_dataset_count"] + status["sample_from_aborted_count"] + status["sample_from_expired_count"]
         self.assertEqual(len(responses) + finished_count + paused_count, sample_count)
         self.assertEqual(len(responses), self.dataflow_cfg.global_batch_size)
@@ -167,9 +167,9 @@ def test_lmdeploy_async_dataflow(self):
         finished_resume_samples_count = sum(1 for data in response_resume[0] for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length")
         self.assertEqual(finished_resume_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
         status = ray.get(self.test_flow.get_replaybuffer_status.remote())
-        finished_count = status["rollout_finished_count"]
-        paused_count = status["rollout_paused_count"]
-        sample_count = status["action_count"]
+        finished_count = status["remain_completed_samples_count"]
+        paused_count = status["remain_aborted_samples_count"]
+        sample_count = status["sample_from_dataset_count"] + status["sample_from_aborted_count"] + status["sample_from_expired_count"]
         self.assertEqual(len(response_resume) + finished_count + paused_count, sample_count)
         self.assertEqual(len(response_resume), self.dataflow_cfg.global_batch_size)
         ray.get(self.test_env.shutdown.remote())
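The rewritten assertions in these tests encode a conservation property of the new status counters: every prompt group the flow draws, whether from the dataset, the aborted buckets, or the expired list, must end up either returned to the caller, left in the buffer as completed, or parked as aborted. A toy illustration of that identity with invented numbers (not code from the test suite):

```python
# Hypothetical status payload shaped like the new ReplayBuffer.status(); all values are made up.
status = {
    "remain_completed_samples_count": 2,
    "remain_aborted_samples_count": 1,
    "remain_expired_samples_count": 0,
    "sample_from_dataset_count": 10,
    "sample_from_aborted_count": 1,
    "sample_from_expired_count": 0,
}
returned_groups = 8  # stands in for len(responses) handed back by the dataflow

total_drawn = (
    status["sample_from_dataset_count"]
    + status["sample_from_aborted_count"]
    + status["sample_from_expired_count"]
)
# The same identity the tests assert: drawn == returned + still-completed + still-aborted.
assert (
    returned_groups
    + status["remain_completed_samples_count"]
    + status["remain_aborted_samples_count"]
) == total_drawn
```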
@@ -236,9 +236,9 @@ def test_sglang_async_dataflow(self):
         finished_samples_count = sum(1 for data in responses[0] for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length")
         self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
         status = ray.get(self.test_flow.get_replaybuffer_status.remote())
-        finished_count = status["rollout_finished_count"] # 已经去掉了data_flow返回的数量
-        paused_count = status["rollout_paused_count"]
-        sample_count = status["action_count"]
+        finished_count = status["remain_completed_samples_count"] # samples already returned by data_flow are excluded
+        paused_count = status["remain_aborted_samples_count"]
+        sample_count = status["sample_from_dataset_count"] + status["sample_from_aborted_count"] + status["sample_from_expired_count"]
         self.assertEqual(len(responses) + finished_count + paused_count, sample_count)
         self.assertEqual(len(responses), self.dataflow_cfg.global_batch_size)
@@ -247,9 +247,9 @@ def test_sglang_async_dataflow(self):
         finished_resume_samples_count = sum(1 for data in response_resume[0] for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length")
         self.assertEqual(finished_resume_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
         status = ray.get(self.test_flow.get_replaybuffer_status.remote())
-        finished_count = status["rollout_finished_count"]
-        paused_count = status["rollout_paused_count"]
-        sample_count = status["action_count"]
+        finished_count = status["remain_completed_samples_count"]
+        paused_count = status["remain_aborted_samples_count"]
+        sample_count = status["sample_from_dataset_count"] + status["sample_from_aborted_count"] + status["sample_from_expired_count"]
         self.assertEqual(len(response_resume) + finished_count + paused_count, sample_count)
         self.assertEqual(len(response_resume), self.dataflow_cfg.global_batch_size)
         ray.get(self.test_env.shutdown.remote())
diff --git a/xtuner/v1/ray/dataflow/flow.py b/xtuner/v1/ray/dataflow/flow.py
index aff0b0ecc..26c2c50cf 100644
--- a/xtuner/v1/ray/dataflow/flow.py
+++ b/xtuner/v1/ray/dataflow/flow.py
@@ -68,15 +68,37 @@ class DataFlowConfig(BaseModel):
         int,
         Parameter(help="Target number of samples to collect before stopping."),
     ] = 8
-    enable_partial_rollout: Annotated[
-        int, Parameter(help="Whether to enable async rollout_controller. 1 for enabled, 0 for disabled")
-    ] = 0
     sample_params: Annotated[SampleParams, Parameter(help="Parameters for sampling from the model.")] = SampleParams()
     extra_params: Annotated[Dict, Parameter(help="Extra parameters for rollout.")] = {}
+    # async params
+    staleness_threshold: Annotated[
+        float,
+        Parameter(
+            help="The maximum allowed threshold of stale (expired) samples in a training batch. Must be between 0.0 and 1.0."
+        ),
+    ] = 0.0
+    enable_partial_rollout: Annotated[
+        bool,
+        Parameter(help="Whether to enable partial rollout for asynchronous data generation."),
+    ] = False
+    tail_batch_candidate_steps: Annotated[
+        int,
+        Parameter(
+            help="Number of rollout steps after which a sample becomes a candidate for the tail batch. Set to 0 to disable."
+        ),
+    ] = 0
+    tail_batch_trigger_size: Annotated[
+        Optional[int],
+        Parameter(
+            help="Number of candidate samples needed in the queue to trigger a tail batch operation. Set to 0 to disable."
+ ), + ] = None worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir" def model_post_init(self, __context: Any) -> None: self.worker_log_dir.mkdir(parents=True, exist_ok=True) + if self.tail_batch_trigger_size is None: + self.tail_batch_trigger_size = self.global_batch_size @ray.remote @@ -106,17 +128,20 @@ def __init__( postprocessor (Optional[Callable]): An optional function to post-process the generated samples. """ + self.logger = get_logger(log_dir=dataflow_cfg.worker_log_dir, tag="DataFlow") self.env = env self.config = dataflow_cfg replay_buffer_cfg.worker_log_dir = self.config.worker_log_dir + replay_buffer_cfg.enable_partial_rollout = self.config.enable_partial_rollout + replay_buffer_cfg.tail_batch_candidate_steps = self.config.tail_batch_candidate_steps + replay_buffer_cfg.tail_batch_trigger_size = self.config.tail_batch_trigger_size self.replay_buffer = ReplayBuffer.remote(replay_buffer_cfg) # type: ignore[attr-defined] self.env_controller = environment self.finished_samples_count = 0 self.skipped_sample_count = 0 self.failed_sample_count = 0 - self.logger = get_logger(log_dir=self.config.worker_log_dir, tag="DataFlow") + self.filtered_samples_count = 0 self.target_batch_size = self.config.global_batch_size - self.logger.info(f"DataFlowConfig:\n{self.config.model_dump_json(indent=2)}") rollout_info = ray.get(self.env_controller.get_rollout_info.remote()) # type: ignore[attr-defined] self.worker_url_list = list(rollout_info["server_url_dict"].values()) self.logger.info(f"DataFlow connected to active rollout workers url: {self.worker_url_list}") @@ -140,28 +165,33 @@ def __init__( f"Dataflow max_concurrent is set to {self.config.max_concurrent}, we proposed to set max_concurrent to {max_concurrent} based on rollout_max_batch_size_per_instance." ) self.enable_partial_rollout = self.config.enable_partial_rollout + self.logger.info(f"DataFlowConfig:\n{self.config.model_dump_json(indent=2)}") - def _reset_internal_states( + def _prepare( self, global_batch_size: Optional[int] = None, sample_params: Optional[SampleParams] = None, extra_params: Optional[Dict] = None, ): """Resets all internal state variables of DataFlow.""" - self.finished_samples_count = 0 - self.skipped_sample_count = 0 - self.failed_sample_count = 0 if global_batch_size and global_batch_size > 0: self.target_batch_size = global_batch_size else: self.target_batch_size = self.config.global_batch_size + self.sample_from_expired_storage, self.finished_samples_count = ray.get( + self.replay_buffer.get_prerun_state.remote(self.target_batch_size) + ) + self.skipped_sample_count = 0 + self.failed_sample_count = 0 + self.filtered_samples_count = 0 + self.sample_params = sample_params if sample_params else self.config.sample_params self.extra_params = extra_params if extra_params else self.config.extra_params logger_msg = ( - f"DataFlow internal states reset for new run: target_batch_size={self.target_batch_size}, " + f"DataFlow states for new generations: target_batch_size={self.target_batch_size}, " f"sample_params: {self.sample_params}, extra_params: {self.extra_params}, " - f"enable_partial_rollout={self.enable_partial_rollout}." 
+ f"sample_from_expired_storage={self.sample_from_expired_storage}, finished_samples_count={self.finished_samples_count}, " ) self.logger.info(logger_msg) @@ -191,7 +221,7 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte # TODO(@duanyanhui): More fine-grained control over group data generation: # Pass n to the inference engine to ensure that the same data is processed by the same server, improving efficiency. group_data_items = await self.replay_buffer.sample.remote( # type: ignore[attr-defined] - self.env, self.enable_partial_rollout, self.config.prompt_repeat_k + self.env, self.config.prompt_repeat_k ) assert len(group_data_items) > 0, "Sampled empty group data items from replay buffer." action_id = group_data_items[0].uid.action_id @@ -207,6 +237,8 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte group_data_items = await self.replay_buffer.post_processor.remote(group_data_items) # type: ignore[attr-defined] if len(group_data_items) > 0: await self.replay_buffer.add.remote(group_data_items) # type: ignore[attr-defined] + else: + self.filtered_samples_count += 1 self.logger.debug(f"Worker task completed successfully for {action_id}.") elif group_state == RolloutState.ABORTED: await self.replay_buffer.add.remote(group_data_items) # type: ignore[attr-defined] @@ -257,7 +289,7 @@ async def concurrent_task_runner(self): while len(waiting_tasks) < self.config.max_concurrent: # In async mode, we keep spawning. In sync mode, we stop if we have enough tasks in flight. if ( - not self.config.enable_partial_rollout + not self.enable_partial_rollout and self.finished_samples_count + len(waiting_tasks) >= self.target_batch_size ): break @@ -265,7 +297,7 @@ async def concurrent_task_runner(self): waiting_tasks.add(task) _, pending_tasks = await asyncio.wait(waiting_tasks, timeout=0.1, return_when=asyncio.FIRST_COMPLETED) - self.finished_samples_count = ray.get(self.replay_buffer.get_finished_samples.remote()) + self.finished_samples_count = ray.get(self.replay_buffer.get_completed_samples_count.remote()) waiting_tasks = pending_tasks pbar.n = self.finished_samples_count @@ -340,7 +372,7 @@ async def run( Returns: List[RLDataFlowItem]: A list of collected training samples. """ - self._reset_internal_states(global_batch_size=num, sample_params=sample_params, extra_params=extra_params) + self._prepare(global_batch_size=num, sample_params=sample_params, extra_params=extra_params) if resume: assert resume_path, "Resuming is enabled but no resume path is provided." 
@@ -357,7 +389,13 @@ async def run(
         return await self.replay_buffer.get_samples.remote(self.target_batch_size) # type: ignore[attr-defined]
     def logging_replaybuffer_state(self):
-        ray.get(self.replay_buffer.print.remote())
+        status = self.get_replaybuffer_status()
+        logging_msg = f"ReplayBuffer Status: {status}"
+        logging_msg += f", finished_samples_count: {self.finished_samples_count}, "
+        logging_msg += f"skipped_samples_count: {self.skipped_sample_count}, "
+        logging_msg += f"failed_samples_count: {self.failed_sample_count}, "
+        logging_msg += f"filtered_samples_count: {self.filtered_samples_count}, "
+        self.logger.info(logging_msg)
     def get_replaybuffer_status(self):
         return ray.get(self.replay_buffer.status.remote())
diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py
index ee67bf520..0b5877f85 100644
--- a/xtuner/v1/ray/dataflow/replay_buffer.py
+++ b/xtuner/v1/ray/dataflow/replay_buffer.py
@@ -15,6 +15,7 @@ from xtuner.v1.data_proto.rl_data import (
     RLDataFlowItem,
     RLDatasetItem,
+    RLEnvDataItem,
     RLExtraDataItem,
     RLUIDItem,
     RolloutState,
@@ -49,10 +50,11 @@ class ReplayMeta:
     root_id: int = 0
     action_id: int = 0 # same prompt share the same action_id
     action_ref: ObjectRef = None
-    observation_ids: List[int] = field(default_factory=list) # observation IDs for different versions
+    observation_ids: List[int] = field(default_factory=list)
     observation_refs: List[ObjectRef] = field(default_factory=list)
-    observation_versions: List[int] = field(default_factory=list) # reserved for async rollout
+    observation_versions: List[int] = field(default_factory=list) # data is dispatched per group for now, so this is currently unused
     state: RolloutState = RolloutState.INIT
+    version: int = 0 # version for partial rollout
     extra_info: Dict[str, Any] = field(default_factory=dict)
@@ -82,20 +84,20 @@ def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> Re
     root_id = grouped_dataitem[0].uid.root_id
     action_id = grouped_dataitem[0].uid.action_id
    data = grouped_dataitem[0].data
+    # Data is now sent per group, so every item in a group shares the same version; even if an item in the group generated nothing in a given rollout step, the version is still incremented
+    group_version = grouped_dataitem[0].uid.version
     observation_ids = []
     observation_refs = []
-    observation_versions = []
-    group_states = []
     for item in grouped_dataitem:
-        version = item.uid.version
         observation_ids.append(item.uid.observation_id)
         observation_refs.append(ray.put(item.env))
-        observation_versions.append(version)
-        group_states.append(item.env.rollout.finish_reason)
     group_state = determine_group_state(grouped_dataitem)
-    logger.debug(f"determined group_state: {group_state}, replay_state: {group_state}")
+    logger.debug(
+        f"Mapping data items to ReplayMeta {action_id} with group_state: {group_state}, group_version: {group_version}"
+    )
+
     replay_meta = ReplayMeta(
         env=env_str,
         root_id=root_id,
@@ -103,8 +105,8 @@ def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> Re
         action_ref=ray.put(data),
         observation_ids=observation_ids,
         observation_refs=observation_refs,
-        observation_versions=observation_versions,
         state=group_state,
+        version=group_version,
         extra_info={},
     )
     return replay_meta
@@ -116,14 +118,14 @@ def mapping_replaymeta_to_dataitem(replay_meta: ReplayMeta) -> List[RLDataFlowIt
     env_str = replay_meta.env
     root_id = replay_meta.root_id
     action_id = replay_meta.action_id
     data_ref = ray.get(replay_meta.action_ref)
     group_data_item = []
-    for obs_id, obs_ref, version in zip(
-        replay_meta.observation_ids, replay_meta.observation_refs, replay_meta.observation_versions
-    ):
+    for obs_id, obs_ref in zip(replay_meta.observation_ids,
replay_meta.observation_refs): env_data = ray.get(obs_ref) ray._private.internal_api.free(obs_ref) item = RLDataFlowItem( - uid=RLUIDItem(env=env_str, root_id=root_id, action_id=action_id, observation_id=obs_id, version=version), + uid=RLUIDItem( + env=env_str, root_id=root_id, action_id=action_id, observation_id=obs_id, version=replay_meta.version + ), data=data_ref, env=env_data, extra_info=RLExtraDataItem(), @@ -193,51 +195,89 @@ class ReplayBufferConfig(BaseModel): dict, Parameter(help="Weights for different states in the replay buffer."), ] = {} + # async rollout related configs, assigned from dataflow cfg + enable_partial_rollout: Annotated[ + bool, + Parameter(help="Whether to enable partial rollout for asynchronous data generation."), + ] = False + tail_batch_candidate_steps: Annotated[ + int, + Parameter( + help="Number of rollout steps after which a sample becomes a candidate for the tail batch. Set to 0 to disable." + ), + ] = 0 + tail_batch_trigger_size: Annotated[ + Optional[int], + Parameter( + help="Number of candidate samples needed in the queue to trigger a tail batch operation. Set to 0 to disable." + ), + ] = None worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir" -class Sampler: - """Sampler for drawing prompts from datasets or the replay buffer.""" +class DatasetSampler: + """Sampler for drawing new prompts from the configured dataset. - def __init__(self, dataset, dataloader, tokenizer, storage): - """Initializes the Sampler. + This class is responsible for building a dataloader from the provided dataset configurations and sampling fresh + data prompts upon request. + """ + + def __init__(self, dataset_cfg, dataloader_cfg, tokenizer): + """Initializes the DatasetSampler. Args: - dataset: The dataset to sample from. - dataloader: The dataloader for the dataset. - tokenizer: The tokenizer for processing text. - storage: The ReplayBufferStorage instance. + dataset_cfg (List): Configuration for the datasets to sample from. + dataloader_cfg (Optional[DataloaderConfig]): Configuration for the + PyTorch DataLoader. + tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast, str]): + The tokenizer for processing text data. Can be a path or an object. """ - self.train_dataset = dataset - self.train_dataloader = dataloader - self.train_dataloader_iter = iter(self.train_dataloader) - self.tokenizer = ( - tokenizer - if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)) - else AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + if isinstance(tokenizer, str): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + else: + self.tokenizer = tokenizer + self.datasets = build_datasets(dataset_cfg, self.tokenizer) + if dataloader_cfg is not None: + self.dataloader_cfg = dataloader_cfg + else: + self.dataloader_cfg = DataloaderConfig( + collator="fake_collator", + pack_level="none", + ) + self.dataloader = build_dataloader( + dataloader_config=self.dataloader_cfg, + datasets=self.datasets, + global_batch_size=1, + micro_batch_size=1, + seed=1, ) - self.storage = storage - self.sample_count = 0 + self.dataloader_iter = iter(self.dataloader) self.logger = get_logger() - def sample_from_datasets(self, env: str, repeat_prompt_k: int) -> List[RLDataFlowItem]: - """Samples a new group of prompts from the original dataset. 
+ def sample(self, env: str, prompt_repeat_k: int) -> List[RLDataFlowItem]: + """Samples a new prompt from the dataset and prepares it as a group. + + This method fetches the next item from the dataloader, assigns new + unique IDs (root_id, action_id, observation_id), and formats it into + a list of RLDataFlowItem objects, repeated `prompt_repeat_k` times. Args: - env (str): The environment name. - repeat_prompt_k (int): The number of times to repeat the prompt. + env (str): The environment name to be associated with the new samples. + prompt_repeat_k (int): The number of times to repeat the sampled + prompt in the returned group. Returns: - List[RLDataFlowItem]: A list of data items for the data group contains repeat_prompt_k samples from same data. + List[RLDataFlowItem]: A list of newly created data items for a rollout. """ root_id = uuid4().int action_id = uuid4().int - group_data_item: List[RLDataFlowItem] = [RLDataFlowItem() for _ in range(repeat_prompt_k)] + group_data_item: List[RLDataFlowItem] = [RLDataFlowItem() for _ in range(prompt_repeat_k)] try: - data = next(self.train_dataloader_iter)[0] + data = next(self.dataloader_iter)[0] except StopIteration: - self.train_dataloader_iter = iter(self.train_dataloader) - data = next(self.train_dataloader_iter)[0] + self.dataloader_iter = iter(self.dataloader) + data = next(self.dataloader_iter)[0] multimodal_train_info = data.pop("multimodal_train_info", {}) if "pixel_values" in multimodal_train_info: @@ -253,65 +293,35 @@ def sample_from_datasets(self, env: str, repeat_prompt_k: int) -> List[RLDataFlo ) data_item.data = RLDatasetItem(**data) data_item.extra_info = RLExtraDataItem(retry_times=0) - + self.logger.debug(f"Sampling new prompt with action_id: {action_id} in env: {env}") return group_data_item - def sample_from_unfinished_buffer(self) -> List[RLDataFlowItem]: - """Samples a prompt from a partially completed (unfinished) rollout.""" - action_id = self.storage._paused.pop(0) - self.logger.debug(f"Sampling unfinished action_id: {action_id} from replay buffer") - replay_meta = self.storage._actions[action_id] - group_samples = mapping_replaymeta_to_dataitem(replay_meta) - self.sample_count += 1 - if len(self.storage._paused) == 0: - self.logger.info(f"Sampled {self.sample_count} unfinished samples from replay buffer") - return group_samples - - def sample(self, env: str, enable_partial_rollout: int, prompt_repeat_k: int) -> List[RLDataFlowItem]: - """Selects a sampling strategy and returns a group of samples. - - It decides whether to sample from the unfinished buffer (for partial - rollouts greater than 0) or from the original dataset. - - Args: - env (str): The environment name. - enable_partial_rollout (int): Flag to enable partial rollout. - prompt_repeat_k (int): Number of times to repeat the prompt. - - Returns: - List[RLDataFlowItem]: A list of sampled data items. - """ - # TODO(@duanyanhui): 考虑sampler结构的独立性,不要传入replay buffer storage, - # sample_from_unfinished_buffer可以作为replay buffer的一个方法 - if enable_partial_rollout > 0 and len(self.storage._paused) > 0: - return self.sample_from_unfinished_buffer() - else: - # note: Sample grouped sample at once. 
They share the same action_id
-            return self.sample_from_datasets(env, prompt_repeat_k)
-
     def resume(self, num: int) -> None:
-        self.train_dataloader_iter = itertools.islice(self.train_dataloader, num, None)
+        self.dataloader_iter = itertools.islice(self.dataloader, num, None)
 class ReplayBufferStorage:
     """Handles the storage of experiences for the replay buffer."""
-    def __init__(self, worker_log_dir):
+    def __init__(self, replay_buffer_cfg):
         """Initializes the data structures for storing replay data."""
-        self._paused: List[int] = [] # List of paused action_id,
-        self._returned: List[int] = [] # List of returned action_id,
-        self._actions: Dict[int, ReplayMeta] = {} # action_id: ReplayMeta
-        self._root2actions: Dict[int, List[int]] = defaultdict(
-            list
-        ) # root_id: [action_id, action_id, ...], designed for grpo
-        self._observations: Dict[int, ReplayMeta] = {} # observation_id: ReplayMeta
-        self._observations2states: Dict[int, str] = {} # observation_id: state_str
-        self._states: Dict[str, List[int]] = defaultdict(list) # str: [observation_id, observation_id, ...]
-        self._action2observations: Dict[int, List[int]] = defaultdict(
-            list
-        ) # action_id: [observation_id, observation_id, ...]
-        self.logger = get_logger(log_dir=worker_log_dir, tag="ReplayBuffer")
+        self.enable_partial_rollout = replay_buffer_cfg.enable_partial_rollout
+        self.tail_batch_candidate_steps = replay_buffer_cfg.tail_batch_candidate_steps
+        self.tail_batch_trigger_size = replay_buffer_cfg.tail_batch_trigger_size
+
+        self._completed_actions: Dict[int, List[int]] = defaultdict(list)
+        self._aborted_actions: Dict[int, List[int]] = defaultdict(list)
+        self._expired_actions: List[int] = []
+        self._actions: Dict[int, ReplayMeta] = {}
+        self._root2actions: Dict[int, List[int]] = {}
+        self._observations: Dict[int, ReplayMeta] = {}
+        self._observations2states: Dict[int, str] = {}
+        self._states: Dict[str, List[int]] = defaultdict(list)
+        self._action2observations: Dict[int, List[int]] = defaultdict(list)
         self._multimodal_train_infos: Dict[int, Dict[str, Any]] = {}
+        self.logger = get_logger(log_dir=replay_buffer_cfg.worker_log_dir, tag="ReplayBuffer")
+        self.sample_from_aborted_count = 0
+        self.sample_from_expired_count = 0
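The `_completed_actions` and `_aborted_actions` dicts initialized above are version buckets: the key is a group's version (how many generation rounds it has been through) and the value is a list of `action_id`s, while `_expired_actions` is a flat list. Retrieval always drains the highest version first via `_pop_highest_version_action`, defined further down in this file. A simplified standalone sketch of that bucket discipline (plain ints stand in for action_ids; this is not the class itself):

```python
from collections import defaultdict
from typing import Dict, List, Optional

completed_actions: Dict[int, List[int]] = defaultdict(list)

def pop_highest_version(buckets: Dict[int, List[int]]) -> Optional[int]:
    """Take one action_id from the largest version key; drop the bucket once it is empty."""
    if not buckets:
        return None
    version = max(buckets)
    action_id = buckets[version].pop()
    if not buckets[version]:
        del buckets[version]
    return action_id

completed_actions[0].extend([101, 102])   # fresh groups
completed_actions[2].append(103)          # a group that is already two versions old
assert pop_highest_version(completed_actions) == 103
assert sorted(completed_actions) == [0]
```

Draining the oldest version first keeps nearly-stale groups from sitting in the buffer across yet another weight update.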
     def add(self, grouped_dataitem: List[RLDataFlowItem]):
         """Adds a group of data items to the storage.
@@ -330,39 +340,29 @@ def add(self, grouped_dataitem: List[RLDataFlowItem]):
         replay_meta = mapping_dataitem_to_replaymeta(grouped_dataitem)
         root_id = replay_meta.root_id
         action_id = replay_meta.action_id
-        state = replay_meta.state
-        if state == RolloutState.ABORTED:
-            self._paused.append(action_id)
-        elif state == RolloutState.COMPLETED:
-            self._returned.append(action_id)
-        self.logger.debug(
-            f"Adding action_id: {action_id} with state: {state} to ReplayBufferStorage. Paused count: {len(self._paused)}, Returned count: {len(self._returned)}"
-        )
-        self._root2actions[root_id].append(action_id)
+        # 1. Update the version
+        if root_id in self._root2actions:
+            # TODO: for the non-colocated case, whether the version should be bumped needs to depend on whether update_weights has run
+            replay_meta.version += 1
+            self._root2actions[root_id].append(action_id)
+            self.logger.info(
+                f"Existing root_id: {root_id} with action_id {action_id} found. Incrementing version to {replay_meta.version}."
+            )
+        else:
+            self._root2actions[root_id] = [action_id]
         self._actions[action_id] = replay_meta
-        # observation
+        # 2. Update the completed/aborted/expired mappings according to the rollout state
+        self._check_rollout_state_and_insert(replay_meta)
+
+        # 3. Update the observation mappings
         for observation_id in replay_meta.observation_ids:
             self._action2observations[action_id].append(observation_id)
             self._observations[observation_id] = replay_meta
             self._observations2states[observation_id] = replay_meta.state
             self._states[replay_meta.state].append(observation_id)
-    def clear(self):
-        attrs_to_clear = [
-            "_paused",
-            "_returned",
-            "_actions",
-            "_root2actions",
-            "_observations",
-            "_observations2states",
-            "_states",
-            "_action2observations",
-        ]
-        for attr in attrs_to_clear:
-            getattr(self, attr).clear()
-
     def get(self, global_batch_size: int) -> Tuple[List[List[RLDataFlowItem]], List[Dict[str, Any] | None]]:
         """Retrieves a batch of finished sample groups from the buffer.
@@ -380,65 +380,56 @@
         """
         samples = []
         multimodal_train_infos = []
-        if len(self._returned) < global_batch_size:
-            self.logger.error("Not enough finished samples in replay buffer")
-            return [], []
-        else:
-            self.logger.info(
-                f"Retrieving global_batch_size {global_batch_size} from replay buffer, len of self.returned: {len(self._returned)}"
-            )
-            target_finished_list = self._returned[:global_batch_size]
-            remain_finished_list = self._returned[global_batch_size:]
-            for action_id in target_finished_list:
-                replay_meta = self._actions.pop(action_id)
-                # todo: add an unified state management
-                replay_meta.state = RolloutState.ARCHIVED
-                group_samples = mapping_replaymeta_to_dataitem(replay_meta)
-                del replay_meta
-                multimodal_train_info = None
-                # TODO: 是否需要额外返回不重复的 multimodal_train_infos?
-                for data_item in group_samples:
-                    if hasattr(data_item.data, "multimodal_train_info"):
-                        multimodal_train_info = data_item.data.multimodal_train_info
-                        del data_item.data.multimodal_train_info
-                samples.append(group_samples)
-                multimodal_train_infos.append(multimodal_train_info)
-            self._returned = remain_finished_list
-
-            return samples, multimodal_train_infos
-
-    def get_finished_samples(self):
-        """Returns the number of finished sample groups."""
-        return len(self._returned)
-
-    def get_unfinished_samples(self):
-        """Returns the number of unfinished sample groups."""
-        return len(self._paused)
-
-    def get_prompt_num(self):
-        return len(self._action2observations)
+        target_batch_size = min(global_batch_size, self.completed_samples_count)
+        self.logger.info(f"Retrieving {target_batch_size} completed samples from the replay buffer.")
+        for _ in range(target_batch_size):
+            action_id = self._pop_highest_version_action(self._completed_actions)
+            if action_id is None:
+                self.logger.warning("Got a None action_id from completed_actions, skipping this iteration.")
+                continue
+            replay_meta = self._actions.pop(action_id)
+            group_samples = mapping_replaymeta_to_dataitem(replay_meta)
+            # Remove this record for good; we no longer need to track the action_ids for this root_id
+            self._clear_meta_for_root(replay_meta)
+            multimodal_train_info = None
+            # TODO: should we additionally return the de-duplicated multimodal_train_infos?
+            for data_item in group_samples:
+                if hasattr(data_item.data, "multimodal_train_info"):
+                    multimodal_train_info = data_item.data.multimodal_train_info
+                    del data_item.data.multimodal_train_info
+                if "partial_rollout_input_ids" in data_item.data.extra_info:
+                    del data_item.data.extra_info["partial_rollout_input_ids"]
+            samples.append(group_samples)
+            multimodal_train_infos.append(multimodal_train_info)
+
+        # Check whether any completed samples remain in the buffer and whether they have expired
+        self.logger.info(f"Remaining completed samples in buffer: {self.completed_samples_count}")
+        self._check_completed_samples_expired()
+        return samples, multimodal_train_infos
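The `_check_completed_samples_expired()` call at the end of the new `get()` feeds the tail-batch mechanism: completed buckets whose version has reached `tail_batch_candidate_steps` are emptied into `_expired_actions`, so those prompts are regenerated from scratch rather than trained on after going stale. A simplified sketch of that move under the same rule (dict/list stand-ins rather than the actual storage class):

```python
from collections import defaultdict
from typing import Dict, List

def expire_stale_completed(completed: Dict[int, List[int]],
                           expired: List[int],
                           tail_batch_candidate_steps: int) -> None:
    """Move every completed bucket with version >= tail_batch_candidate_steps to the expired list."""
    if tail_batch_candidate_steps <= 0:
        return  # feature disabled
    for version in [v for v in completed if v >= tail_batch_candidate_steps]:
        expired.extend(completed.pop(version))

completed = defaultdict(list, {0: [1, 2], 3: [3]})
expired: List[int] = []
expire_stale_completed(completed, expired, tail_batch_candidate_steps=2)
assert expired == [3] and sorted(completed) == [0]
```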
+
+    def sample(self, sample_from_expired_states) -> List[RLDataFlowItem]:
+        if sample_from_expired_states and self.expired_samples_count > 0:
+            self.sample_from_expired_count += 1
+            return self._sample_from_expired_storage()
+        if self.aborted_samples_count > 0:
+            self.sample_from_aborted_count += 1
+            return self._sample_from_aborted_storage()
+        return []
-    def status(self):
-        return {
-            "rollout_finished_count": len(self._returned),
-            "rollout_paused_count": len(self._paused),
-            "action_count": len(self._actions),
-            "observation_count": len(self._observations),
-        }
-
-    def print(self):
-        rollout_finished_count = len(self._returned)
-        rollout_paused_count = len(self._paused)
-        action_count = len(self._actions)
-        observation_count = len(self._observations)
-
-        log_message = (
-            "[ReplayBuffer] ReplayBufferStorage states:\n"
-            f" - Rollout States: Finished={rollout_finished_count}, Paused={rollout_paused_count}\n"
-            f" - History Actions: {action_count}\n"
-            f" - History Observations: {observation_count}"
-        )
-        self.logger.info(log_message)
+    def clear(self):
+        attrs_to_clear = [
+            "_aborted_actions",
+            "_completed_actions",
+            "_expired_actions",
+            "_actions",
+            "_root2actions",
+            "_observations",
+            "_observations2states",
+            "_states",
+            "_action2observations",
+        ]
+        for attr in attrs_to_clear:
+            getattr(self, attr).clear()
     def dump(self, file_path: str):
         """Dumps the entire state of the replay buffer storage to a single
@@ -485,12 +476,19 @@ def resume(self, file_path: str):
             replay_meta = mapping_dataitem_to_replaymeta(group_data_items)
             root_id = replay_meta.root_id
             action_id = replay_meta.action_id
-            state_str = replay_meta.state
-            if state_str == "abort":
-                self._paused.append(action_id)
-            elif state_str == "returned":
-                self._returned.append(action_id)
-            self._root2actions[root_id].append(action_id)
+            state = replay_meta.state
+            version = replay_meta.version
+            self.logger.info(f"state of replay_meta while resuming: {replay_meta}")
+            if state == RolloutState.ABORTED:
+                self._aborted_actions[version].append(action_id)
+            elif state == RolloutState.EXPIRED:
+                self._expired_actions.append(action_id)
+            elif state == RolloutState.COMPLETED:
+                self._completed_actions[version].append(action_id)
+            if root_id not in self._root2actions:
+                self._root2actions[root_id] = [action_id]
+            else:
+                self._root2actions[root_id].append(action_id)
             self._actions[action_id] = replay_meta
             for observation_id in replay_meta.observation_ids:
                 self._action2observations[action_id].append(observation_id)
@@ -500,7 +498,209 @@
         self.logger.info(f"ReplayBufferStorage state successfully resumed from {file_path}")
-        self.print()
+    @property
+    def completed_samples_count(self) -> int:
+        return sum(len(bucket) for bucket in self._completed_actions.values())
+
+    @property
+    def aborted_samples_count(self):
+        return sum(len(bucket) for bucket in self._aborted_actions.values())
+
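These count properties feed the sampling priority: `ReplayBufferStorage.sample()` above, together with `ReplayBuffer.sample()` further down, prefers re-rolling expired groups (when the pre-run check enabled it), then retrying aborted partial rollouts, and only then draws a fresh prompt from the dataset. A compressed sketch of that decision order (an illustrative helper, not a method from this diff):

```python
def choose_source(sample_from_expired: bool, expired_count: int, aborted_count: int) -> str:
    """Return which pool the next group should come from, mirroring the sample() priority."""
    if sample_from_expired and expired_count > 0:
        return "expired"
    if aborted_count > 0:
        return "aborted"
    return "dataset"

assert choose_source(True, 3, 5) == "expired"
assert choose_source(False, 3, 5) == "aborted"
assert choose_source(False, 0, 0) == "dataset"
```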
+    @property
+    def expired_samples_count(self):
+        return len(self._expired_actions)
+
+    def _sample_from_expired_storage(self) -> List[RLDataFlowItem]:
+        """Samples an item from the expired storage for re-rollout.
+
+        This method takes an action_id from the expired queue, retrieves its
+        original prompt data, cleans up all its previous rollout outputs, and
+        prepares it as a new sample group with a fresh action_id and reset
+        version (0) to be sent for a new generation attempt.
+
+        Returns:
+            List[RLDataFlowItem]: A list of data items ready for a new rollout.
+        """
+        assert len(self._expired_actions) > 0
+        action_id = self._expired_actions.pop()
+        replay_meta = self._actions.pop(action_id)
+        group_samples = mapping_replaymeta_to_dataitem(replay_meta)
+        # Delete every record from this sample's previous rollout and start over; root2actions must be cleared as well
+        self._clear_meta_for_root(replay_meta)
+
+        for sample in group_samples:
+            assert sample.data.input_ids and sample.data.num_tokens, "input_ids or num_tokens is empty!"
+            del sample.env
+            sample.env = RLEnvDataItem() # reset the env data
+            sample.uid.action_id = action_id
+            sample.uid.version = 0
+
+        self.logger.info(
+            f"Sampling expired action_id: {action_id} from replay buffer, remain expired samples: {len(self._expired_actions)}"
+        )
+        return group_samples
+
+    def _sample_from_aborted_storage(self) -> List[RLDataFlowItem]:
+        """Samples an item from the aborted storage for re-rollout.
+
+        This method retrieves an action with the highest version (oldest version) from the
+        aborted buckets. It then cleans up its previous (aborted) rollout
+        outputs and prepares it as a new sample group with a fresh action_id.
+        The original version number is preserved to track its retry history.
+
+        Returns:
+            List[RLDataFlowItem]: A list of data items ready for a new rollout.
+        """
+        assert self.aborted_samples_count > 0
+        action_id = self._pop_highest_version_action(self._aborted_actions)
+        # self.aborted_samples_count was checked above, so this will not return None
+        replay_meta = self._actions.pop(action_id) # type: ignore[arg-type]
+        replay_meta_version = replay_meta.version
+        group_samples = mapping_replaymeta_to_dataitem(replay_meta)
+        # Delete the records of the outputs produced by the previous rollout; that data is already stored in the RLEnvDataItem
+        self._clear_meta_for_actions(replay_meta)
+
+        sample_action_id = uuid4().int
+        for sample in group_samples:
+            assert sample.data.input_ids and sample.data.num_tokens, "input_ids or num_tokens is empty!"
+            if not self.enable_partial_rollout:
+                # Clear the previous env data such as response_ids
+                del sample.env
+                sample.env = RLEnvDataItem()
+            else:
+                # Keep the async logic inside the replay buffer as much as possible, rather than in env/rollout
+                history_response_ids = list(itertools.chain.from_iterable(sample.env.rollout.versioned_response_ids))
+                sample.data.extra_info["partial_rollout_input_ids"] = sample.data.input_ids + history_response_ids
+                self.logger.debug(
+                    f"Partial rollout enabled, passing {len(history_response_ids)} response_ids to data extra info when sampling."
+ ) + sample.uid.action_id = int(sample_action_id) + sample.uid.version = replay_meta_version + + self.logger.info( + f"Sampling aborted action_id: {sample_action_id}, root_id: {group_samples[0].uid.root_id} from replay buffer, remain aborted samples: {self.aborted_samples_count}" + ) + return group_samples + + def _pop_highest_version_action(self, buckets: Dict[int, List[int]]) -> Optional[int]: + if not buckets: + return None + + highest_version = sorted(buckets.keys())[-1] + action_list = buckets[highest_version] + action_id = action_list.pop() + if not action_list: + del buckets[highest_version] + + return action_id + + def _check_completed_samples_expired(self): + """Moves samples from completed buckets to the expired list if they are + too old after get target completed samples from replay buffer. + + This method iterates through the `_completed_actions` buckets. If a + bucket's version index is greater than or equal to the configured + `tail_batch_candidate_steps`, all action_ids within that bucket are + moved to the `_expired_actions` list, marking them as expired. + """ + if self.tail_batch_candidate_steps <= 0: + return + + expired_versions = [v for v in self._completed_actions if v >= self.tail_batch_candidate_steps] + + for version in expired_versions: + bucket = self._completed_actions.pop(version) + self._expired_actions.extend(bucket) + self.logger.info( + f"Moved {len(bucket)} completed samples with version {version} to expired samples due to exceeding tail_batch_candidate_steps." + ) + + def _clear_meta_for_actions(self, replay_meta: ReplayMeta): + """Completely removes an action and all its associated data from the + storage. + + This is the single source of truth for deleting an action. + """ + action_id = replay_meta.action_id + + for observation_id in replay_meta.observation_ids: + self._observations.pop(observation_id, None) + state = self._observations2states.pop(observation_id, None) + if state and observation_id in self._states.get(state, []): + self._states[state].remove(observation_id) + + self._action2observations.pop(action_id, None) + del replay_meta + + def _clear_meta_for_root(self, replay_meta: ReplayMeta): + """Clears all actions and associated metadata linked to the same + root_id. + + This function is called after a sample group is successfully retrieved + for training. It ensures that all historical versions of a prompt + (linked by root_id) are purged from the storage to prevent them from + being re-sampled or replayed. + + Args: + replay_meta (ReplayMeta): The metadata of the action that was just + retrieved. The root_id from this object will be used to find + and clear all related actions. + """ + root_id = replay_meta.root_id + if root_id in self._root2actions: + for action_id in self._root2actions[root_id]: + new_replay_meta = self._actions.pop(action_id, None) + if new_replay_meta: + self._clear_meta_for_actions(new_replay_meta) + del self._root2actions[root_id] + del replay_meta + + def _check_rollout_state_and_insert(self, replay_meta: ReplayMeta): + """Checks the rollout state of a ReplayMeta object and inserts its + action_id into the appropriate state bucket. + + This method acts as a router, directing action_ids to different storage + lists (_aborted_actions, _completed_actions, _expired_actions) based on + their final rollout state and version. It also handles the logic for + when an aborted sample becomes expired due to too many retries. 
+
+        Args:
+            replay_meta (ReplayMeta): The metadata object containing the final
+                state and version of a rollout action.
+        """
+        state = replay_meta.state
+        root_id = replay_meta.root_id
+        action_id = replay_meta.action_id
+
+        if state == RolloutState.ABORTED:
+            if self.tail_batch_candidate_steps == 0:
+                replay_meta.version = 0
+                self._aborted_actions[replay_meta.version].append(action_id)
+                self.logger.debug(
+                    f"Add aborted sample with action_id: {action_id} version 0 to _aborted_actions because of no tail_batch_candidate_steps."
+                )
+            elif self.tail_batch_candidate_steps > 0 and replay_meta.version < self.tail_batch_candidate_steps:
+                self._aborted_actions[replay_meta.version].append(action_id)
+                self.logger.debug(
+                    f"Add aborted sample with action_id: {action_id} version: {replay_meta.version} to _aborted_actions."
+                )
+            elif self.tail_batch_candidate_steps > 0 and replay_meta.version >= self.tail_batch_candidate_steps:
+                self.logger.debug(
+                    f"Add expired sample with action_id: {action_id} to _expired_actions because version: {replay_meta.version} >= tail_batch_candidate_steps: {self.tail_batch_candidate_steps}."
+                )
+                # Expired data needs its state reset
+                replay_meta.version = 0
+                replay_meta.state = RolloutState.EXPIRED
+                self._expired_actions.append(action_id)
+            else:
+                assert False, (
+                    f"Unsupported rollout state {state} and rollout version {replay_meta.version} for action_id {action_id} in ReplayBufferStorage."
+                )
+        elif state == RolloutState.COMPLETED:
+            self._completed_actions[replay_meta.version].append(action_id)
+            self.logger.debug(f"Add sample with root_id: {root_id}, action_id: {action_id} to finished_actions.")
+        else:
+            assert False, f"Unsupported rollout state {state} for action_id {action_id} in ReplayBufferStorage."
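`_check_rollout_state_and_insert` is the single router for incoming groups: completed groups land in the completed bucket for their current version; aborted groups stay retryable at their version until they have been through `tail_batch_candidate_steps` rounds, after which they are downgraded to expired with their version reset. A condensed sketch of those routing rules, using plain strings in place of the `RolloutState` enum:

```python
def route_on_add(state: str, version: int, tail_batch_candidate_steps: int) -> tuple:
    """Return (bucket, stored_version) for a group, following the rules above."""
    if state == "aborted":
        if tail_batch_candidate_steps == 0:
            return ("aborted", 0)           # tail-batch handling disabled: version is reset
        if version < tail_batch_candidate_steps:
            return ("aborted", version)     # still retryable, keep its age
        return ("expired", 0)               # too many rounds: restart from scratch later
    if state == "completed":
        return ("completed", version)
    raise ValueError(f"unsupported rollout state: {state}")

assert route_on_add("aborted", 1, 0) == ("aborted", 0)
assert route_on_add("aborted", 1, 4) == ("aborted", 1)
assert route_on_add("aborted", 4, 4) == ("expired", 0)
assert route_on_add("completed", 2, 4) == ("completed", 2)
```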
Expired samples: {self.storage.expired_samples_count}, threshold: {expired_threshold}, remain_size: {remain_size}" + ) + else: + self.sample_from_expired_states = False + return self.sample_from_expired_states, self.storage.completed_samples_count def get_train_dataset_length(self): """Returns the length of the training dataloader.""" - return len(self.dataloader) + return len(self.sampler.dataloader) def post_processor(self, group_samples): """Applies a post-processing function to a group of samples. @@ -565,7 +762,7 @@ def post_processor(self, group_samples): return group_samples return group_samples - def sample(self, env, enable_partial_rollout: int, prompt_repeat_k: int) -> List[RLDataFlowItem]: + def sample(self, env, prompt_repeat_k) -> List[RLDataFlowItem]: """Samples a batch of experiences from the replay buffer. Args: @@ -576,8 +773,12 @@ def sample(self, env, enable_partial_rollout: int, prompt_repeat_k: int) -> List Returns: A list of sampled data items. """ - - return self.sampler.sample(env, enable_partial_rollout, prompt_repeat_k) + storage_samples = self.storage.sample(self.sample_from_expired_states) + if storage_samples: + return storage_samples + else: + self.sample_from_dataset_count += 1 + return self.sampler.sample(env, prompt_repeat_k) def get_samples( self, @@ -602,9 +803,15 @@ def add(self, grouped_dataitem: List[RLDataFlowItem]): """ self.storage.add(grouped_dataitem) - def print(self): - """Prints the current state of the replay buffer storage.""" - self.storage.print() + def status(self): + return { + "remain_completed_samples_count": self.storage.completed_samples_count, + "remain_aborted_samples_count": self.storage.aborted_samples_count, + "remain_expired_samples_count": self.storage.expired_samples_count, + "sample_from_dataset_count": self.sample_from_dataset_count, + "sample_from_aborted_count": self.storage.sample_from_aborted_count, + "sample_from_expired_count": self.storage.sample_from_expired_count, + } def dump(self, file_path: str): """Dumps the replay buffer's storage to a file. @@ -614,9 +821,6 @@ def dump(self, file_path: str): """ self.storage.dump(file_path) - def status(self): - return self.storage.status() - def resume(self, file_path: str): """Resumes the replay buffer's storage from a file. @@ -625,16 +829,17 @@ def resume(self, file_path: str): state. """ self.storage.resume(file_path) - num = self.storage.get_prompt_num() - self.sampler.resume(num) + status = self.status() + prompt_num = status["prompt_count"] + self.sampler.resume(prompt_num) - def get_finished_samples(self): - """Returns the number of finished sample groups in the storage.""" - return self.storage.get_finished_samples() + def get_completed_samples_count(self) -> int: + """Returns the count of completed samples in the replay buffer. - def get_unfinished_samples(self): - """Returns the number of unfinished sample groups in the storage.""" - return self.storage.get_unfinished_samples() + Returns: + int: The number of completed samples. + """ + return self.storage.completed_samples_count def clear(self): """Clears the replay buffer storage."""