
Commit 9f00d21

LLLLxmmm and liuximeng authored
[data] feat: TransferQueue - Support managing multiple data partitions for Train/Val/Test in controller (#43)
* [data] feat: TransferQueue - Support managing multiple data partitions for Train/Val/Test in controller

* fix codecheck

* fix comments

* fix comments

---------

Co-authored-by: liuximeng <13073314+liuximeng18772102439@user.noreply.gitee.com>
1 parent 20d0f98 commit 9f00d21
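
The heart of this change: instead of building one TransferQueue controller/storage stack for training and another for validation, the trainer keeps a single controller and addresses data by partition, keyed by split and step (e.g. train_0, val_0). A minimal sketch of that naming convention in plain Python (the helper below is hypothetical, not part of the TransferQueue API):

# Hypothetical helper mirroring the partition naming used throughout this commit:
# one controller, many partitions, keyed by split ("train"/"val") and 0-based step.
def partition_id(split: str, global_steps: int) -> str:
    # self.global_steps starts from 1 in the trainer, hence the "- 1"
    return f"{split}_{global_steps - 1}"


print(partition_id("train", 1))  # -> train_0
print(partition_id("val", 1))    # -> val_0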

File tree

5 files changed, +71 -152 lines changed


recipe/transfer_queue/agent_loop.py

Lines changed: 2 additions & 5 deletions
@@ -67,10 +67,7 @@ def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: Data
 
         return timing
 
-    def create_transferqueue_client(self, controller_infos, storage_infos, role):
+    def create_transferqueue_client(self, controller_info, config):
         ray.get(
-            [
-                worker.create_transferqueue_client.remote(controller_infos, storage_infos, role)
-                for worker in self.agent_loop_workers
-            ]
+            [worker.create_transferqueue_client.remote(controller_info, config) for worker in self.agent_loop_workers]
         )
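
The manager-side change above just forwards the new two-argument signature to every agent-loop worker and blocks once on the batch of Ray futures. A self-contained sketch of that fan-out pattern, using placeholder actor and argument names rather than the real verl/TransferQueue classes:

import ray

ray.init(ignore_reinit_error=True)


@ray.remote
class AgentLoopWorkerStub:
    # Placeholder worker: the real one builds a TransferQueue client from these args.
    def create_transferqueue_client(self, controller_info, config):
        self.controller_info, self.config = controller_info, config
        return True


workers = [AgentLoopWorkerStub.remote() for _ in range(4)]
controller_info = {"id": "controller-0"}  # stand-in for a ZMQServerInfo object
config = {"storage_unit_infos": {}}       # stand-in for the merged trainer config

# Same shape as the new AgentLoopManager.create_transferqueue_client:
# issue all remote calls, then wait once for the whole batch.
ray.get([w.create_transferqueue_client.remote(controller_info, config) for w in workers])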

recipe/transfer_queue/ray_trainer.py

Lines changed: 58 additions & 122 deletions
@@ -41,8 +41,8 @@
 from tqdm import tqdm
 from transfer_queue import (
     BatchMeta,
+    SimpleStorageUnit,
     TransferQueueController,
-    TransferQueueStorageSimpleUnit,
     get_placement_group,
     process_zmq_server_info,
 )
@@ -89,7 +89,6 @@
 from verl.utils.transferqueue_utils import (
     create_transferqueue_client,
     get_transferqueue_client,
-    get_val_transferqueue_client,
     tqbridge,
 )
 
@@ -412,111 +411,56 @@ def __init__(
 
         self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
 
-        self.data_system_client = self._initialize_train_data_system(
-            self.config.data.train_batch_size, self.config.actor_rollout_ref.rollout.n
+        self.data_system_client = self._initialize_data_system()
+
+    def _initialize_data_system(self):
+        # 1. initialize TransferQueueStorage
+        train_data_size = (
+            self.config.data.train_batch_size
+            * self.config.trainer.num_global_batch
+            * self.config.actor_rollout_ref.rollout.n
         )
-        self.val_data_system_client = self._initialize_val_data_system(
-            self.val_batch_size, self.config.actor_rollout_ref.rollout.val_kwargs.n
+        val_data_size = (
+            self.val_batch_size
+            * self.config.trainer.num_global_batch
+            * self.config.actor_rollout_ref.rollout.val_kwargs.n
         )
 
-    def _initialize_train_data_system(self, global_batch_size, num_n_samples, role="train"):
-        # 1. initialize TransferQueueStorage
-        total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples
+        total_storage_size = train_data_size + val_data_size
         self.data_system_storage_units = {}
         storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1)
         for storage_unit_rank in range(self.config.trainer.num_data_storage_units):
-            storage_node = TransferQueueStorageSimpleUnit.options(
+            storage_node = SimpleStorageUnit.options(
                 placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
-            ).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))
+            ).remote(storage_unit_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))
             self.data_system_storage_units[storage_unit_rank] = storage_node
-            logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.")
+            logging.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.")
 
         # 2. initialize TransferQueueController
-        # we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly
-        # one controller for a single WorkerGroup.
-        self.data_system_controllers = {}
-        controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1)
-        for controller_rank in range(self.config.trainer.num_data_controllers):
-            self.data_system_controllers[controller_rank] = TransferQueueController.options(
-                placement_group=controller_placement_group, placement_group_bundle_index=controller_rank
-            ).remote(
-                num_storage_units=self.config.trainer.num_data_storage_units,
-                global_batch_size=global_batch_size,
-                num_global_batch=self.config.trainer.num_global_batch,
-                num_n_samples=num_n_samples,
-            )
-            logging.info(f"TransferQueueController #{controller_rank} has been created.")
+        self.data_system_controller = TransferQueueController.remote()
+        logging.info("TransferQueueController has been created.")
 
-        # 3. register controller & storage
-        self.data_system_controller_infos = process_zmq_server_info(self.data_system_controllers)
+        # 3. register controller & storage and prepare necessary information
+        self.data_system_controller_info = process_zmq_server_info(self.data_system_controller)
         self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units)
 
-        ray.get(
-            [
-                storage_unit.register_controller_info.remote(self.data_system_controller_infos)
-                for storage_unit in self.data_system_storage_units.values()
-            ]
-        )
+        # Note: Need to generate a new DictConfig with allow_objects=True to preserve ZMQServerInfo instances
+        # (which contain socket connection details). Without this flag, OmegaConf would flatten these objects to dicts,
+        # breaking the transfer queue client initialization.
+        tq_config = OmegaConf.create({}, flags={"allow_objects": True})
+        tq_config.controller_info = self.data_system_controller_info
+        tq_config.storage_unit_infos = self.data_system_storage_unit_infos
+        self.config = OmegaConf.merge(tq_config, self.config)
 
         # 4. create client
-        # each client should be allocated to exactly one controller
         create_transferqueue_client(
-            client_id="Trainer-" + role,
-            controller_infos=self.data_system_controller_infos,
-            storage_infos=self.data_system_storage_unit_infos,
+            client_id="Trainer",
+            controller_info=self.data_system_controller_info,
+            config=self.config,
         )
         data_system_client = get_transferqueue_client()
         return data_system_client
 
-    def _initialize_val_data_system(self, global_batch_size, num_n_samples, role="val"):
-        # 1. initialize TransferQueueStorage
-        total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples
-        self.val_data_system_storage_units = {}
-        storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1)
-        for storage_unit_rank in range(self.config.trainer.num_data_storage_units):
-            storage_node = TransferQueueStorageSimpleUnit.options(
-                placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
-            ).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))
-            self.val_data_system_storage_units[storage_unit_rank] = storage_node
-            logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.")
-
-        # 2. initialize TransferQueueController
-        # we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly
-        # one controller for a single WorkerGroup.
-        self.val_data_system_controllers = {}
-        controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1)
-        for controller_rank in range(self.config.trainer.num_data_controllers):
-            self.val_data_system_controllers[controller_rank] = TransferQueueController.options(
-                placement_group=controller_placement_group, placement_group_bundle_index=controller_rank
-            ).remote(
-                num_storage_units=self.config.trainer.num_data_storage_units,
-                global_batch_size=global_batch_size,
-                num_global_batch=self.config.trainer.num_global_batch,
-                num_n_samples=num_n_samples,
-            )
-            logging.info(f"TransferQueueController #{controller_rank} has been created.")
-
-        # 3. register controller & storage
-        self.val_data_system_controller_infos = process_zmq_server_info(self.val_data_system_controllers)
-        self.val_data_system_storage_unit_infos = process_zmq_server_info(self.val_data_system_storage_units)
-
-        ray.get(
-            [
-                storage_unit.register_controller_info.remote(self.val_data_system_controller_infos)
-                for storage_unit in self.val_data_system_storage_units.values()
-            ]
-        )
-
-        # 4. create client
-        # each client should be allocated to exactly one controller
-        create_transferqueue_client(
-            client_id="Trainer-" + role,
-            controller_infos=self.val_data_system_controller_infos,
-            storage_infos=self.val_data_system_storage_unit_infos,
-        )
-        data_system_client = get_val_transferqueue_client()
-        return data_system_client
-
     def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
         """
         Creates the train and validation dataloaders.
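
The allow_objects flag used in the new _initialize_data_system is what lets the ZMQ server-info objects ride along inside the trainer config through OmegaConf.merge. A small sketch of the behavior difference, with a stand-in class instead of TransferQueue's real ZMQServerInfo (exact error types may vary across OmegaConf versions):

from omegaconf import OmegaConf


class ServerInfoStub:
    # Stand-in for ZMQServerInfo: an arbitrary Python object, not a plain dict.
    def __init__(self, ip: str, port: int):
        self.ip, self.port = ip, port


info = ServerInfoStub("127.0.0.1", 5555)

# With allow_objects=True the object is stored as-is and comes back unchanged,
# so socket connection details survive being merged into the trainer config.
cfg = OmegaConf.create({}, flags={"allow_objects": True})
cfg.controller_info = info
print(type(cfg.controller_info).__name__)  # ServerInfoStub

# Without the flag, OmegaConf only accepts primitive/config values, so the same
# assignment is rejected instead of silently carrying the object through.
plain = OmegaConf.create({})
try:
    plain.controller_info = info
except Exception as exc:  # typically UnsupportedValueType
    print(f"rejected: {type(exc).__name__}")
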
@@ -726,19 +670,19 @@ def _validate(self):
         if self.config.reward_model.enable and test_batch[0]["reward_model"]["style"] == "model":
             return {}
 
-        asyncio.run(self.val_data_system_client.async_put(data=test_batch, global_step=self.global_steps - 1))
+        asyncio.run(self.data_system_client.async_put(data=test_batch, partition_id=f"val_{self.global_steps - 1}"))
 
         # Store original inputs
         batch_meta = asyncio.run(
-            self.val_data_system_client.async_get_meta(
+            self.data_system_client.async_get_meta(
                 data_fields=["input_ids", "uid", "reward_model"],
                 batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
-                global_step=self.global_steps - 1,
+                partition_id=f"val_{self.global_steps - 1}",
                 get_n_samples=False,
                 task_name="get_data",
             )
         )
-        data = asyncio.run(self.val_data_system_client.async_get_data(batch_meta))
+        data = asyncio.run(self.data_system_client.async_get_data(batch_meta))
         input_ids = data["input_ids"]
         # TODO: Can we keep special tokens except for padding tokens?
         input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
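
With the separate validation client gone, validation data follows the same put / get_meta / get_data / clear lifecycle as training data, just under a val_* partition. A rough sketch of that flow against a stubbed client (the real one comes from get_transferqueue_client(); the stub's signatures are abbreviated from the calls shown above):

import asyncio


class StubTQClient:
    # Toy in-memory stand-in for the TransferQueue client used in this recipe.
    def __init__(self):
        self.partitions = {}

    async def async_put(self, data, partition_id):
        self.partitions.setdefault(partition_id, []).append(data)

    async def async_get_meta(self, data_fields, batch_size, partition_id, get_n_samples, task_name):
        # The real call returns a BatchMeta; a dict is enough to show the flow.
        return {"partition_id": partition_id, "fields": data_fields, "batch_size": batch_size}

    async def async_get_data(self, meta):
        return self.partitions[meta["partition_id"]]

    async def async_clear(self, partition_id):
        self.partitions.pop(partition_id, None)


client = StubTQClient()
pid = "val_0"  # partition for validation data at global step 1
asyncio.run(client.async_put(data={"input_ids": [1, 2, 3]}, partition_id=pid))
meta = asyncio.run(client.async_get_meta(["input_ids"], 1, pid, False, "get_data"))
print(asyncio.run(client.async_get_data(meta)))
asyncio.run(client.async_clear(partition_id=pid))
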
@@ -749,10 +693,10 @@ def _validate(self):
             sample_gts.extend(ground_truths)
 
         test_gen_meta = asyncio.run(
-            self.val_data_system_client.async_get_meta(
+            self.data_system_client.async_get_meta(
                 data_fields=list(test_batch.keys()),  # TODO: (TQ) Get metadata by specified fields
                 batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
-                global_step=self.global_steps - 1,  # self.global_steps start from 1
+                partition_id=f"val_{self.global_steps - 1}",  # self.global_steps start from 1
                 get_n_samples=False,
                 task_name="generate_sequences",
             )
@@ -779,15 +723,15 @@ def _validate(self):
 
         # Store generated outputs
         test_response_meta = asyncio.run(
-            self.val_data_system_client.async_get_meta(
+            self.data_system_client.async_get_meta(
                 data_fields=["responses"],
                 batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
-                global_step=self.global_steps - 1,  # self.global_steps start from 1
+                partition_id=f"val_{self.global_steps - 1}",  # self.global_steps start from 1
                 get_n_samples=False,
                 task_name="get_response",
             )
         )
-        data = asyncio.run(self.val_data_system_client.async_get_data(test_response_meta))
+        data = asyncio.run(self.data_system_client.async_get_data(test_response_meta))
         output_ids = data["responses"]
         output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
         sample_outputs.extend(output_texts)
@@ -808,10 +752,10 @@ def _validate(self):
         if "rm_scores" in batch_meta.field_names:
             compute_reward_fields = ["rm_scores"]
             val_reward_meta = asyncio.run(
-                self.val_data_system_client.async_get_meta(
+                self.data_system_client.async_get_meta(
                     data_fields=compute_reward_fields,
                     batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
-                    global_step=self.global_steps - 1,
+                    partition_id=f"val_{self.global_steps - 1}",
                     get_n_samples=False,
                     task_name="compute_reward",
                 )
@@ -832,29 +776,29 @@ def _validate(self):
         # collect num_turns of each prompt
         if "__num_turns__" in test_batch_meta.field_names:
             num_turns_meta = asyncio.run(
-                self.val_data_system_client.async_get_meta(
+                self.data_system_client.async_get_meta(
                     data_fields=["__num_turns__"],
                     batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
-                    global_step=self.global_steps - 1,  # self.global_steps start from 1
+                    partition_id=f"val_{self.global_steps - 1}",  # self.global_steps start from 1
                     get_n_samples=False,
                     task_name="get_num_turns",
                 )
             )
-            data = asyncio.run(self.val_data_system_client.async_get_data(num_turns_meta))
+            data = asyncio.run(self.data_system_client.async_get_data(num_turns_meta))
             sample_turns.append(data["__num_turns__"])
 
         data_source = ["unknown"] * reward_tensor.shape[0]
         if "data_source" in test_batch_meta.field_names:
             data_source_meta = asyncio.run(
-                self.val_data_system_client.async_get_meta(
+                self.data_system_client.async_get_meta(
                     data_fields=["data_source"],
                     batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
-                    global_step=self.global_steps - 1,  # self.global_steps start from 1
+                    partition_id=f"val_{self.global_steps - 1}",  # self.global_steps start from 1
                     get_n_samples=False,
                     task_name="get_data_source",
                 )
             )
-            data = asyncio.run(self.val_data_system_client.async_get_data(data_source_meta))
+            data = asyncio.run(self.data_system_client.async_get_data(data_source_meta))
             data_source = data["data_source"]
 
         data_source_lst.append(data_source)
@@ -902,7 +846,7 @@ def _validate(self):
             metric_dict["val-aux/num_turns/max"] = sample_turns.max()
             metric_dict["val-aux/num_turns/mean"] = sample_turns.mean()
 
-        asyncio.run(self.val_data_system_client.async_clear(self.global_steps - 1))
+        asyncio.run(self.data_system_client.async_clear(partition_id=f"val_{self.global_steps - 1}"))
         return metric_dict
 
     def init_workers(self):
@@ -1003,12 +947,7 @@ def init_workers(self):
 
         # set transferqueue server info for each worker
         for _, wg in all_wg.items():
-            wg.create_transferqueue_client(
-                self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train"
-            )
-            wg.create_transferqueue_client(
-                self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val"
-            )
+            wg.create_transferqueue_client(self.data_system_controller_info, self.config)
 
         # create async rollout manager and request scheduler
         self.async_rollout_mode = False
@@ -1020,12 +959,7 @@ def init_workers(self):
                 config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg
             )
 
-            self.async_rollout_manager.create_transferqueue_client(
-                self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train"
-            )
-            self.async_rollout_manager.create_transferqueue_client(
-                self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val"
-            )
+            self.async_rollout_manager.create_transferqueue_client(self.data_system_controller_info, self.config)
 
     def _save_checkpoint(self):
         from verl.utils.fs import local_mkdir_safe
@@ -1313,7 +1247,7 @@ def fit(self):
                 timing_raw = {}
                 base_get_meta_kwargs = dict(
                     batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,
-                    global_step=self.global_steps - 1,  # self.global_steps starts from 1
+                    partition_id=f"train_{self.global_steps - 1}",  # self.global_steps starts from 1
                     get_n_samples=False,
                 )
 
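The shared kwargs built above are then unpacked into each per-task async_get_meta call during the step, so only data_fields and task_name vary while batch_size and the train_* partition stay fixed. A minimal illustration of that reuse pattern with a stub in place of the real client:

# Stub standing in for self.data_system_client.async_get_meta: it just echoes its kwargs.
def get_meta_stub(**kwargs):
    return kwargs


global_steps = 1
base_get_meta_kwargs = dict(
    batch_size=8,
    partition_id=f"train_{global_steps - 1}",  # shared partition for this training step
    get_n_samples=False,
)

# Each task layers its own fields and name on top of the shared kwargs.
gen_meta = get_meta_stub(data_fields=["input_ids"], task_name="generate_sequences", **base_get_meta_kwargs)
actor_meta = get_meta_stub(data_fields=["responses"], task_name="update_actor", **base_get_meta_kwargs)
print(gen_meta["partition_id"], actor_meta["task_name"])
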
@@ -1333,7 +1267,9 @@ def fit(self):
                     batch_dict, repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
                 )
                 batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict)
-                asyncio.run(self.data_system_client.async_put(data=batch, global_step=self.global_steps - 1))
+                asyncio.run(
+                    self.data_system_client.async_put(data=batch, partition_id=f"train_{self.global_steps - 1}")
+                )
 
                 gen_meta = asyncio.run(
                     self.data_system_client.async_get_meta(
@@ -1709,7 +1645,7 @@ def fit(self):
                             ],
                             batch_size=self.config.data.train_batch_size
                             * self.config.actor_rollout_ref.rollout.n,
-                            global_step=self.global_steps - 1,
+                            partition_id=f"train_{self.global_steps - 1}",
                             get_n_samples=False,
                             task_name="update_actor",
                         )
@@ -1735,7 +1671,7 @@ def fit(self):
                         self.data_system_client.async_get_meta(
                             data_fields=data_fields,
                             batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,
-                            global_step=self.global_steps - 1,
+                            partition_id=f"train_{self.global_steps - 1}",
                             get_n_samples=False,
                             task_name="log_rollout",
                         )
@@ -1857,7 +1793,7 @@ def fit(self):
                 # TODO: (TQ) support transfer queue
                 self.train_dataloader.sampler.update(batch=batch)
 
-                asyncio.run(self.data_system_client.async_clear(self.global_steps - 1))
+                asyncio.run(self.data_system_client.async_clear(partition_id=f"train_{self.global_steps - 1}"))
                 # TODO: make a canonical logger that supports various backend
                 logger.log(data=metrics, step=self.global_steps)
 
requirements_transferqueue.txt

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 # requirements.txt records the full set of dependencies for development
-git+https://github.com/TransferQueue/TransferQueue.git@68c04e7
+git+https://github.com/TransferQueue/TransferQueue.git@862b74a
