
Commit 2a36e0e

clean code
1 parent aedfa53 commit 2a36e0e

File tree: 3 files changed, +11 −192 lines changed


trinity/common/models/vllm_worker.py

Lines changed: 3 additions & 3 deletions
@@ -57,16 +57,16 @@ def init_process_group(
         )
         self._explorer_actor = None
 
-    def update_weight(self, name, dtype, shape, empty_cache=False):
+    def update_weight(self, name: str, dtype_str: str, shape: tuple, empty_cache=False):
         """Broadcast weight to all vllm workers from source rank 0 (actor model)"""
         if self._weight_update_rank == 0:
             if self._explorer_actor is None:
                 self._explorer_actor = ray.get_actor(name="explorer")
             weight = ray.get(self._explorer_actor.get_weight.remote(name))
             weight = weight.to(self.device)
         else:
-            weight = torch.empty(shape, dtype=dtype, device="cuda")
-
+            dtype = getattr(torch, dtype_str.split(".")[-1])
+            weight = torch.empty(shape, dtype=dtype, device=self.device)
         torch.distributed.broadcast(weight, 0, group=self._model_update_group)
         weight = weight.type(self.model_config.dtype)
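Note (not part of the diff): a minimal, standalone sketch of the dtype round-trip the new signature relies on. The trainer records `str(param.dtype)` and `tuple(param.shape)`; a non-source worker rebuilds the torch dtype with `getattr` and allocates an empty receive buffer before the broadcast. The names and shapes below are illustrative, not repo values.

```python
import torch

def encode_dtype(dtype: torch.dtype) -> str:
    # str(torch.bfloat16) -> "torch.bfloat16": a plain string that serializes cleanly
    return str(dtype)

def decode_dtype(dtype_str: str) -> torch.dtype:
    # "torch.bfloat16" -> torch.bfloat16 via attribute lookup on the torch module
    return getattr(torch, dtype_str.split(".")[-1])

# One metadata entry as the trainer side would describe it (illustrative values).
name, dtype_str, shape = ("layer.weight", encode_dtype(torch.bfloat16), (4, 8))

# A non-source rank allocates an empty receive buffer before the broadcast.
buf = torch.empty(shape, dtype=decode_dtype(dtype_str))
assert buf.dtype is torch.bfloat16 and buf.shape == shape
```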

trinity/explorer/explorer.py

Lines changed: 1 addition & 0 deletions
@@ -85,6 +85,7 @@ def setup_weight_sync_group(
             f"world_size={world_size}, rank_offset={base_offset}"
         )
         self.state_dict_meta = state_dict_meta
+        # TODO: save state_dict in models
         refs = [
             model.init_process_group.remote(
                 master_address=master_address,
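Note (not part of the diff): a hedged sketch of how a metadata list in this shape could drive per-tensor weight sync against the string-typed `update_weight` signature above. `workers` and the loop structure are illustrative assumptions, not the repo's actual call path.

```python
import ray

def sync_all_weights(workers, state_dict_meta):
    # workers: hypothetical list of Ray actor handles exposing update_weight(name, dtype_str, shape)
    for name, dtype_str, shape in state_dict_meta:
        refs = [w.update_weight.remote(name, dtype_str, shape) for w in workers]
        ray.get(refs)  # finish this tensor's broadcast before starting the next one
```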

trinity/trainer/verl/fsdp_workers.py

Lines changed: 7 additions & 189 deletions
@@ -26,7 +26,7 @@
 import torch
 import torch.distributed
 import torch.distributed as dist
-import vllm  # noqa: F401 ; import vllm to avoid "Cuda failure 1 'invalid argument'"
+import vllm  # noqa: F401 ; import vllm to set NCCL_CUMEM_ENABLE automatically.
 from codetiming import Timer
 from omegaconf import DictConfig, open_dict
 from peft import LoraConfig, TaskType, get_peft_model
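Note (not part of the diff): the updated comment describes a side effect of merely importing vllm. Assuming the installed vLLM build really does override this NCCL setting at import time, the effect can be observed directly in the process environment:

```python
import os

import vllm  # noqa: F401

# If the import sets the variable, it shows up in os.environ.
print("NCCL_CUMEM_ENABLE =", os.environ.get("NCCL_CUMEM_ENABLE", "<unset>"))
```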
@@ -126,7 +126,6 @@ def __init__(self, config: DictConfig, role: str):
         assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"]
 
         self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
-        self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
         self._is_ref = self.role in ["ref", "actor_rollout_ref"]
 
         self._is_offload_param = False
@@ -170,14 +169,6 @@ def __init__(self, config: DictConfig, role: str):
                 > 0
             ), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"
 
-        # normalize rollout config
-        if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:
-            self.config.rollout.log_prob_micro_batch_size //= (
-                self.device_mesh.size() // self.ulysses_sequence_parallel_size
-            )
-            self.config.rollout.log_prob_micro_batch_size_per_gpu = (
-                self.config.rollout.log_prob_micro_batch_size
-            )
         # normalize ref config
         if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
             self.config.ref.log_prob_micro_batch_size //= (
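Note (not part of the diff): the surviving ref-side normalization is integer division of the global micro batch by the data-parallel degree, where dp is the device-mesh size divided by the Ulysses sequence-parallel size. A toy calculation with illustrative numbers, treating the device-mesh size as the world size:

```python
world_size = 8                       # stand-in for device_mesh.size()
ulysses_sequence_parallel_size = 2
log_prob_micro_batch_size = 32       # global micro batch across all ranks

dp_size = world_size // ulysses_sequence_parallel_size                    # 4 data-parallel groups
log_prob_micro_batch_size_per_gpu = log_prob_micro_batch_size // dp_size  # 8 per rank
print(dp_size, log_prob_micro_batch_size_per_gpu)                         # 4 8
```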
@@ -339,10 +330,6 @@ def _build_model_optimizer( # noqa: C901
             is_lora=self.config.model.get("lora_rank", 0) > 0,
         )
 
-        if self._is_rollout and self.config.rollout.name == "hf":
-            # TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma
-            auto_wrap_policy = None
-
         if self.rank == 0:
             print(f"wrap_policy: {auto_wrap_policy}")
 
@@ -450,136 +437,6 @@
 
         return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config
 
-    def _build_rollout(self, trust_remote_code=False):
-        from torch.distributed.device_mesh import init_device_mesh
-
-        # TODO(sgm): support FSDP hybrid shard for larger model
-        infer_tp = self.config.rollout.tensor_model_parallel_size
-        dp = self.world_size // infer_tp
-        assert (
-            self.world_size % infer_tp == 0
-        ), f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
-        rollout_device_mesh = init_device_mesh(
-            device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
-        )
-        rollout_name = self.config.rollout.name
-        if rollout_name == "hf":
-            from verl.workers.rollout import HFRollout
-            from verl.workers.sharding_manager.base import BaseShardingManager
-
-            rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
-            rollout_sharding_manager = BaseShardingManager()
-            # TODO: a sharding manager that do nothing?
-
-        elif rollout_name == "vllm":
-            from verl.workers.rollout.vllm_rollout import vllm_mode, vLLMRollout
-            from verl.workers.sharding_manager.fsdp_vllm import FSDPVLLMShardingManager
-
-            log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
-            local_path = copy_to_local(
-                self.config.model.path, use_shm=self.config.model.get("use_shm", False)
-            )
-            lora_kwargs = (
-                {
-                    "lora_kwargs": {
-                        "enable_lora": True,
-                        "max_loras": 1,
-                        "max_lora_rank": self._lora_rank,
-                    }
-                }
-                if self._is_lora
-                else {}
-            )
-            # lora_kwargs = {}
-            if vllm_mode == "customized":
-                rollout = vLLMRollout(
-                    actor_module=self.actor_module_fsdp,
-                    config=self.config.rollout,
-                    tokenizer=self.tokenizer,
-                    model_hf_config=self.actor_model_config,
-                    trust_remote_code=trust_remote_code,
-                    **lora_kwargs,
-                )
-            elif vllm_mode == "spmd":
-                from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout
-
-                vllm_rollout_cls = (
-                    vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
-                )
-                rollout = vllm_rollout_cls(
-                    model_path=local_path,
-                    config=self.config.rollout,
-                    tokenizer=self.tokenizer,
-                    model_hf_config=self.actor_model_config,
-                    device_mesh=rollout_device_mesh,
-                    trust_remote_code=trust_remote_code,
-                    **lora_kwargs,
-                )
-            else:
-                raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'")
-
-            log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
-            full_params = torch.distributed.get_world_size() == 1
-            rollout_sharding_manager = FSDPVLLMShardingManager(
-                module=self.actor_module_fsdp,
-                inference_engine=rollout.inference_engine,
-                model_config=self.actor_model_config,
-                full_params=full_params,
-                device_mesh=rollout_device_mesh,
-                offload_param=self._is_offload_param,
-                load_format=self.config.rollout.load_format,
-                layered_summon=self.config.rollout.get("layered_summon", False),
-            )
-            log_gpu_memory_usage("After building sharding manager", logger=logger)
-
-        elif rollout_name in ["sglang", "sglang_async"]:
-            if rollout_name == "sglang_async":
-                warnings.warn(
-                    "'sglang_async' has been deprecated and merged into 'sglang'. Please use 'sglang' going forward.",
-                    DeprecationWarning,
-                    stacklevel=2,
-                )
-            from verl.workers.rollout.sglang_rollout import SGLangRollout
-
-            # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to
-            # SGLang's model_runner would check CUDA device capability. However, due to verl's setting,
-            # the main process of ray can not find any CUDA device, which would potentially lead to:
-            # "RuntimeError: No CUDA GPUs are available".
-            # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and
-            # we import it here use the abs path.
-            # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
-            from verl.workers.sharding_manager.fsdp_sglang import (
-                FSDPSGLangShardingManager,
-            )
-
-            local_path = copy_to_local(self.config.model.path)
-            log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
-            rollout = SGLangRollout(
-                actor_module=local_path,
-                config=self.config.rollout,
-                tokenizer=self.tokenizer,
-                model_hf_config=self.actor_model_config,
-                trust_remote_code=trust_remote_code,
-            )
-            log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
-
-            if torch.distributed.get_world_size() == 1:
-                self.config.rollout.load_format = "dummy_hf"
-            rollout_sharding_manager = FSDPSGLangShardingManager(
-                module=self.actor_module_fsdp,
-                inference_engine=rollout._engine,
-                model_config=self.actor_model_config,
-                full_params="hf" in self.config.rollout.load_format,
-                device_mesh=rollout_device_mesh,
-                offload_param=self._is_offload_param,
-            )
-            log_gpu_memory_usage("After building sharding manager", logger=logger)
-
-        else:
-            raise NotImplementedError(f"Rollout name: {self.config.rollout.name} is not supported")
-
-        return rollout, rollout_sharding_manager
-
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def init_model(self):
         from trinity.trainer.verl.dp_actor import DataParallelPPOActor
@@ -597,14 +454,10 @@ def init_model(self):
         use_shm = self.config.model.get("use_shm", False)
         use_fused_kernels = self.config.model.get("use_fused_kernels", False)
 
-        if self._is_actor or self._is_rollout:
+        if self._is_actor:
             # we need the model for actor and rollout
-            if self._is_actor:
-                optim_config = self.config.actor.optim
-                fsdp_config = self.config.actor.fsdp_config
-            else:
-                optim_config = None
-                fsdp_config = OmegaConf.create()
+            optim_config = self.config.actor.optim
+            fsdp_config = self.config.actor.fsdp_config
 
             local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
             (
@@ -651,11 +504,6 @@ def init_model(self):
                 actor_optimizer=self.actor_optimizer,
             )
 
-        if self._is_rollout:
-            self.rollout, self.rollout_sharding_manager = self._build_rollout(
-                trust_remote_code=self.config.model.get("trust_remote_code", False)
-            )
-
         if self._is_ref:
             local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
             self.ref_module_fsdp = self._build_model_optimizer(
@@ -713,7 +561,9 @@ def setup_weight_sync_group(self):
                 realname = (
                     name_prefix[len(FSDP_PREFIX) :] + "." + name if name_prefix else name
                 )
-                self.state_dict_meta.append((realname, param.dtype, param.shape))
+                self.state_dict_meta.append(
+                    (realname, str(param.dtype), tuple(param.shape))
+                )
                 param = None
         torch.cuda.empty_cache()
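Note (not part of the diff): a minimal sketch of the producer side of that metadata, using a plain nn.Module instead of the FSDP-wrapped actor and skipping the prefix stripping. `str(...)` and `tuple(...)` keep every entry a plain Python value, so the list can be pickled and shipped through Ray without referencing tensors.

```python
from torch import nn

model = nn.Linear(8, 4)  # stand-in for the FSDP-wrapped actor module

state_dict_meta = []
for name, param in model.state_dict().items():
    # Each entry is (parameter name, dtype as string, shape as tuple).
    state_dict_meta.append((name, str(param.dtype), tuple(param.shape)))

print(state_dict_meta)
# e.g. [('weight', 'torch.float32', (4, 8)), ('bias', 'torch.float32', (4,))]
```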

@@ -815,38 +665,6 @@ def update_actor(self, data: DataProto):
 
         return output
 
-    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
-    def generate_sequences(self, prompts: DataProto):
-        # Support all hardwares
-        prompts = prompts.to(get_torch_device().current_device())
-
-        assert self._is_rollout
-
-        meta_info = {
-            "eos_token_id": self.generation_config.eos_token_id
-            if self.generation_config is not None
-            else self.tokenizer.eos_token_id,
-            "pad_token_id": self.generation_config.pad_token_id
-            if self.generation_config is not None
-            else self.tokenizer.pad_token_id,
-        }
-        prompts.meta_info.update(meta_info)
-        with self.rollout_sharding_manager:
-            log_gpu_memory_usage("After entering rollout sharding manager", logger=logger)
-
-            prompts = self.rollout_sharding_manager.preprocess_data(prompts)
-            output = self.rollout.generate_sequences(prompts=prompts)
-
-            log_gpu_memory_usage("After rollout generation", logger=logger)
-
-            output = self.rollout_sharding_manager.postprocess_data(output)
-
-        output = output.to("cpu")
-
-        # clear kv cache
-        get_torch_device().empty_cache()
-        return output
-
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def compute_log_prob(self, data: DataProto):
         # when is_lora is True, we use the actor without lora applied to calculate the log_prob
