4141from tqdm import tqdm
4242from transfer_queue import (
4343 BatchMeta ,
44+ SimpleStorageUnit ,
4445 TransferQueueController ,
45- TransferQueueStorageSimpleUnit ,
4646 get_placement_group ,
4747 process_zmq_server_info ,
4848)
8989from verl .utils .transferqueue_utils import (
9090 create_transferqueue_client ,
9191 get_transferqueue_client ,
92- get_val_transferqueue_client ,
9392 tqbridge ,
9493)
9594
@@ -412,111 +411,56 @@ def __init__(
412411
413412 self ._create_dataloader (train_dataset , val_dataset , collate_fn , train_sampler )
414413
415- self .data_system_client = self ._initialize_train_data_system (
416- self .config .data .train_batch_size , self .config .actor_rollout_ref .rollout .n
414+ self .data_system_client = self ._initialize_data_system ()
415+
def _initialize_data_system(self):
    """Set up the unified TransferQueue data system: storage units, controller, and client.

    Train and validation data share one storage pool (they are separated by
    partition ids such as ``train_<step>`` / ``val_<step>`` elsewhere in this
    class), so the pool is sized for the sum of both, spread evenly across
    ``num_data_storage_units`` Ray actors. A single ``TransferQueueController``
    actor coordinates them.

    Side effects:
        - Creates Ray actors (storage units on a dedicated placement group,
          plus one controller).
        - Merges ``controller_info`` / ``storage_unit_infos`` (ZMQ server
          info) into ``self.config`` so workers can connect later.
        - Registers the trainer client via ``create_transferqueue_client``.

    Returns:
        The trainer-side transfer-queue client from ``get_transferqueue_client()``.
    """
    # 1. initialize TransferQueueStorage
    # Training capacity: one global batch times num_global_batch, with
    # rollout.n responses per prompt.
    train_data_size = (
        self.config.data.train_batch_size
        * self.config.trainer.num_global_batch
        * self.config.actor_rollout_ref.rollout.n
    )
    # Validation capacity, sized analogously with val_kwargs.n samples.
    val_data_size = (
        self.val_batch_size
        * self.config.trainer.num_global_batch
        * self.config.actor_rollout_ref.rollout.val_kwargs.n
    )

    # Total capacity is split evenly across units; ceil ensures the aggregate
    # is never under-provisioned after integer division.
    total_storage_size = train_data_size + val_data_size
    self.data_system_storage_units = {}
    storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1)
    for storage_unit_rank in range(self.config.trainer.num_data_storage_units):
        # Pin each storage actor to its own placement-group bundle.
        storage_node = SimpleStorageUnit.options(
            placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
        ).remote(storage_unit_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))
        self.data_system_storage_units[storage_unit_rank] = storage_node
        logging.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.")

    # 2. initialize TransferQueueController (a single controller actor now
    # serves both train and val traffic).
    self.data_system_controller = TransferQueueController.remote()
    logging.info("TransferQueueController has been created.")

    # 3. register controller & storage and prepare necessary information
    self.data_system_controller_info = process_zmq_server_info(self.data_system_controller)
    self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units)

    # Note: Need to generate a new DictConfig with allow_objects=True to preserve ZMQServerInfo instances
    # (which contain socket connection details). Without this flag, OmegaConf would flatten these objects to dicts,
    # breaking the transfer queue client initialization.
    tq_config = OmegaConf.create({}, flags={"allow_objects": True})
    tq_config.controller_info = self.data_system_controller_info
    tq_config.storage_unit_infos = self.data_system_storage_unit_infos
    # Merge order matters: self.config is the second argument, so existing
    # user config keys take precedence over tq_config on any key collision.
    self.config = OmegaConf.merge(tq_config, self.config)

    # 4. create client
    create_transferqueue_client(
        client_id="Trainer",
        controller_info=self.data_system_controller_info,
        config=self.config,
    )
    data_system_client = get_transferqueue_client()
    return data_system_client
470463
471- def _initialize_val_data_system (self , global_batch_size , num_n_samples , role = "val" ):
472- # 1. initialize TransferQueueStorage
473- total_storage_size = global_batch_size * self .config .trainer .num_global_batch * num_n_samples
474- self .val_data_system_storage_units = {}
475- storage_placement_group = get_placement_group (self .config .trainer .num_data_storage_units , num_cpus_per_actor = 1 )
476- for storage_unit_rank in range (self .config .trainer .num_data_storage_units ):
477- storage_node = TransferQueueStorageSimpleUnit .options (
478- placement_group = storage_placement_group , placement_group_bundle_index = storage_unit_rank
479- ).remote (storage_size = math .ceil (total_storage_size / self .config .trainer .num_data_storage_units ))
480- self .val_data_system_storage_units [storage_unit_rank ] = storage_node
481- logging .info (f"TransferQueueStorageSimpleUnit #{ storage_unit_rank } has been created." )
482-
483- # 2. initialize TransferQueueController
484- # we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly
485- # one controller for a single WorkerGroup.
486- self .val_data_system_controllers = {}
487- controller_placement_group = get_placement_group (self .config .trainer .num_data_controllers , num_cpus_per_actor = 1 )
488- for controller_rank in range (self .config .trainer .num_data_controllers ):
489- self .val_data_system_controllers [controller_rank ] = TransferQueueController .options (
490- placement_group = controller_placement_group , placement_group_bundle_index = controller_rank
491- ).remote (
492- num_storage_units = self .config .trainer .num_data_storage_units ,
493- global_batch_size = global_batch_size ,
494- num_global_batch = self .config .trainer .num_global_batch ,
495- num_n_samples = num_n_samples ,
496- )
497- logging .info (f"TransferQueueController #{ controller_rank } has been created." )
498-
499- # 3. register controller & storage
500- self .val_data_system_controller_infos = process_zmq_server_info (self .val_data_system_controllers )
501- self .val_data_system_storage_unit_infos = process_zmq_server_info (self .val_data_system_storage_units )
502-
503- ray .get (
504- [
505- storage_unit .register_controller_info .remote (self .val_data_system_controller_infos )
506- for storage_unit in self .val_data_system_storage_units .values ()
507- ]
508- )
509-
510- # 4. create client
511- # each client should be allocated to exactly one controller
512- create_transferqueue_client (
513- client_id = "Trainer-" + role ,
514- controller_infos = self .val_data_system_controller_infos ,
515- storage_infos = self .val_data_system_storage_unit_infos ,
516- )
517- data_system_client = get_val_transferqueue_client ()
518- return data_system_client
519-
520464 def _create_dataloader (self , train_dataset , val_dataset , collate_fn , train_sampler : Optional [Sampler ]):
521465 """
522466 Creates the train and validation dataloaders.
@@ -726,19 +670,19 @@ def _validate(self):
726670 if self .config .reward_model .enable and test_batch [0 ]["reward_model" ]["style" ] == "model" :
727671 return {}
728672
729- asyncio .run (self .val_data_system_client .async_put (data = test_batch , global_step = self .global_steps - 1 ))
673+ asyncio .run (self .data_system_client .async_put (data = test_batch , partition_id = f"val_ { self .global_steps - 1 } " ))
730674
731675 # Store original inputs
732676 batch_meta = asyncio .run (
733- self .val_data_system_client .async_get_meta (
677+ self .data_system_client .async_get_meta (
734678 data_fields = ["input_ids" , "uid" , "reward_model" ],
735679 batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
736- global_step = self .global_steps - 1 ,
680+ partition_id = f"val_ { self .global_steps - 1 } " ,
737681 get_n_samples = False ,
738682 task_name = "get_data" ,
739683 )
740684 )
741- data = asyncio .run (self .val_data_system_client .async_get_data (batch_meta ))
685+ data = asyncio .run (self .data_system_client .async_get_data (batch_meta ))
742686 input_ids = data ["input_ids" ]
743687 # TODO: Can we keep special tokens except for padding tokens?
744688 input_texts = [self .tokenizer .decode (ids , skip_special_tokens = True ) for ids in input_ids ]
@@ -749,10 +693,10 @@ def _validate(self):
749693 sample_gts .extend (ground_truths )
750694
751695 test_gen_meta = asyncio .run (
752- self .val_data_system_client .async_get_meta (
696+ self .data_system_client .async_get_meta (
753697 data_fields = list (test_batch .keys ()), # TODO: (TQ) Get metadata by specified fields
754698 batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
755- global_step = self .global_steps - 1 , # self.global_steps start from 1
699+ partition_id = f"val_ { self .global_steps - 1 } " , # self.global_steps start from 1
756700 get_n_samples = False ,
757701 task_name = "generate_sequences" ,
758702 )
@@ -779,15 +723,15 @@ def _validate(self):
779723
780724 # Store generated outputs
781725 test_response_meta = asyncio .run (
782- self .val_data_system_client .async_get_meta (
726+ self .data_system_client .async_get_meta (
783727 data_fields = ["responses" ],
784728 batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
785- global_step = self .global_steps - 1 , # self.global_steps start from 1
729+ partition_id = f"val_ { self .global_steps - 1 } " , # self.global_steps start from 1
786730 get_n_samples = False ,
787731 task_name = "get_response" ,
788732 )
789733 )
790- data = asyncio .run (self .val_data_system_client .async_get_data (test_response_meta ))
734+ data = asyncio .run (self .data_system_client .async_get_data (test_response_meta ))
791735 output_ids = data ["responses" ]
792736 output_texts = [self .tokenizer .decode (ids , skip_special_tokens = True ) for ids in output_ids ]
793737 sample_outputs .extend (output_texts )
@@ -808,10 +752,10 @@ def _validate(self):
808752 if "rm_scores" in batch_meta .field_names :
809753 compute_reward_fields = ["rm_scores" ]
810754 val_reward_meta = asyncio .run (
811- self .val_data_system_client .async_get_meta (
755+ self .data_system_client .async_get_meta (
812756 data_fields = compute_reward_fields ,
813757 batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
814- global_step = self .global_steps - 1 ,
758+ partition_id = f"val_ { self .global_steps - 1 } " ,
815759 get_n_samples = False ,
816760 task_name = "compute_reward" ,
817761 )
@@ -832,29 +776,29 @@ def _validate(self):
832776 # collect num_turns of each prompt
833777 if "__num_turns__" in test_batch_meta .field_names :
834778 num_turns_meta = asyncio .run (
835- self .val_data_system_client .async_get_meta (
779+ self .data_system_client .async_get_meta (
836780 data_fields = ["__num_turns__" ],
837781 batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
838- global_step = self .global_steps - 1 , # self.global_steps start from 1
782+ partition_id = f"val_ { self .global_steps - 1 } " , # self.global_steps start from 1
839783 get_n_samples = False ,
840784 task_name = "get_num_turns" ,
841785 )
842786 )
843- data = asyncio .run (self .val_data_system_client .async_get_data (num_turns_meta ))
787+ data = asyncio .run (self .data_system_client .async_get_data (num_turns_meta ))
844788 sample_turns .append (data ["__num_turns__" ])
845789
846790 data_source = ["unknown" ] * reward_tensor .shape [0 ]
847791 if "data_source" in test_batch_meta .field_names :
848792 data_source_meta = asyncio .run (
849- self .val_data_system_client .async_get_meta (
793+ self .data_system_client .async_get_meta (
850794 data_fields = ["data_source" ],
851795 batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
852- global_step = self .global_steps - 1 , # self.global_steps start from 1
796+ partition_id = f"val_ { self .global_steps - 1 } " , # self.global_steps start from 1
853797 get_n_samples = False ,
854798 task_name = "get_data_source" ,
855799 )
856800 )
857- data = asyncio .run (self .val_data_system_client .async_get_data (data_source_meta ))
801+ data = asyncio .run (self .data_system_client .async_get_data (data_source_meta ))
858802 data_source = data ["data_source" ]
859803
860804 data_source_lst .append (data_source )
@@ -902,7 +846,7 @@ def _validate(self):
902846 metric_dict ["val-aux/num_turns/max" ] = sample_turns .max ()
903847 metric_dict ["val-aux/num_turns/mean" ] = sample_turns .mean ()
904848
905- asyncio .run (self .val_data_system_client .async_clear (self .global_steps - 1 ))
849+ asyncio .run (self .data_system_client .async_clear (partition_id = f"val_ { self .global_steps - 1 } " ))
906850 return metric_dict
907851
908852 def init_workers (self ):
@@ -1003,12 +947,7 @@ def init_workers(self):
1003947
1004948 # set transferqueue server info for each worker
1005949 for _ , wg in all_wg .items ():
1006- wg .create_transferqueue_client (
1007- self .data_system_controller_infos , self .data_system_storage_unit_infos , role = "train"
1008- )
1009- wg .create_transferqueue_client (
1010- self .val_data_system_controller_infos , self .val_data_system_storage_unit_infos , role = "val"
1011- )
950+ wg .create_transferqueue_client (self .data_system_controller_info , self .config )
1012951
1013952 # create async rollout manager and request scheduler
1014953 self .async_rollout_mode = False
@@ -1020,12 +959,7 @@ def init_workers(self):
1020959 config = self .config , worker_group = self .actor_rollout_wg , rm_wg = self .rm_wg
1021960 )
1022961
1023- self .async_rollout_manager .create_transferqueue_client (
1024- self .data_system_controller_infos , self .data_system_storage_unit_infos , role = "train"
1025- )
1026- self .async_rollout_manager .create_transferqueue_client (
1027- self .val_data_system_controller_infos , self .val_data_system_storage_unit_infos , role = "val"
1028- )
962+ self .async_rollout_manager .create_transferqueue_client (self .data_system_controller_info , self .config )
1029963
1030964 def _save_checkpoint (self ):
1031965 from verl .utils .fs import local_mkdir_safe
@@ -1313,7 +1247,7 @@ def fit(self):
13131247 timing_raw = {}
13141248 base_get_meta_kwargs = dict (
13151249 batch_size = self .config .data .train_batch_size * self .config .actor_rollout_ref .rollout .n ,
1316- global_step = self .global_steps - 1 , # self.global_steps starts from 1
1250+ partition_id = f"train_ { self .global_steps - 1 } " , # self.global_steps starts from 1
13171251 get_n_samples = False ,
13181252 )
13191253
@@ -1333,7 +1267,9 @@ def fit(self):
13331267 batch_dict , repeat_times = self .config .actor_rollout_ref .rollout .n , interleave = True
13341268 )
13351269 batch : TensorDict = self .dict_to_tensordict (repeated_batch_dict )
1336- asyncio .run (self .data_system_client .async_put (data = batch , global_step = self .global_steps - 1 ))
1270+ asyncio .run (
1271+ self .data_system_client .async_put (data = batch , partition_id = f"train_{ self .global_steps - 1 } " )
1272+ )
13371273
13381274 gen_meta = asyncio .run (
13391275 self .data_system_client .async_get_meta (
@@ -1709,7 +1645,7 @@ def fit(self):
17091645 ],
17101646 batch_size = self .config .data .train_batch_size
17111647 * self .config .actor_rollout_ref .rollout .n ,
1712- global_step = self .global_steps - 1 ,
1648+ partition_id = f"train_ { self .global_steps - 1 } " ,
17131649 get_n_samples = False ,
17141650 task_name = "update_actor" ,
17151651 )
@@ -1735,7 +1671,7 @@ def fit(self):
17351671 self .data_system_client .async_get_meta (
17361672 data_fields = data_fields ,
17371673 batch_size = self .config .data .train_batch_size * self .config .actor_rollout_ref .rollout .n ,
1738- global_step = self .global_steps - 1 ,
1674+ partition_id = f"train_ { self .global_steps - 1 } " ,
17391675 get_n_samples = False ,
17401676 task_name = "log_rollout" ,
17411677 )
@@ -1857,7 +1793,7 @@ def fit(self):
18571793 # TODO: (TQ) support transfer queue
18581794 self .train_dataloader .sampler .update (batch = batch )
18591795
1860- asyncio .run (self .data_system_client .async_clear (self .global_steps - 1 ))
1796+ asyncio .run (self .data_system_client .async_clear (partition_id = f"train_ { self .global_steps - 1 } " ))
18611797 # TODO: make a canonical logger that supports various backend
18621798 logger .log (data = metrics , step = self .global_steps )
18631799
0 commit comments