 import torch
 import torch.distributed
 import torch.distributed as dist
-import vllm  # noqa: F401; import vllm to avoid "Cuda failure 1 'invalid argument'"
+import vllm  # noqa: F401; import vllm to set NCCL_CUMEM_ENABLE automatically.
 from codetiming import Timer
 from omegaconf import DictConfig, open_dict
 from peft import LoraConfig, TaskType, get_peft_model
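
Note on the changed import comment above: both versions keep import vllm purely for its side effect on NCCL configuration. A minimal sketch of the rough equivalent, assuming (as the new comment implies) that importing vLLM sets NCCL_CUMEM_ENABLE when it is not already set; the exact value and mechanism are vLLM-internal and are not specified by this diff:

    import os

    # Approximation of the side effect relied upon here (an assumption, not vLLM's documented API):
    # disable NCCL's cuMem allocator, which has been associated with "Cuda failure 1 'invalid argument'".
    os.environ.setdefault("NCCL_CUMEM_ENABLE", "0")
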
@@ -126,7 +126,6 @@ def __init__(self, config: DictConfig, role: str):
         assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"]

         self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
-        self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
         self._is_ref = self.role in ["ref", "actor_rollout_ref"]

         self._is_offload_param = False
@@ -170,14 +169,6 @@ def __init__(self, config: DictConfig, role: str):
                     > 0
                 ), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"

-        # normalize rollout config
-        if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:
-            self.config.rollout.log_prob_micro_batch_size //= (
-                self.device_mesh.size() // self.ulysses_sequence_parallel_size
-            )
-            self.config.rollout.log_prob_micro_batch_size_per_gpu = (
-                self.config.rollout.log_prob_micro_batch_size
-            )
         # normalize ref config
         if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
             self.config.ref.log_prob_micro_batch_size //= (
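
The surviving ref-side normalization divides the global log_prob_micro_batch_size by the data-parallel size (device mesh size over the Ulysses sequence-parallel size), mirroring the rollout normalization deleted above. A worked example with assumed values, for illustration only:

    # Assumed values, not taken from any config in this diff.
    world_size = 8                          # self.device_mesh.size()
    ulysses_sequence_parallel_size = 2
    log_prob_micro_batch_size = 32          # global micro batch size
    dp_size = world_size // ulysses_sequence_parallel_size              # 4
    log_prob_micro_batch_size_per_gpu = log_prob_micro_batch_size // dp_size  # 8
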
@@ -339,10 +330,6 @@ def _build_model_optimizer(  # noqa: C901
             is_lora=self.config.model.get("lora_rank", 0) > 0,
         )

-        if self._is_rollout and self.config.rollout.name == "hf":
-            # TODO(zhangchi.usc1992, shengguangming) fix me. Currently, auto_wrap_policy causes HFRollout to hang in Gemma
-            auto_wrap_policy = None
-
         if self.rank == 0:
             print(f"wrap_policy: {auto_wrap_policy}")

@@ -450,136 +437,6 @@ def _build_model_optimizer(  # noqa: C901

         return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config

-    def _build_rollout(self, trust_remote_code=False):
-        from torch.distributed.device_mesh import init_device_mesh
-
-        # TODO(sgm): support FSDP hybrid shard for larger model
-        infer_tp = self.config.rollout.tensor_model_parallel_size
-        dp = self.world_size // infer_tp
-        assert (
-            self.world_size % infer_tp == 0
-        ), f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
-        rollout_device_mesh = init_device_mesh(
-            device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
-        )
-        rollout_name = self.config.rollout.name
-        if rollout_name == "hf":
-            from verl.workers.rollout import HFRollout
-            from verl.workers.sharding_manager.base import BaseShardingManager
-
-            rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
-            rollout_sharding_manager = BaseShardingManager()
-            # TODO: a sharding manager that does nothing?
-
-        elif rollout_name == "vllm":
-            from verl.workers.rollout.vllm_rollout import vllm_mode, vLLMRollout
-            from verl.workers.sharding_manager.fsdp_vllm import FSDPVLLMShardingManager
-
-            log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
-            local_path = copy_to_local(
-                self.config.model.path, use_shm=self.config.model.get("use_shm", False)
-            )
-            lora_kwargs = (
-                {
-                    "lora_kwargs": {
-                        "enable_lora": True,
-                        "max_loras": 1,
-                        "max_lora_rank": self._lora_rank,
-                    }
-                }
-                if self._is_lora
-                else {}
-            )
-            # lora_kwargs = {}
-            if vllm_mode == "customized":
-                rollout = vLLMRollout(
-                    actor_module=self.actor_module_fsdp,
-                    config=self.config.rollout,
-                    tokenizer=self.tokenizer,
-                    model_hf_config=self.actor_model_config,
-                    trust_remote_code=trust_remote_code,
-                    **lora_kwargs,
-                )
-            elif vllm_mode == "spmd":
-                from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout
-
-                vllm_rollout_cls = (
-                    vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
-                )
-                rollout = vllm_rollout_cls(
-                    model_path=local_path,
-                    config=self.config.rollout,
-                    tokenizer=self.tokenizer,
-                    model_hf_config=self.actor_model_config,
-                    device_mesh=rollout_device_mesh,
-                    trust_remote_code=trust_remote_code,
-                    **lora_kwargs,
-                )
-            else:
-                raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'")
-
-            log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
-            full_params = torch.distributed.get_world_size() == 1
-            rollout_sharding_manager = FSDPVLLMShardingManager(
-                module=self.actor_module_fsdp,
-                inference_engine=rollout.inference_engine,
-                model_config=self.actor_model_config,
-                full_params=full_params,
-                device_mesh=rollout_device_mesh,
-                offload_param=self._is_offload_param,
-                load_format=self.config.rollout.load_format,
-                layered_summon=self.config.rollout.get("layered_summon", False),
-            )
-            log_gpu_memory_usage("After building sharding manager", logger=logger)
-
-        elif rollout_name in ["sglang", "sglang_async"]:
-            if rollout_name == "sglang_async":
-                warnings.warn(
-                    "'sglang_async' has been deprecated and merged into 'sglang'. Please use 'sglang' going forward.",
-                    DeprecationWarning,
-                    stacklevel=2,
-                )
-            from verl.workers.rollout.sglang_rollout import SGLangRollout
-
-            # NOTE(linjunrong): Due to recent fp8 support in SGLang, importing any symbol related to
-            # SGLang's model_runner checks CUDA device capability. However, due to verl's setup,
-            # the main Ray process cannot see any CUDA device, which can lead to:
-            # "RuntimeError: No CUDA GPUs are available".
-            # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager;
-            # we import it here via its absolute path instead.
-            # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
-            from verl.workers.sharding_manager.fsdp_sglang import (
-                FSDPSGLangShardingManager,
-            )
-
-            local_path = copy_to_local(self.config.model.path)
-            log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
-            rollout = SGLangRollout(
-                actor_module=local_path,
-                config=self.config.rollout,
-                tokenizer=self.tokenizer,
-                model_hf_config=self.actor_model_config,
-                trust_remote_code=trust_remote_code,
-            )
-            log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
-
-            if torch.distributed.get_world_size() == 1:
-                self.config.rollout.load_format = "dummy_hf"
-            rollout_sharding_manager = FSDPSGLangShardingManager(
-                module=self.actor_module_fsdp,
-                inference_engine=rollout._engine,
-                model_config=self.actor_model_config,
-                full_params="hf" in self.config.rollout.load_format,
-                device_mesh=rollout_device_mesh,
-                offload_param=self._is_offload_param,
-            )
-            log_gpu_memory_usage("After building sharding manager", logger=logger)
-
-        else:
-            raise NotImplementedError(f"Rollout name: {self.config.rollout.name} is not supported")
-
-        return rollout, rollout_sharding_manager
-
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def init_model(self):
         from trinity.trainer.verl.dp_actor import DataParallelPPOActor
@@ -597,14 +454,10 @@ def init_model(self):
         use_shm = self.config.model.get("use_shm", False)
         use_fused_kernels = self.config.model.get("use_fused_kernels", False)

-        if self._is_actor or self._is_rollout:
+        if self._is_actor:
             # we need the model for actor and rollout
-            if self._is_actor:
-                optim_config = self.config.actor.optim
-                fsdp_config = self.config.actor.fsdp_config
-            else:
-                optim_config = None
-                fsdp_config = OmegaConf.create()
+            optim_config = self.config.actor.optim
+            fsdp_config = self.config.actor.fsdp_config

             local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
             (
@@ -651,11 +504,6 @@ def init_model(self):
                 actor_optimizer=self.actor_optimizer,
             )

-        if self._is_rollout:
-            self.rollout, self.rollout_sharding_manager = self._build_rollout(
-                trust_remote_code=self.config.model.get("trust_remote_code", False)
-            )
-
         if self._is_ref:
             local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
             self.ref_module_fsdp = self._build_model_optimizer(
@@ -713,7 +561,9 @@ def setup_weight_sync_group(self):
                     realname = (
                         name_prefix[len(FSDP_PREFIX) :] + "." + name if name_prefix else name
                     )
-                    self.state_dict_meta.append((realname, param.dtype, param.shape))
+                    self.state_dict_meta.append(
+                        (realname, str(param.dtype), tuple(param.shape))
+                    )
             param = None
             torch.cuda.empty_cache()

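
The replacement above stores plain Python types (a dtype string and a shape tuple) instead of torch objects in state_dict_meta, which keeps the metadata cheap to serialize when it is handed to whatever consumes the weight-sync group. A sketch of how a receiver might rebuild a tensor buffer from one such entry; the helper name and the example entry are hypothetical and not part of this commit:

    import torch

    def resolve_dtype(dtype_str: str) -> torch.dtype:
        # "torch.bfloat16" -> torch.bfloat16
        return getattr(torch, dtype_str.split(".")[-1])

    # Hypothetical entry in the (name, dtype_str, shape) format produced above.
    realname, dtype_str, shape = ("model.embed_tokens.weight", "torch.bfloat16", (4096, 4096))
    buffer = torch.empty(shape, dtype=resolve_dtype(dtype_str))
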
@@ -815,38 +665,6 @@ def update_actor(self, data: DataProto):

         return output

-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
-    def generate_sequences(self, prompts: DataProto):
-        # Support all hardware
-        prompts = prompts.to(get_torch_device().current_device())
-
-        assert self._is_rollout
-
-        meta_info = {
-            "eos_token_id": self.generation_config.eos_token_id
-            if self.generation_config is not None
-            else self.tokenizer.eos_token_id,
-            "pad_token_id": self.generation_config.pad_token_id
-            if self.generation_config is not None
-            else self.tokenizer.pad_token_id,
-        }
-        prompts.meta_info.update(meta_info)
-        with self.rollout_sharding_manager:
-            log_gpu_memory_usage("After entering rollout sharding manager", logger=logger)
-
-            prompts = self.rollout_sharding_manager.preprocess_data(prompts)
-            output = self.rollout.generate_sequences(prompts=prompts)
-
-            log_gpu_memory_usage("After rollout generation", logger=logger)
-
-            output = self.rollout_sharding_manager.postprocess_data(output)
-
-        output = output.to("cpu")
-
-        # clear kv cache
-        get_torch_device().empty_cache()
-        return output
-
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def compute_log_prob(self, data: DataProto):
         # when is_lora is True, we use the actor without lora applied to calculate the log_prob