diff --git a/examples/dpo_humanlike/train_dpo.yaml b/examples/dpo_humanlike/train_dpo.yaml
index 8ffc68b397..028c997e06 100644
--- a/examples/dpo_humanlike/train_dpo.yaml
+++ b/examples/dpo_humanlike/train_dpo.yaml
@@ -26,8 +26,7 @@ actor_rollout_ref:
       min_lr_ratio: 0.1 # only useful for warmup with cosine
       warmup_style: cosine # select from constant/cosine
      total_training_steps: 783 #
-      beta1: 0.9
-      beta2: 0.95
+      betas: [0.9, 0.95]
     fsdp_config:
       wrap_policy:
         # transformer_layer_cls_to_wrap: None
diff --git a/examples/opmd_gsm8k/train_opmd_gsm8k.yaml b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml
index 326904d987..44a0111d64 100644
--- a/examples/opmd_gsm8k/train_opmd_gsm8k.yaml
+++ b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml
@@ -15,8 +15,8 @@
 # entropy_coeff: default to 0.0 for now
 #
 # optimizer:
-# beta1, beta2: 0.0, 0.95 # smaller than default values (0.9, 0.999), as a remedy for abrupt distribution shift
-# lr: set smaller to account for beta1 = 0.0
+# betas: [0.0, 0.95] # smaller than default values (0.9, 0.999), as a remedy for abrupt distribution shift
+# lr: set smaller to account for betas[0] = 0.0
 #
 # misc:
 # adv_estimator: grpo # merely to disable critic model, doesn't affect adv compute when algorithm_type is opmd
@@ -50,8 +50,7 @@ actor_rollout_ref:
       # min_lr_ratio: null # only useful for warmup with cosine
       warmup_style: constant # select from constant/cosine
       total_training_steps: -1 # must be override by program
-      beta1: 0.0 # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval)
-      beta2: 0.95 # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval)
+      betas: [0.0, 0.95] # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval)
     fsdp_config:
       wrap_policy:
         # transformer_layer_cls_to_wrap: None
diff --git a/pyproject.toml b/pyproject.toml
index 022c9a8ffe..bafa620470 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ classifiers = [
 ]
 requires-python = ">=3.10"
 dependencies = [
-    "verl==0.3.0.post1",
+    "verl==0.4.0",
     "ray[default]>=2.45.0",
     "vllm==0.8.5.post1",
     "tensordict==0.6.2",
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index 644fe9a8f5..e6b1b9e4e1 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -33,8 +33,7 @@ class Optim:
     min_lr_ratio: Optional[float] = 0.0
     warmup_style: str = "constant"
     total_training_steps: int = -1
-    beta1: float = 0.9
-    beta2: float = 0.999
+    betas: List[float] = field(default_factory=lambda: [0.9, 0.999])


 @dataclass
@@ -82,6 +81,7 @@ class Actor:
     tau: float = 0.001  # strength of regularization w.r.t. old / ref policy
     opmd_baseline: str = "mean"  # mean / logavgexp, applicable to opmd
     use_uid: bool = False  # True / False, applicable to pairwise_opmd
+    loss_agg_mode: str = "token-mean"  # do not set


 @dataclass
@@ -99,12 +99,20 @@ class _ValKwargs:
     do_sample: bool = False


+@dataclass
+class _MultiTurn:
+    enable: bool = False
+
+
 @dataclass
 class Rollout:
     # do not set
     val_kwargs: _ValKwargs = field(default_factory=_ValKwargs)
+    multi_turn: _MultiTurn = field(default_factory=_MultiTurn)
     temperature: float = 1.0
     n: int = 1  # > 1 for grpo
+    log_prob_micro_batch_size: Optional[int] = None
+    log_prob_micro_batch_size_per_gpu: int = 1


 @dataclass
@@ -148,6 +156,7 @@ class Critic:
     cliprange_value: float = 0.0
     checkpoint: Checkpoint = field(default_factory=Checkpoint)
     rollout_n: int = 1
+    loss_agg_mode: str = "token-mean"


 @dataclass
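Note on the `beta1`/`beta2` → `betas` migration above: the list-valued field ends up in `torch.optim.AdamW`, which takes the two coefficients as a single tuple. A minimal sketch of how such a config value is consumed (the `optim_config` dict and the tiny model below are placeholders, not code from this repository):

```python
import torch
from torch import nn

# Hypothetical stand-in for the Optim config above; the real values come from YAML.
optim_config = {"lr": 1e-6, "betas": [0.9, 0.95], "weight_decay": 1e-2}
model = nn.Linear(8, 8)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=optim_config["lr"],
    betas=tuple(optim_config["betas"]),  # AdamW expects a (beta1, beta2) tuple
    weight_decay=optim_config["weight_decay"],
)
```

The dataclass default `[0.9, 0.999]` matches AdamW's own default, so omitting the field leaves behavior unchanged.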
diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py
index 595084ac02..d98e6a2993 100644
--- a/trinity/trainer/verl/dp_actor.py
+++ b/trinity/trainer/verl/dp_actor.py
@@ -1,4 +1,6 @@
 # Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2023-2024 SGLang Team
+# Copyright 2025 ModelBest Inc. and/or its affiliates
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,49 +14,42 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Modified from dp_actor.py
+Single Process Actor.
+Modified from https://github.com/volcengine/verl/blob/0758489422e8d41a89e6c36d4c477714520f0dcc/verl/workers/actor/dp_actor.py
 """

 import itertools
-from typing import Tuple
+import logging
+import os

 import torch
-import verl.utils.torch_functional as verl_F
-from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
 from torch import nn
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from verl import DataProto
+from verl.utils.debug import GPUMemoryLogger
+from verl.utils.device import get_torch_device
 from verl.utils.py_functional import append_to_dict
 from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
-from verl.utils.torch_functional import logprobs_from_logits
-from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
-from verl.workers.actor import BasePPOActor
+from verl.workers.actor.dp_actor import DataParallelPPOActor as DPActor

 from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN
+from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import DummyEntropyLossFn
 from trinity.algorithm.kl_fn.kl_fn import DummyKLFn
 from trinity.algorithm.utils import prefix_metrics
 from trinity.common.config import AlgorithmConfig

 __all__ = ["DataParallelPPOActor"]

+logger = logging.getLogger(__file__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

-class DataParallelPPOActor(BasePPOActor):
+
+class DataParallelPPOActor(DPActor):
     def __init__(
-        self,
-        config,
-        actor_module: nn.Module,
-        actor_optimizer: torch.optim.Optimizer = None,
+        self, config, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None
     ):
         """When optimizer is None, it is Reference Policy"""
-        super().__init__(config)
-        self.actor_module = actor_module
-        self.actor_optimizer = actor_optimizer
-        self.use_remove_padding = self.config.get("use_remove_padding", False)
-        print(f"Actor use_remove_padding={self.use_remove_padding}")
-        self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size
-        self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1
-
-        self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
+        super().__init__(config, actor_module, actor_optimizer)
+
         self.policy_loss_fn = None
         self.kl_loss_fn = None
         self.entropy_loss_fn = None
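The next hunk removes the local `_forward_micro_batch` and `_optimizer_step` implementations in favour of the versions inherited from verl's `DataParallelPPOActor`. As a reference for the per-token quantities involved, here is a small self-contained sketch (plain PyTorch on toy tensors, not the repository's implementation, which additionally handles padding removal, sequence parallelism and temperature scaling) of the log-probs and entropies that the patched `compute_log_prob` returns:

```python
import torch
import torch.nn.functional as F

# Toy shapes; the real code derives these from model logits over the response tokens.
logits = torch.randn(2, 5, 32)            # (batch, response_len, vocab)
responses = torch.randint(0, 32, (2, 5))  # sampled token ids, (batch, response_len)

log_softmax = F.log_softmax(logits, dim=-1)
# per-token log-prob of the sampled token
log_probs = torch.gather(log_softmax, dim=-1, index=responses.unsqueeze(-1)).squeeze(-1)
# per-token entropy of the predictive distribution
entropys = -(log_softmax.exp() * log_softmax).sum(dim=-1)
```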
@@ -68,150 +63,8 @@ def set_algorithm(self, algorithm_config: AlgorithmConfig):
             **algorithm_config.entropy_loss_fn_args
         )

-    def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Returns:
-            entropy: # (bs, response_len)
-            log_probs: # (bs, response_len)
-        """
-        response_length = micro_batch["responses"].size(-1)
-        multi_modal_inputs = {}
-        if "multi_modal_inputs" in micro_batch:
-            for key in micro_batch["multi_modal_inputs"][0].keys():
-                multi_modal_inputs[key] = torch.cat(
-                    [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
-                )
-
-        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
-            input_ids = micro_batch["input_ids"]
-            batch_size, seqlen = input_ids.shape
-            attention_mask = micro_batch["attention_mask"]
-            position_ids = micro_batch["position_ids"]
-            if position_ids.dim() == 3:  # qwen2vl mrope
-                position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)
-
-            if self.use_remove_padding:
-                input_ids_rmpad, indices, *_ = unpad_input(
-                    input_ids.unsqueeze(-1), attention_mask
-                )  # input_ids_rmpad (total_nnz, ...)
-                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)
-
-                # unpad the position_ids to align the rotary
-                if position_ids.dim() == 3:
-                    position_ids_rmpad = (
-                        index_first_axis(
-                            rearrange(position_ids, "c b s ... -> (b s) c ..."), indices
-                        )
-                        .transpose(0, 1)
-                        .unsqueeze(1)
-                    )  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
-                else:
-                    position_ids_rmpad = index_first_axis(
-                        rearrange(position_ids.unsqueeze(-1), "b s ...
-> (b s) ..."), indices - ).transpose(0, 1) - - # for compute the log_prob - input_ids_rmpad_rolled = torch.roll( - input_ids_rmpad, shifts=-1, dims=1 - ) # (1, total_nnz) - - # pad and slice the inputs if sp > 1 - if self.use_ulysses_sp: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, - position_ids_rmpad, - sp_size=self.ulysses_sequence_parallel_size, - ) - input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( - input_ids_rmpad_rolled, None, self.ulysses_sequence_parallel_size - ) - - input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze( - 0 - ) # ((total_nnz / sp) + pad) - - # only pass input_ids and position_ids to enable flash_attn_varlen - output = self.actor_module( - input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - **multi_modal_inputs, - use_cache=False, - ) # prevent model thinks we are generating - logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) - - logits_rmpad.div_(temperature) - - # compute entropy - entropy_rmpad = self.compute_entropy_from_logits( - logits_rmpad - ) # ((total_nnz / sp) + pad) - - # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) - log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) - - # gather log_prob if sp > 1 - if self.use_ulysses_sp: - # gather and unpad for the ulysses sp - log_probs = gather_outpus_and_unpad( - log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size - ) - entropy_rmpad = gather_outpus_and_unpad( - entropy_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size - ) - # pad back to (bsz, seqlen) - full_entropy = pad_input( - hidden_states=entropy_rmpad.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen, - ) - full_log_probs = pad_input( - hidden_states=log_probs.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen, - ) - - # only return response part: - entropy = full_entropy.squeeze(-1)[ - :, -response_length - 1 : -1 - ] # (bsz, response_length) - log_probs = full_log_probs.squeeze(-1)[ - :, -response_length - 1 : -1 - ] # (bsz, response_length) - - else: # not using rmpad and no ulysses sp - output = self.actor_module( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - **multi_modal_inputs, - use_cache=False, - ) # prevent model thinks we are generating - logits = output.logits - logits.div_(temperature) - logits = logits[ - :, -response_length - 1 : -1, : - ] # (bsz, response_length, vocab_size) - log_probs = logprobs_from_logits(logits, micro_batch["responses"]) - entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) - - return entropy, log_probs - - def _optimizer_step(self): - assert self.config.grad_clip is not None - - if isinstance(self.actor_module, FSDP): - grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip) - else: - grad_norm = torch.nn.utils.clip_grad_norm_( - self.actor_module.parameters(), max_norm=self.config.grad_clip - ) - self.actor_optimizer.step() - return grad_norm - - def compute_log_prob(self, data: DataProto) -> torch.Tensor: + @GPUMemoryLogger(role="dp actor", logger=logger) + def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor: """Compute the log probability of the responses given input_ids, attention_mask and position_ids Args: @@ -235,7 +88,7 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor: micro_batch_size = data.meta_info["micro_batch_size"] temperature = data.meta_info[ 
"temperature" - ] # temperature must be in the data.meta_info to avoid slient error + ] # temperature must be in the data.meta_info to avoid silent error use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] @@ -258,30 +111,40 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor: micro_batches = batch.split(micro_batch_size) log_probs_lst = [] + entropy_lst = [] for micro_batch in micro_batches: if isinstance(micro_batch, DataProto): micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} - with torch.no_grad(): - _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature) + entropy, log_probs = self._forward_micro_batch( + micro_batch, temperature=temperature, calculate_entropy=calculate_entropy + ) log_probs_lst.append(log_probs) - log_probs = torch.concat(log_probs_lst, dim=0) + if calculate_entropy: + entropy_lst.append(entropy) + log_probs = torch.concat(log_probs_lst, dim=0) + entropys = None + if calculate_entropy: + entropys = torch.concat(entropy_lst, dim=0) if use_dynamic_bsz: indices = list(itertools.chain.from_iterable(indices)) assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}" revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) log_probs = log_probs[revert_indices] + if calculate_entropy: + entropys = entropys[revert_indices] # type: ignore - return log_probs + return log_probs, entropys - def update_policy(self, data: DataProto): # noqa: C901 + @GPUMemoryLogger(role="dp actor", logger=logger) + def update_policy(self, data: DataProto): # make sure we are in training mode self.actor_module.train() temperature = data.meta_info[ "temperature" - ] # temperature must be in the data.meta_info to avoid slient error + ] # temperature must be in the data.meta_info to avoid silent error select_keys = [ "input_ids", "position_ids", @@ -356,12 +219,12 @@ def update_policy(self, data: DataProto): # noqa: C901 # Support all hardwares if isinstance(data, DataProto): data = { - **data.batch.to(torch.cuda.current_device()), + **data.batch.to(get_torch_device().current_device()), **data.non_tensor_batch, } else: data = data.to( - torch.cuda.current_device() + get_torch_device().current_device() ) # actor device is cpu when using offload responses = data["responses"] response_length = responses.size(1) @@ -370,8 +233,11 @@ def update_policy(self, data: DataProto): # noqa: C901 assert response_mask.shape == attention_mask[:, -response_length:].shape # all return: (bsz, response_length) + calculate_entropy = self.entropy_loss_fn != DummyEntropyLossFn entropy, log_prob = self._forward_micro_batch( - micro_batch=data, temperature=temperature + micro_batch=data, + temperature=temperature, + calculate_entropy=calculate_entropy, ) kwargs = { diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index c0af427b4a..66d055feeb 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -12,74 +12,70 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -The main entry point to run the PPO algorithm +The main entry point to run the PPO algorithm. 
+Modified from https://github.com/volcengine/verl/blob/0758489422e8d41a89e6c36d4c477714520f0dcc/verl/workers/fsdp_workers.py """ +import json import logging import os import warnings +from dataclasses import asdict import psutil import torch import torch.distributed -import verl.utils.torch_functional as verl_F +import torch.distributed as dist +import vllm # noqa: F401 ; import vllm to avoid "Cuda failure 1 'invalid argument'" from codetiming import Timer from omegaconf import DictConfig, open_dict +from peft import LoraConfig, TaskType, get_peft_model +from safetensors.torch import save_file from torch.distributed.device_mesh import init_device_mesh from torch.distributed.fsdp import FlatParameter from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import FSDP_PREFIX from verl import DataProto +from verl.models.transformers.monkey_patch import apply_monkey_patch from verl.single_controller.base import Worker from verl.single_controller.base.decorator import Dispatch, register from verl.utils import hf_processor, hf_tokenizer +from verl.utils.activation_offload import enable_activation_offloading from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager from verl.utils.debug import log_gpu_memory_usage +from verl.utils.device import get_torch_device, is_cuda_available from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_to_local from verl.utils.fsdp_utils import ( + CPUOffloadPolicy, + MixedPrecisionPolicy, + apply_fsdp2, + fsdp2_load_full_state_dict, + fsdp_version, get_fsdp_wrap_policy, get_init_weight_context_manager, init_fn, + layered_summon_lora_params, load_fsdp_model_to_gpu, load_fsdp_optimizer, offload_fsdp_model_to_cpu, offload_fsdp_optimizer, ) from verl.utils.import_utils import import_external_libs -from verl.utils.model import compute_position_id_with_mask +from verl.utils.py_functional import convert_to_regular_types +from verl.workers.fsdp_workers import ( + create_device_mesh, + device_name, + get_sharding_strategy, +) from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager from trinity.common.config import AlgorithmConfig -from trinity.common.constants import SyncMethod +from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod from trinity.utils.distributed import init_process_group, is_ipv6_address logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN")) - - -def create_device_mesh(world_size, fsdp_size): - if fsdp_size < 0 or fsdp_size >= world_size: - device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) - else: - device_mesh = init_device_mesh( - "cuda", mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] - ) - return device_mesh - - -def get_sharding_strategy(device_mesh): - from torch.distributed.fsdp import ShardingStrategy - - if device_mesh.ndim == 1: - sharding_strategy = ShardingStrategy.FULL_SHARD - elif device_mesh.ndim == 2: - sharding_strategy = ShardingStrategy.HYBRID_SHARD - else: - raise NotImplementedError( - f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2" - ) - return sharding_strategy +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) class ActorRolloutRefWorker(Worker): @@ -94,7 +90,13 @@ def __init__(self, config: DictConfig, role: str): import torch.distributed if not torch.distributed.is_initialized(): - 
torch.distributed.init_process_group(backend="nccl") + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl" if is_cuda_available else "cpu:gloo,npu:hccl", + rank=rank, + world_size=world_size, + ) # build device mesh for FSDP world_size = torch.distributed.get_world_size() @@ -111,12 +113,14 @@ def __init__(self, config: DictConfig, role: str): dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( - "cuda", + device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"], ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + self._lora_rank = self.config.model.get("lora_rank", 0) + self._is_lora = self._lora_rank > 0 self.role = role assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"] @@ -153,6 +157,8 @@ def __init__(self, config: DictConfig, role: str): self.config.actor.ppo_micro_batch_size_per_gpu = ( self.config.actor.ppo_micro_batch_size ) + + if self.config.actor.ppo_micro_batch_size_per_gpu is not None: assert ( self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu @@ -181,22 +187,22 @@ def __init__(self, config: DictConfig, role: str): self.config.ref.log_prob_micro_batch_size ) - def _build_model_optimizer( + def _build_model_optimizer( # noqa: C901 self, model_path, fsdp_config, optim_config, override_model_config, use_remove_padding=False, + use_fused_kernels=False, enable_gradient_checkpointing=False, trust_remote_code=False, use_liger=False, role="actor", + enable_activation_offload=False, ): from torch import optim - from torch.distributed.fsdp import CPUOffload - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - from torch.distributed.fsdp import MixedPrecision + from torch.distributed.fsdp import CPUOffload, MixedPrecision from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -211,8 +217,8 @@ def _build_model_optimizer( assert role in ["actor", "ref"] - log_gpu_memory_usage("Before init from HF AutoModel", logger=logger) - local_path = copy_to_local(model_path) + log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=logger) + local_path = model_path # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. 
Support init with FSDP directly @@ -227,9 +233,13 @@ def _build_model_optimizer( # override model kwargs actor_model_config = AutoConfig.from_pretrained( - local_path, trust_remote_code=trust_remote_code + local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2" ) + # patch for kimi-vl + if getattr(actor_model_config, "model_type", None) == "kimi_vl": + actor_model_config.text_config.topk_method = "greedy" + self.generation_config = get_generation_config( local_path, trust_remote_code=trust_remote_code ) @@ -260,17 +270,9 @@ def _build_model_optimizer( pretrained_model_name_or_path=local_path, torch_dtype=torch_dtype, config=actor_model_config, - attn_implementation="flash_attention_2", trust_remote_code=trust_remote_code, ) - if use_remove_padding or self.ulysses_sequence_parallel_size > 1: - from verl.models.transformers.monkey_patch import apply_monkey_patch - - apply_monkey_patch( - model=actor_module, ulysses_sp_size=self.ulysses_sequence_parallel_size - ) - # Apply Liger kernel to the model if use_liger is set to True if use_liger: from liger_kernel.transformers.monkey_patch import ( @@ -279,6 +281,13 @@ def _build_model_optimizer( _apply_liger_kernel_to_instance(model=actor_module) + apply_monkey_patch( + model=actor_module, + use_remove_padding=use_remove_padding, + ulysses_sp_size=self.ulysses_sequence_parallel_size, + use_fused_kernels=use_fused_kernels, + ) + # some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2 actor_module.to(torch_dtype) @@ -286,12 +295,24 @@ def _build_model_optimizer( actor_module.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False} ) + if self._is_lora: + print("Applying LoRA to actor module") + actor_module.enable_input_require_grads() + # Convert config to regular Python types before creating PEFT model + lora_config = { + "task_type": TaskType.CAUSAL_LM, + "r": self.config.model.lora_rank, + "lora_alpha": self.config.model.lora_alpha, + "target_modules": convert_to_regular_types(self.config.model.target_modules), + "bias": "none", + } + actor_module = get_peft_model(actor_module, LoraConfig(**lora_config)) torch.distributed.barrier() if self.rank == 0: print_model_size(actor_module) - log_gpu_memory_usage("After init from HF AutoModel", logger=logger) + log_gpu_memory_usage(f"After init {role} from HF AutoModel", logger=logger) # We wrap FSDP for rollout as well mixed_precision_config = fsdp_config.get("mixed_precision", None) @@ -313,14 +334,17 @@ def _build_model_optimizer( ) auto_wrap_policy = get_fsdp_wrap_policy( - module=actor_module, config=fsdp_config.get("wrap_policy", None) + module=actor_module, + config=fsdp_config.get("wrap_policy", None), + is_lora=self.config.model.get("lora_rank", 0) > 0, ) if self._is_rollout and self.config.rollout.name == "hf": # TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma auto_wrap_policy = None - print(f"wrap_policy: {auto_wrap_policy}") + if self.rank == 0: + print(f"wrap_policy: {auto_wrap_policy}") fsdp_mesh = self.device_mesh sharding_strategy = get_sharding_strategy(fsdp_mesh) @@ -329,74 +353,104 @@ def _build_model_optimizer( # We force reference policy to use CPUOffload to save memory. 
# We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation cpu_offload = None if role == "actor" else CPUOffload(offload_params=True) - actor_module_fsdp = FSDP( - actor_module, - cpu_offload=cpu_offload, - param_init_fn=init_fn, - use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), - sharding_strategy=sharding_strategy, # zero3 - mixed_precision=mixed_precision, - sync_module_states=True, - device_mesh=self.device_mesh, - forward_prefetch=False, - ) + fsdp_strategy = self.config.actor.strategy + if fsdp_strategy == "fsdp": + actor_module_fsdp = FSDP( + actor_module, + cpu_offload=cpu_offload, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=get_torch_device().current_device(), + sharding_strategy=sharding_strategy, # zero3 + mixed_precision=mixed_precision, + sync_module_states=True, + device_mesh=self.device_mesh, + forward_prefetch=False, + ) + elif fsdp_strategy == "fsdp2": + assert ( + CPUOffloadPolicy is not None + ), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + mp_policy = MixedPrecisionPolicy( + param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True + ) + if role == "actor" and fsdp_config.offload_policy: + cpu_offload = CPUOffloadPolicy(pin_memory=True) + self._is_offload_param = False + self._is_offload_optimizer = False + else: + cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True) + + fsdp_kwargs = { + "mesh": fsdp_mesh, + "mp_policy": mp_policy, + "offload_policy": cpu_offload, + "reshard_after_forward": fsdp_config.reshard_after_forward, + } + full_state = actor_module.state_dict() + apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config) + fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload) + actor_module_fsdp = actor_module + else: + raise NotImplementedError(f"not implement {fsdp_strategy}") - log_gpu_memory_usage("After Actor FSDP init", logger=logger) + if enable_activation_offload: + enable_activation_offloading( + actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing + ) + + log_gpu_memory_usage(f"After {role} FSDP init", logger=logger) # TODO: add more optimizer args into config if role == "actor" and optim_config is not None: - beta1 = optim_config.get("beta1", 0.9) - beta2 = optim_config.get("beta2", 0.999) + from verl.utils.torch_functional import ( + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + ) + actor_optimizer = optim.AdamW( actor_module_fsdp.parameters(), lr=optim_config.lr, - betas=(beta1, beta2), + betas=optim_config.get("betas", (0.9, 0.999)), weight_decay=optim_config.get("weight_decay", 1e-2), ) total_steps = optim_config.get("total_training_steps", 0) num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1)) + warmup_style = optim_config.get("warmup_style", "constant") + min_lr_ratio = optim_config.get("min_lr_ratio", 0.0) + num_cycles = optim_config.get("num_cycles", 0.5) if num_warmup_steps < 0: num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0) num_warmup_steps = int(num_warmup_steps_ratio * total_steps) - print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") - - if optim_config.warmup_style == "constant": - from verl.utils.torch_functional import ( - get_constant_schedule_with_warmup, - ) + if self.rank == 0: + print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") + if warmup_style == "constant": 
actor_lr_scheduler = get_constant_schedule_with_warmup( optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps ) - elif optim_config.warmup_style == "cosine": - from verl.utils.torch_functional import get_cosine_schedule_with_warmup - - assert ( - total_steps > 0 - ), "Cosine scheduler of actor requires total_training_steps > 0" + elif warmup_style == "cosine": actor_lr_scheduler = get_cosine_schedule_with_warmup( optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps, - min_lr_ratio=optim_config.min_lr_ratio, + min_lr_ratio=min_lr_ratio, + num_cycles=num_cycles, ) else: - raise NotImplementedError( - f"Lr scheduler style {optim_config.warmup_style} is not supported" - ) + raise NotImplementedError(f"Warmup style {warmup_style} is not supported") + + log_gpu_memory_usage(f"After {role} optimizer init", logger=logger) else: actor_optimizer = None actor_lr_scheduler = None - log_gpu_memory_usage("After actor optimizer init", logger=logger) - return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config - def _build_rollout(self): + def _build_rollout(self, trust_remote_code=False): from torch.distributed.device_mesh import init_device_mesh # TODO(sgm): support FSDP hybrid shard for larger model @@ -406,62 +460,129 @@ def _build_rollout(self): self.world_size % infer_tp == 0 ), f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" rollout_device_mesh = init_device_mesh( - "cuda", mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"] + device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"] ) - - if self.config.rollout.name == "hf": + rollout_name = self.config.rollout.name + if rollout_name == "hf": from verl.workers.rollout import HFRollout - from verl.workers.sharding_manager import BaseShardingManager + from verl.workers.sharding_manager.base import BaseShardingManager rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout) rollout_sharding_manager = BaseShardingManager() # TODO: a sharding manager that do nothing? 
- elif self.config.rollout.name == "vllm": - if self.config.rollout.use_fire_sampling: - from verl.workers.rollout.vllm_rollout import ( - FIREvLLMRollout as vLLMRollout, - ) - from verl.workers.rollout.vllm_rollout import vllm_mode - else: - from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode - from verl.workers.sharding_manager import FSDPVLLMShardingManager - log_gpu_memory_usage("Before building vllm rollout", logger=None) - local_path = copy_to_local(self.config.model.path) + elif rollout_name == "vllm": + from verl.workers.rollout.vllm_rollout import vllm_mode, vLLMRollout + from verl.workers.sharding_manager.fsdp_vllm import FSDPVLLMShardingManager + + log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger) + local_path = copy_to_local( + self.config.model.path, use_shm=self.config.model.get("use_shm", False) + ) + lora_kwargs = ( + { + "lora_kwargs": { + "enable_lora": True, + "max_loras": 1, + "max_lora_rank": self._lora_rank, + } + } + if self._is_lora + else {} + ) + # lora_kwargs = {} if vllm_mode == "customized": rollout = vLLMRollout( actor_module=self.actor_module_fsdp, config=self.config.rollout, tokenizer=self.tokenizer, model_hf_config=self.actor_model_config, + trust_remote_code=trust_remote_code, + **lora_kwargs, ) elif vllm_mode == "spmd": - rollout = vLLMRollout( + from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout + + vllm_rollout_cls = ( + vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout + ) + rollout = vllm_rollout_cls( model_path=local_path, config=self.config.rollout, tokenizer=self.tokenizer, model_hf_config=self.actor_model_config, device_mesh=rollout_device_mesh, + trust_remote_code=trust_remote_code, + **lora_kwargs, ) else: raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'") - log_gpu_memory_usage("After building vllm rollout", logger=None) - if torch.distributed.get_world_size() == 1: - self.config.rollout.load_format = "dummy_hf" + + log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger) + full_params = torch.distributed.get_world_size() == 1 rollout_sharding_manager = FSDPVLLMShardingManager( module=self.actor_module_fsdp, inference_engine=rollout.inference_engine, model_config=self.actor_model_config, + full_params=full_params, + device_mesh=rollout_device_mesh, + offload_param=self._is_offload_param, + load_format=self.config.rollout.load_format, + layered_summon=self.config.rollout.get("layered_summon", False), + ) + log_gpu_memory_usage("After building sharding manager", logger=logger) + + elif rollout_name in ["sglang", "sglang_async"]: + if rollout_name == "sglang_async": + warnings.warn( + "'sglang_async' has been deprecated and merged into 'sglang'. Please use 'sglang' going forward.", + DeprecationWarning, + stacklevel=2, + ) + from verl.workers.rollout.sglang_rollout import SGLangRollout + + # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to + # SGLang's model_runner would check CUDA device capability. However, due to verl's setting, + # the main process of ray can not find any CUDA device, which would potentially lead to: + # "RuntimeError: No CUDA GPUs are available". + # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and + # we import it here use the abs path. 
+ # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76 + from verl.workers.sharding_manager.fsdp_sglang import ( + FSDPSGLangShardingManager, + ) + + local_path = copy_to_local(self.config.model.path) + log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger) + rollout = SGLangRollout( + actor_module=local_path, + config=self.config.rollout, + tokenizer=self.tokenizer, + model_hf_config=self.actor_model_config, + trust_remote_code=trust_remote_code, + ) + log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger) + + if torch.distributed.get_world_size() == 1: + self.config.rollout.load_format = "dummy_hf" + rollout_sharding_manager = FSDPSGLangShardingManager( + module=self.actor_module_fsdp, + inference_engine=rollout._engine, + model_config=self.actor_model_config, full_params="hf" in self.config.rollout.load_format, device_mesh=rollout_device_mesh, + offload_param=self._is_offload_param, ) - log_gpu_memory_usage("After building sharding manager", logger=None) + log_gpu_memory_usage("After building sharding manager", logger=logger) + + else: + raise NotImplementedError(f"Rollout name: {self.config.rollout.name} is not supported") return rollout, rollout_sharding_manager @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): - from .dp_actor import DataParallelPPOActor + from trinity.trainer.verl.dp_actor import DataParallelPPOActor # This is used to import external_lib into the huggingface systems import_external_libs(self.config.model.get("external_lib", None)) @@ -473,6 +594,8 @@ def init_model(self): ) use_remove_padding = self.config.model.get("use_remove_padding", False) + use_shm = self.config.model.get("use_shm", False) + use_fused_kernels = self.config.model.get("use_fused_kernels", False) if self._is_actor or self._is_rollout: # we need the model for actor and rollout @@ -482,27 +605,36 @@ def init_model(self): else: optim_config = None fsdp_config = OmegaConf.create() + + local_path = copy_to_local(self.config.model.path, use_shm=use_shm) ( self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config, ) = self._build_model_optimizer( - model_path=self.config.model.path, + model_path=local_path, fsdp_config=fsdp_config, optim_config=optim_config, override_model_config=override_model_config, use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, enable_gradient_checkpointing=self.config.model.get( "enable_gradient_checkpointing", False ), trust_remote_code=self.config.model.get("trust_remote_code", False), use_liger=self.config.model.get("use_liger", False), role="actor", + enable_activation_offload=self.config.model.get("enable_activation_offload", False), ) # get the original unwrapped module - self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module + if fsdp_version(self.actor_module_fsdp) == 1: + self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + log_gpu_memory_usage("After offload actor model during init", logger=logger) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) @@ -512,6 +644,7 @@ def init_model(self): OmegaConf.set_struct(self.config.actor, True) with open_dict(self.config.actor): self.config.actor.use_remove_padding = use_remove_padding + self.config.actor.use_fused_kernels = use_fused_kernels self.actor = 
DataParallelPPOActor( config=self.config.actor, actor_module=self.actor_module_fsdp, @@ -519,15 +652,19 @@ def init_model(self): ) if self._is_rollout: - self.rollout, self.rollout_sharding_manager = self._build_rollout() + self.rollout, self.rollout_sharding_manager = self._build_rollout( + trust_remote_code=self.config.model.get("trust_remote_code", False) + ) if self._is_ref: + local_path = copy_to_local(self.config.model.path, use_shm=use_shm) self.ref_module_fsdp = self._build_model_optimizer( - model_path=self.config.model.path, + model_path=local_path, fsdp_config=self.config.ref.fsdp_config, optim_config=None, override_model_config=override_model_config, use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, trust_remote_code=self.config.model.get("trust_remote_code", False), use_liger=self.config.model.get("use_liger", False), role="ref", @@ -535,6 +672,7 @@ def init_model(self): OmegaConf.set_struct(self.config.ref, True) with open_dict(self.config.ref): self.config.ref.use_remove_padding = use_remove_padding + self.config.ref.use_fused_kernels = use_fused_kernels self.ref_policy = DataParallelPPOActor( config=self.config.ref, actor_module=self.ref_module_fsdp ) @@ -555,8 +693,6 @@ def init_model(self): checkpoint_contents=self.config.actor.checkpoint.contents, ) - torch.cuda.empty_cache() - @register(dispatch_mode=Dispatch.ONE_TO_ALL) def setup_weight_sync_group(self): if ( @@ -588,7 +724,6 @@ def setup_weight_sync_group(self): world_size = self.config.synchronizer.explorer_world_size + 1 print(f"Trainer init_process_group {master_address}:{master_port} ({world_size}).") explorer = ray.get_actor("explorer") - group_name = "rollout_weight_sync" setup_ref = explorer.setup_weight_sync_group.remote( master_address, master_port, self.state_dict_meta ) @@ -605,7 +740,7 @@ def setup_weight_sync_group(self): timeout=timeout, world_size=world_size, rank=0, - group_name=group_name, + group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME, ) ray.get(setup_ref) @@ -630,18 +765,16 @@ def set_algorithm(self, algo_config: AlgorithmConfig): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_actor(self, data: DataProto): # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) assert self._is_actor if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) if self._is_offload_optimizer: load_fsdp_optimizer( - optimizer=self.actor_optimizer, device_id=torch.cuda.current_device() + optimizer=self.actor_optimizer, device_id=get_torch_device().current_device() ) - log_gpu_memory_usage("Before update policy", logger=logger) - with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) # perform training @@ -655,17 +788,17 @@ def update_actor(self, data: DataProto): metrics["perf/mfu/actor"] = ( estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size ) - metrics["perf/max_memory_allocated_gb"] = torch.cuda.max_memory_allocated() / ( + metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / ( + 1024**3 + ) + metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / ( 1024**3 ) - metrics["perf/max_memory_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3) metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) - self.actor_lr_scheduler.step() lr = self.actor_lr_scheduler.get_last_lr()[0] metrics["actor/lr"] = lr - - log_gpu_memory_usage("After update 
policy", logger=logger) + self.actor_lr_scheduler.step() # TODO: here, we should return all metrics output = DataProto(meta_info={"metrics": metrics}) @@ -675,19 +808,19 @@ def update_actor(self, data: DataProto): if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) + log_gpu_memory_usage("After offload actor model during update_actor", logger=logger) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) + log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger) return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def generate_sequences(self, prompts: DataProto): # Support all hardwares - prompts = prompts.to(torch.cuda.current_device()) + prompts = prompts.to(get_torch_device().current_device()) assert self._is_rollout - if self._is_offload_param: - load_fsdp_model_to_gpu(self.actor_module_fsdp) meta_info = { "eos_token_id": self.generation_config.eos_token_id @@ -699,12 +832,6 @@ def generate_sequences(self, prompts: DataProto): } prompts.meta_info.update(meta_info) with self.rollout_sharding_manager: - # after parameters sync with rollout, offload actor model to CPU - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) - if self._is_offload_optimizer: - offload_fsdp_optimizer(optimizer=self.actor_optimizer) - log_gpu_memory_usage("After entering rollout sharding manager", logger=logger) prompts = self.rollout_sharding_manager.preprocess_data(prompts) @@ -717,18 +844,23 @@ def generate_sequences(self, prompts: DataProto): output = output.to("cpu") # clear kv cache - torch.cuda.empty_cache() - log_gpu_memory_usage("After recompute log prob", logger=logger) + get_torch_device().empty_cache() return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_log_prob(self, data: DataProto): + # when is_lora is True, we use the actor without lora applied to calculate the log_prob + # which is mostly used for ref log_prob calculation assert self._is_actor if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) # Support all hardwares - data = data.to(torch.cuda.current_device()) + from contextlib import nullcontext + + is_lora = data.meta_info.pop("is_lora", False) + adapter_ctx = self.actor.actor_module.disable_adapter() if is_lora else nullcontext() + data = data.to(get_torch_device().current_device()) # we should always recompute old_log_probs when it is HybridEngine data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu @@ -737,9 +869,10 @@ def compute_log_prob(self, data: DataProto): # perform recompute log_prob with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data) - output = self.actor.compute_log_prob(data=data) + with adapter_ctx: + output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True) output = DataProto.from_dict( - tensors={"old_log_probs": output}, + tensors={"old_log_probs": output, "entropys": entropys}, meta_info={"temperature": self.config.rollout.temperature}, ) output = self.ulysses_sharding_manager.postprocess_data(output) @@ -748,21 +881,29 @@ def compute_log_prob(self, data: DataProto): # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module - if self.world_size > 1: + if self.world_size > 1 and fsdp_version(self.actor.actor_module) == 1: self.actor.actor_module._handle.reshard(True) if 
self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) + log_gpu_memory_usage("After offload actor model during compute_log_prob", logger=logger) - log_gpu_memory_usage("After compute_log_prob", logger=logger) return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_ref_log_prob(self, data: DataProto): + if self._is_lora: + # if _is_lora, actor without lora applied is the ref + data.meta_info["is_lora"] = True + data = self.compute_log_prob(data) + # this old_log_probs is in fact ref_log_prob + data = DataProto.from_dict(tensors={"ref_log_prob": data.batch["old_log_probs"]}) + return data assert self._is_ref - + # else: + # otherwise, the class have a standalone ref model # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size @@ -771,7 +912,7 @@ def compute_ref_log_prob(self, data: DataProto): data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data) - output = self.ref_policy.compute_log_prob(data=data) + output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) output = DataProto.from_dict(tensors={"ref_log_prob": output}) output = self.ulysses_sharding_manager.postprocess_data(output) @@ -779,17 +920,15 @@ def compute_ref_log_prob(self, data: DataProto): # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module - if self.world_size > 1: + if self.world_size > 1 and fsdp_version(self.ref_policy.actor_module) == 1: self.ref_policy.actor_module._handle.reshard(True) - torch.cuda.empty_cache() return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): # only support save and load ckpt for actor assert self._is_actor - import torch if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) @@ -800,8 +939,42 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep, ) + dist.barrier() + + if self._is_lora and hasattr( + getattr(self, "actor_module", self.actor_module_fsdp), "peft_config" + ): + lora_save_path = os.path.join(local_path, "lora_adapter") + peft_model = getattr(self, "actor_module", self.actor_module_fsdp) + peft_config = {} + if dist.get_rank() == 0: + os.makedirs(lora_save_path, exist_ok=True) + peft_config = asdict(peft_model.peft_config.get("default", {})) + peft_config["task_type"] = peft_config["task_type"].value + peft_config["peft_type"] = peft_config["peft_type"].value + peft_config["target_modules"] = list(peft_config["target_modules"]) + try: + if fsdp_version(self.actor_module_fsdp) > 0: + self.actor_module_fsdp = self.actor_module_fsdp.cuda() + lora_params = layered_summon_lora_params(self.actor_module_fsdp) + if dist.get_rank() == 0: + save_file( + lora_params, os.path.join(lora_save_path, "adapter_model.safetensors") + ) + with open( + os.path.join(lora_save_path, "adapter_config.json"), + "w", + encoding="utf-8", + ) as f: + json.dump(peft_config, f, ensure_ascii=False, indent=4) + except Exception as e: + if dist.get_rank() == 0: + print(f"[rank-{self.rank}]: Save LoRA Adapter Error ({e})") + + dist.barrier() + if dist.get_rank() == 0: + print(f"[rank-{self.rank}]: Saved LoRA 
adapter to: {lora_save_path}") - torch.distributed.barrier() if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) @@ -839,7 +1012,7 @@ def __init__(self, config): import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl") self.config = config # build device mesh for Ulysses Sequence Parallel @@ -854,7 +1027,7 @@ def __init__(self, config): dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( - "cuda", + device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"], ) @@ -879,26 +1052,29 @@ def __init__(self, config): ) self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size + + if self.config.ppo_micro_batch_size_per_gpu is not None: assert ( self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0 ), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" assert ( self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0 ), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" + self._is_lora = self.config.model.get("lora_rank", 0) > 0 def _build_critic_model_optimizer(self, config): # the following line is necessary from torch import optim - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import MixedPrecision from verl.utils.model import print_model_size from verl.utils.torch_dtypes import PrecisionType - local_path = copy_to_local(config.model.path) + use_shm = config.model.get("use_shm", False) + local_path = copy_to_local(config.model.path, use_shm=use_shm) # note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info # using random initialized model from any architecture. May not be the same as Actor. 
- tokenizer_path = copy_to_local(config.model.tokenizer_path) + tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm) self.tokenizer = hf_tokenizer( tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False) ) @@ -925,11 +1101,15 @@ def _build_critic_model_optimizer(self, config): from transformers import AutoConfig, AutoModelForTokenClassification - trust_remote_code = False critic_model_config = AutoConfig.from_pretrained( - local_path, trust_remote_code=trust_remote_code + local_path, + attn_implementation="flash_attention_2", + trust_remote_code=config.model.get("trust_remote_code", False), ) critic_model_config.num_labels = 1 + # patch for kimi-vl + if getattr(critic_model_config, "model_type", None) == "kimi_vl": + critic_model_config.text_config.topk_method = "greedy" init_context = get_init_weight_context_manager( use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh @@ -937,23 +1117,22 @@ def _build_critic_model_optimizer(self, config): with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") - setattr(critic_model_config, "classifier_dropout", 0.0) - setattr(critic_model_config, "hidden_dropout", "0") + critic_model_config.classifier_dropout = 0.0 + critic_model_config.hidden_dropout = "0" critic_module = AutoModelForTokenClassification.from_pretrained( pretrained_model_name_or_path=local_path, torch_dtype=torch_dtype, config=critic_model_config, - attn_implementation="flash_attention_2", - trust_remote_code=trust_remote_code, + trust_remote_code=config.model.get("trust_remote_code", False), ) use_remove_padding = config.model.get("use_remove_padding", False) - if use_remove_padding or self.ulysses_sequence_parallel_size > 1: - from verl.models.transformers.monkey_patch import apply_monkey_patch - apply_monkey_patch( - model=critic_module, ulysses_sp_size=self.ulysses_sequence_parallel_size - ) + apply_monkey_patch( + model=critic_module, + use_remove_padding=use_remove_padding, + ulysses_sp_size=self.ulysses_sequence_parallel_size, + ) # some parameters may not in torch_dtype critic_module.to(torch_dtype) @@ -962,6 +1141,20 @@ def _build_critic_model_optimizer(self, config): critic_module.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False} ) + + if self._is_lora: + print("Applying LoRA to critic module") + critic_module.enable_input_require_grads() + # Convert config to regular Python types before creating PEFT model + lora_config = { + "task_type": TaskType.CAUSAL_LM, + "r": self.config.model.lora_rank, + "lora_alpha": self.config.model.lora_alpha, + "target_modules": convert_to_regular_types(self.config.model.target_modules), + "bias": "none", + } + critic_module = get_peft_model(critic_module, LoraConfig(**lora_config)) + if self.rank == 0: print_model_size(critic_module) @@ -987,7 +1180,9 @@ def _build_critic_model_optimizer(self, config): ) auto_wrap_policy = get_fsdp_wrap_policy( - module=critic_module, config=self.config.model.fsdp_config.wrap_policy + module=critic_module, + config=self.config.model.fsdp_config.wrap_policy, + is_lora=self.config.model.get("lora_rank", 0) > 0, ) log_gpu_memory_usage("Before critic FSDP", logger=None) @@ -996,59 +1191,87 @@ def _build_critic_model_optimizer(self, config): sharding_strategy = get_sharding_strategy(fsdp_mesh) # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation - critic_module = FSDP( - critic_module, - param_init_fn=init_fn, - 
use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), - sharding_strategy=sharding_strategy, - mixed_precision=mixed_precision, - sync_module_states=True, - forward_prefetch=False, - device_mesh=self.device_mesh, - cpu_offload=None, - ) + if config.strategy == "fsdp": + critic_module = FSDP( + critic_module, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=get_torch_device().current_device(), + sharding_strategy=sharding_strategy, + mixed_precision=mixed_precision, + sync_module_states=True, + forward_prefetch=False, + device_mesh=self.device_mesh, + cpu_offload=None, + ) + elif config.strategy == "fsdp2": + assert ( + CPUOffloadPolicy is not None + ), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + mp_policy = MixedPrecisionPolicy( + param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True + ) + offload_policy = None + if fsdp_config.offload_policy: + self._is_offload_param = False + self._is_offload_optimizer = False + offload_policy = CPUOffloadPolicy(pin_memory=True) + + fsdp_kwargs = { + "mesh": fsdp_mesh, + "mp_policy": mp_policy, + "offload_policy": offload_policy, + "reshard_after_forward": fsdp_config.reshard_after_forward, + } + full_state = critic_module.state_dict() + apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config) + fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy) + else: + raise NotImplementedError(f"Unknown strategy {config.strategy}") + + if config.model.get("enable_activation_offload", False): + enable_gradient_checkpointing = config.model.get("enable_gradient_checkpointing", False) + enable_activation_offloading( + critic_module, config.strategy, enable_gradient_checkpointing + ) log_gpu_memory_usage("After critic FSDP", logger=None) - beta1 = config.optim.get("beta1", 0.9) - beta2 = config.optim.get("beta2", 0.999) critic_optimizer = optim.AdamW( critic_module.parameters(), lr=config.optim.lr, - betas=(beta1, beta2), + betas=config.optim.get("betas", (0.9, 0.999)), weight_decay=config.optim.get("weight_decay", 1e-2), ) total_steps = config.optim.get("total_training_steps", 0) num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1)) + warmup_style = config.optim.get("warmup_style", "constant") if num_warmup_steps < 0: num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0) num_warmup_steps = int(num_warmup_steps_ratio * total_steps) - print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") + if self.rank == 0: + print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") - if config.optim.warmup_style == "constant": - from verl.utils.torch_functional import get_constant_schedule_with_warmup + from verl.utils.torch_functional import ( + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + ) + if warmup_style == "constant": critic_lr_scheduler = get_constant_schedule_with_warmup( optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps ) - elif config.optim.warmup_style == "cosine": - from verl.utils.torch_functional import get_cosine_schedule_with_warmup - - assert total_steps > 0, "Cosine scheduler of critic requires total_training_steps > 0" + elif warmup_style == "cosine": critic_lr_scheduler = get_cosine_schedule_with_warmup( optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps, - min_lr_ratio=config.optim.min_lr_ratio, ) else: - raise NotImplementedError( - f"Lr scheduler 
style {config.optim.warmup_style} is not supported" - ) + raise NotImplementedError(f"Warmup style {warmup_style} is not supported") return critic_module, critic_optimizer, critic_lr_scheduler @@ -1067,8 +1290,10 @@ def init_model(self): if self._is_offload_param: offload_fsdp_model_to_cpu(self.critic_module) + log_gpu_memory_usage("After offload critic model during init", logger=logger) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.critic_optimizer) + log_gpu_memory_usage("After offload critic optimizer during init", logger=logger) self.critic = DataParallelPPOCritic( config=self.config, @@ -1088,7 +1313,7 @@ def init_model(self): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_values(self, data: DataProto): # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) @@ -1111,12 +1336,12 @@ def compute_values(self, data: DataProto): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_critic(self, data: DataProto): # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) if self._is_offload_optimizer: load_fsdp_optimizer( - optimizer=self.critic_optimizer, device_id=torch.cuda.current_device() + optimizer=self.critic_optimizer, device_id=get_torch_device().current_device() ) # perform forward computation @@ -1197,327 +1422,3 @@ def clear_optimizer_state(self): self.critic_optimizer.zero_grad() if self._is_offload_optimizer: offload_fsdp_optimizer(self.critic_optimizer) - - -# TODO(sgm): we may need to extract it to dp_reward_model.py -class RewardModelWorker(Worker): - """ - Note that we only implement the reward model that is subclass of AutoModelForTokenClassification. 
- """ - - def __init__(self, config): - super().__init__() - import torch.distributed - - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") - self.config = config - - # build device mesh for Ulysses Sequence Parallel - world_size = torch.distributed.get_world_size() - from torch.distributed.device_mesh import init_device_mesh - - fsdp_size = self.config.model.fsdp_config.fsdp_size - self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) - - self.ulysses_device_mesh = None - self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) - dp = world_size // self.ulysses_sequence_parallel_size - if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh( - "cuda", - mesh_shape=(dp, self.ulysses_sequence_parallel_size), - mesh_dim_names=["dp", "sp"], - ) - - self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) - - self.use_remove_padding = self.config.model.get("use_remove_padding", False) - - # normalize config - if self.config.micro_batch_size is not None: - self.config.micro_batch_size //= torch.distributed.get_world_size() - self.config.micro_batch_size_per_gpu = self.config.micro_batch_size - - def _build_model(self, config): - # the following line is necessary - from torch.distributed.fsdp import CPUOffload - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - from transformers import AutoConfig, AutoModelForTokenClassification - - # download the checkpoint from hdfs - local_path = copy_to_local(config.model.path) - - if self.config.model.input_tokenizer is None: - self._do_switch_chat_template = False - else: - self._do_switch_chat_template = True - input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer) - self.input_tokenizer = hf_tokenizer( - input_tokenizer_local_path, - trust_remote_code=config.model.get("trust_remote_code", False), - ) - self.tokenizer = hf_tokenizer( - local_path, trust_remote_code=config.model.get("trust_remote_code", False) - ) - - trust_remote_code = config.model.get("trust_remote_code", False) - model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) - model_config.num_labels = 1 - - # note that we have to create model in fp32. 
Otherwise, the optimizer is in bf16, which is incorrect - init_context = get_init_weight_context_manager( - use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh - ) - - with init_context(), warnings.catch_warnings(): - warnings.simplefilter("ignore") - setattr(model_config, "classifier_dropout", 0.0) - reward_module = AutoModelForTokenClassification.from_pretrained( - pretrained_model_name_or_path=local_path, - config=model_config, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - trust_remote_code=trust_remote_code, - ) - - if ( - config.model.get("use_remove_padding", False) - or self.ulysses_sequence_parallel_size > 1 - ): - from verl.models.transformers.monkey_patch import apply_monkey_patch - - apply_monkey_patch( - model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size - ) - - reward_module.to(torch.bfloat16) - auto_wrap_policy = get_fsdp_wrap_policy( - module=reward_module, config=self.config.model.fsdp_config - ) - - fsdp_mesh = self.device_mesh - sharding_strategy = get_sharding_strategy(fsdp_mesh) - - reward_module = FSDP( - reward_module, - param_init_fn=init_fn, - use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), - sharding_strategy=sharding_strategy, # zero3 - sync_module_states=True, - cpu_offload=CPUOffload(offload_params=True), - forward_prefetch=False, - device_mesh=self.device_mesh, - ) - - return reward_module - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def init_model(self): - # This is used to import external_lib into the huggingface systems - import_external_libs(self.config.model.get("external_lib", None)) - self.reward_module = self._build_model(config=self.config) - - def _forward_micro_batch(self, micro_batch): - from flash_attn.bert_padding import ( - index_first_axis, - pad_input, - rearrange, - unpad_input, - ) - from verl.utils.ulysses import ( - gather_outpus_and_unpad, - ulysses_pad_and_slice_inputs, - ) - - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - input_ids = micro_batch["input_ids"] - batch_size, seqlen = input_ids.shape - attention_mask = micro_batch["attention_mask"] - position_ids = micro_batch["position_ids"] - - if self.use_remove_padding: - input_ids_rmpad, indices, *_ = unpad_input( - input_ids.unsqueeze(-1), attention_mask - ) # input_ids_rmpad (total_nnz, ...) - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - - # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), indices - ).transpose(0, 1) - - # pad and slice the inputs if sp > 1 - if self.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, - position_ids_rmpad, - sp_size=self.ulysses_sequence_parallel_size, - ) - - # only pass input_ids and position_ids to enable flash_attn_varlen - output = self.reward_module( - input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - use_cache=False, - ) # prevent model thinks we are generating - reward_rmpad = output.logits - reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz) - - # gather output if sp > 1 - if self.ulysses_sequence_parallel_size > 1: - reward_rmpad = gather_outpus_and_unpad( - reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size - ) - - # pad it back - rm_score = pad_input( - reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen - ).squeeze(-1) - else: - output = self.reward_module( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - use_cache=False, - ) - rm_score = output.logits # (batch_size, seq_len, 1) - rm_score = rm_score.squeeze(-1) - - # extract the result of the last valid token - eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) - rm_score = rm_score[torch.arange(batch_size), eos_mask_idx] - return rm_score - - def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor): - batch_size = data.batch.batch_size[0] - # expand as token_level_reward - attention_mask = data.batch["attention_mask"] - position_ids = data.batch["position_ids"] - response_length = data.batch["responses"].shape[-1] - eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) - token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) # (bsz, seqlen) - token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores - - # select the response part - token_level_scores = token_level_scores[:, -response_length:] - - return token_level_scores - - def _switch_chat_template(self, data: DataProto): - src_max_length = data.batch["attention_mask"].shape[-1] - - src_tokenizer = self.input_tokenizer - target_tokenizer = self.tokenizer - - rm_input_ids = [] - rm_attention_mask = [] - - for i in range(data.batch.batch_size[0]): - # extract raw prompt - chat: list = data.non_tensor_batch["raw_prompt"][i].tolist() - - # extract response - response_ids = data.batch["responses"][i] - response_length = response_ids.shape[-1] - valid_response_length = data.batch["attention_mask"][i][-response_length:].sum() - valid_response_ids = response_ids[:valid_response_length] - - # decode - response = src_tokenizer.decode(valid_response_ids) - # remove bos and eos - response = response.replace(src_tokenizer.eos_token, "") - - chat.append({"role": "assistant", "content": response}) - - prompt_with_chat_template = target_tokenizer.apply_chat_template( - chat, add_generation_prompt=False, tokenize=False - ) - if self.rank == 0 and i == 0: - # for debugging purpose - print(f"Switch template. 
chat: {prompt_with_chat_template}") - - # the maximum length is actually determined by the reward model itself - max_length = self.config.get("max_length", src_max_length) - if max_length is None: - max_length = src_max_length - input_ids, attention_mask = verl_F.tokenize_and_postprocess_data( - prompt=prompt_with_chat_template, - tokenizer=target_tokenizer, - max_length=max_length, - pad_token_id=target_tokenizer.pad_token_id, - left_pad=False, # right padding - truncation=self.config.get("truncation", "right"), - ) # truncate from the right - - rm_input_ids.append(input_ids) - rm_attention_mask.append(attention_mask) - - rm_input_ids = torch.cat(rm_input_ids, dim=0) - rm_attention_mask = torch.cat(rm_attention_mask, dim=0) - - rm_position_ids = compute_position_id_with_mask(rm_attention_mask) - - rm_inputs = { - "input_ids": rm_input_ids, - "attention_mask": rm_attention_mask, - "position_ids": rm_position_ids, - } - - return DataProto.from_dict(rm_inputs) - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def compute_rm_score(self, data: DataProto): - import itertools - - from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches - - # Support all hardwares - data = data.to(torch.cuda.current_device()) - if self._do_switch_chat_template: - rm_data = self._switch_chat_template(data) - - # Support all hardwares - rm_data.batch = rm_data.batch.to(torch.cuda.current_device()) - - # perform forward computation - with self.ulysses_sharding_manager: - rm_data = self.ulysses_sharding_manager.preprocess_data(data=rm_data) - data = self.ulysses_sharding_manager.preprocess_data(data=data) - - use_dynamic_bsz = self.config.use_dynamic_bsz - if use_dynamic_bsz: - max_token_len = ( - self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - ) - micro_batches, indices = rearrange_micro_batches( - batch=rm_data.batch, max_token_len=max_token_len - ) - else: - micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu) - output = [] - for micro_batch in micro_batches: - rm_score = self._forward_micro_batch(micro_batch) - output.append(rm_score) - scores = torch.cat(output, dim=0) # (batch_size) - - if use_dynamic_bsz: - indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == scores.size(0), f"{len(indices)} vs. {scores.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) - scores = scores[revert_indices] - - token_level_scores = self._expand_to_token_level(data, scores) - # Note that this is only the scores, may not be the final rewards used to train RL - output = DataProto.from_dict(tensors={"rm_scores": token_level_scores}) - output = self.ulysses_sharding_manager.postprocess_data(data=output) - - # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes - # unshard the root FSDP module - self.reward_module._handle.reshard(True) - - output = output.to("cpu") - return output diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index d040c329dd..a149853b43 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -159,7 +159,16 @@ def _validate_config(self): # TODO super()._validate_config() def init_workers(self): - """Init resource pool and worker group""" + """Initialize distributed training workers using Ray backend. + + + Creates: + + 1. Ray resource pools from configuration + + 2. Worker groups for each role (actor, critic, etc.) 
+ + """ self.resource_pool_manager.create_resource_pool() self.resource_pool_to_cls = { @@ -207,25 +216,31 @@ def init_workers(self): # initialize WorkerGroup # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, - # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups. + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. all_wg = {} - self.wg_dicts = [] + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs[ + "ray_wait_register_center_timeout" + ] = self.config.trainer.ray_wait_register_center_timeout for resource_pool, class_dict in self.resource_pool_to_cls.items(): worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) wg_dict = self.ray_worker_group_cls( - resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + device_name=self.device_name, + **wg_kwargs, ) spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) all_wg.update(spawn_wg) - # keep the referece of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699 - self.wg_dicts.append(wg_dict) if self.use_critic: self.critic_wg = all_wg["critic"] self.critic_wg.init_model() - if self.use_reference_policy: + if self.use_reference_policy and not self.ref_in_actor: self.ref_policy_wg = all_wg["ref"] self.ref_policy_wg.init_model() @@ -266,7 +281,7 @@ def prepare(self): if self.config.trainer.get("val_only", False): return - def _create_dataloader(self): + def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler): self.train_dataloader = _InternalDataLoader(self.config) # TODO: compute total training steps self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize