Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions recipe/transfer_queue/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: Data

return timing

def create_transferqueue_client(self, controller_infos, storage_infos, role):
def create_transferqueue_client(self, controller_info, config):
ray.get(
[
worker.create_transferqueue_client.remote(controller_infos, storage_infos, role)
worker.create_transferqueue_client.remote(controller_info, config)
for worker in self.agent_loop_workers
]
)
175 changes: 56 additions & 119 deletions recipe/transfer_queue/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from transfer_queue import (
BatchMeta,
TransferQueueController,
TransferQueueStorageSimpleUnit,
SimpleStorageUnit,
get_placement_group,
process_zmq_server_info,
)
Expand Down Expand Up @@ -89,7 +89,6 @@
from verl.utils.transferqueue_utils import (
create_transferqueue_client,
get_transferqueue_client,
get_val_transferqueue_client,
tqbridge,
)

Expand Down Expand Up @@ -412,111 +411,53 @@ def __init__(

self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)

self.data_system_client = self._initialize_train_data_system(
self.config.data.train_batch_size, self.config.actor_rollout_ref.rollout.n
self.data_system_client = self._initialize_data_system()

def _initialize_data_system(self):
# 1. initialize TransferQueueStorage
train_data_size = (
self.config.data.train_batch_size * self.config.trainer.num_global_batch *
self.config.actor_rollout_ref.rollout.n
)
self.val_data_system_client = self._initialize_val_data_system(
self.val_batch_size, self.config.actor_rollout_ref.rollout.val_kwargs.n
val_data_size = (
self.val_batch_size * self.config.trainer.num_global_batch *
self.config.actor_rollout_ref.rollout.val_kwargs.n
)

def _initialize_train_data_system(self, global_batch_size, num_n_samples, role="train"):
# 1. initialize TransferQueueStorage
total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples
total_storage_size = train_data_size + val_data_size
self.data_system_storage_units = {}
storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1)
for storage_unit_rank in range(self.config.trainer.num_data_storage_units):
storage_node = TransferQueueStorageSimpleUnit.options(
storage_node = SimpleStorageUnit.options(
placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))
).remote(storage_unit_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))
self.data_system_storage_units[storage_unit_rank] = storage_node
logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.")
logging.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.")

# 2. initialize TransferQueueController
# we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly
# one controller for a single WorkerGroup.
self.data_system_controllers = {}
controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1)
for controller_rank in range(self.config.trainer.num_data_controllers):
self.data_system_controllers[controller_rank] = TransferQueueController.options(
placement_group=controller_placement_group, placement_group_bundle_index=controller_rank
).remote(
num_storage_units=self.config.trainer.num_data_storage_units,
global_batch_size=global_batch_size,
num_global_batch=self.config.trainer.num_global_batch,
num_n_samples=num_n_samples,
)
logging.info(f"TransferQueueController #{controller_rank} has been created.")
self.data_system_controller = TransferQueueController.remote()
logging.info("TransferQueueController has been created.")

# 3. register controller & storage
self.data_system_controller_infos = process_zmq_server_info(self.data_system_controllers)
# 3. register controller & storage and prepare necessary information
self.data_system_controller_info = process_zmq_server_info(self.data_system_controller)
self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units)

ray.get(
[
storage_unit.register_controller_info.remote(self.data_system_controller_infos)
for storage_unit in self.data_system_storage_units.values()
]
)
tq_config = OmegaConf.create({}, flags={"allow_objects": True}) # Note: Need to generate a new DictConfig
# with allow_objects=True to maintain ZMQServerInfo instance. Otherwise it will be flattened to dict
tq_config.controller_info = self.data_system_controller_info
tq_config.storage_unit_infos = self.data_system_storage_unit_infos
self.config = OmegaConf.merge(tq_config, self.config)

# 4. create client
# each client should be allocated to exactly one controller
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# each client should be allocated to exactly one controller

create_transferqueue_client(
client_id="Trainer-" + role,
controller_infos=self.data_system_controller_infos,
storage_infos=self.data_system_storage_unit_infos,
client_id="Trainer",
controller_info=self.data_system_controller_info,
config=self.config,
)
data_system_client = get_transferqueue_client()
return data_system_client

def _initialize_val_data_system(self, global_batch_size, num_n_samples, role="val"):
# 1. initialize TransferQueueStorage
total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples
self.val_data_system_storage_units = {}
storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1)
for storage_unit_rank in range(self.config.trainer.num_data_storage_units):
storage_node = TransferQueueStorageSimpleUnit.options(
placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))
self.val_data_system_storage_units[storage_unit_rank] = storage_node
logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.")

# 2. initialize TransferQueueController
# we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly
# one controller for a single WorkerGroup.
self.val_data_system_controllers = {}
controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1)
for controller_rank in range(self.config.trainer.num_data_controllers):
self.val_data_system_controllers[controller_rank] = TransferQueueController.options(
placement_group=controller_placement_group, placement_group_bundle_index=controller_rank
).remote(
num_storage_units=self.config.trainer.num_data_storage_units,
global_batch_size=global_batch_size,
num_global_batch=self.config.trainer.num_global_batch,
num_n_samples=num_n_samples,
)
logging.info(f"TransferQueueController #{controller_rank} has been created.")

# 3. register controller & storage
self.val_data_system_controller_infos = process_zmq_server_info(self.val_data_system_controllers)
self.val_data_system_storage_unit_infos = process_zmq_server_info(self.val_data_system_storage_units)

ray.get(
[
storage_unit.register_controller_info.remote(self.val_data_system_controller_infos)
for storage_unit in self.val_data_system_storage_units.values()
]
)

# 4. create client
# each client should be allocated to exactly one controller
create_transferqueue_client(
client_id="Trainer-" + role,
controller_infos=self.val_data_system_controller_infos,
storage_infos=self.val_data_system_storage_unit_infos,
)
data_system_client = get_val_transferqueue_client()
return data_system_client

def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
"""
Creates the train and validation dataloaders.
Expand Down Expand Up @@ -726,19 +667,21 @@ def _validate(self):
if self.config.reward_model.enable and test_batch[0]["reward_model"]["style"] == "model":
return {}

asyncio.run(self.val_data_system_client.async_put(data=test_batch, global_step=self.global_steps - 1))
asyncio.run(
self.data_system_client.async_put(data=test_batch, partition_id=f"val_{self.global_steps - 1}")
)

# Store original inputs
batch_meta = asyncio.run(
self.val_data_system_client.async_get_meta(
self.data_system_client.async_get_meta(
data_fields=["input_ids", "uid", "reward_model"],
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
global_step=self.global_steps - 1,
partition_id=f"val_{self.global_steps - 1}",
get_n_samples=False,
task_name="get_data",
)
)
data = asyncio.run(self.val_data_system_client.async_get_data(batch_meta))
data = asyncio.run(self.data_system_client.async_get_data(batch_meta))
input_ids = data["input_ids"]
# TODO: Can we keep special tokens except for padding tokens?
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
Expand All @@ -749,10 +692,10 @@ def _validate(self):
sample_gts.extend(ground_truths)

test_gen_meta = asyncio.run(
self.val_data_system_client.async_get_meta(
self.data_system_client.async_get_meta(
data_fields=list(test_batch.keys()), # TODO: (TQ) Get metadata by specified fields
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
global_step=self.global_steps - 1, # self.global_steps start from 1
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
get_n_samples=False,
task_name="generate_sequences",
)
Expand All @@ -779,15 +722,15 @@ def _validate(self):

# Store generated outputs
test_response_meta = asyncio.run(
self.val_data_system_client.async_get_meta(
self.data_system_client.async_get_meta(
data_fields=["responses"],
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
global_step=self.global_steps - 1, # self.global_steps start from 1
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
get_n_samples=False,
task_name="get_response",
)
)
data = asyncio.run(self.val_data_system_client.async_get_data(test_response_meta))
data = asyncio.run(self.data_system_client.async_get_data(test_response_meta))
output_ids = data["responses"]
output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
sample_outputs.extend(output_texts)
Expand All @@ -808,10 +751,10 @@ def _validate(self):
if "rm_scores" in batch_meta.field_names:
compute_reward_fields = ["rm_scores"]
val_reward_meta = asyncio.run(
self.val_data_system_client.async_get_meta(
self.data_system_client.async_get_meta(
data_fields=compute_reward_fields,
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
global_step=self.global_steps - 1,
partition_id=f"val_{self.global_steps - 1}",
get_n_samples=False,
task_name="compute_reward",
)
Expand All @@ -832,29 +775,29 @@ def _validate(self):
# collect num_turns of each prompt
if "__num_turns__" in test_batch_meta.field_names:
num_turns_meta = asyncio.run(
self.val_data_system_client.async_get_meta(
self.data_system_client.async_get_meta(
data_fields=["__num_turns__"],
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
global_step=self.global_steps - 1, # self.global_steps start from 1
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
get_n_samples=False,
task_name="get_num_turns",
)
)
data = asyncio.run(self.val_data_system_client.async_get_data(num_turns_meta))
data = asyncio.run(self.data_system_client.async_get_data(num_turns_meta))
sample_turns.append(data["__num_turns__"])

data_source = ["unknown"] * reward_tensor.shape[0]
if "data_source" in test_batch_meta.field_names:
data_source_meta = asyncio.run(
self.val_data_system_client.async_get_meta(
self.data_system_client.async_get_meta(
data_fields=["data_source"],
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
global_step=self.global_steps - 1, # self.global_steps start from 1
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
get_n_samples=False,
task_name="get_data_source",
)
)
data = asyncio.run(self.val_data_system_client.async_get_data(data_source_meta))
data = asyncio.run(self.data_system_client.async_get_data(data_source_meta))
data_source = data["data_source"]

data_source_lst.append(data_source)
Expand Down Expand Up @@ -902,7 +845,7 @@ def _validate(self):
metric_dict["val-aux/num_turns/max"] = sample_turns.max()
metric_dict["val-aux/num_turns/mean"] = sample_turns.mean()

asyncio.run(self.val_data_system_client.async_clear(self.global_steps - 1))
asyncio.run(self.data_system_client.async_clear(partition_id=f"val_{self.global_steps - 1}"))
return metric_dict

def init_workers(self):
Expand Down Expand Up @@ -1003,12 +946,7 @@ def init_workers(self):

# set transferqueue server info for each worker
for _, wg in all_wg.items():
wg.create_transferqueue_client(
self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train"
)
wg.create_transferqueue_client(
self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val"
)
wg.create_transferqueue_client(self.data_system_controller_info, self.config)

# create async rollout manager and request scheduler
self.async_rollout_mode = False
Expand All @@ -1021,10 +959,7 @@ def init_workers(self):
)

self.async_rollout_manager.create_transferqueue_client(
self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train"
)
self.async_rollout_manager.create_transferqueue_client(
self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val"
self.data_system_controller_info, self.config
)

def _save_checkpoint(self):
Expand Down Expand Up @@ -1313,7 +1248,7 @@ def fit(self):
timing_raw = {}
base_get_meta_kwargs = dict(
batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,
global_step=self.global_steps - 1, # self.global_steps starts from 1
partition_id=f"train_{self.global_steps - 1}", # self.global_steps starts from 1
get_n_samples=False,
)

Expand All @@ -1333,7 +1268,9 @@ def fit(self):
batch_dict, repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
)
batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict)
asyncio.run(self.data_system_client.async_put(data=batch, global_step=self.global_steps - 1))
asyncio.run(
self.data_system_client.async_put(data=batch, partition_id=f"train_{self.global_steps - 1}")
)

gen_meta = asyncio.run(
self.data_system_client.async_get_meta(
Expand Down Expand Up @@ -1709,7 +1646,7 @@ def fit(self):
],
batch_size=self.config.data.train_batch_size
* self.config.actor_rollout_ref.rollout.n,
global_step=self.global_steps - 1,
partition_id=f"train_{self.global_steps - 1}",
get_n_samples=False,
task_name="update_actor",
)
Expand All @@ -1735,7 +1672,7 @@ def fit(self):
self.data_system_client.async_get_meta(
data_fields=data_fields,
batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,
global_step=self.global_steps - 1,
partition_id=f"train_{self.global_steps - 1}",
get_n_samples=False,
task_name="log_rollout",
)
Expand Down Expand Up @@ -1857,7 +1794,7 @@ def fit(self):
# TODO: (TQ) support transfer queue
self.train_dataloader.sampler.update(batch=batch)

asyncio.run(self.data_system_client.async_clear(self.global_steps - 1))
asyncio.run(self.data_system_client.async_clear(partition_id=f"train_{self.global_steps - 1}"))
# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)

Expand Down
2 changes: 1 addition & 1 deletion requirements_transferqueue.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# requirements.txt records the full set of dependencies for development
git+https://github.com/TransferQueue/TransferQueue.git@68c04e7
git+https://github.com/TransferQueue/TransferQueue.git@862b74a
Loading
Loading