2222import ray
2323from omegaconf import OmegaConf
2424
25- from verl .experimental .dataset .sampler import AbstractSampler
2625from verl .trainer .constants_ppo import get_ppo_ray_runtime_env
26+ from verl .trainer .main_ppo import (
27+ TaskRunner as MainTaskRunner ,
28+ )
29+ from verl .trainer .main_ppo import (
30+ create_rl_dataset ,
31+ create_rl_sampler ,
32+ )
2733from verl .trainer .ppo .reward import load_reward_manager
2834from verl .trainer .ppo .utils import need_critic , need_reference_policy
2935from verl .utils .config import validate_config
3036from verl .utils .device import is_cuda_available
31- from verl .utils .import_utils import load_extern_type
3237
3338from .ray_trainer import RayPPOTrainer
3439
@@ -44,13 +49,14 @@ def main(config):
4449
4550
4651# Define a function to run the PPO-like training process
47- def run_ppo (config ) -> None :
52+ def run_ppo (config , task_runner_class = None ) -> None :
4853 """Initialize Ray cluster and run distributed PPO training process.
4954
5055 Args:
5156 config: Training configuration object containing all necessary parameters
5257 for distributed PPO training including Ray initialization settings,
5358 model paths, and training hyperparameters.
59+ task_runner_class: For recipe to change TaskRunner.
5460 """
5561 # Check if Ray is not initialized
5662 if not ray .is_initialized ():
@@ -63,9 +69,14 @@ def run_ppo(config) -> None:
6369 runtime_env_kwargs = ray_init_kwargs .get ("runtime_env" , {})
6470 runtime_env = OmegaConf .merge (default_runtime_env , runtime_env_kwargs )
6571 ray_init_kwargs = OmegaConf .create ({** ray_init_kwargs , "runtime_env" : runtime_env })
72+ if config .transfer_queue .enable :
73+ ray_init_kwargs ["TRANSFER_QUEUE_ENABLE" ] = "1"
6674 print (f"ray init kwargs: { ray_init_kwargs } " )
6775 ray .init (** OmegaConf .to_container (ray_init_kwargs ))
6876
77+ if task_runner_class is None :
78+ task_runner_class = ray .remote (num_cpus = 1 )(TaskRunner ) # please make sure main_task is not scheduled on head
79+
6980 # Create a remote instance of the TaskRunner class, and
7081 # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete
7182 if (
@@ -80,9 +91,9 @@ def run_ppo(config) -> None:
8091 nsight_options = OmegaConf .to_container (
8192 config .global_profiler .global_tool_config .nsys .controller_nsight_options
8293 )
83- runner = TaskRunner .options (runtime_env = {"nsight" : nsight_options }).remote ()
94+ runner = task_runner_class .options (runtime_env = {"nsight" : nsight_options }).remote ()
8495 else :
85- runner = TaskRunner .remote ()
96+ runner = task_runner_class .remote ()
8697 ray .get (runner .run .remote (config ))
8798
8899 # [Optional] get the path of the timeline trace file from the configuration, default to None
@@ -92,137 +103,7 @@ def run_ppo(config) -> None:
92103 ray .timeline (filename = timeline_json_file )
93104
94105
95- @ray .remote (num_cpus = 1 ) # please make sure main_task is not scheduled on head
96- class TaskRunner :
97- """Ray remote class for executing distributed PPO training tasks.
98-
99- This class encapsulates the main training logic and runs as a Ray remote actor
100- to enable distributed execution across multiple nodes and GPUs.
101-
102- Attributes:
103- role_worker_mapping: Dictionary mapping Role enums to Ray remote worker classes
104- mapping: Dictionary mapping Role enums to resource pool IDs for GPU allocation
105- """
106-
def __init__(self):
    """Start with empty role registries; the add_*_worker helpers fill them in."""
    # Role -> ray.remote-wrapped worker class.
    self.role_worker_mapping = {}
    # Role -> resource-pool id used for GPU allocation.
    self.mapping = {}
110-
def add_actor_rollout_worker(self, config):
    """Add actor rollout worker based on the actor strategy.

    Registers Role.ActorRollout in ``self.role_worker_mapping`` and returns
    the selected worker class along with the Ray worker-group class.

    Args:
        config: Training config; ``config.actor_rollout_ref.actor.strategy``
            selects the backend (``fsdp``/``fsdp2`` or ``megatron``) and
            ``config.actor_rollout_ref.rollout.mode`` selects async vs sync.

    Returns:
        tuple: ``(actor_rollout_cls, ray_worker_group_cls)``.

    Raises:
        NotImplementedError: If the actor strategy is not recognized.
    """
    from verl.single_controller.ray import RayWorkerGroup

    # The fsdp and megatron branches previously duplicated the async/sync
    # selection verbatim; only the import module differs, so import per
    # strategy and pick the class once.
    strategy = config.actor_rollout_ref.actor.strategy
    if strategy in {"fsdp", "fsdp2"}:
        from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
    elif strategy == "megatron":
        from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
    else:
        raise NotImplementedError

    actor_rollout_cls = (
        AsyncActorRolloutRefWorker
        if config.actor_rollout_ref.rollout.mode == "async"
        else ActorRolloutRefWorker
    )
    ray_worker_group_cls = RayWorkerGroup

    from verl.trainer.ppo.ray_trainer import Role

    self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls)

    return actor_rollout_cls, ray_worker_group_cls
143-
def add_critic_worker(self, config):
    """Register the critic worker class for Role.Critic.

    The class is chosen from ``config.critic.strategy`` and, for fsdp
    backends, ``config.trainer.use_legacy_worker_impl``.

    Raises:
        ValueError: For an invalid ``use_legacy_worker_impl`` setting.
        NotImplementedError: For an unknown critic strategy.
    """
    strategy = config.critic.strategy
    if strategy in {"fsdp", "fsdp2"}:
        use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
        if use_legacy_worker_impl in ["auto", "enable"]:
            from verl.workers.fsdp_workers import CriticWorker
        elif use_legacy_worker_impl == "disable":
            from verl.workers.roles import CriticWorker

            print("Using new worker implementation")
        else:
            raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
    elif strategy == "megatron":
        from verl.workers.megatron_workers import CriticWorker
    else:
        raise NotImplementedError

    from verl.trainer.ppo.ray_trainer import Role

    self.role_worker_mapping[Role.Critic] = ray.remote(CriticWorker)
166-
def init_resource_pool_mgr(self, config):
    """Build the ResourcePoolManager and the role-to-pool mapping.

    A single global pool spans all trainer nodes; an optional dedicated
    reward pool is added when ``config.reward_model.enable_resource_pool``
    is set.

    Returns:
        ResourcePoolManager: Manager over the declared resource pools.

    Raises:
        ValueError: If the reward pool is enabled with non-positive
            ``n_gpus_per_node`` or ``nnodes``.
    """
    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

    global_pool_id = "global_pool"
    resource_pool_spec = {
        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
    }
    # TODO Here you can use the new registration method to support dynamic registration of roles
    if config.reward_model.enable_resource_pool:
        if config.reward_model.n_gpus_per_node <= 0:
            raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0")
        if config.reward_model.nnodes <= 0:
            raise ValueError("config.reward_model.nnodes must be greater than 0")
        resource_pool_spec["reward_pool"] = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes

    self.mapping[Role.ActorRollout] = global_pool_id
    self.mapping[Role.Critic] = global_pool_id

    return ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping)
191-
def add_reward_model_worker(self, config):
    """Register the reward-model worker if ``config.reward_model.enable`` is set.

    Chooses the worker implementation from ``use_legacy_worker_impl`` and the
    reward-model strategy, and maps Role.RewardModel onto either the dedicated
    reward pool or the global pool.

    Raises:
        ValueError: For an invalid ``use_legacy_worker_impl`` setting.
        NotImplementedError: For an unknown reward-model strategy.
    """
    from verl.trainer.ppo.ray_trainer import Role

    if not config.reward_model.enable:
        return

    use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
    if use_legacy_worker_impl in ["auto", "enable"]:
        strategy = config.reward_model.strategy
        if strategy in {"fsdp", "fsdp2"}:
            from verl.workers.fsdp_workers import RewardModelWorker
        elif strategy == "megatron":
            from verl.workers.megatron_workers import RewardModelWorker
        else:
            raise NotImplementedError
    elif use_legacy_worker_impl == "disable":
        from verl.workers.roles import RewardModelWorker

        print("Using new worker implementation")
    else:
        raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")

    self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
    self.mapping[Role.RewardModel] = (
        "reward_pool" if config.reward_model.enable_resource_pool else "global_pool"
    )
217-
def add_ref_policy_worker(self, config, ref_policy_cls):
    """Register a reference-policy worker when KL appears in the reward or the loss."""
    from verl.trainer.ppo.ray_trainer import Role

    kl_is_used = config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss
    if not kl_is_used:
        return
    self.role_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls)
    self.mapping[Role.RefPolicy] = "global_pool"
225-
106+ class TaskRunner (MainTaskRunner ):
226107 def run (self , config ):
227108 """Execute the main PPO training workflow.
228109
@@ -236,8 +117,6 @@ def run(self, config):
236117 # Print the initial configuration. `resolve=True` will evaluate symbolic values.
237118 from pprint import pprint
238119
239- from omegaconf import OmegaConf
240-
241120 from verl .utils .fs import copy_to_local
242121
243122 print (f"TaskRunner hostname: { socket .gethostname ()} , PID: { os .getpid ()} " )
@@ -317,97 +196,5 @@ def run(self, config):
317196 trainer .fit ()
318197
319198
def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True):
    """Build the RL dataset configured by ``data_config``.

    Arguments:
        data_paths: List of paths to data files.
        data_config: The data config; may name a custom dataset class via
            ``custom_cls`` or a data-generation strategy via ``datagen``.
        tokenizer (Tokenizer): The tokenizer.
        processor (Processor): The processor.
        is_train (bool): Whether this is the training split; the ``datagen``
            strategy is only honored for training data.

    Returns:
        dataset (Dataset): The instantiated dataset.

    Raises:
        TypeError: If the custom dataset class is not a torch Dataset subclass.
    """
    from torch.utils.data import Dataset

    from verl.utils.dataset.rl_dataset import RLHFDataset

    # Branch predicates: a user-supplied dataset class takes precedence over
    # a data-generation strategy, which takes precedence over the default.
    has_custom_cls = "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None
    has_datagen = "datagen" in data_config and data_config.datagen.get("path", None) is not None

    if has_custom_cls:
        # Dynamically load the custom dataset class and verify its base class.
        dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)
        if not issubclass(dataset_cls, Dataset):
            raise TypeError(
                f"The custom dataset class '{data_config.custom_cls.name}' from "
                f"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset"
            )
    elif has_datagen and is_train:
        # A data generation strategy is specified: use the DynamicGenDataset class.
        from verl.utils.dataset.dynamicgen_dataset import DynamicGenDataset

        dataset_cls = DynamicGenDataset
        print("Using DynamicGenDataset for data generation.")
    else:
        # Fall back to the default RLHFDataset class.
        dataset_cls = RLHFDataset
        print(f"Using dataset class: {dataset_cls.__name__}")

    # Instantiate the dataset using the determined dataset class.
    return dataset_cls(
        data_files=data_paths,
        tokenizer=tokenizer,
        processor=processor,
        config=data_config,
    )
368-
369-
def create_rl_sampler(data_config, dataset):
    """Create a sampler for the dataset.

    Arguments:
        data_config: The data config; may specify a curriculum sampler via
            ``sampler.class_path`` / ``sampler.class_name``, otherwise
            ``shuffle`` and ``seed`` select the default sampler.
        dataset (Dataset): The dataset to sample from.

    Returns:
        sampler (Sampler): The sampler.

    Raises:
        TypeError: If a custom sampler class is not an AbstractSampler.
        ValueError: If a curriculum sampler is used with a non-zero
            ``dataloader_num_workers``.
    """
    import torch
    from torch.utils.data import RandomSampler, SequentialSampler

    if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None:
        # Dynamically load the user-provided curriculum sampler class.
        curriculum_class = load_extern_type(
            data_config.sampler.class_path,
            data_config.sampler.class_name,
        )
        sampler = curriculum_class(
            data_source=dataset,
            data_config=data_config,
        )
        # Explicit raises instead of `assert`: asserts are stripped under -O,
        # which would silently disable these checks.
        if not isinstance(sampler, AbstractSampler):
            raise TypeError(
                f"Custom sampler class {type(sampler).__name__} must inherit from AbstractSampler"
            )
        if data_config.get("dataloader_num_workers", 8) != 0:
            raise ValueError(
                "If using curriculum, num_workers must be 0 to prevent data caching. "
                "If the dataloader caches data before the batch is done the "
                "curriculum sampler won't have the opportunity to reorder it. "
            )

    # Use a sampler to facilitate checkpoint resumption.
    # If shuffling is enabled in the data configuration, create a random sampler
    # with a fixed seed for reproducibility.
    elif data_config.shuffle:
        train_dataloader_generator = torch.Generator()
        train_dataloader_generator.manual_seed(data_config.get("seed", 1))
        sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)
    else:
        # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order.
        sampler = SequentialSampler(data_source=dataset)

    return sampler
410-
411-
412199if __name__ == "__main__" :
413200 main ()
0 commit comments