Skip to content

Commit d6d12a0

Browse files
author
liuximeng
committed
[data] feat: TransferQueue - Support managing multiple data partitions for Train/Val/Test in controller
1 parent 20d0f98 commit d6d12a0

File tree

5 files changed

+69
-146
lines changed

5 files changed

+69
-146
lines changed

recipe/transfer_queue/agent_loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: Data
6767

6868
return timing
6969

70-
def create_transferqueue_client(self, controller_infos, storage_infos, role):
70+
def create_transferqueue_client(self, controller_info, config):
7171
ray.get(
7272
[
73-
worker.create_transferqueue_client.remote(controller_infos, storage_infos, role)
73+
worker.create_transferqueue_client.remote(controller_info, config)
7474
for worker in self.agent_loop_workers
7575
]
7676
)

recipe/transfer_queue/ray_trainer.py

Lines changed: 56 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from transfer_queue import (
4343
BatchMeta,
4444
TransferQueueController,
45-
TransferQueueStorageSimpleUnit,
45+
SimpleStorageUnit,
4646
get_placement_group,
4747
process_zmq_server_info,
4848
)
@@ -89,7 +89,6 @@
8989
from verl.utils.transferqueue_utils import (
9090
create_transferqueue_client,
9191
get_transferqueue_client,
92-
get_val_transferqueue_client,
9392
tqbridge,
9493
)
9594

@@ -412,111 +411,53 @@ def __init__(
412411

413412
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
414413

415-
self.data_system_client = self._initialize_train_data_system(
416-
self.config.data.train_batch_size, self.config.actor_rollout_ref.rollout.n
414+
self.data_system_client = self._initialize_data_system()
415+
416+
def _initialize_data_system(self):
417+
# 1. initialize TransferQueueStorage
418+
train_data_size = (
419+
self.config.data.train_batch_size * self.config.trainer.num_global_batch *
420+
self.config.actor_rollout_ref.rollout.n
417421
)
418-
self.val_data_system_client = self._initialize_val_data_system(
419-
self.val_batch_size, self.config.actor_rollout_ref.rollout.val_kwargs.n
422+
val_data_size = (
423+
self.val_batch_size * self.config.trainer.num_global_batch *
424+
self.config.actor_rollout_ref.rollout.val_kwargs.n
420425
)
421426

422-
def _initialize_train_data_system(self, global_batch_size, num_n_samples, role="train"):
423-
# 1. initialize TransferQueueStorage
424-
total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples
427+
total_storage_size = train_data_size + val_data_size
425428
self.data_system_storage_units = {}
426429
storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1)
427430
for storage_unit_rank in range(self.config.trainer.num_data_storage_units):
428-
storage_node = TransferQueueStorageSimpleUnit.options(
431+
storage_node = SimpleStorageUnit.options(
429432
placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
430-
).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))
433+
).remote(storage_unit_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))
431434
self.data_system_storage_units[storage_unit_rank] = storage_node
432-
logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.")
435+
logging.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.")
433436

434437
# 2. initialize TransferQueueController
435-
# we support initializing multiple controller instances for large-scale scenarios. Please allocate exactly
436-
# one controller for a single WorkerGroup.
437-
self.data_system_controllers = {}
438-
controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1)
439-
for controller_rank in range(self.config.trainer.num_data_controllers):
440-
self.data_system_controllers[controller_rank] = TransferQueueController.options(
441-
placement_group=controller_placement_group, placement_group_bundle_index=controller_rank
442-
).remote(
443-
num_storage_units=self.config.trainer.num_data_storage_units,
444-
global_batch_size=global_batch_size,
445-
num_global_batch=self.config.trainer.num_global_batch,
446-
num_n_samples=num_n_samples,
447-
)
448-
logging.info(f"TransferQueueController #{controller_rank} has been created.")
438+
self.data_system_controller = TransferQueueController.remote()
439+
logging.info("TransferQueueController has been created.")
449440

450-
# 3. register controller & storage
451-
self.data_system_controller_infos = process_zmq_server_info(self.data_system_controllers)
441+
# 3. register controller & storage and prepare necessary information
442+
self.data_system_controller_info = process_zmq_server_info(self.data_system_controller)
452443
self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units)
453444

454-
ray.get(
455-
[
456-
storage_unit.register_controller_info.remote(self.data_system_controller_infos)
457-
for storage_unit in self.data_system_storage_units.values()
458-
]
459-
)
445+
tq_config = OmegaConf.create({}, flags={"allow_objects": True}) # Note: Need to generate a new DictConfig
446+
# with allow_objects=True to maintain ZMQServerInfo instance. Otherwise it will be flattened to dict
447+
tq_config.controller_info = self.data_system_controller_info
448+
tq_config.storage_unit_infos = self.data_system_storage_unit_infos
449+
self.config = OmegaConf.merge(tq_config, self.config)
460450

461451
# 4. create client
462452
# each client should be allocated to exactly one controller
463453
create_transferqueue_client(
464-
client_id="Trainer-" + role,
465-
controller_infos=self.data_system_controller_infos,
466-
storage_infos=self.data_system_storage_unit_infos,
454+
client_id="Trainer",
455+
controller_info=self.data_system_controller_info,
456+
config=self.config,
467457
)
468458
data_system_client = get_transferqueue_client()
469459
return data_system_client
470460

471-
def _initialize_val_data_system(self, global_batch_size, num_n_samples, role="val"):
472-
# 1. initialize TransferQueueStorage
473-
total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples
474-
self.val_data_system_storage_units = {}
475-
storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1)
476-
for storage_unit_rank in range(self.config.trainer.num_data_storage_units):
477-
storage_node = TransferQueueStorageSimpleUnit.options(
478-
placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
479-
).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))
480-
self.val_data_system_storage_units[storage_unit_rank] = storage_node
481-
logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.")
482-
483-
# 2. initialize TransferQueueController
484-
# we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly
485-
# one controller for a single WorkerGroup.
486-
self.val_data_system_controllers = {}
487-
controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1)
488-
for controller_rank in range(self.config.trainer.num_data_controllers):
489-
self.val_data_system_controllers[controller_rank] = TransferQueueController.options(
490-
placement_group=controller_placement_group, placement_group_bundle_index=controller_rank
491-
).remote(
492-
num_storage_units=self.config.trainer.num_data_storage_units,
493-
global_batch_size=global_batch_size,
494-
num_global_batch=self.config.trainer.num_global_batch,
495-
num_n_samples=num_n_samples,
496-
)
497-
logging.info(f"TransferQueueController #{controller_rank} has been created.")
498-
499-
# 3. register controller & storage
500-
self.val_data_system_controller_infos = process_zmq_server_info(self.val_data_system_controllers)
501-
self.val_data_system_storage_unit_infos = process_zmq_server_info(self.val_data_system_storage_units)
502-
503-
ray.get(
504-
[
505-
storage_unit.register_controller_info.remote(self.val_data_system_controller_infos)
506-
for storage_unit in self.val_data_system_storage_units.values()
507-
]
508-
)
509-
510-
# 4. create client
511-
# each client should be allocated to exactly one controller
512-
create_transferqueue_client(
513-
client_id="Trainer-" + role,
514-
controller_infos=self.val_data_system_controller_infos,
515-
storage_infos=self.val_data_system_storage_unit_infos,
516-
)
517-
data_system_client = get_val_transferqueue_client()
518-
return data_system_client
519-
520461
def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
521462
"""
522463
Creates the train and validation dataloaders.
@@ -726,19 +667,21 @@ def _validate(self):
726667
if self.config.reward_model.enable and test_batch[0]["reward_model"]["style"] == "model":
727668
return {}
728669

729-
asyncio.run(self.val_data_system_client.async_put(data=test_batch, global_step=self.global_steps - 1))
670+
asyncio.run(
671+
self.data_system_client.async_put(data=test_batch, partition_id=f"val_{self.global_steps - 1}")
672+
)
730673

731674
# Store original inputs
732675
batch_meta = asyncio.run(
733-
self.val_data_system_client.async_get_meta(
676+
self.data_system_client.async_get_meta(
734677
data_fields=["input_ids", "uid", "reward_model"],
735678
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
736-
global_step=self.global_steps - 1,
679+
partition_id=f"val_{self.global_steps - 1}",
737680
get_n_samples=False,
738681
task_name="get_data",
739682
)
740683
)
741-
data = asyncio.run(self.val_data_system_client.async_get_data(batch_meta))
684+
data = asyncio.run(self.data_system_client.async_get_data(batch_meta))
742685
input_ids = data["input_ids"]
743686
# TODO: Can we keep special tokens except for padding tokens?
744687
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
@@ -749,10 +692,10 @@ def _validate(self):
749692
sample_gts.extend(ground_truths)
750693

751694
test_gen_meta = asyncio.run(
752-
self.val_data_system_client.async_get_meta(
695+
self.data_system_client.async_get_meta(
753696
data_fields=list(test_batch.keys()), # TODO: (TQ) Get metadata by specified fields
754697
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
755-
global_step=self.global_steps - 1, # self.global_steps start from 1
698+
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
756699
get_n_samples=False,
757700
task_name="generate_sequences",
758701
)
@@ -779,15 +722,15 @@ def _validate(self):
779722

780723
# Store generated outputs
781724
test_response_meta = asyncio.run(
782-
self.val_data_system_client.async_get_meta(
725+
self.data_system_client.async_get_meta(
783726
data_fields=["responses"],
784727
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
785-
global_step=self.global_steps - 1, # self.global_steps start from 1
728+
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
786729
get_n_samples=False,
787730
task_name="get_response",
788731
)
789732
)
790-
data = asyncio.run(self.val_data_system_client.async_get_data(test_response_meta))
733+
data = asyncio.run(self.data_system_client.async_get_data(test_response_meta))
791734
output_ids = data["responses"]
792735
output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
793736
sample_outputs.extend(output_texts)
@@ -808,10 +751,10 @@ def _validate(self):
808751
if "rm_scores" in batch_meta.field_names:
809752
compute_reward_fields = ["rm_scores"]
810753
val_reward_meta = asyncio.run(
811-
self.val_data_system_client.async_get_meta(
754+
self.data_system_client.async_get_meta(
812755
data_fields=compute_reward_fields,
813756
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
814-
global_step=self.global_steps - 1,
757+
partition_id=f"val_{self.global_steps - 1}",
815758
get_n_samples=False,
816759
task_name="compute_reward",
817760
)
@@ -832,29 +775,29 @@ def _validate(self):
832775
# collect num_turns of each prompt
833776
if "__num_turns__" in test_batch_meta.field_names:
834777
num_turns_meta = asyncio.run(
835-
self.val_data_system_client.async_get_meta(
778+
self.data_system_client.async_get_meta(
836779
data_fields=["__num_turns__"],
837780
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
838-
global_step=self.global_steps - 1, # self.global_steps start from 1
781+
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
839782
get_n_samples=False,
840783
task_name="get_num_turns",
841784
)
842785
)
843-
data = asyncio.run(self.val_data_system_client.async_get_data(num_turns_meta))
786+
data = asyncio.run(self.data_system_client.async_get_data(num_turns_meta))
844787
sample_turns.append(data["__num_turns__"])
845788

846789
data_source = ["unknown"] * reward_tensor.shape[0]
847790
if "data_source" in test_batch_meta.field_names:
848791
data_source_meta = asyncio.run(
849-
self.val_data_system_client.async_get_meta(
792+
self.data_system_client.async_get_meta(
850793
data_fields=["data_source"],
851794
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
852-
global_step=self.global_steps - 1, # self.global_steps start from 1
795+
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
853796
get_n_samples=False,
854797
task_name="get_data_source",
855798
)
856799
)
857-
data = asyncio.run(self.val_data_system_client.async_get_data(data_source_meta))
800+
data = asyncio.run(self.data_system_client.async_get_data(data_source_meta))
858801
data_source = data["data_source"]
859802

860803
data_source_lst.append(data_source)
@@ -902,7 +845,7 @@ def _validate(self):
902845
metric_dict["val-aux/num_turns/max"] = sample_turns.max()
903846
metric_dict["val-aux/num_turns/mean"] = sample_turns.mean()
904847

905-
asyncio.run(self.val_data_system_client.async_clear(self.global_steps - 1))
848+
asyncio.run(self.data_system_client.async_clear(partition_id=f"val_{self.global_steps - 1}"))
906849
return metric_dict
907850

908851
def init_workers(self):
@@ -1003,12 +946,7 @@ def init_workers(self):
1003946

1004947
# set transferqueue server info for each worker
1005948
for _, wg in all_wg.items():
1006-
wg.create_transferqueue_client(
1007-
self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train"
1008-
)
1009-
wg.create_transferqueue_client(
1010-
self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val"
1011-
)
949+
wg.create_transferqueue_client(self.data_system_controller_info, self.config)
1012950

1013951
# create async rollout manager and request scheduler
1014952
self.async_rollout_mode = False
@@ -1021,10 +959,7 @@ def init_workers(self):
1021959
)
1022960

1023961
self.async_rollout_manager.create_transferqueue_client(
1024-
self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train"
1025-
)
1026-
self.async_rollout_manager.create_transferqueue_client(
1027-
self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val"
962+
self.data_system_controller_info, self.config
1028963
)
1029964

1030965
def _save_checkpoint(self):
@@ -1313,7 +1248,7 @@ def fit(self):
13131248
timing_raw = {}
13141249
base_get_meta_kwargs = dict(
13151250
batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,
1316-
global_step=self.global_steps - 1, # self.global_steps starts from 1
1251+
partition_id=f"train_{self.global_steps - 1}", # self.global_steps starts from 1
13171252
get_n_samples=False,
13181253
)
13191254

@@ -1333,7 +1268,9 @@ def fit(self):
13331268
batch_dict, repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
13341269
)
13351270
batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict)
1336-
asyncio.run(self.data_system_client.async_put(data=batch, global_step=self.global_steps - 1))
1271+
asyncio.run(
1272+
self.data_system_client.async_put(data=batch, partition_id=f"train_{self.global_steps - 1}")
1273+
)
13371274

13381275
gen_meta = asyncio.run(
13391276
self.data_system_client.async_get_meta(
@@ -1709,7 +1646,7 @@ def fit(self):
17091646
],
17101647
batch_size=self.config.data.train_batch_size
17111648
* self.config.actor_rollout_ref.rollout.n,
1712-
global_step=self.global_steps - 1,
1649+
partition_id=f"train_{self.global_steps - 1}",
17131650
get_n_samples=False,
17141651
task_name="update_actor",
17151652
)
@@ -1735,7 +1672,7 @@ def fit(self):
17351672
self.data_system_client.async_get_meta(
17361673
data_fields=data_fields,
17371674
batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,
1738-
global_step=self.global_steps - 1,
1675+
partition_id=f"train_{self.global_steps - 1}",
17391676
get_n_samples=False,
17401677
task_name="log_rollout",
17411678
)
@@ -1857,7 +1794,7 @@ def fit(self):
18571794
# TODO: (TQ) support transfer queue
18581795
self.train_dataloader.sampler.update(batch=batch)
18591796

1860-
asyncio.run(self.data_system_client.async_clear(self.global_steps - 1))
1797+
asyncio.run(self.data_system_client.async_clear(partition_id=f"train_{self.global_steps - 1}"))
18611798
# TODO: make a canonical logger that supports various backend
18621799
logger.log(data=metrics, step=self.global_steps)
18631800

requirements_transferqueue.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# requirements.txt records the full set of dependencies for development
2-
git+https://github.com/TransferQueue/TransferQueue.git@68c04e7
2+
git+https://github.com/TransferQueue/TransferQueue.git@862b74a

0 commit comments

Comments
 (0)