diff --git a/recipe/transfer_queue/agent_loop.py b/recipe/transfer_queue/agent_loop.py index 871ae8025c0..7f936e6730e 100644 --- a/recipe/transfer_queue/agent_loop.py +++ b/recipe/transfer_queue/agent_loop.py @@ -67,10 +67,7 @@ def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: Data return timing - def create_transferqueue_client(self, controller_infos, storage_infos, role): + def create_transferqueue_client(self, controller_info, config): ray.get( - [ - worker.create_transferqueue_client.remote(controller_infos, storage_infos, role) - for worker in self.agent_loop_workers - ] + [worker.create_transferqueue_client.remote(controller_info, config) for worker in self.agent_loop_workers] ) diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py index d6adbddb676..83b10d8c467 100644 --- a/recipe/transfer_queue/ray_trainer.py +++ b/recipe/transfer_queue/ray_trainer.py @@ -41,8 +41,8 @@ from tqdm import tqdm from transfer_queue import ( BatchMeta, + SimpleStorageUnit, TransferQueueController, - TransferQueueStorageSimpleUnit, get_placement_group, process_zmq_server_info, ) @@ -89,7 +89,6 @@ from verl.utils.transferqueue_utils import ( create_transferqueue_client, get_transferqueue_client, - get_val_transferqueue_client, tqbridge, ) @@ -412,111 +411,56 @@ def __init__( self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) - self.data_system_client = self._initialize_train_data_system( - self.config.data.train_batch_size, self.config.actor_rollout_ref.rollout.n + self.data_system_client = self._initialize_data_system() + + def _initialize_data_system(self): + # 1. initialize TransferQueueStorage + train_data_size = ( + self.config.data.train_batch_size + * self.config.trainer.num_global_batch + * self.config.actor_rollout_ref.rollout.n ) - self.val_data_system_client = self._initialize_val_data_system( - self.val_batch_size, self.config.actor_rollout_ref.rollout.val_kwargs.n + val_data_size = ( + self.val_batch_size + * self.config.trainer.num_global_batch + * self.config.actor_rollout_ref.rollout.val_kwargs.n ) - def _initialize_train_data_system(self, global_batch_size, num_n_samples, role="train"): - # 1. initialize TransferQueueStorage - total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples + total_storage_size = train_data_size + val_data_size self.data_system_storage_units = {} storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1) for storage_unit_rank in range(self.config.trainer.num_data_storage_units): - storage_node = TransferQueueStorageSimpleUnit.options( + storage_node = SimpleStorageUnit.options( placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank - ).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units)) + ).remote(storage_unit_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units)) self.data_system_storage_units[storage_unit_rank] = storage_node - logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.") + logging.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.") # 2. initialize TransferQueueController - # we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly - # one controller for a single WorkerGroup. - self.data_system_controllers = {} - controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1) - for controller_rank in range(self.config.trainer.num_data_controllers): - self.data_system_controllers[controller_rank] = TransferQueueController.options( - placement_group=controller_placement_group, placement_group_bundle_index=controller_rank - ).remote( - num_storage_units=self.config.trainer.num_data_storage_units, - global_batch_size=global_batch_size, - num_global_batch=self.config.trainer.num_global_batch, - num_n_samples=num_n_samples, - ) - logging.info(f"TransferQueueController #{controller_rank} has been created.") + self.data_system_controller = TransferQueueController.remote() + logging.info("TransferQueueController has been created.") - # 3. register controller & storage - self.data_system_controller_infos = process_zmq_server_info(self.data_system_controllers) + # 3. register controller & storage and prepare necessary information + self.data_system_controller_info = process_zmq_server_info(self.data_system_controller) self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units) - ray.get( - [ - storage_unit.register_controller_info.remote(self.data_system_controller_infos) - for storage_unit in self.data_system_storage_units.values() - ] - ) + # Note: Need to generate a new DictConfig with allow_objects=True to preserve ZMQServerInfo instances + # (which contain socket connection details). Without this flag, OmegaConf would flatten these objects to dicts, + # breaking the transfer queue client initialization. + tq_config = OmegaConf.create({}, flags={"allow_objects": True}) + tq_config.controller_info = self.data_system_controller_info + tq_config.storage_unit_infos = self.data_system_storage_unit_infos + self.config = OmegaConf.merge(tq_config, self.config) # 4. create client - # each client should be allocated to exactly one controller create_transferqueue_client( - client_id="Trainer-" + role, - controller_infos=self.data_system_controller_infos, - storage_infos=self.data_system_storage_unit_infos, + client_id="Trainer", + controller_info=self.data_system_controller_info, + config=self.config, ) data_system_client = get_transferqueue_client() return data_system_client - def _initialize_val_data_system(self, global_batch_size, num_n_samples, role="val"): - # 1. initialize TransferQueueStorage - total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples - self.val_data_system_storage_units = {} - storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1) - for storage_unit_rank in range(self.config.trainer.num_data_storage_units): - storage_node = TransferQueueStorageSimpleUnit.options( - placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank - ).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units)) - self.val_data_system_storage_units[storage_unit_rank] = storage_node - logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.") - - # 2. initialize TransferQueueController - # we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly - # one controller for a single WorkerGroup. - self.val_data_system_controllers = {} - controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1) - for controller_rank in range(self.config.trainer.num_data_controllers): - self.val_data_system_controllers[controller_rank] = TransferQueueController.options( - placement_group=controller_placement_group, placement_group_bundle_index=controller_rank - ).remote( - num_storage_units=self.config.trainer.num_data_storage_units, - global_batch_size=global_batch_size, - num_global_batch=self.config.trainer.num_global_batch, - num_n_samples=num_n_samples, - ) - logging.info(f"TransferQueueController #{controller_rank} has been created.") - - # 3. register controller & storage - self.val_data_system_controller_infos = process_zmq_server_info(self.val_data_system_controllers) - self.val_data_system_storage_unit_infos = process_zmq_server_info(self.val_data_system_storage_units) - - ray.get( - [ - storage_unit.register_controller_info.remote(self.val_data_system_controller_infos) - for storage_unit in self.val_data_system_storage_units.values() - ] - ) - - # 4. create client - # each client should be allocated to exactly one controller - create_transferqueue_client( - client_id="Trainer-" + role, - controller_infos=self.val_data_system_controller_infos, - storage_infos=self.val_data_system_storage_unit_infos, - ) - data_system_client = get_val_transferqueue_client() - return data_system_client - def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): """ Creates the train and validation dataloaders. @@ -726,19 +670,19 @@ def _validate(self): if self.config.reward_model.enable and test_batch[0]["reward_model"]["style"] == "model": return {} - asyncio.run(self.val_data_system_client.async_put(data=test_batch, global_step=self.global_steps - 1)) + asyncio.run(self.data_system_client.async_put(data=test_batch, partition_id=f"val_{self.global_steps - 1}")) # Store original inputs batch_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=["input_ids", "uid", "reward_model"], batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, + partition_id=f"val_{self.global_steps - 1}", get_n_samples=False, task_name="get_data", ) ) - data = asyncio.run(self.val_data_system_client.async_get_data(batch_meta)) + data = asyncio.run(self.data_system_client.async_get_data(batch_meta)) input_ids = data["input_ids"] # TODO: Can we keep special tokens except for padding tokens? input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] @@ -749,10 +693,10 @@ def _validate(self): sample_gts.extend(ground_truths) test_gen_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=list(test_batch.keys()), # TODO: (TQ) Get metadata by specified fields batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, # self.global_steps start from 1 + partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1 get_n_samples=False, task_name="generate_sequences", ) @@ -779,15 +723,15 @@ def _validate(self): # Store generated outputs test_response_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=["responses"], batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, # self.global_steps start from 1 + partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1 get_n_samples=False, task_name="get_response", ) ) - data = asyncio.run(self.val_data_system_client.async_get_data(test_response_meta)) + data = asyncio.run(self.data_system_client.async_get_data(test_response_meta)) output_ids = data["responses"] output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] sample_outputs.extend(output_texts) @@ -808,10 +752,10 @@ def _validate(self): if "rm_scores" in batch_meta.field_names: compute_reward_fields = ["rm_scores"] val_reward_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=compute_reward_fields, batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, + partition_id=f"val_{self.global_steps - 1}", get_n_samples=False, task_name="compute_reward", ) @@ -832,29 +776,29 @@ def _validate(self): # collect num_turns of each prompt if "__num_turns__" in test_batch_meta.field_names: num_turns_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=["__num_turns__"], batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, # self.global_steps start from 1 + partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1 get_n_samples=False, task_name="get_num_turns", ) ) - data = asyncio.run(self.val_data_system_client.async_get_data(num_turns_meta)) + data = asyncio.run(self.data_system_client.async_get_data(num_turns_meta)) sample_turns.append(data["__num_turns__"]) data_source = ["unknown"] * reward_tensor.shape[0] if "data_source" in test_batch_meta.field_names: data_source_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=["data_source"], batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, # self.global_steps start from 1 + partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1 get_n_samples=False, task_name="get_data_source", ) ) - data = asyncio.run(self.val_data_system_client.async_get_data(data_source_meta)) + data = asyncio.run(self.data_system_client.async_get_data(data_source_meta)) data_source = data["data_source"] data_source_lst.append(data_source) @@ -902,7 +846,7 @@ def _validate(self): metric_dict["val-aux/num_turns/max"] = sample_turns.max() metric_dict["val-aux/num_turns/mean"] = sample_turns.mean() - asyncio.run(self.val_data_system_client.async_clear(self.global_steps - 1)) + asyncio.run(self.data_system_client.async_clear(partition_id=f"val_{self.global_steps - 1}")) return metric_dict def init_workers(self): @@ -1003,12 +947,7 @@ def init_workers(self): # set transferqueue server info for each worker for _, wg in all_wg.items(): - wg.create_transferqueue_client( - self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train" - ) - wg.create_transferqueue_client( - self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val" - ) + wg.create_transferqueue_client(self.data_system_controller_info, self.config) # create async rollout manager and request scheduler self.async_rollout_mode = False @@ -1020,12 +959,7 @@ def init_workers(self): config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg ) - self.async_rollout_manager.create_transferqueue_client( - self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train" - ) - self.async_rollout_manager.create_transferqueue_client( - self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val" - ) + self.async_rollout_manager.create_transferqueue_client(self.data_system_controller_info, self.config) def _save_checkpoint(self): from verl.utils.fs import local_mkdir_safe @@ -1313,7 +1247,7 @@ def fit(self): timing_raw = {} base_get_meta_kwargs = dict( batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, - global_step=self.global_steps - 1, # self.global_steps starts from 1 + partition_id=f"train_{self.global_steps - 1}", # self.global_steps starts from 1 get_n_samples=False, ) @@ -1333,7 +1267,9 @@ def fit(self): batch_dict, repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True ) batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict) - asyncio.run(self.data_system_client.async_put(data=batch, global_step=self.global_steps - 1)) + asyncio.run( + self.data_system_client.async_put(data=batch, partition_id=f"train_{self.global_steps - 1}") + ) gen_meta = asyncio.run( self.data_system_client.async_get_meta( @@ -1709,7 +1645,7 @@ def fit(self): ], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, - global_step=self.global_steps - 1, + partition_id=f"train_{self.global_steps - 1}", get_n_samples=False, task_name="update_actor", ) @@ -1735,7 +1671,7 @@ def fit(self): self.data_system_client.async_get_meta( data_fields=data_fields, batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, - global_step=self.global_steps - 1, + partition_id=f"train_{self.global_steps - 1}", get_n_samples=False, task_name="log_rollout", ) @@ -1857,7 +1793,7 @@ def fit(self): # TODO: (TQ) support transfer queue self.train_dataloader.sampler.update(batch=batch) - asyncio.run(self.data_system_client.async_clear(self.global_steps - 1)) + asyncio.run(self.data_system_client.async_clear(partition_id=f"train_{self.global_steps - 1}")) # TODO: make a canonical logger that supports various backend logger.log(data=metrics, step=self.global_steps) diff --git a/requirements_transferqueue.txt b/requirements_transferqueue.txt index 8479d27bb21..387a61a456f 100644 --- a/requirements_transferqueue.txt +++ b/requirements_transferqueue.txt @@ -1,2 +1,2 @@ # requirements.txt records the full set of dependencies for development -git+https://github.com/TransferQueue/TransferQueue.git@68c04e7 +git+https://github.com/TransferQueue/TransferQueue.git@862b74a diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py index 2513c57f99c..399ac75a063 100644 --- a/verl/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -131,13 +131,13 @@ def _query_collect_info(self, mesh_name: str): return self.__collect_dp_rank[mesh_name] @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True) - def create_transferqueue_client(self, controller_infos, storage_infos, role="train"): + def create_transferqueue_client(self, controller_info, config): from verl.utils.transferqueue_utils import create_transferqueue_client create_transferqueue_client( - client_id=f"{role}_worker_{self.rank}", - controller_infos=controller_infos, - storage_infos=storage_infos, + client_id=f"worker_{self.rank}", + controller_info=controller_info, + config=config, ) @classmethod diff --git a/verl/utils/transferqueue_utils.py b/verl/utils/transferqueue_utils.py index 27160571ef3..c692578e3a0 100644 --- a/verl/utils/transferqueue_utils.py +++ b/verl/utils/transferqueue_utils.py @@ -38,32 +38,24 @@ class BatchMeta: from verl.protocol import DataProto _TRANSFER_QUEUE_CLIENT = None -_VAL_TRANSFER_QUEUE_CLIENT = None is_transferqueue_enabled = os.environ.get("TRANSFER_QUEUE_ENABLE", False) def create_transferqueue_client( client_id: str, - controller_infos: dict[Any, "ZMQServerInfo"], - storage_infos: dict[Any, "ZMQServerInfo"], + controller_info: dict[Any, "ZMQServerInfo"], + config, ) -> None: global _TRANSFER_QUEUE_CLIENT - global _VAL_TRANSFER_QUEUE_CLIENT - if "val" in client_id: - _VAL_TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_infos, storage_infos) - else: - _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_infos, storage_infos) + _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_info) + _TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) def get_transferqueue_client() -> "AsyncTransferQueueClient": return _TRANSFER_QUEUE_CLIENT -def get_val_transferqueue_client() -> "AsyncTransferQueueClient": - return _VAL_TRANSFER_QUEUE_CLIENT - - def _run_async_in_temp_loop(async_func: Callable[..., Any], *args, **kwargs) -> Any: # Use a temporary event loop in a new thread because event # loop may already exist in server mode @@ -109,10 +101,7 @@ async def _async_batchmeta_to_dataproto(batchmeta: "BatchMeta") -> DataProto: meta_info=batchmeta.extra_info.copy(), ) - if batchmeta.extra_info.get("validate", False): - tensordict = await _VAL_TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta) - else: - tensordict = await _TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta) + tensordict = await _TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta) return DataProto.from_tensordict(tensordict, meta_info=batchmeta.extra_info.copy()) @@ -130,10 +119,7 @@ async def _async_update_batchmeta_with_output(output: DataProto, batchmeta: "Bat for key in output.meta_info.keys(): tensordict.pop(key) batchmeta.add_fields(tensordict) - if batchmeta.extra_info.get("validate", False): - await _VAL_TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta) - else: - await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta) + await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta) def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta") -> None: