Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions recipe/transfer_queue/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,7 @@ def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: Data

return timing

def create_transferqueue_client(self, controller_infos, storage_infos, role):
def create_transferqueue_client(self, controller_info, config):
ray.get(
[
worker.create_transferqueue_client.remote(controller_infos, storage_infos, role)
for worker in self.agent_loop_workers
]
[worker.create_transferqueue_client.remote(controller_info, config) for worker in self.agent_loop_workers]
)
180 changes: 58 additions & 122 deletions recipe/transfer_queue/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
from tqdm import tqdm
from transfer_queue import (
BatchMeta,
SimpleStorageUnit,
TransferQueueController,
TransferQueueStorageSimpleUnit,
get_placement_group,
process_zmq_server_info,
)
Expand Down Expand Up @@ -89,7 +89,6 @@
from verl.utils.transferqueue_utils import (
create_transferqueue_client,
get_transferqueue_client,
get_val_transferqueue_client,
tqbridge,
)

Expand Down Expand Up @@ -412,111 +411,56 @@ def __init__(

self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)

self.data_system_client = self._initialize_train_data_system(
self.config.data.train_batch_size, self.config.actor_rollout_ref.rollout.n
self.data_system_client = self._initialize_data_system()

def _initialize_data_system(self):
# 1. initialize TransferQueueStorage
train_data_size = (
self.config.data.train_batch_size
* self.config.trainer.num_global_batch
* self.config.actor_rollout_ref.rollout.n
)
self.val_data_system_client = self._initialize_val_data_system(
self.val_batch_size, self.config.actor_rollout_ref.rollout.val_kwargs.n
val_data_size = (
self.val_batch_size
* self.config.trainer.num_global_batch
* self.config.actor_rollout_ref.rollout.val_kwargs.n
)

def _initialize_train_data_system(self, global_batch_size, num_n_samples, role="train"):
# 1. initialize TransferQueueStorage
total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples
total_storage_size = train_data_size + val_data_size
self.data_system_storage_units = {}
storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1)
for storage_unit_rank in range(self.config.trainer.num_data_storage_units):
storage_node = TransferQueueStorageSimpleUnit.options(
storage_node = SimpleStorageUnit.options(
placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))
).remote(storage_unit_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))
self.data_system_storage_units[storage_unit_rank] = storage_node
logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.")
logging.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.")

# 2. initialize TransferQueueController
# we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly
# one controller for a single WorkerGroup.
self.data_system_controllers = {}
controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1)
for controller_rank in range(self.config.trainer.num_data_controllers):
self.data_system_controllers[controller_rank] = TransferQueueController.options(
placement_group=controller_placement_group, placement_group_bundle_index=controller_rank
).remote(
num_storage_units=self.config.trainer.num_data_storage_units,
global_batch_size=global_batch_size,
num_global_batch=self.config.trainer.num_global_batch,
num_n_samples=num_n_samples,
)
logging.info(f"TransferQueueController #{controller_rank} has been created.")
self.data_system_controller = TransferQueueController.remote()
logging.info("TransferQueueController has been created.")

# 3. register controller & storage
self.data_system_controller_infos = process_zmq_server_info(self.data_system_controllers)
# 3. register controller & storage and prepare necessary information
self.data_system_controller_info = process_zmq_server_info(self.data_system_controller)
self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units)

ray.get(
[
storage_unit.register_controller_info.remote(self.data_system_controller_infos)
for storage_unit in self.data_system_storage_units.values()
]
)
# Note: Need to generate a new DictConfig with allow_objects=True to preserve ZMQServerInfo instances
# (which contain socket connection details). Without this flag, OmegaConf would flatten these objects to dicts,
# breaking the transfer queue client initialization.
tq_config = OmegaConf.create({}, flags={"allow_objects": True})
tq_config.controller_info = self.data_system_controller_info
tq_config.storage_unit_infos = self.data_system_storage_unit_infos
self.config = OmegaConf.merge(tq_config, self.config)

# 4. create client
# each client should be allocated to exactly one controller
create_transferqueue_client(
client_id="Trainer-" + role,
controller_infos=self.data_system_controller_infos,
storage_infos=self.data_system_storage_unit_infos,
client_id="Trainer",
controller_info=self.data_system_controller_info,
config=self.config,
)
data_system_client = get_transferqueue_client()
return data_system_client

def _initialize_val_data_system(self, global_batch_size, num_n_samples, role="val"):
# 1. initialize TransferQueueStorage
total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples
self.val_data_system_storage_units = {}
storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1)
for storage_unit_rank in range(self.config.trainer.num_data_storage_units):
storage_node = TransferQueueStorageSimpleUnit.options(
placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))
self.val_data_system_storage_units[storage_unit_rank] = storage_node
logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.")

# 2. initialize TransferQueueController
# we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly
# one controller for a single WorkerGroup.
self.val_data_system_controllers = {}
controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1)
for controller_rank in range(self.config.trainer.num_data_controllers):
self.val_data_system_controllers[controller_rank] = TransferQueueController.options(
placement_group=controller_placement_group, placement_group_bundle_index=controller_rank
).remote(
num_storage_units=self.config.trainer.num_data_storage_units,
global_batch_size=global_batch_size,
num_global_batch=self.config.trainer.num_global_batch,
num_n_samples=num_n_samples,
)
logging.info(f"TransferQueueController #{controller_rank} has been created.")

# 3. register controller & storage
self.val_data_system_controller_infos = process_zmq_server_info(self.val_data_system_controllers)
self.val_data_system_storage_unit_infos = process_zmq_server_info(self.val_data_system_storage_units)

ray.get(
[
storage_unit.register_controller_info.remote(self.val_data_system_controller_infos)
for storage_unit in self.val_data_system_storage_units.values()
]
)

# 4. create client
# each client should be allocated to exactly one controller
create_transferqueue_client(
client_id="Trainer-" + role,
controller_infos=self.val_data_system_controller_infos,
storage_infos=self.val_data_system_storage_unit_infos,
)
data_system_client = get_val_transferqueue_client()
return data_system_client

def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
"""
Creates the train and validation dataloaders.
Expand Down Expand Up @@ -726,19 +670,19 @@ def _validate(self):
if self.config.reward_model.enable and test_batch[0]["reward_model"]["style"] == "model":
return {}

asyncio.run(self.val_data_system_client.async_put(data=test_batch, global_step=self.global_steps - 1))
asyncio.run(self.data_system_client.async_put(data=test_batch, partition_id=f"val_{self.global_steps - 1}"))

# Store original inputs
batch_meta = asyncio.run(
self.val_data_system_client.async_get_meta(
self.data_system_client.async_get_meta(
data_fields=["input_ids", "uid", "reward_model"],
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
global_step=self.global_steps - 1,
partition_id=f"val_{self.global_steps - 1}",
get_n_samples=False,
task_name="get_data",
)
)
data = asyncio.run(self.val_data_system_client.async_get_data(batch_meta))
data = asyncio.run(self.data_system_client.async_get_data(batch_meta))
input_ids = data["input_ids"]
# TODO: Can we keep special tokens except for padding tokens?
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
Expand All @@ -749,10 +693,10 @@ def _validate(self):
sample_gts.extend(ground_truths)

test_gen_meta = asyncio.run(
self.val_data_system_client.async_get_meta(
self.data_system_client.async_get_meta(
data_fields=list(test_batch.keys()), # TODO: (TQ) Get metadata by specified fields
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
global_step=self.global_steps - 1, # self.global_steps start from 1
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
get_n_samples=False,
task_name="generate_sequences",
)
Expand All @@ -779,15 +723,15 @@ def _validate(self):

# Store generated outputs
test_response_meta = asyncio.run(
self.val_data_system_client.async_get_meta(
self.data_system_client.async_get_meta(
data_fields=["responses"],
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
global_step=self.global_steps - 1, # self.global_steps start from 1
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
get_n_samples=False,
task_name="get_response",
)
)
data = asyncio.run(self.val_data_system_client.async_get_data(test_response_meta))
data = asyncio.run(self.data_system_client.async_get_data(test_response_meta))
output_ids = data["responses"]
output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
sample_outputs.extend(output_texts)
Expand All @@ -808,10 +752,10 @@ def _validate(self):
if "rm_scores" in batch_meta.field_names:
compute_reward_fields = ["rm_scores"]
val_reward_meta = asyncio.run(
self.val_data_system_client.async_get_meta(
self.data_system_client.async_get_meta(
data_fields=compute_reward_fields,
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
global_step=self.global_steps - 1,
partition_id=f"val_{self.global_steps - 1}",
get_n_samples=False,
task_name="compute_reward",
)
Expand All @@ -832,29 +776,29 @@ def _validate(self):
# collect num_turns of each prompt
if "__num_turns__" in test_batch_meta.field_names:
num_turns_meta = asyncio.run(
self.val_data_system_client.async_get_meta(
self.data_system_client.async_get_meta(
data_fields=["__num_turns__"],
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
global_step=self.global_steps - 1, # self.global_steps start from 1
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
get_n_samples=False,
task_name="get_num_turns",
)
)
data = asyncio.run(self.val_data_system_client.async_get_data(num_turns_meta))
data = asyncio.run(self.data_system_client.async_get_data(num_turns_meta))
sample_turns.append(data["__num_turns__"])

data_source = ["unknown"] * reward_tensor.shape[0]
if "data_source" in test_batch_meta.field_names:
data_source_meta = asyncio.run(
self.val_data_system_client.async_get_meta(
self.data_system_client.async_get_meta(
data_fields=["data_source"],
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
global_step=self.global_steps - 1, # self.global_steps start from 1
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
get_n_samples=False,
task_name="get_data_source",
)
)
data = asyncio.run(self.val_data_system_client.async_get_data(data_source_meta))
data = asyncio.run(self.data_system_client.async_get_data(data_source_meta))
data_source = data["data_source"]

data_source_lst.append(data_source)
Expand Down Expand Up @@ -902,7 +846,7 @@ def _validate(self):
metric_dict["val-aux/num_turns/max"] = sample_turns.max()
metric_dict["val-aux/num_turns/mean"] = sample_turns.mean()

asyncio.run(self.val_data_system_client.async_clear(self.global_steps - 1))
asyncio.run(self.data_system_client.async_clear(partition_id=f"val_{self.global_steps - 1}"))
return metric_dict

def init_workers(self):
Expand Down Expand Up @@ -1003,12 +947,7 @@ def init_workers(self):

# set transferqueue server info for each worker
for _, wg in all_wg.items():
wg.create_transferqueue_client(
self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train"
)
wg.create_transferqueue_client(
self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val"
)
wg.create_transferqueue_client(self.data_system_controller_info, self.config)

# create async rollout manager and request scheduler
self.async_rollout_mode = False
Expand All @@ -1020,12 +959,7 @@ def init_workers(self):
config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg
)

self.async_rollout_manager.create_transferqueue_client(
self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train"
)
self.async_rollout_manager.create_transferqueue_client(
self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val"
)
self.async_rollout_manager.create_transferqueue_client(self.data_system_controller_info, self.config)

def _save_checkpoint(self):
from verl.utils.fs import local_mkdir_safe
Expand Down Expand Up @@ -1313,7 +1247,7 @@ def fit(self):
timing_raw = {}
base_get_meta_kwargs = dict(
batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,
global_step=self.global_steps - 1, # self.global_steps starts from 1
partition_id=f"train_{self.global_steps - 1}", # self.global_steps starts from 1
get_n_samples=False,
)

Expand All @@ -1333,7 +1267,9 @@ def fit(self):
batch_dict, repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
)
batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict)
asyncio.run(self.data_system_client.async_put(data=batch, global_step=self.global_steps - 1))
asyncio.run(
self.data_system_client.async_put(data=batch, partition_id=f"train_{self.global_steps - 1}")
)

gen_meta = asyncio.run(
self.data_system_client.async_get_meta(
Expand Down Expand Up @@ -1709,7 +1645,7 @@ def fit(self):
],
batch_size=self.config.data.train_batch_size
* self.config.actor_rollout_ref.rollout.n,
global_step=self.global_steps - 1,
partition_id=f"train_{self.global_steps - 1}",
get_n_samples=False,
task_name="update_actor",
)
Expand All @@ -1735,7 +1671,7 @@ def fit(self):
self.data_system_client.async_get_meta(
data_fields=data_fields,
batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,
global_step=self.global_steps - 1,
partition_id=f"train_{self.global_steps - 1}",
get_n_samples=False,
task_name="log_rollout",
)
Expand Down Expand Up @@ -1857,7 +1793,7 @@ def fit(self):
# TODO: (TQ) support transfer queue
self.train_dataloader.sampler.update(batch=batch)

asyncio.run(self.data_system_client.async_clear(self.global_steps - 1))
asyncio.run(self.data_system_client.async_clear(partition_id=f"train_{self.global_steps - 1}"))
# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)

Expand Down
2 changes: 1 addition & 1 deletion requirements_transferqueue.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# requirements.txt records the full set of dependencies for development
git+https://github.com/TransferQueue/TransferQueue.git@68c04e7
git+https://github.com/TransferQueue/TransferQueue.git@862b74a
Loading