
Commit 20d0f98

simplify the implementation of main_ppo in recipe/transfer_queue
1 parent 77c9a0e commit 20d0f98

8 files changed, +45 -238 lines changed


recipe/transfer_queue/config/transfer_queue_ppo_trainer.yaml

Lines changed: 4 additions & 0 deletions
@@ -5,3 +5,7 @@ hydra:
 defaults:
   - ppo_trainer
   - _self_
+
+# config for TransferQueue
+transfer_queue:
+  enable: True
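
Since the recipe config layers on top of the base ppo_trainer via Hydra's defaults list, the queue is on by default for this recipe, while the base templates below keep enable: False. Assuming the recipe is launched through its Hydra entry point (module path inferred from the file location, not stated in this commit), the flag could be flipped back per run with a standard override:

    python -m recipe.transfer_queue.main_ppo transfer_queue.enable=False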

recipe/transfer_queue/main_ppo.py

Lines changed: 17 additions & 230 deletions
@@ -22,13 +22,18 @@
 import ray
 from omegaconf import OmegaConf

-from verl.experimental.dataset.sampler import AbstractSampler
 from verl.trainer.constants_ppo import get_ppo_ray_runtime_env
+from verl.trainer.main_ppo import (
+    TaskRunner as MainTaskRunner,
+)
+from verl.trainer.main_ppo import (
+    create_rl_dataset,
+    create_rl_sampler,
+)
 from verl.trainer.ppo.reward import load_reward_manager
 from verl.trainer.ppo.utils import need_critic, need_reference_policy
 from verl.utils.config import validate_config
 from verl.utils.device import is_cuda_available
-from verl.utils.import_utils import load_extern_type

 from .ray_trainer import RayPPOTrainer

@@ -44,13 +49,14 @@ def main(config):


 # Define a function to run the PPO-like training process
-def run_ppo(config) -> None:
+def run_ppo(config, task_runner_class=None) -> None:
     """Initialize Ray cluster and run distributed PPO training process.

     Args:
         config: Training configuration object containing all necessary parameters
             for distributed PPO training including Ray initialization settings,
             model paths, and training hyperparameters.
+        task_runner_class: Optional override letting a recipe substitute its own TaskRunner.
     """
     # Check if Ray is not initialized
     if not ray.is_initialized():
@@ -63,9 +69,14 @@ def run_ppo(config) -> None:
         runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
         runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
         ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
+        if config.transfer_queue.enable:
+            ray_init_kwargs["TRANSFER_QUEUE_ENABLE"] = "1"
         print(f"ray init kwargs: {ray_init_kwargs}")
         ray.init(**OmegaConf.to_container(ray_init_kwargs))

+    if task_runner_class is None:
+        task_runner_class = ray.remote(num_cpus=1)(TaskRunner)  # make sure the main task is not scheduled on the head node
+
     # Create a remote instance of the TaskRunner class, and
     # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete
     if (
@@ -80,9 +91,9 @@ def run_ppo(config) -> None:
         nsight_options = OmegaConf.to_container(
             config.global_profiler.global_tool_config.nsys.controller_nsight_options
         )
-        runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
+        runner = task_runner_class.options(runtime_env={"nsight": nsight_options}).remote()
     else:
-        runner = TaskRunner.remote()
+        runner = task_runner_class.remote()
     ray.get(runner.run.remote(config))

     # [Optional] get the path of the timeline trace file from the configuration, default to None
@@ -92,137 +103,7 @@ def run_ppo(config) -> None:
         ray.timeline(filename=timeline_json_file)


-@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
-class TaskRunner:
-    """Ray remote class for executing distributed PPO training tasks.
-
-    This class encapsulates the main training logic and runs as a Ray remote actor
-    to enable distributed execution across multiple nodes and GPUs.
-
-    Attributes:
-        role_worker_mapping: Dictionary mapping Role enums to Ray remote worker classes
-        mapping: Dictionary mapping Role enums to resource pool IDs for GPU allocation
-    """
-
-    def __init__(self):
-        self.role_worker_mapping = {}
-        self.mapping = {}
-
-    def add_actor_rollout_worker(self, config):
-        """Add actor rollout worker based on the actor strategy."""
-        from verl.single_controller.ray import RayWorkerGroup
-
-        if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
-            from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
-
-            actor_rollout_cls = (
-                AsyncActorRolloutRefWorker
-                if config.actor_rollout_ref.rollout.mode == "async"
-                else ActorRolloutRefWorker
-            )
-            ray_worker_group_cls = RayWorkerGroup
-
-        elif config.actor_rollout_ref.actor.strategy == "megatron":
-            from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
-
-            actor_rollout_cls = (
-                AsyncActorRolloutRefWorker
-                if config.actor_rollout_ref.rollout.mode == "async"
-                else ActorRolloutRefWorker
-            )
-            ray_worker_group_cls = RayWorkerGroup
-
-        else:
-            raise NotImplementedError
-
-        from verl.trainer.ppo.ray_trainer import Role
-
-        self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls)
-
-        return actor_rollout_cls, ray_worker_group_cls
-
-    def add_critic_worker(self, config):
-        """Add critic worker to role mapping."""
-        if config.critic.strategy in {"fsdp", "fsdp2"}:
-            use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
-            if use_legacy_worker_impl in ["auto", "enable"]:
-                from verl.workers.fsdp_workers import CriticWorker
-            elif use_legacy_worker_impl == "disable":
-                from verl.workers.roles import CriticWorker
-
-                print("Using new worker implementation")
-            else:
-                raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
-
-        elif config.critic.strategy == "megatron":
-            from verl.workers.megatron_workers import CriticWorker
-
-        else:
-            raise NotImplementedError
-
-        from verl.trainer.ppo.ray_trainer import Role
-
-        self.role_worker_mapping[Role.Critic] = ray.remote(CriticWorker)
-
-    def init_resource_pool_mgr(self, config):
-        """Initialize resource pool manager."""
-        from verl.trainer.ppo.ray_trainer import Role
-
-        global_pool_id = "global_pool"
-        resource_pool_spec = {
-            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
-        }
-        # TODO Here you can use the new registration method to support dynamic registration of roles
-        if config.reward_model.enable_resource_pool:
-            if config.reward_model.n_gpus_per_node <= 0:
-                raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0")
-            if config.reward_model.nnodes <= 0:
-                raise ValueError("config.reward_model.nnodes must be greater than 0")
-
-            reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes
-            resource_pool_spec["reward_pool"] = reward_pool
-
-        self.mapping[Role.ActorRollout] = global_pool_id
-        self.mapping[Role.Critic] = global_pool_id
-        from verl.trainer.ppo.ray_trainer import ResourcePoolManager
-
-        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping)
-        return resource_pool_manager
-
-    def add_reward_model_worker(self, config):
-        """Add reward model worker if enabled."""
-        from verl.trainer.ppo.ray_trainer import Role
-
-        if config.reward_model.enable:
-            use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
-            if use_legacy_worker_impl in ["auto", "enable"]:
-                if config.reward_model.strategy in {"fsdp", "fsdp2"}:
-                    from verl.workers.fsdp_workers import RewardModelWorker
-                elif config.reward_model.strategy == "megatron":
-                    from verl.workers.megatron_workers import RewardModelWorker
-                else:
-                    raise NotImplementedError
-            elif use_legacy_worker_impl == "disable":
-                from verl.workers.roles import RewardModelWorker
-
-                print("Using new worker implementation")
-            else:
-                raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
-
-            self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
-            if config.reward_model.enable_resource_pool:
-                self.mapping[Role.RewardModel] = "reward_pool"
-            else:
-                self.mapping[Role.RewardModel] = "global_pool"
-
-    def add_ref_policy_worker(self, config, ref_policy_cls):
-        """Add reference policy worker if KL loss or KL reward is used."""
-        from verl.trainer.ppo.ray_trainer import Role
-
-        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
-            self.role_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls)
-            self.mapping[Role.RefPolicy] = "global_pool"
-
+class TaskRunner(MainTaskRunner):
     def run(self, config):
         """Execute the main PPO training workflow.

@@ -236,8 +117,6 @@ def run(self, config):
         # Print the initial configuration. `resolve=True` will evaluate symbolic values.
         from pprint import pprint

-        from omegaconf import OmegaConf
-
         from verl.utils.fs import copy_to_local

         print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
@@ -317,97 +196,5 @@ def run(self, config):
         trainer.fit()


-def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True):
-    """Create a dataset.
-
-    Arguments:
-        data_paths: List of paths to data files.
-        data_config: The data config.
-        tokenizer (Tokenizer): The tokenizer.
-        processor (Processor): The processor.
-
-    Returns:
-        dataset (Dataset): The dataset.
-    """
-    from torch.utils.data import Dataset
-
-    from verl.utils.dataset.rl_dataset import RLHFDataset
-
-    # Check if a custom dataset class is specified in the data configuration
-    # and if the path to the custom class is provided
-    if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None:
-        # Dynamically load the custom dataset class
-        dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)
-        # Verify that the custom dataset class inherits from torch.utils.data.Dataset
-        if not issubclass(dataset_cls, Dataset):
-            raise TypeError(
-                f"The custom dataset class '{data_config.custom_cls.name}' from "
-                f"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset"
-            )
-    elif "datagen" in data_config and data_config.datagen.get("path", None) is not None and is_train:
-        # If a data generation strategy is specified, use the DynamicGenDataset class
-        from verl.utils.dataset.dynamicgen_dataset import DynamicGenDataset
-
-        dataset_cls = DynamicGenDataset
-        print("Using DynamicGenDataset for data generation.")
-
-    else:
-        # Use the default RLHFDataset class if no custom class is specified
-        dataset_cls = RLHFDataset
-    print(f"Using dataset class: {dataset_cls.__name__}")
-
-    # Instantiate the dataset using the determined dataset class
-    dataset = dataset_cls(
-        data_files=data_paths,
-        tokenizer=tokenizer,
-        processor=processor,
-        config=data_config,
-    )
-
-    return dataset
-
-
-def create_rl_sampler(data_config, dataset):
-    """Create a sampler for the dataset.
-
-    Arguments:
-        data_config: The data config.
-        dataset (Dataset): The dataset.
-
-    Returns:
-        sampler (Sampler): The sampler.
-    """
-    import torch
-    from torch.utils.data import RandomSampler, SequentialSampler
-
-    if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None:
-        curriculum_class = load_extern_type(
-            data_config.sampler.class_path,
-            data_config.sampler.class_name,
-        )
-        sampler = curriculum_class(
-            data_source=dataset,
-            data_config=data_config,
-        )
-        assert isinstance(sampler, AbstractSampler)
-        assert data_config.get("dataloader_num_workers", 8) == 0, (
-            "If using curriculum, num_workers must be 0 to prevent data caching. "
-            "If the dataloader caches data before the batch is done the "
-            "curriculum sampler won't have the opportunity to reorder it. "
-        )
-
-    # Use a sampler to facilitate checkpoint resumption.
-    # If shuffling is enabled in the data configuration, create a random sampler.
-    elif data_config.shuffle:
-        train_dataloader_generator = torch.Generator()
-        train_dataloader_generator.manual_seed(data_config.get("seed", 1))
-        sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)
-    else:
-        # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order.
-        sampler = SequentialSampler(data_source=dataset)
-
-    return sampler
-
-
 if __name__ == "__main__":
     main()
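
With the shared pieces now imported from verl.trainer.main_ppo, the recipe's TaskRunner overrides only run(). A minimal sketch of how another recipe could use the same extension point (RecipeTaskRunner and launch are illustrative names; the signatures come from the diff above):

    import ray

    from verl.trainer.main_ppo import TaskRunner as MainTaskRunner
    from verl.trainer.main_ppo import run_ppo


    class RecipeTaskRunner(MainTaskRunner):
        def run(self, config):
            # recipe-specific setup could go here before delegating
            # to the upstream PPO training workflow
            super().run(config)


    def launch(config):
        # mirror run_ppo's default wrapping: a 1-CPU Ray actor, so the
        # main task is not scheduled on the head node
        runner_cls = ray.remote(num_cpus=1)(RecipeTaskRunner)
        run_ppo(config, task_runner_class=runner_cls)

From there, run_ppo chains .options() or .remote() onto the supplied class exactly as it does for the default runner.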

verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 2 additions & 0 deletions
@@ -542,6 +542,8 @@ global_profiler:
       context: all
       stacks: all
       kw_args: {}
+transfer_queue:
+  enable: false
 ray_kwargs:
   ray_init:
     num_cpus: null

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 2 additions & 0 deletions
@@ -526,6 +526,8 @@ global_profiler:
       context: all
       stacks: all
       kw_args: {}
+transfer_queue:
+  enable: false
 ray_kwargs:
   ray_init:
     num_cpus: null

verl/trainer/config/ppo_megatron_trainer.yaml

Lines changed: 6 additions & 0 deletions
@@ -189,6 +189,12 @@ global_profiler:
       # devices, record_context etc.
       kw_args: {}

+# configs for TransferQueue
+transfer_queue:
+
+  # Whether to enable transfer queue
+  enable: False
+
 ray_kwargs:
   ray_init:
     num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.

verl/trainer/config/ppo_trainer.yaml

Lines changed: 6 additions & 0 deletions
@@ -317,6 +317,12 @@ global_profiler:
       # devices, record_context etc.
       kw_args: {}

+# configs for TransferQueue
+transfer_queue:
+
+  # Whether to enable transfer queue
+  enable: False
+
 # configs related to ray
 ray_kwargs:
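
Both base templates now declare the flag with a False default, so config.transfer_queue.enable resolves whenever training starts from these files. Code that must also tolerate older configs generated before this commit could read the flag defensively; a small sketch (not part of this commit; config is assumed to be an OmegaConf object):

    from omegaconf import OmegaConf

    # falls back to False when an older config lacks the transfer_queue block
    enable_tq = OmegaConf.select(config, "transfer_queue.enable", default=False)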

verl/trainer/main_ppo.py

Lines changed: 3 additions & 3 deletions
@@ -63,13 +63,13 @@ def run_ppo(config, task_runner_class=None) -> None:
         runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
         runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
         ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
+        if config.transfer_queue.enable:
+            ray_init_kwargs["TRANSFER_QUEUE_ENABLE"] = "1"
         print(f"ray init kwargs: {ray_init_kwargs}")
         ray.init(**OmegaConf.to_container(ray_init_kwargs))

     if task_runner_class is None:
-        task_runner_class = ray.remote(TaskRunner).options(
-            num_cpus=1
-        )  # please make sure main_task is not scheduled on head
+        task_runner_class = ray.remote(num_cpus=1)(TaskRunner)  # make sure the main task is not scheduled on the head node

     # Create a remote instance of the TaskRunner class, and
     # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete
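
The rewritten default still pins the runner to one CPU, but decorating the class directly yields an ordinary Ray actor class. That appears to be the point of the change: run_ppo may later chain task_runner_class.options(runtime_env={"nsight": ...}) for profiling runs, and the object returned by .options() in the old form exposes only .remote(), not a second .options(). A sketch of the new shape (names from the diff; behavior per Ray's actor API as I understand it):

    # a decorated actor class; per-launch options can still be applied
    task_runner_class = ray.remote(num_cpus=1)(TaskRunner)
    runner = task_runner_class.options(runtime_env={"nsight": nsight_options}).remote()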
