4242from transfer_queue import (
4343 BatchMeta ,
4444 TransferQueueController ,
45- TransferQueueStorageSimpleUnit ,
45+ SimpleStorageUnit ,
4646 get_placement_group ,
4747 process_zmq_server_info ,
4848)
8989from verl .utils .transferqueue_utils import (
9090 create_transferqueue_client ,
9191 get_transferqueue_client ,
92- get_val_transferqueue_client ,
9392 tqbridge ,
9493)
9594
@@ -412,111 +411,53 @@ def __init__(
412411
413412 self ._create_dataloader (train_dataset , val_dataset , collate_fn , train_sampler )
414413
415- self .data_system_client = self ._initialize_train_data_system (
416- self .config .data .train_batch_size , self .config .actor_rollout_ref .rollout .n
414+ self .data_system_client = self ._initialize_data_system ()
415+
416+ def _initialize_data_system (self ):
417+ # 1. initialize TransferQueueStorage
418+ train_data_size = (
419+ self .config .data .train_batch_size * self .config .trainer .num_global_batch *
420+ self .config .actor_rollout_ref .rollout .n
417421 )
418- self .val_data_system_client = self ._initialize_val_data_system (
419- self .val_batch_size , self .config .actor_rollout_ref .rollout .val_kwargs .n
422+ val_data_size = (
423+ self .val_batch_size * self .config .trainer .num_global_batch *
424+ self .config .actor_rollout_ref .rollout .val_kwargs .n
420425 )
421426
422- def _initialize_train_data_system (self , global_batch_size , num_n_samples , role = "train" ):
423- # 1. initialize TransferQueueStorage
424- total_storage_size = global_batch_size * self .config .trainer .num_global_batch * num_n_samples
427+ total_storage_size = train_data_size + val_data_size
425428 self .data_system_storage_units = {}
426429 storage_placement_group = get_placement_group (self .config .trainer .num_data_storage_units , num_cpus_per_actor = 1 )
427430 for storage_unit_rank in range (self .config .trainer .num_data_storage_units ):
428- storage_node = TransferQueueStorageSimpleUnit .options (
431+ storage_node = SimpleStorageUnit .options (
429432 placement_group = storage_placement_group , placement_group_bundle_index = storage_unit_rank
430- ).remote (storage_size = math .ceil (total_storage_size / self .config .trainer .num_data_storage_units ))
433+ ).remote (storage_unit_size = math .ceil (total_storage_size / self .config .trainer .num_data_storage_units ))
431434 self .data_system_storage_units [storage_unit_rank ] = storage_node
432- logging .info (f"TransferQueueStorageSimpleUnit #{ storage_unit_rank } has been created." )
435+ logging .info (f"SimpleStorageUnit #{ storage_unit_rank } has been created." )
433436
434437 # 2. initialize TransferQueueController
435- # we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly
436- # one controller for a single WorkerGroup.
437- self .data_system_controllers = {}
438- controller_placement_group = get_placement_group (self .config .trainer .num_data_controllers , num_cpus_per_actor = 1 )
439- for controller_rank in range (self .config .trainer .num_data_controllers ):
440- self .data_system_controllers [controller_rank ] = TransferQueueController .options (
441- placement_group = controller_placement_group , placement_group_bundle_index = controller_rank
442- ).remote (
443- num_storage_units = self .config .trainer .num_data_storage_units ,
444- global_batch_size = global_batch_size ,
445- num_global_batch = self .config .trainer .num_global_batch ,
446- num_n_samples = num_n_samples ,
447- )
448- logging .info (f"TransferQueueController #{ controller_rank } has been created." )
438+ self .data_system_controller = TransferQueueController .remote ()
439+ logging .info ("TransferQueueController has been created." )
449440
450- # 3. register controller & storage
451- self .data_system_controller_infos = process_zmq_server_info (self .data_system_controllers )
441+ # 3. register controller & storage and prepare necessary information
442+ self .data_system_controller_info = process_zmq_server_info (self .data_system_controller )
452443 self .data_system_storage_unit_infos = process_zmq_server_info (self .data_system_storage_units )
453444
454- ray .get (
455- [
456- storage_unit .register_controller_info .remote (self .data_system_controller_infos )
457- for storage_unit in self .data_system_storage_units .values ()
458- ]
459- )
445+ tq_config = OmegaConf .create ({}, flags = {"allow_objects" : True }) # Note: Need to generate a new DictConfig
446+ # with allow_objects=True to maintain ZMQServerInfo instance. Otherwise it will be flattened to dict
447+ tq_config .controller_info = self .data_system_controller_info
448+ tq_config .storage_unit_infos = self .data_system_storage_unit_infos
449+ self .config = OmegaConf .merge (tq_config , self .config )
460450
461451 # 4. create client
462452 # each client should be allocated to exactly one controller
463453 create_transferqueue_client (
464- client_id = "Trainer-" + role ,
465- controller_infos = self .data_system_controller_infos ,
466- storage_infos = self .data_system_storage_unit_infos ,
454+ client_id = "Trainer" ,
455+ controller_info = self .data_system_controller_info ,
456+ config = self .config ,
467457 )
468458 data_system_client = get_transferqueue_client ()
469459 return data_system_client
470460
471- def _initialize_val_data_system (self , global_batch_size , num_n_samples , role = "val" ):
472- # 1. initialize TransferQueueStorage
473- total_storage_size = global_batch_size * self .config .trainer .num_global_batch * num_n_samples
474- self .val_data_system_storage_units = {}
475- storage_placement_group = get_placement_group (self .config .trainer .num_data_storage_units , num_cpus_per_actor = 1 )
476- for storage_unit_rank in range (self .config .trainer .num_data_storage_units ):
477- storage_node = TransferQueueStorageSimpleUnit .options (
478- placement_group = storage_placement_group , placement_group_bundle_index = storage_unit_rank
479- ).remote (storage_size = math .ceil (total_storage_size / self .config .trainer .num_data_storage_units ))
480- self .val_data_system_storage_units [storage_unit_rank ] = storage_node
481- logging .info (f"TransferQueueStorageSimpleUnit #{ storage_unit_rank } has been created." )
482-
483- # 2. initialize TransferQueueController
484- # we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly
485- # one controller for a single WorkerGroup.
486- self .val_data_system_controllers = {}
487- controller_placement_group = get_placement_group (self .config .trainer .num_data_controllers , num_cpus_per_actor = 1 )
488- for controller_rank in range (self .config .trainer .num_data_controllers ):
489- self .val_data_system_controllers [controller_rank ] = TransferQueueController .options (
490- placement_group = controller_placement_group , placement_group_bundle_index = controller_rank
491- ).remote (
492- num_storage_units = self .config .trainer .num_data_storage_units ,
493- global_batch_size = global_batch_size ,
494- num_global_batch = self .config .trainer .num_global_batch ,
495- num_n_samples = num_n_samples ,
496- )
497- logging .info (f"TransferQueueController #{ controller_rank } has been created." )
498-
499- # 3. register controller & storage
500- self .val_data_system_controller_infos = process_zmq_server_info (self .val_data_system_controllers )
501- self .val_data_system_storage_unit_infos = process_zmq_server_info (self .val_data_system_storage_units )
502-
503- ray .get (
504- [
505- storage_unit .register_controller_info .remote (self .val_data_system_controller_infos )
506- for storage_unit in self .val_data_system_storage_units .values ()
507- ]
508- )
509-
510- # 4. create client
511- # each client should be allocated to exactly one controller
512- create_transferqueue_client (
513- client_id = "Trainer-" + role ,
514- controller_infos = self .val_data_system_controller_infos ,
515- storage_infos = self .val_data_system_storage_unit_infos ,
516- )
517- data_system_client = get_val_transferqueue_client ()
518- return data_system_client
519-
520461 def _create_dataloader (self , train_dataset , val_dataset , collate_fn , train_sampler : Optional [Sampler ]):
521462 """
522463 Creates the train and validation dataloaders.
@@ -726,19 +667,21 @@ def _validate(self):
726667 if self .config .reward_model .enable and test_batch [0 ]["reward_model" ]["style" ] == "model" :
727668 return {}
728669
729- asyncio .run (self .val_data_system_client .async_put (data = test_batch , global_step = self .global_steps - 1 ))
670+ asyncio .run (
671+ self .data_system_client .async_put (data = test_batch , partition_id = f"val_{ self .global_steps - 1 } " )
672+ )
730673
731674 # Store original inputs
732675 batch_meta = asyncio .run (
733- self .val_data_system_client .async_get_meta (
676+ self .data_system_client .async_get_meta (
734677 data_fields = ["input_ids" , "uid" , "reward_model" ],
735678 batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
736- global_step = self .global_steps - 1 ,
679+ partition_id = f"val_ { self .global_steps - 1 } " ,
737680 get_n_samples = False ,
738681 task_name = "get_data" ,
739682 )
740683 )
741- data = asyncio .run (self .val_data_system_client .async_get_data (batch_meta ))
684+ data = asyncio .run (self .data_system_client .async_get_data (batch_meta ))
742685 input_ids = data ["input_ids" ]
743686 # TODO: Can we keep special tokens except for padding tokens?
744687 input_texts = [self .tokenizer .decode (ids , skip_special_tokens = True ) for ids in input_ids ]
@@ -749,10 +692,10 @@ def _validate(self):
749692 sample_gts .extend (ground_truths )
750693
751694 test_gen_meta = asyncio .run (
752- self .val_data_system_client .async_get_meta (
695+ self .data_system_client .async_get_meta (
753696 data_fields = list (test_batch .keys ()), # TODO: (TQ) Get metadata by specified fields
754697 batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
755- global_step = self .global_steps - 1 , # self.global_steps start from 1
698+ partition_id = f"val_ { self .global_steps - 1 } " , # self.global_steps start from 1
756699 get_n_samples = False ,
757700 task_name = "generate_sequences" ,
758701 )
@@ -779,15 +722,15 @@ def _validate(self):
779722
780723 # Store generated outputs
781724 test_response_meta = asyncio .run (
782- self .val_data_system_client .async_get_meta (
725+ self .data_system_client .async_get_meta (
783726 data_fields = ["responses" ],
784727 batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
785- global_step = self .global_steps - 1 , # self.global_steps start from 1
728+ partition_id = f"val_ { self .global_steps - 1 } " , # self.global_steps start from 1
786729 get_n_samples = False ,
787730 task_name = "get_response" ,
788731 )
789732 )
790- data = asyncio .run (self .val_data_system_client .async_get_data (test_response_meta ))
733+ data = asyncio .run (self .data_system_client .async_get_data (test_response_meta ))
791734 output_ids = data ["responses" ]
792735 output_texts = [self .tokenizer .decode (ids , skip_special_tokens = True ) for ids in output_ids ]
793736 sample_outputs .extend (output_texts )
@@ -808,10 +751,10 @@ def _validate(self):
808751 if "rm_scores" in batch_meta .field_names :
809752 compute_reward_fields = ["rm_scores" ]
810753 val_reward_meta = asyncio .run (
811- self .val_data_system_client .async_get_meta (
754+ self .data_system_client .async_get_meta (
812755 data_fields = compute_reward_fields ,
813756 batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
814- global_step = self .global_steps - 1 ,
757+ partition_id = f"val_ { self .global_steps - 1 } " ,
815758 get_n_samples = False ,
816759 task_name = "compute_reward" ,
817760 )
@@ -832,29 +775,29 @@ def _validate(self):
832775 # collect num_turns of each prompt
833776 if "__num_turns__" in test_batch_meta .field_names :
834777 num_turns_meta = asyncio .run (
835- self .val_data_system_client .async_get_meta (
778+ self .data_system_client .async_get_meta (
836779 data_fields = ["__num_turns__" ],
837780 batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
838- global_step = self .global_steps - 1 , # self.global_steps start from 1
781+ partition_id = f"val_ { self .global_steps - 1 } " , # self.global_steps start from 1
839782 get_n_samples = False ,
840783 task_name = "get_num_turns" ,
841784 )
842785 )
843- data = asyncio .run (self .val_data_system_client .async_get_data (num_turns_meta ))
786+ data = asyncio .run (self .data_system_client .async_get_data (num_turns_meta ))
844787 sample_turns .append (data ["__num_turns__" ])
845788
846789 data_source = ["unknown" ] * reward_tensor .shape [0 ]
847790 if "data_source" in test_batch_meta .field_names :
848791 data_source_meta = asyncio .run (
849- self .val_data_system_client .async_get_meta (
792+ self .data_system_client .async_get_meta (
850793 data_fields = ["data_source" ],
851794 batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
852- global_step = self .global_steps - 1 , # self.global_steps start from 1
795+ partition_id = f"val_ { self .global_steps - 1 } " , # self.global_steps start from 1
853796 get_n_samples = False ,
854797 task_name = "get_data_source" ,
855798 )
856799 )
857- data = asyncio .run (self .val_data_system_client .async_get_data (data_source_meta ))
800+ data = asyncio .run (self .data_system_client .async_get_data (data_source_meta ))
858801 data_source = data ["data_source" ]
859802
860803 data_source_lst .append (data_source )
@@ -902,7 +845,7 @@ def _validate(self):
902845 metric_dict ["val-aux/num_turns/max" ] = sample_turns .max ()
903846 metric_dict ["val-aux/num_turns/mean" ] = sample_turns .mean ()
904847
905- asyncio .run (self .val_data_system_client .async_clear (self .global_steps - 1 ))
848+ asyncio .run (self .data_system_client .async_clear (partition_id = f"val_ { self .global_steps - 1 } " ))
906849 return metric_dict
907850
908851 def init_workers (self ):
@@ -1003,12 +946,7 @@ def init_workers(self):
1003946
1004947 # set transferqueue server info for each worker
1005948 for _ , wg in all_wg .items ():
1006- wg .create_transferqueue_client (
1007- self .data_system_controller_infos , self .data_system_storage_unit_infos , role = "train"
1008- )
1009- wg .create_transferqueue_client (
1010- self .val_data_system_controller_infos , self .val_data_system_storage_unit_infos , role = "val"
1011- )
949+ wg .create_transferqueue_client (self .data_system_controller_info , self .config )
1012950
1013951 # create async rollout manager and request scheduler
1014952 self .async_rollout_mode = False
@@ -1021,10 +959,7 @@ def init_workers(self):
1021959 )
1022960
1023961 self .async_rollout_manager .create_transferqueue_client (
1024- self .data_system_controller_infos , self .data_system_storage_unit_infos , role = "train"
1025- )
1026- self .async_rollout_manager .create_transferqueue_client (
1027- self .val_data_system_controller_infos , self .val_data_system_storage_unit_infos , role = "val"
962+ self .data_system_controller_info , self .config
1028963 )
1029964
1030965 def _save_checkpoint (self ):
@@ -1313,7 +1248,7 @@ def fit(self):
13131248 timing_raw = {}
13141249 base_get_meta_kwargs = dict (
13151250 batch_size = self .config .data .train_batch_size * self .config .actor_rollout_ref .rollout .n ,
1316- global_step = self .global_steps - 1 , # self.global_steps starts from 1
1251+ partition_id = f"train_ { self .global_steps - 1 } " , # self.global_steps starts from 1
13171252 get_n_samples = False ,
13181253 )
13191254
@@ -1333,7 +1268,9 @@ def fit(self):
13331268 batch_dict , repeat_times = self .config .actor_rollout_ref .rollout .n , interleave = True
13341269 )
13351270 batch : TensorDict = self .dict_to_tensordict (repeated_batch_dict )
1336- asyncio .run (self .data_system_client .async_put (data = batch , global_step = self .global_steps - 1 ))
1271+ asyncio .run (
1272+ self .data_system_client .async_put (data = batch , partition_id = f"train_{ self .global_steps - 1 } " )
1273+ )
13371274
13381275 gen_meta = asyncio .run (
13391276 self .data_system_client .async_get_meta (
@@ -1709,7 +1646,7 @@ def fit(self):
17091646 ],
17101647 batch_size = self .config .data .train_batch_size
17111648 * self .config .actor_rollout_ref .rollout .n ,
1712- global_step = self .global_steps - 1 ,
1649+ partition_id = f"train_ { self .global_steps - 1 } " ,
17131650 get_n_samples = False ,
17141651 task_name = "update_actor" ,
17151652 )
@@ -1735,7 +1672,7 @@ def fit(self):
17351672 self .data_system_client .async_get_meta (
17361673 data_fields = data_fields ,
17371674 batch_size = self .config .data .train_batch_size * self .config .actor_rollout_ref .rollout .n ,
1738- global_step = self .global_steps - 1 ,
1675+ partition_id = f"train_ { self .global_steps - 1 } " ,
17391676 get_n_samples = False ,
17401677 task_name = "log_rollout" ,
17411678 )
@@ -1857,7 +1794,7 @@ def fit(self):
18571794 # TODO: (TQ) support transfer queue
18581795 self .train_dataloader .sampler .update (batch = batch )
18591796
1860- asyncio .run (self .data_system_client .async_clear (self .global_steps - 1 ))
1797+ asyncio .run (self .data_system_client .async_clear (partition_id = f"train_ { self .global_steps - 1 } " ))
18611798 # TODO: make a canonical logger that supports various backend
18621799 logger .log (data = metrics , step = self .global_steps )
18631800
0 commit comments