diff --git a/pyproject.toml b/pyproject.toml index 365053f2ed..9bb03e46fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ ] requires-python = ">=3.10,<3.13" dependencies = [ - "verl==0.5.0", + "verl==0.7.0", "ray[default]>=2.50.0", "tensordict", "wandb", diff --git a/trinity/common/config.py b/trinity/common/config.py index df5286cb41..3307024015 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -116,6 +116,8 @@ class LoRAConfig: lora_alpha: int = 32 lora_dtype: str = "auto" target_modules: str = "all-linear" + exclude_modules: Optional[str] = None + is_dummy: bool = False # DO NOT SET, automatically set @Experimental @@ -1356,16 +1358,19 @@ def _check_explorer(self) -> None: self.explorer.rollout_model.enable_lora = True if len(self.model.lora_configs) > 1: raise ValueError("Only one lora adapter is supported for now.") - if self.model.lora_configs[0].path is None: + lora_config = self.model.lora_configs[0] + if lora_config.path is None: logger.info("Creating dummy lora, since no lora_path is provided.") lora_path = create_dummy_lora( model_path=self.model.model_path, checkpoint_job_dir=self.checkpoint_job_dir, - lora_rank=self.model.lora_configs[0].lora_rank, - lora_alpha=self.model.lora_configs[0].lora_alpha, - target_modules=self.model.lora_configs[0].target_modules, + lora_rank=lora_config.lora_rank, + lora_alpha=lora_config.lora_alpha, + target_modules=lora_config.target_modules, + exclude_modules=lora_config.exclude_modules, ) - self.model.lora_configs[0].path = lora_path + lora_config.path = lora_path + lora_config.is_dummy = True self.explorer.rollout_model.lora_modules = [ { "lora_int_id": i + 1, diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 5db971245d..5aa791240b 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -4,7 +4,9 @@ from typing import Any, Dict, List, Optional from omegaconf import OmegaConf +from verl.workers.config import PolicyLossConfig, RouterReplayConfig +from trinity.algorithm import ALGORITHM_TYPE from trinity.common.config import Config, SynchronizerConfig, set_if_none from trinity.common.constants import EXPLORER_NAME from trinity.utils.log import get_logger @@ -41,6 +43,8 @@ class ActorModel: lora_rank: int = 0 # The rank of the LoRA model, default to 0. If lora_rank > 0, LoRA module is enabled in trainer lora_alpha: int = 32 target_modules: Optional[str] = "all-linear" + exclude_modules: Optional[str] = None + lora_adapter_path: Optional[str] = None # rope configs rope_scaling: Optional[dict] = None @@ -51,6 +55,8 @@ class ActorModel: class Optim: # For actor, most fields are set in algorithm.optimizer # For critic, you can set trainer_config.critic.optim + optimizer: str = "AdamW" + optimizer_impl: str = "torch.optim" lr: float = 1e-6 lr_warmup_steps: int = -1 lr_warmup_steps_ratio: float = 0.0 @@ -58,7 +64,6 @@ class Optim: warmup_style: str = "constant" total_training_steps: int = -1 # ! 
DO NOT SET, use trainer.total_steps betas: List[float] = field(default_factory=lambda: [0.9, 0.999]) - optimizer: str = "adam" clip_grad: float = 1.0 lr_warmup_init: float = 0.0 lr_decay_steps: Optional[int] = None @@ -69,6 +74,7 @@ class Optim: lr_wsd_decay_style: str = "exponential" lr_wsd_decay_steps: Optional[int] = None use_checkpoint_opt_param_scheduler: bool = False + override_optimizer_config: Optional[dict] = None @dataclass @@ -78,6 +84,7 @@ class WrapPolicy: @dataclass class FSDPConfig: + _target_: str = "verl.workers.config.FSDPEngineConfig" # DO NOT SET param_offload: bool = False optimizer_offload: bool = False offload_policy: bool = False @@ -92,7 +99,7 @@ class FSDPConfig: class Checkpoint: load_contents: List[str] = field(default_factory=lambda: ["model", "optimizer", "extra"]) save_contents: List[str] = field(default_factory=lambda: ["model", "optimizer", "extra"]) - async_save: bool = False # do not set, async save has bug in verl megatron training + async_save: bool = False # TODO: testing async save @dataclass @@ -124,6 +131,8 @@ class MegatronConfig: default_factory=OverrideTransformerConfig ) use_mbridge: bool = False + dtype: str = "bfloat16" + use_remove_padding: bool = True @dataclass @@ -157,6 +166,9 @@ class Actor: profile: ProfileConfig = field(default_factory=ProfileConfig) data_loader_seed: Optional[int] = None load_weight: bool = True + policy_loss: PolicyLossConfig = field(default_factory=PolicyLossConfig) + profiler: dict = field(default_factory=dict) + router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig) # do not set loss_agg_mode: str = "token-mean" clip_ratio: float = 0.2 @@ -182,6 +194,8 @@ class Ref: megatron: MegatronConfig = field(default_factory=MegatronConfig) profile: ProfileConfig = field(default_factory=ProfileConfig) load_weight: bool = True + profiler: dict = field(default_factory=dict) + router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig) @dataclass @@ -214,6 +228,7 @@ class ActorRolloutRef: actor: Actor = field(default_factory=Actor) ref: Ref = field(default_factory=Ref) rollout: Rollout = field(default_factory=Rollout) + nccl_timeout: float = 600 # ! DO NOT SET, it will be set by `config.synchronizer.sync_timeout` synchronizer: Optional[SynchronizerConfig] = None explorer_name: str = EXPLORER_NAME @@ -232,6 +247,7 @@ class CriticModel: @dataclass class Critic: + enable: bool = False strategy: Optional[str] = None optim: Optim = field(default_factory=Optim) model: CriticModel = field(default_factory=CriticModel) @@ -255,7 +271,9 @@ class Critic: profile: ProfileConfig = field(default_factory=ProfileConfig) data_loader_seed: Optional[int] = None load_weight: bool = True + nccl_timeout: float = 600 # ! 
DO NOT SET, it will be set by `config.synchronizer.sync_timeout` ray_namespace: str = "" # automatically generated + profiler: dict = field(default_factory=dict) @dataclass @@ -278,6 +296,7 @@ class RewardModel: use_dynamic_bsz: bool = False forward_max_token_len_per_gpu: int = 0 reward_manager: str = "naive" + use_reward_loop: bool = True @dataclass @@ -294,8 +313,24 @@ class KL_Ctrl: target_kl: float = 0.1 +@dataclass +class RolloutCorrection: + rollout_is: Optional[str] = None + rollout_is_threshold: float = 2.0 + rollout_rs: Optional[str] = None + rollout_rs_threshold: Optional[float] = None + rollout_rs_threshold_lower: Optional[float] = None + rollout_token_veto_threshold: Optional[float] = None + # Because rollout and training in Trinity runs separately, + # rollout_is_batch_normalize is default to True + bypass_mode: bool = True + loss_type: str = "ppo_clip" + rollout_is_batch_normalize: bool = False + + @dataclass class Algorithm: + rollout_correction: RolloutCorrection = field(default_factory=RolloutCorrection) # ! DO NOT SET gamma or lam below; they are kept here merely for compatibility with verl, # and their values will be overwritten by those in AlgorithmConfig.advantage_fn_args # if they are really needed (e.g., for GAE advantage/returns computation) @@ -349,6 +384,7 @@ class veRLConfig: custom_reward_function: CustomRewardFunction = field(default_factory=CustomRewardFunction) algorithm: Algorithm = field(default_factory=Algorithm) trainer: Trainer = field(default_factory=Trainer) + global_profiler: dict = field(default_factory=dict) synchronizer: Optional[SynchronizerConfig] = None enable_preview: bool = True @@ -423,8 +459,12 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 ) # kept to pass RayPPOTrainer._validate_config self.synchronizer = config.synchronizer + self.actor_rollout_ref.nccl_timeout = config.synchronizer.sync_timeout self.actor_rollout_ref.synchronizer = config.synchronizer self.actor_rollout_ref.explorer_name = config.explorer.name + algorithm = ALGORITHM_TYPE.get(config.algorithm.algorithm_type) + self.critic.enable = algorithm.use_critic + self.critic.nccl_timeout = config.synchronizer.sync_timeout self.critic.ray_namespace = config.synchronizer.ray_namespace # Actor / Rollout Config @@ -539,11 +579,12 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 # LoRA related config if config.model.lora_configs is not None: - self.actor_rollout_ref.model.lora_rank = config.model.lora_configs[0].lora_rank - self.actor_rollout_ref.model.lora_alpha = config.model.lora_configs[0].lora_alpha - self.actor_rollout_ref.model.target_modules = config.model.lora_configs[ - 0 - ].target_modules + lora_config = config.model.lora_configs[0] + actor_model_config = self.actor_rollout_ref.model + for attr in ["lora_rank", "lora_alpha", "target_modules", "exclude_modules"]: + setattr(actor_model_config, attr, getattr(lora_config, attr)) + if not lora_config.is_dummy: + actor_model_config.lora_adapter_path = lora_config.path if self.actor_rollout_ref.actor.strategy not in ["fsdp", "fsdp2"]: logger.warning( f"Lora is only supported for fsdp and fsdp2, but got {self.actor_rollout_ref.actor.strategy} instead, changed to fsdp." 
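The preceding hunk copies the user-facing LoRA settings onto verl's actor model config field by field and forwards the adapter path only when it is a real checkpoint rather than the auto-generated dummy adapter. Below is a minimal standalone sketch of that propagation pattern; the dataclasses and the function name are simplified stand-ins, not the actual Trinity/verl config classes.

from dataclasses import dataclass
from typing import Optional


@dataclass
class LoRASettings:  # simplified stand-in for trinity.common.config.LoRAConfig
    lora_rank: int = 32
    lora_alpha: int = 32
    target_modules: str = "all-linear"
    exclude_modules: Optional[str] = None
    path: Optional[str] = None
    is_dummy: bool = False  # True when the adapter was auto-created as a placeholder


@dataclass
class ActorModelSettings:  # simplified stand-in for the verl-side ActorModel config
    lora_rank: int = 0
    lora_alpha: int = 32
    target_modules: Optional[str] = "all-linear"
    exclude_modules: Optional[str] = None
    lora_adapter_path: Optional[str] = None


def sync_lora_settings(lora: LoRASettings, actor_model: ActorModelSettings) -> None:
    # Copy the shared fields one by one instead of repeating long attribute chains.
    for attr in ("lora_rank", "lora_alpha", "target_modules", "exclude_modules"):
        setattr(actor_model, attr, getattr(lora, attr))
    # Only a real (non-dummy) adapter checkpoint is handed to the trainer for loading.
    if not lora.is_dummy:
        actor_model.lora_adapter_path = lora.path


if __name__ == "__main__":
    lora = LoRASettings(lora_rank=16, path="/tmp/adapter", is_dummy=False)
    actor_model = ActorModelSettings()
    sync_lora_settings(lora, actor_model)
    print(actor_model)

This mirrors why `is_dummy` exists on `LoRAConfig`: when the adapter was auto-generated as a placeholder, `lora_adapter_path` is left unset, presumably so the trainer does not try to load the placeholder as a pre-trained adapter.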
@@ -559,6 +600,13 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 for field_name in config.algorithm.optimizer.__dataclass_fields__: field_value = getattr(config.algorithm.optimizer, field_name) if field_name == "optimizer_type": + if config.trainer.trainer_strategy.startswith("fsdp"): + optim_map = { + "adam": "AdamW", + "adamw": "AdamW", + "sgd": "SGD", + } + field_value = optim_map.get(field_value, field_value) setattr(self.actor_rollout_ref.actor.optim, "optimizer", field_value) elif hasattr(self.actor_rollout_ref.actor.optim, field_name): setattr(self.actor_rollout_ref.actor.optim, field_name, field_value) diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index 2d1258b9cb..80bc4a16c8 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -15,7 +15,7 @@ # limitations under the License. """ Single Process Actor. -Modified from https://github.com/volcengine/verl/blob/v0.5.0/verl/workers/actor/dp_actor.py +Modified from https://github.com/volcengine/verl/blob/v0.7.0/verl/workers/actor/dp_actor.py """ import logging @@ -67,9 +67,8 @@ def update_policy(self, data: DataProto): # noqa: C901 # make sure we are in training mode self.actor_module.train() - temperature = data.meta_info[ - "temperature" - ] # temperature must be in the data.meta_info to avoid silent error + # temperature must be in the data.meta_info to avoid silent error + temperature = data.meta_info["temperature"] select_keys = [ "input_ids", "position_ids", @@ -80,6 +79,8 @@ def update_policy(self, data: DataProto): # noqa: C901 select_keys.extend(self.policy_loss_fn.select_keys) if not isinstance(self.kl_loss_fn, DummyKLFn): select_keys.append("ref_log_prob") + # rollout_is_weights will be used in policy loss + # rollout_log_probs is equal to old_log_prob in Trinity select_keys = list(set(select_keys)) has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() @@ -87,6 +88,8 @@ def update_policy(self, data: DataProto): # noqa: C901 data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + # Split to make minibatch iterator for updating the actor + # See PPO paper for details. 
https://arxiv.org/abs/1707.06347 mini_batches = data.split(self.config.ppo_mini_batch_size) # EXPERIMENTAL: apply loss scale fix @@ -119,12 +122,11 @@ def update_policy(self, data: DataProto): # noqa: C901 self.actor_optimizer.zero_grad() for micro_batch in micro_batches: + micro_batch = micro_batch.to(get_device_id()) micro_batch_metrics = {} - model_inputs = { - **micro_batch.batch.to(get_device_id()), - **micro_batch.non_tensor_batch, - } + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} response_mask = model_inputs["response_mask"] + loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") # all return: (bsz, response_length) calculate_entropy = self.entropy_loss_fn != DummyEntropyLossFn @@ -141,6 +143,23 @@ def update_policy(self, data: DataProto): # noqa: C901 src_metrics=pg_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics ) + # TODO: to be checked + # Skip if using bypass_mode loss (metrics already computed in pg_metrics) + rollout_log_prob = model_inputs.get("rollout_log_probs", None) + if loss_mode != "bypass_mode" and rollout_log_prob is not None: + # Compute metrics using CURRENT policy π_θ vs π_rollout + # Tracks evolving off-policy gap as π_θ updates during mini-batch training + from verl.trainer.ppo.rollout_corr_helper import ( + compute_rollout_corr_metrics_from_logprobs, + ) + + rollout_corr_metrics = compute_rollout_corr_metrics_from_logprobs( + log_prob=log_prob, + rollout_log_prob=rollout_log_prob, + response_mask=response_mask, + ) + micro_batch_metrics.update(rollout_corr_metrics) + # compute entropy loss from entropy entropy_loss, entropy_loss_metrics = self.entropy_loss_fn( # type: ignore entropy=entropy, @@ -185,7 +204,15 @@ def update_policy(self, data: DataProto): # noqa: C901 loss = policy_loss * loss_scale micro_batch_metrics["actor/final_loss"] = loss.detach().item() - loss.backward() + if "actor/kl_loss" in micro_batch_metrics: + micro_batch_metrics["actor/kl_loss"] *= loss_scale + if "actor/pg_loss" in micro_batch_metrics: + micro_batch_metrics["actor/pg_loss"] *= loss_scale + + if self.scaler is not None: + self.scaler.scale(loss).backward() + else: + loss.backward() append_to_dict(metrics, micro_batch_metrics) diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py index 4c3a6aba57..839b83d327 100644 --- a/trinity/trainer/verl/fsdp_checkpoint_manager.py +++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py @@ -13,7 +13,7 @@ # limitations under the License. """ FSDP Checkpoint Manager. 
-Modified from https://github.com/volcengine/verl/blob/v0.5.0/verl/utils/checkpoint/fsdp_checkpoint_manager.py +Modified from https://github.com/volcengine/verl/blob/v0.7.0/verl/utils/checkpoint/fsdp_checkpoint_manager.py """ import json @@ -33,6 +33,7 @@ StateDictType, ) from transformers import GenerationConfig +from transformers.dynamic_module_utils import custom_object_save from verl.utils.checkpoint.fsdp_checkpoint_manager import ( FSDPCheckpointManager as OldFSDPCheckpointManager, ) @@ -209,7 +210,7 @@ def _save_extra_state(self, local_path, global_step) -> bool: self.latest_extra_state_save_step = global_step return True - def _get_model_config(self): + def _get_unwrap_model_and_config(self): if fsdp_version(self.model) == 1: unwrap_model = self.model._fsdp_wrapped_module else: @@ -221,13 +222,21 @@ def _get_model_config(self): and hasattr(model_config, "name_or_path") and model_config.name_or_path ): - # Some model's name_or_path is empty if not initialized from pretrained, - # in this cases, we don't save generation config. - generation_config = GenerationConfig.from_pretrained(model_config.name_or_path) + try: + # Some model's name_or_path is empty if not initialized from pretrained, + # in these cases, we don't save the generation config. + generation_config = GenerationConfig.from_pretrained(model_config.name_or_path) + except Exception: + # if the generation config isn't available, we don't save it + generation_config = None else: generation_config = None - return model_config, generation_config + if hasattr(model_config, "auto_map") and None in model_config.auto_map: + model_config.auto_map = { + k: v for k, v in model_config.auto_map.items() if k is not None + } + return unwrap_model, model_config, generation_config def _save_tokenizer(self, local_path, global_step): """ @@ -251,12 +260,12 @@ def _save_tokenizer(self, local_path, global_step): hf_config_tokenizer_path = os.path.join(local_path, "huggingface") local_mkdir_safe(hf_config_tokenizer_path) - model_config, generation_config = self._get_model_config() + unwrap_model, model_config, generation_config = self._get_unwrap_model_and_config() model_config.save_pretrained(hf_config_tokenizer_path) if generation_config is not None: generation_config.save_pretrained(hf_config_tokenizer_path) - - self.processing_class.save_pretrained(hf_config_tokenizer_path) + if self.processing_class is not None: + self.processing_class.save_pretrained(hf_config_tokenizer_path) log_with_rank( f"Saved model config and tokenizer class to {os.path.abspath(hf_config_tokenizer_path)}", rank=self.rank, @@ -264,6 +273,11 @@ def _save_tokenizer(self, local_path, global_step): log_only_rank_0=True, ) + # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. 
+ if hasattr(model_config, "auto_map"): + custom_object_save(unwrap_model, hf_config_tokenizer_path, config=model_config) + # Also save runtime FSDP config fsdp_config_path = os.path.join(local_path, "fsdp_config.json") fsdp_config = FSDPConfig( @@ -303,7 +317,7 @@ def _save_hf_model(self, local_path, global_step) -> bool: hf_local_path = os.path.join(local_path, "huggingface") os.makedirs(hf_local_path, exist_ok=True) - model_config, generation_config = self._get_model_config() + _, model_config, generation_config = self._get_unwrap_model_and_config() if "ForTokenClassification" in model_config.architectures[0]: from transformers import AutoModelForTokenClassification @@ -314,9 +328,20 @@ def _save_hf_model(self, local_path, global_step) -> bool: auto_model_cls = AutoModelForCausalLM elif "ForConditionalGeneration" in model_config.architectures[0]: - from transformers import AutoModelForVision2Seq + # Handle different transformers versions for Vision2Seq models + import transformers + from packaging import version + + if version.parse(transformers.__version__) >= version.parse("4.54.0"): + # transformers >= 4.54.0 uses AutoModelForImageTextToText + from transformers import AutoModelForImageTextToText + + auto_model_cls = AutoModelForImageTextToText + else: + # transformers < 4.54.0 uses AutoModelForVision2Seq + from transformers import AutoModelForVision2Seq - auto_model_cls = AutoModelForVision2Seq + auto_model_cls = AutoModelForVision2Seq else: raise NotImplementedError(f"Unknown architecture {model_config['architectures']}") diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index dd977206c8..ddb8bdf854 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -13,16 +13,16 @@ # limitations under the License. """ The main entry point to run the PPO algorithm. 
-Modified from https://github.com/volcengine/verl/blob/v0.5.0/verl/workers/fsdp_workers.py +Modified from https://github.com/volcengine/verl/blob/v0.7.0/verl/workers/fsdp_workers.py """ +import datetime import json import logging import os import warnings from contextlib import contextmanager from dataclasses import asdict -from datetime import timedelta import psutil import torch @@ -40,15 +40,20 @@ from verl import DataProto from verl.models.transformers.monkey_patch import apply_monkey_patch from verl.single_controller.base import Worker -from verl.single_controller.base.decorator import Dispatch, register +from verl.single_controller.base.decorator import ( + Dispatch, + make_nd_compute_dataproto_dispatch_fn, + register, +) from verl.utils import hf_processor, hf_tokenizer from verl.utils.activation_offload import enable_activation_offloading -from verl.utils.debug import log_gpu_memory_usage +from verl.utils.config import omega_conf_to_dataclass from verl.utils.device import ( get_device_id, get_device_name, get_nccl_backend, get_torch_device, + set_expandable_segments, ) from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_to_local @@ -60,6 +65,7 @@ fsdp_version, get_fsdp_wrap_policy, get_init_weight_context_manager, + get_shard_placement_fn, init_fn, layered_summon_lora_params, load_fsdp_model_to_gpu, @@ -69,11 +75,21 @@ ) from verl.utils.import_utils import import_external_libs from verl.utils.logger import log_with_rank +from verl.utils.memory_utils import aggressive_empty_cache +from verl.utils.profiler import ( + DistProfiler, + DistProfilerExtension, + ProfilerConfig, + log_gpu_memory_usage, +) from verl.utils.py_functional import convert_to_regular_types +from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig +from verl.workers.config.optimizer import build_optimizer from verl.workers.fsdp_workers import ( create_device_mesh, device_name, get_sharding_strategy, + get_vl_model_vision_tower, ) from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager @@ -87,14 +103,15 @@ logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) -class ActorRolloutRefWorker(Worker): +class ActorRolloutRefWorker(Worker, DistProfilerExtension): """ This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy or a hybrid engine based on the config.rollout """ - def __init__(self, config: DictConfig, role: str): - super().__init__() + def __init__(self, config: DictConfig, role: str, **kwargs): + Worker.__init__(self) + self.config = config import torch.distributed @@ -105,8 +122,8 @@ def __init__(self, config: DictConfig, role: str): backend=f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}", rank=rank, world_size=world_size, + timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), init_method=os.environ.get("DIST_INIT_METHOD", None), - timeout=timedelta(seconds=self.config.synchronizer.sync_timeout), ) # build device mesh for FSDP @@ -129,15 +146,59 @@ def __init__(self, config: DictConfig, role: str): mesh_dim_names=["dp", "sp"], ) + # create training dispatch + if self.ulysses_device_mesh is not None: + is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0 + self._register_dispatch_collect_info( + "actor", + dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), + is_collect=is_collect, + ) + else: + self._register_dispatch_collect_info("actor", dp_rank=self.rank, is_collect=True) + self.ulysses_sharding_manager = 
FSDPUlyssesShardingManager(self.ulysses_device_mesh) self._lora_rank = self.config.model.get("lora_rank", 0) - self._is_lora = self._lora_rank > 0 + self._is_lora = ( + self.config.model.get("lora_adapter_path") is not None or self._lora_rank > 0 + ) self.role = role assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"] self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"] self._is_ref = self.role in ["ref", "actor_rollout_ref"] + self.use_orig_params = self.config.actor.fsdp_config.get("use_orig_params", False) + + # TODO(haibin.lin): + # As of now the type of config is DictConfig, if we assign config.profiler with ProfilerConfig, + # it will actually convert the ProfilerConfig dataclass back to a DictConfig. + # We can still use ProfilerConfig for testing purpose (tests/utils/test_nvtx_profile.py) + # as they provides DictConfig-like interface + # The benefit of creating the dataclass config is to perform validation during __post_init__ + if self._is_actor: + omega_profiler_config = config.actor.get("profiler", {}) + elif self._is_ref: + omega_profiler_config = config.ref.get("profiler", {}) + else: + raise ValueError( + f"Invalid role {self.role}, should be one of " + "['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']" + ) + # omega_profiler_config is DictConfig + # profiler_config is a ProfilerConfig dataclass + profiler_config = omega_conf_to_dataclass( + omega_profiler_config, dataclass_type=ProfilerConfig + ) + if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: + tool_config = omega_conf_to_dataclass( + omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) + ) + else: + tool_config = None + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config) + ) self._is_offload_param = False self._is_offload_optimizer = False @@ -156,9 +217,10 @@ def __init__(self, config: DictConfig, role: str): self.config.actor.ppo_mini_batch_size //= ( self.device_mesh.size() // self.ulysses_sequence_parallel_size ) - assert ( - self.config.actor.ppo_mini_batch_size > 0 - ), f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after normalization" + assert self.config.actor.ppo_mini_batch_size > 0, ( + f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after " + f"normalization" + ) # micro bsz if self.config.actor.ppo_micro_batch_size is not None: self.config.actor.ppo_micro_batch_size //= ( @@ -176,12 +238,18 @@ def __init__(self, config: DictConfig, role: str): self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0 - ), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" + ), ( + f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by " + f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" + ) assert ( self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0 - ), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" + ), ( + f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than " + f"ppo_micro_batch_size_per_gpu 
{self.config.actor.ppo_micro_batch_size_per_gpu}" + ) # normalize ref config if ( @@ -212,7 +280,7 @@ def _fsdp_offload_context(self): def _build_model_optimizer( # noqa: C901 self, model_path, - fsdp_config, + fsdp_config: FSDPEngineConfig, optim_config, override_model_config, use_remove_padding=False, @@ -222,12 +290,15 @@ def _build_model_optimizer( # noqa: C901 use_liger=False, role="actor", enable_activation_offload=False, + use_tiled_mlp=False, + tiled_mlp_shards=4, ): - from torch import optim from torch.distributed.fsdp import CPUOffload, MixedPrecision from transformers import ( AutoConfig, + AutoModel, AutoModelForCausalLM, + AutoModelForImageTextToText, AutoModelForVision2Seq, ) from verl.utils.model import ( @@ -239,6 +310,12 @@ def _build_model_optimizer( # noqa: C901 assert role in ["actor", "ref"] + # TiledMLP requires FSDP2 for correct gradient computation + if use_tiled_mlp and self.config.actor.strategy == "fsdp": + raise ValueError( + "TiledMLP requires FSDP2. Set `actor_rollout_ref.actor.strategy=fsdp2`." + ) + log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=logger) local_path = model_path @@ -260,9 +337,24 @@ def _build_model_optimizer( # noqa: C901 torch_dtype = PrecisionType.to_dtype(torch_dtype) # override model kwargs + attn_implementation = override_model_config.get("attn_implementation", "flash_attention_2") actor_model_config = AutoConfig.from_pretrained( - local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2" + local_path, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation ) + # TODO: VL models use VisionAttention, which directly uses flash_attention in transformers>=4.53 + # which will be patched by _ulysses_flash_attention_forward, but erroneously misses position_ids + # Maybe support Ulysses in VisionAttention in the future and remove this patch + if self.ulysses_sequence_parallel_size > 1 and hasattr(actor_model_config, "vision_config"): + actor_model_config.vision_config._attn_implementation = "eager" + + # patch for qwen2.5-vl: when using flash_attention_3, set vision tower to use flash_attention_2 + # because the vision tower does not support flash_attention_3 + if ( + getattr(actor_model_config, "model_type", None) == "qwen2_5_vl" + and attn_implementation == "flash_attention_3" + and hasattr(actor_model_config, "vision_config") + ): + actor_model_config.vision_config._attn_implementation = "flash_attention_2" # patch for rope if self.config.model.rope_scaling is not None: @@ -295,16 +387,41 @@ def _build_model_optimizer( # noqa: C901 with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") - if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys(): - actor_module_class = AutoModelForVision2Seq + has_remote_code = hasattr(actor_model_config, "auto_map") and any( + actor_model_config.architectures[0] in val + for val in actor_model_config.auto_map.values() + ) + if has_remote_code: + auto_class = next( + k + for k, v in actor_model_config.auto_map.items() + if actor_model_config.architectures[0] in v + ) + match auto_class: + case "AutoModelForVision2Seq": + actor_module_class = AutoModelForVision2Seq + case "AutoModelForCausalLM": + actor_module_class = AutoModelForCausalLM + case "AutoModelForImageTextToText": + actor_module_class = AutoModelForImageTextToText + case _: + actor_module_class = AutoModel else: - actor_module_class = AutoModelForCausalLM + if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys(): 
actor_module_class = AutoModelForVision2Seq + elif type(actor_model_config) in AutoModelForCausalLM._model_mapping.keys(): + actor_module_class = AutoModelForCausalLM + elif type(actor_model_config) in AutoModelForImageTextToText._model_mapping.keys(): + actor_module_class = AutoModelForImageTextToText + else: + actor_module_class = AutoModel actor_module = actor_module_class.from_pretrained( pretrained_model_name_or_path=local_path, torch_dtype=torch_dtype, config=actor_model_config, trust_remote_code=trust_remote_code, + attn_implementation=attn_implementation, ) # Apply Liger kernel to the model if use_liger is set to True @@ -331,24 +448,62 @@ def _build_model_optimizer( # noqa: C901 ulysses_sp_size=self.ulysses_sequence_parallel_size, use_fused_kernels=use_fused_kernels, fused_kernels_backend=fused_kernels_backend, + use_tiled_mlp=use_tiled_mlp, + tiled_mlp_shards=tiled_mlp_shards, ) if enable_gradient_checkpointing: actor_module.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False} ) - if self._is_lora: - print("Applying LoRA to actor module") - actor_module.enable_input_require_grads() + + if self._is_lora: + print("Applying LoRA to actor module") + actor_module.enable_input_require_grads() + + lora_adapter_path = self.config.model.get("lora_adapter_path") + if lora_adapter_path is not None: + from peft import PeftModel + + print(f"Loading pre-trained LoRA adapter to {role} from: {lora_adapter_path}") + + # Copy adapter to local if needed + local_adapter_path = copy_to_local( + lora_adapter_path, use_shm=self.config.model.get("use_shm", False) + ) + + actor_module = PeftModel.from_pretrained( + actor_module, local_adapter_path, is_trainable=True + ) + peft_config = actor_module.peft_config["default"] + # Ensure task_type is TaskType enum, not string + if isinstance(peft_config.task_type, str): + peft_config.task_type = TaskType.CAUSAL_LM + + else: # Convert config to regular Python types before creating PEFT model lora_config = { "task_type": TaskType.CAUSAL_LM, "r": self.config.model.lora_rank, "lora_alpha": self.config.model.lora_alpha, "target_modules": convert_to_regular_types(self.config.model.target_modules), + "exclude_modules": convert_to_regular_types(self.config.model.exclude_modules), "bias": "none", } actor_module = get_peft_model(actor_module, LoraConfig(**lora_config)) + + self.use_orig_params = fsdp_config.get("use_orig_params", False) + if self.config.actor.get("freeze_vision_tower", False): + vision_tower = get_vl_model_vision_tower(actor_module) + if vision_tower is not None: + vision_tower.requires_grad_(False) + self.use_orig_params = True + if self.rank == 0: + print("[actor model] Vision tower is set to not trainable.") + else: + if self.rank == 0: + print("[actor model] No vision tower found.") + torch.distributed.barrier() if self.rank == 0: @@ -367,7 +522,7 @@ def _build_model_optimizer( # noqa: C901 mixed_precision_config.get("buffer_dtype", "fp32") ) else: - param_dtype = torch.bfloat16 + param_dtype = PrecisionType.to_dtype(fsdp_config.dtype) reduce_dtype = torch.float32 buffer_dtype = torch.float32 @@ -378,14 +533,15 @@ def _build_model_optimizer( # noqa: C901 auto_wrap_policy = get_fsdp_wrap_policy( module=actor_module, config=fsdp_config.get("wrap_policy", None), - is_lora=self.config.model.get("lora_rank", 0) > 0, + is_lora=self._is_lora, ) if self.rank == 0: print(f"wrap_policy: {auto_wrap_policy}") fsdp_mesh = self.device_mesh - sharding_strategy = get_sharding_strategy(fsdp_mesh) + fsdp_enable_zero3 = 
fsdp_config.reshard_after_forward + sharding_strategy = get_sharding_strategy(fsdp_mesh, fsdp_enable_zero3) # TODO: add transformer policy # We force reference policy to use CPUOffload to save memory. @@ -397,13 +553,13 @@ def _build_model_optimizer( # noqa: C901 actor_module, cpu_offload=cpu_offload, param_init_fn=init_fn, - use_orig_params=False, auto_wrap_policy=auto_wrap_policy, device_id=get_device_id(), sharding_strategy=sharding_strategy, # zero3 mixed_precision=mixed_precision, sync_module_states=True, device_mesh=self.device_mesh, + use_orig_params=self.use_orig_params, forward_prefetch=self.config.actor.fsdp_config.forward_prefetch, ) elif fsdp_strategy == "fsdp2": @@ -425,6 +581,7 @@ def _build_model_optimizer( # noqa: C901 "mp_policy": mp_policy, "offload_policy": cpu_offload, "reshard_after_forward": fsdp_config.reshard_after_forward, + "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]), } full_state = actor_module.state_dict() apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config) @@ -447,16 +604,11 @@ def _build_model_optimizer( # noqa: C901 get_cosine_schedule_with_warmup, ) - actor_optimizer = optim.AdamW( - actor_module_fsdp.parameters(), - lr=optim_config.lr, - betas=optim_config.get("betas", (0.9, 0.999)), - weight_decay=optim_config.get("weight_decay", 1e-2), - ) + actor_optimizer = build_optimizer(actor_module_fsdp.parameters(), optim_config) total_steps = optim_config.get("total_training_steps", 0) num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1)) - warmup_style = optim_config.get("warmup_style", "constant") + lr_scheduler_type = optim_config.get("lr_scheduler_type", "constant") min_lr_ratio = optim_config.get("min_lr_ratio", 0.0) num_cycles = optim_config.get("num_cycles", 0.5) if num_warmup_steps < 0: @@ -464,13 +616,13 @@ def _build_model_optimizer( # noqa: C901 num_warmup_steps = int(num_warmup_steps_ratio * total_steps) if self.rank == 0: - print(f"num_warmup_steps: {num_warmup_steps}") + print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") - if warmup_style == "constant": + if lr_scheduler_type == "constant": actor_lr_scheduler = get_constant_schedule_with_warmup( optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps ) - elif warmup_style == "cosine": + elif lr_scheduler_type == "cosine": actor_lr_scheduler = get_cosine_schedule_with_warmup( optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps, @@ -479,7 +631,7 @@ def _build_model_optimizer( # noqa: C901 num_cycles=num_cycles, ) else: - raise NotImplementedError(f"Warmup style {warmup_style} is not supported") + raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported") log_gpu_memory_usage(f"After {role} optimizer init", logger=logger) else: @@ -488,6 +640,24 @@ def _build_model_optimizer( # noqa: C901 return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config + async def trainer_mode(self): # TODO: check this + """Context switch hybridengine to trainer mode.""" + # if self.config.rollout.free_cache_engine: + # log_gpu_memory_usage("Before rollout offload", logger=logger) + # await self.rollout.release() + # log_gpu_memory_usage("After rollout offload", logger=logger) + + self.actor_module_fsdp.train() + + # add empty cache after each compute + aggressive_empty_cache(force_sync=True) + + set_expandable_segments(True) + + # restore random states + self.gen_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.torch_random_states) + 
@register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): from trinity.trainer.verl.dp_actor import DataParallelPPOActor @@ -496,19 +666,23 @@ def init_model(self): import_external_libs(self.config.model.get("external_lib", None)) override_model_config = OmegaConf.to_container( - self.config.model.get("override_config", OmegaConf.create()) + OmegaConf.create(self.config.model.get("override_config", {})) ) - use_remove_padding = self.config.model.get("use_remove_padding", False) use_shm = self.config.model.get("use_shm", False) use_fused_kernels = self.config.model.get("use_fused_kernels", False) if self._is_actor: - # we need the model for actor and rollout + # we need the model for actor optim_config = self.config.actor.optim - fsdp_config = self.config.actor.fsdp_config + fsdp_config = omega_conf_to_dataclass(self.config.actor.fsdp_config) local_path = copy_to_local(self.config.model.path, use_shm=use_shm) + # TiledMLP configuration for memory-efficient MLP computation + tiled_mlp_config = self.config.model.get("tiled_mlp", {}) + use_tiled_mlp = tiled_mlp_config.get("enabled", False) + tiled_mlp_shards = tiled_mlp_config.get("num_shards", 4) + ( self.actor_module_fsdp, self.actor_optimizer, @@ -528,6 +702,8 @@ def init_model(self): use_liger=self.config.model.get("use_liger", False), role="actor", enable_activation_offload=self.config.model.get("enable_activation_offload", False), + use_tiled_mlp=use_tiled_mlp, + tiled_mlp_shards=tiled_mlp_shards, ) # get the original unwrapped module @@ -554,10 +730,25 @@ def init_model(self): ) if self._is_ref: - local_path = copy_to_local(self.config.model.path, use_shm=use_shm) + ref_model_path = self.config.model.path + ref_model = self.config.ref.get("model", None) + if ref_model is not None: + ref_model_path = ref_model.get("path", self.config.model.path) + + if self.rank == 0: + print("reference model:", ref_model_path) + local_path = copy_to_local(ref_model_path, use_shm=use_shm) + + # TiledMLP for ref model: use ref config if specified, otherwise use actor config + ref_tiled_mlp_config = self.config.ref.get("tiled_mlp", None) + if ref_tiled_mlp_config is None: + ref_tiled_mlp_config = self.config.model.get("tiled_mlp", {}) + ref_use_tiled_mlp = ref_tiled_mlp_config.get("enabled", False) + ref_tiled_mlp_shards = ref_tiled_mlp_config.get("num_shards", 4) + self.ref_module_fsdp = self._build_model_optimizer( model_path=local_path, - fsdp_config=self.config.ref.fsdp_config, + fsdp_config=omega_conf_to_dataclass(self.config.ref.fsdp_config), optim_config=None, override_model_config=override_model_config, use_remove_padding=use_remove_padding, @@ -565,6 +756,8 @@ def init_model(self): trust_remote_code=self.config.model.get("trust_remote_code", False), use_liger=self.config.model.get("use_liger", False), role="ref", + use_tiled_mlp=ref_use_tiled_mlp, + tiled_mlp_shards=ref_tiled_mlp_shards, )[0] OmegaConf.set_struct(self.config.ref, True) with open_dict(self.config.ref): @@ -683,11 +876,9 @@ def upload_state_dict(self, trainer_step: int): def set_algorithm(self, algo_config: AlgorithmConfig): self.actor.set_algorithm(algo_config) - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @DistProfiler.annotate(color="red", role="actor_update") def update_actor(self, data: DataProto): - # Support all hardwares - data = data.to("cpu") # data will to device with each micro batch on actor.update_policy - assert self._is_actor if self._is_offload_param: 
load_fsdp_model_to_gpu(self.actor_module_fsdp) @@ -695,7 +886,10 @@ def update_actor(self, data: DataProto): load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id()) with self.ulysses_sharding_manager: - data = self.ulysses_sharding_manager.preprocess_data(data=data) + data = data.to( + "cpu" + ) # data will be moved to the device with each micro batch in actor.update_policy + # perform training with Timer(name="update_policy", logger=None) as timer: metrics = self.actor.update_policy(data=data) @@ -716,13 +910,12 @@ def update_actor(self, data: DataProto): metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) lr = self.actor_lr_scheduler.get_last_lr()[0] - metrics["actor/lr"] = lr + metrics["actor/lr"] = lr.item() if torch.is_tensor(lr) else lr self.actor_lr_scheduler.step() # TODO: here, we should return all metrics output = DataProto(meta_info={"metrics": metrics}) - output = self.ulysses_sharding_manager.postprocess_data(data=output) output = output.to("cpu") if self._is_offload_param: @@ -734,7 +927,8 @@ def update_actor(self, data: DataProto): return output - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") def compute_log_prob(self, data: DataProto): # when is_lora is True, we use the actor without lora applied to calculate the log_prob # which is mostly used for ref log_prob calculation @@ -747,22 +941,25 @@ def compute_log_prob(self, data: DataProto): is_lora = data.meta_info.pop("is_lora", False) adapter_ctx = self.actor.actor_module.disable_adapter() if is_lora else nullcontext() - data = data.to(get_device_id()) # we should always recompute old_log_probs when it is HybridEngine - data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu - data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu - data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz + config_source = self.config.ref if is_lora else self.config.rollout + data.meta_info["micro_batch_size"] = config_source.log_prob_micro_batch_size_per_gpu + data.meta_info["max_token_len"] = config_source.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = config_source.log_prob_use_dynamic_bsz data.meta_info["temperature"] = self.config.rollout.temperature # perform recompute log_prob with self.ulysses_sharding_manager: - data = self.ulysses_sharding_manager.preprocess_data(data) with adapter_ctx: - output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True) + output, entropys = self.actor.compute_log_prob( + data=data, calculate_entropy=not is_lora + ) + tensors = {"ref_log_prob": output} if is_lora else {"old_log_probs": output} + if not is_lora: + tensors["entropys"] = entropys output = DataProto.from_dict( - tensors={"old_log_probs": output, "entropys": entropys}, + tensors=tensors, meta_info={"temperature": self.config.rollout.temperature}, ) - output = self.ulysses_sharding_manager.postprocess_data(output) output = output.to("cpu") @@ -777,20 +974,16 @@ def compute_log_prob(self, data: DataProto): return output - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") def compute_ref_log_prob(self, data: DataProto): if self._is_lora: # if _is_lora, actor without lora 
applied is the ref data.meta_info["is_lora"] = True - data = self.compute_log_prob(data) - # this old_log_probs is in fact ref_log_prob - data = DataProto.from_dict(tensors={"ref_log_prob": data.batch["old_log_probs"]}) - return data + return self.compute_log_prob(data) assert self._is_ref # else: # otherwise, the class have a standalone ref model - # Support all hardwares - data = data.to(get_device_id()) micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size @@ -798,17 +991,21 @@ def compute_ref_log_prob(self, data: DataProto): data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz with self.ulysses_sharding_manager: - data = self.ulysses_sharding_manager.preprocess_data(data) + data = data.to( + "cpu" + ) # data will to device with each micro batch on ref.compute_log_prob output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) output = DataProto.from_dict(tensors={"ref_log_prob": output}) - output = self.ulysses_sharding_manager.postprocess_data(output) output = output.to("cpu") # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module - if self.world_size > 1 and fsdp_version(self.ref_policy.actor_module) == 1: - self.ref_policy.actor_module._handle.reshard(True) + if self.world_size > 1: + if fsdp_version(self.ref_policy.actor_module) == 1: + self.ref_policy.actor_module._handle.reshard(True) + elif fsdp_version(self.ref_policy.actor_module) == 2: + self.ref_policy.actor_module.reshard() return output @@ -896,46 +1093,89 @@ def save_checkpoint( @register(dispatch_mode=Dispatch.ONE_TO_ALL) def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False): - if self._is_actor and self._is_offload_param: + assert self._is_actor or (not self._is_actor and self._is_rollout), ( + f"Checkpoint loading is only supported for Actor or standalone Rollout Workers, but got " + f"{self._is_actor} and {self._is_rollout}" + ) + + # No checkpoint to load, just offload the model and optimizer to CPU + if local_path is None: + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + if self._is_offload_optimizer: + offload_fsdp_optimizer(self.actor_optimizer) + return + + if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) self.checkpoint_manager.load_checkpoint( local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load ) - if self._is_actor and self._is_offload_param: + if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) - if self._is_actor and self._is_offload_optimizer: + if self._is_offload_optimizer: offload_fsdp_optimizer(self.actor_optimizer) @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def clear_optimizer_state(self): - print("Clear actor optimizer state") - if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id()) - self.actor_optimizer.state.clear() - self.actor_optimizer.zero_grad() - if self._is_offload_optimizer: - offload_fsdp_optimizer(self.actor_optimizer) + def start_profile(self, **kwargs) -> None: + """Start profiling for the current rank in the current training step.""" + self.profiler.start(**kwargs) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def stop_profile(self) -> None: + """Stop profiling for the current rank in the current training step.""" + self.profiler.stop() + + 
@register(dispatch_mode=Dispatch.ONE_TO_ALL) + def dump_memory_snapshot(self, tag: str = "manual", sub_dir: str = None) -> None: + """Manually trigger a CUDA memory snapshot dump on all ranks.""" + # Memory snapshot is now handled by the profiler system + # This method is kept for backward compatibility but delegates to profiler + if hasattr(self, "profiler") and hasattr(self.profiler, "_impl"): + try: + # Try to use the profiler's memory snapshot functionality + if hasattr(self.profiler._impl, "sampler"): + out_dir = OmegaConf.select(self.config, "actor.profiler.save_path") or "." + self.profiler._impl.sampler.dump_memory_snapshot( + out_dir=out_dir, tag=tag, sub_dir=sub_dir + ) + except Exception: + # silently ignore if profiler doesn't support memory snapshots + pass @register(dispatch_mode=Dispatch.ONE_TO_ALL) def wait_on_save_thread(self) -> None: self.checkpoint_manager.wait_on_save_thread() -class CriticWorker(Worker): - def __init__(self, config): - super().__init__() +class CriticWorker(Worker, DistProfilerExtension): + def __init__(self, config: FSDPCriticConfig): + Worker.__init__(self) + omega_profiler_config = config.get("profiler", {}) + profiler_config = omega_conf_to_dataclass( + omega_profiler_config, dataclass_type=ProfilerConfig + ) + if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: + tool_config = omega_conf_to_dataclass( + omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) + ) + else: + tool_config = None + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config) + ) import torch.distributed if not torch.distributed.is_initialized(): torch.distributed.init_process_group( backend=get_nccl_backend(), + timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), init_method=os.environ.get("DIST_INIT_METHOD", None), - timeout=timedelta(seconds=self.config.synchronizer.sync_timeout), ) - self.config = config + self.config: FSDPCriticConfig = config # build device mesh for Ulysses Sequence Parallel world_size = torch.distributed.get_world_size() @@ -954,6 +1194,17 @@ def __init__(self, config): mesh_dim_names=["dp", "sp"], ) + # create training dispatch + if self.ulysses_device_mesh is not None: + is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0 + self._register_dispatch_collect_info( + "critic", + dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), + is_collect=is_collect, + ) + else: + self._register_dispatch_collect_info("critic", dp_rank=self.rank, is_collect=True) + self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) # set FSDP offload params @@ -978,15 +1229,24 @@ def __init__(self, config): if not self.config.use_dynamic_bsz and self.config.ppo_micro_batch_size_per_gpu is not None: assert ( self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0 - ), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" + ), ( + f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by " + f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" + ) assert ( self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0 - ), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" - 
self._is_lora = self.config.model.get("lora_rank", 0) > 0 + ), ( + f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than " + f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" + ) + self._is_lora = ( + self.config.model.get("lora_adapter_path") is not None + or self.config.model.get("lora_rank", 0) > 0 + ) + self.use_orig_params = self.config.model.fsdp_config.get("use_orig_params", False) def _build_critic_model_optimizer(self, config): # noqa: C901 # the following line is necessary - from torch import optim from torch.distributed.fsdp import MixedPrecision from verl.utils.model import load_valuehead_model, print_model_size from verl.utils.torch_dtypes import PrecisionType @@ -1013,7 +1273,7 @@ def _build_critic_model_optimizer(self, config): # noqa: C901 from omegaconf import OmegaConf override_config = OmegaConf.to_container( - self.config.model.get("override_config", OmegaConf.create()) + OmegaConf.create(self.config.model.get("override_config", {})) ) override_config_kwargs = { "bos_token_id": self.tokenizer.bos_token_id, @@ -1029,11 +1289,21 @@ def _build_critic_model_optimizer(self, config): # noqa: C901 from transformers import AutoConfig + # override model kwargs + attn_implementation = override_config.get("attn_implementation", "flash_attention_2") critic_model_config = AutoConfig.from_pretrained( local_path, - attn_implementation="flash_attention_2", + attn_implementation=attn_implementation, trust_remote_code=config.model.get("trust_remote_code", False), ) + # TODO: VL models use VisionAttention, which directly uses flash_attention in transformers>=4.53 + # which will be patched by _ulysses_flash_attention_forward, but errorly misses position_ids + # Maybe support Ulysses in VisionAttention in the future and remove this patch + if self.ulysses_sequence_parallel_size > 1 and hasattr( + critic_model_config, "vision_config" + ): + critic_model_config.vision_config._attn_implementation = "eager" + critic_model_config.num_labels = 1 # patch for kimi-vl if getattr(critic_model_config, "model_type", None) == "kimi_vl": @@ -1043,6 +1313,15 @@ def _build_critic_model_optimizer(self, config): # noqa: C901 use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh ) + # TiledMLP configuration for memory-efficient MLP computation + tiled_mlp_config = config.model.get("tiled_mlp", {}) + use_tiled_mlp = tiled_mlp_config.get("enabled", False) + tiled_mlp_shards = tiled_mlp_config.get("num_shards", 4) + + # TiledMLP requires FSDP2 for correct gradient computation + if use_tiled_mlp and config.strategy == "fsdp": + raise ValueError("TiledMLP requires FSDP2. 
Set `critic.strategy=fsdp2`.") + with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") critic_model_config.classifier_dropout = 0.0 @@ -1062,6 +1341,8 @@ def _build_critic_model_optimizer(self, config): # noqa: C901 model=critic_module, use_remove_padding=use_remove_padding, ulysses_sp_size=self.ulysses_sequence_parallel_size, + use_tiled_mlp=use_tiled_mlp, + tiled_mlp_shards=tiled_mlp_shards, ) if config.model.get("enable_gradient_checkpointing", False): @@ -1072,15 +1353,39 @@ def _build_critic_model_optimizer(self, config): # noqa: C901 if self._is_lora: print("Applying LoRA to critic module") critic_module.enable_input_require_grads() - # Convert config to regular Python types before creating PEFT model - lora_config = { - "task_type": TaskType.CAUSAL_LM, - "r": self.config.model.lora_rank, - "lora_alpha": self.config.model.lora_alpha, - "target_modules": convert_to_regular_types(self.config.model.target_modules), - "bias": "none", - } - critic_module = get_peft_model(critic_module, LoraConfig(**lora_config)) + + # Check if we should load a pre-trained LoRA adapter + lora_adapter_path = self.config.model.get("lora_adapter_path") + if lora_adapter_path is not None: + from peft import PeftModel + + print(f"Loading pre-trained LoRA adapter to critic from: {lora_adapter_path}") + + # Copy adapter to local if needed + local_adapter_path = copy_to_local( + lora_adapter_path, use_shm=self.config.model.get("use_shm", False) + ) + + critic_module = PeftModel.from_pretrained( + critic_module, local_adapter_path, is_trainable=True + ) + peft_config = critic_module.peft_config["default"] + # Ensure task_type is TaskType enum, not string + # Use TOKEN_CLS for Critic since it's loaded as AutoModelForTokenClassification + if isinstance(peft_config.task_type, str): + peft_config.task_type = TaskType.TOKEN_CLS + + else: + # Convert config to regular Python types before creating PEFT model + # Use TOKEN_CLS for Critic since it's loaded as AutoModelForTokenClassification + lora_config = { + "task_type": TaskType.TOKEN_CLS, + "r": self.config.model.lora_rank, + "lora_alpha": self.config.model.lora_alpha, + "target_modules": convert_to_regular_types(self.config.model.target_modules), + "bias": "none", + } + critic_module = get_peft_model(critic_module, LoraConfig(**lora_config)) if self.rank == 0: print_model_size(critic_module) @@ -1109,7 +1414,7 @@ def _build_critic_model_optimizer(self, config): # noqa: C901 auto_wrap_policy = get_fsdp_wrap_policy( module=critic_module, config=self.config.model.fsdp_config.wrap_policy, - is_lora=self.config.model.get("lora_rank", 0) > 0, + is_lora=self._is_lora, ) log_gpu_memory_usage("Before critic FSDP", logger=None) @@ -1117,12 +1422,24 @@ def _build_critic_model_optimizer(self, config): # noqa: C901 fsdp_mesh = self.device_mesh sharding_strategy = get_sharding_strategy(fsdp_mesh) + self.use_orig_params = fsdp_config.get("use_orig_params", False) + if self.config.model.get("freeze_vision_tower", False): + vision_tower = get_vl_model_vision_tower(critic_module) + if vision_tower is not None: + vision_tower.requires_grad_(False) + self.use_orig_params = True + if self.rank == 0: + print("[critic model] Vision tower is set to not trainable.") + else: + if self.rank == 0: + print("[critic model] No vision tower found.") + # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation if config.strategy == "fsdp": critic_module = FSDP( critic_module, param_init_fn=init_fn, - use_orig_params=False, 
+ use_orig_params=self.use_orig_params, auto_wrap_policy=auto_wrap_policy, device_id=get_device_id(), sharding_strategy=sharding_strategy, @@ -1150,6 +1467,7 @@ def _build_critic_model_optimizer(self, config): # noqa: C901 "mp_policy": mp_policy, "offload_policy": offload_policy, "reshard_after_forward": fsdp_config.reshard_after_forward, + "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]), } full_state = critic_module.state_dict() apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config) @@ -1165,40 +1483,40 @@ def _build_critic_model_optimizer(self, config): # noqa: C901 log_gpu_memory_usage("After critic FSDP", logger=None) - critic_optimizer = optim.AdamW( - critic_module.parameters(), - lr=config.optim.lr, - betas=config.optim.get("betas", (0.9, 0.999)), - weight_decay=config.optim.get("weight_decay", 1e-2), - ) + critic_optimizer = build_optimizer(critic_module.parameters(), config.optim) total_steps = config.optim.get("total_training_steps", 0) num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1)) - warmup_style = config.optim.get("warmup_style", "constant") + + lr_scheduler_type = config.optim.get("lr_scheduler_type", "constant") if num_warmup_steps < 0: num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0) num_warmup_steps = int(num_warmup_steps_ratio * total_steps) if self.rank == 0: - print(f"num_warmup_steps: {num_warmup_steps}") + print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") from verl.utils.torch_functional import ( get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup, ) - if warmup_style == "constant": + if lr_scheduler_type == "constant": critic_lr_scheduler = get_constant_schedule_with_warmup( optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps ) - elif warmup_style == "cosine": + elif lr_scheduler_type == "cosine": + min_lr_ratio = config.optim.get("min_lr_ratio", 0.0) + num_cycles = config.optim.get("num_cycles", 0.5) critic_lr_scheduler = get_cosine_schedule_with_warmup( optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps, + min_lr_ratio=min_lr_ratio, + num_cycles=num_cycles, ) else: - raise NotImplementedError(f"Warmup style {warmup_style} is not supported") + raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported") return critic_module, critic_optimizer, critic_lr_scheduler @@ -1238,11 +1556,9 @@ def init_model(self): ray_namespace=self.config.ray_namespace, ) - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) + @DistProfiler.annotate(color="cyan", role="compute_values") def compute_values(self, data: DataProto): - # Support all hardwares - data = data.to(get_device_id()) - if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) micro_batch_size = self.config.forward_micro_batch_size_per_gpu @@ -1251,20 +1567,20 @@ def compute_values(self, data: DataProto): data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz # perform forward computation with self.ulysses_sharding_manager: - data = self.ulysses_sharding_manager.preprocess_data(data=data) + data = data.to( + "cpu" + ) # data will to device with each micro batch on critic.compute_values values = self.critic.compute_values(data=data) output = DataProto.from_dict(tensors={"values": values}) - output = self.ulysses_sharding_manager.postprocess_data(data=output) output = output.to("cpu") if self._is_offload_param: 
offload_fsdp_model_to_cpu(self.critic_module) return output - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) + @DistProfiler.annotate(color="pink", role="critic_update") def update_critic(self, data: DataProto): - # Support all hardwares - data = data.to(get_device_id()) if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) if self._is_offload_optimizer: @@ -1272,8 +1588,9 @@ def update_critic(self, data: DataProto): # perform forward computation with self.ulysses_sharding_manager: - data = self.ulysses_sharding_manager.preprocess_data(data=data) - + data = data.to( + "cpu" + ) # data will to device with each micro batch on critic.update_critic with Timer(name="update_critic", logger=None) as timer: metrics = self.critic.update_critic(data=data) delta_time = timer.last @@ -1286,12 +1603,11 @@ def update_critic(self, data: DataProto): estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size ) - self.critic_lr_scheduler.step() lr = self.critic_lr_scheduler.get_last_lr()[0] metrics["critic/lr"] = lr + self.critic_lr_scheduler.step() output = DataProto(batch=None, meta_info={"metrics": metrics}) - output = self.ulysses_sharding_manager.postprocess_data(data=output) if self._is_offload_param: offload_fsdp_model_to_cpu(self.critic_module) @@ -1343,16 +1659,6 @@ def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True) if self._is_offload_optimizer: offload_fsdp_optimizer(self.critic_optimizer) - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def clear_optimizer_state(self): - print("Clear critic optimizer state") - if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_device_id()) - self.critic_optimizer.state.clear() - self.critic_optimizer.zero_grad() - if self._is_offload_optimizer: - offload_fsdp_optimizer(self.critic_optimizer) - @register(dispatch_mode=Dispatch.ONE_TO_ALL) def wait_on_save_thread(self) -> None: self.checkpoint_manager.wait_on_save_thread() diff --git a/trinity/trainer/verl/megatron_actor.py b/trinity/trainer/verl/megatron_actor.py index 105cc5456c..b771b9e901 100644 --- a/trinity/trainer/verl/megatron_actor.py +++ b/trinity/trainer/verl/megatron_actor.py @@ -18,7 +18,7 @@ Note that our model doesn't have to be `MegatronModule` because we don't share embedding in the last layer -Modified from https://github.com/volcengine/verl/blob/v0.5.0/verl/workers/actor/megatron_actor.py +Modified from https://github.com/volcengine/verl/blob/v0.7.0/verl/workers/actor/megatron_actor.py """ from functools import partial @@ -31,10 +31,18 @@ from verl import DataProto from verl.utils.device import get_device_id, get_torch_device from verl.utils.megatron.pipeline_parallel import make_batch_generator +from verl.utils.megatron.router_replay_patch import RouterReplay, RouterReplayAction +from verl.utils.megatron.router_replay_utils import ( + RouterReplayHelper, + merge_router_topk_indices, + reorder_and_merge_vpp_layers, + set_router_replay_data, +) from verl.utils.megatron.tensor_parallel import ( vocab_parallel_entropy, vocab_parallel_log_probs_from_logits, ) +from verl.utils.megatron_utils import unwrap_model from verl.utils.profiler import GPUMemoryLogger from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import rearrange_micro_batches @@ -87,12 +95,15 @@ def forward_backward_batch( # noqa: C901 """ # broadcast from last pp rank to all other pp ranks # TODO: 
actually, we just need to control the sampling order. + data.to(get_device_id()) + data.batch = data.batch.contiguous() mini_batch = data broadcast_dict_tensor( mini_batch.batch, src=mpu.get_pipeline_model_parallel_last_rank(), group=mpu.get_pipeline_model_parallel_group(), ) + mini_batch.to("cpu") # split into micro-batches mini_batch.batch["attention_mask"] = mini_batch.batch["attention_mask"].to(bool) self.has_multi_modal_inputs = "multi_modal_inputs" in mini_batch.non_tensor_batch.keys() @@ -110,6 +121,7 @@ def forward_backward_batch( # noqa: C901 ] # mcore patch recompute qwen2vl's pos ids during forward indices = None + temperature = data.meta_info["temperature"] if use_dynamic_bsz: assert ( max_token_len is not None @@ -150,7 +162,17 @@ def forward_backward_batch( # noqa: C901 def loss_func(output, data, meta_info): # For memory efficiency # We move calculation of entropy to compute_log_probs, forward_only == True - device = output["log_probs"].device + log_probs = None + entropy = None + if isinstance(output, dict): + log_probs = output["log_probs"] + if "entropy" in output: + entropy = output["entropy"] + else: + assert isinstance(output, torch.Tensor) + log_probs = output + + device = log_probs.device metrics = {} if forward_only: if post_process_fn is None: @@ -167,14 +189,34 @@ def loss_func(output, data, meta_info): response_mask = data["response_mask"].to(bool) # compute policy loss - log_prob = output["log_probs"][:, -response_length - 1 : -1].contiguous() + log_prob = log_probs[:, -response_length - 1 : -1].contiguous() ret_entropy = None stats = {} if not forward_only: + loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") + pg_loss, pg_loss_metrics = self.policy_loss_fn( # type: ignore logprob=log_prob, **data ) prefix_metrics(src_metrics=pg_loss_metrics, prefix="actor", dst_metrics=stats) + + # TODO: to be check + # Skip if using bypass_mode loss (metrics already computed in pg_metrics) + rollout_log_prob = data.get("rollout_log_probs", None) + if loss_mode != "bypass_mode" and rollout_log_prob is not None: + # Compute metrics using CURRENT policy π_θ vs π_rollout + # Tracks evolving off-policy gap as π_θ updates during mini-batch training + from verl.trainer.ppo.rollout_corr_helper import ( + compute_rollout_corr_metrics_from_logprobs, + ) + + rollout_corr_metrics = compute_rollout_corr_metrics_from_logprobs( + log_prob=log_prob, + rollout_log_prob=rollout_log_prob, + response_mask=response_mask, + ) + stats.update(rollout_corr_metrics) + policy_loss = pg_loss if calculate_entropy: @@ -220,20 +262,54 @@ def loss_func(output, data, meta_info): append_to_dict(metrics, stats) return policy_loss, [metrics, ret_entropy] - def forward_step(batch_iter, model): + def forward_step(batch_iter, model, return_schedule_plan: bool = False): + """ + Args: + batch_iter: the batch iterator + model: the model + return_schedule_plan: whether to return the schedule plan, for 1f1b overlap + """ + if return_schedule_plan: + assert ( + self.tf_config.overlap_moe_expert_parallel_comm + ), "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan" + # TODO: Fix this + assert ( + not calculate_entropy + ), "calculate_entropy must be disabled to return the schedule plan" + from megatron.core.models.gpt.gpt_model import GPTModel + + assert isinstance(model, GPTModel), "model must be a GPTModel" + assert ( + self.use_fused_kernels + ), "use_fused_kernels must be enabled to return the schedule plan" + # TODO: support VLM with MoE + from 
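# The rollout-correction metrics referenced above compare the current policy's
# log-probs with the log-probs recorded by the rollout engine, masked to the
# response tokens, to track the off-policy gap as the policy updates within a
# mini-batch. A single-device sketch of that kind of metric; the function name
# and metric keys are illustrative, not verl's rollout_corr_helper.
import torch

def rollout_gap_metrics(log_prob, rollout_log_prob, response_mask):
    mask = response_mask.float()
    tok = mask.sum().clamp_min(1.0)
    diff = (log_prob - rollout_log_prob) * mask
    token_mean_logratio = diff.sum() / tok                 # E[log pi_theta - log pi_rollout]
    seq_logratio = diff.sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
    return {
        "rollout_corr/token_mean_logratio": token_mean_logratio.item(),
        "rollout_corr/seq_mean_ratio": seq_logratio.exp().mean().item(),
    }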
verl.models.mcore.model_forward_1f1b_overlap import ( + gptmodel_forward_1f1b_overlap, + ) + batch = next(batch_iter) + batch = batch.to(get_device_id()) + batch = batch.contiguous() + input_ids = batch["input_ids"] attention_mask = batch["attention_mask"].to(bool) position_ids = batch["position_ids"] + unwrapped_model = unwrap_model(model) + if hasattr(unwrapped_model, "vp_stage"): + vp_rank = unwrapped_model.vp_stage + else: + vp_rank = 0 + multi_modal_inputs = {} if "multi_modal_inputs" in batch: - for key in batch["multi_modal_inputs"][0].keys(): - idxs = batch["multi_modal_inputs_idx"] - mmi = batch["multi_modal_inputs"] - multi_modal_inputs[key] = torch.cat( - [mmi[idx].get(key) for idx in idxs if mmi[idx].get(key) is not None], dim=0 - ) + from verl.utils.model import extract_multi_modal_inputs + + indices = batch.get("multi_modal_inputs_idx", None) + multi_modal_inputs = extract_multi_modal_inputs( + batch["multi_modal_inputs"], indices + ) responses = batch["responses"] response_length = responses.size(1) label = position_ids.clone() @@ -242,6 +318,17 @@ def forward_step(batch_iter, model): label_mask[:, : -response_length - 1] = False label_mask[:, -1] = False + if RouterReplayHelper.is_replay_backward_action(self.tf_config, vp_rank): + router_instance_list = RouterReplayHelper.get_micro_batch_router_list( + self.tf_config, vp_rank + ) + for router in router_instance_list: + router.set_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + + if RouterReplayHelper.is_replay_forward_action(self.tf_config, vp_rank): + layers_topk_idx = batch["routed_experts"] + set_router_replay_data(layers_topk_idx, attention_mask, self.tf_config, vp_rank) + from verl.models.mcore import ( get_mcore_forward_fn, get_mcore_forward_fused_fn, @@ -249,16 +336,18 @@ def forward_step(batch_iter, model): if self.use_fused_kernels: forward_fn = get_mcore_forward_fused_fn(self.hf_config) + if return_schedule_plan: + forward_fn = gptmodel_forward_1f1b_overlap # return dict of [logits, entropy] output = forward_fn( - model, - input_ids, - position_ids, - attention_mask, - sequence_parallel=self.tf_config.sequence_parallel, - multi_modal_inputs=multi_modal_inputs, + model=model, + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, labels=label, labels_mask=label_mask, + temperature=temperature, + multi_modal_inputs=multi_modal_inputs, ) else: forward_fn = get_mcore_forward_fn(self.hf_config) @@ -266,25 +355,35 @@ def forward_step(batch_iter, model): def logits_processor(logits, label, label_mask): assert logits.shape[:2] == label.shape[:2] assert label.shape == label_mask.shape + logits.div_(temperature) ret = {} if calculate_entropy: + logits_bak = logits.clone() + # # disable the hint until the fused_kernel is optimized for triton>=3.3 + # logger.warning_once( + # "For memory-efficient computation, enable fused kernels via " + # "`actor_rollout_ref.model.use_fused_kernels=True`. " + # "The current `clone()` operation ensures correctness but increases memory usage." 
+ # ) entropy = vocab_parallel_entropy(logits) ret["entropy"] = entropy - log_probs = vocab_parallel_log_probs_from_logits(logits, label) + else: + logits_bak = logits + log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label) log_probs = log_probs.masked_fill(~label_mask, 0.0) ret["log_probs"] = log_probs return ret logits_processor_args = {"label": label, "label_mask": label_mask} output = forward_fn( - model, - input_ids, - attention_mask, - position_ids, - sequence_parallel=self.tf_config.sequence_parallel, + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, multi_modal_inputs=multi_modal_inputs, logits_processor=logits_processor, logits_processor_args=logits_processor_args, + data_format="thd" if self.config.megatron.use_remove_padding else "bshd", ) if forward_only: @@ -296,6 +395,23 @@ def logits_processor(logits, label, label_mask): "entropy_coeff": self.config.entropy_coeff, "clip_ratio_c": clip_ratio_c, } + + if RouterReplayHelper.is_r2_record_action(self.tf_config, vp_rank): + merge_router_topk_indices( + attention_mask, + input_ids, + self.mini_layer_topk_idx_list, # type: ignore + self.tf_config, + vp_rank, + ) + + if RouterReplayHelper.is_replay_forward_action(self.tf_config, vp_rank): + router_instance_list = RouterReplayHelper.get_micro_batch_router_list( + self.tf_config, vp_rank + ) + for router in router_instance_list: + router.set_router_replay_action(RouterReplayAction.REPLAY_BACKWARD) + return output, partial(loss_func, data=batch, meta_info=meta_info) # batch should be a list of batches inside micro-batches @@ -333,6 +449,23 @@ def logits_processor(logits, label, label_mask): losses_reduced = {"output": losses_reduced} if use_dynamic_bsz: losses_reduced["indices"] = indices + if RouterReplayHelper.is_r2_record_action(self.tf_config): + if self.tf_config.virtual_pipeline_model_parallel_size is not None: + # config = self.actor_module[0].module.module.config + vp_size = len(self.actor_module) + microbatch_group_size_per_vp_stage = ( + self.tf_config.microbatch_group_size_per_vp_stage + ) + bs = n_micro_batch + losses_reduced["mini_layer_topk_idx_tensor"] = reorder_and_merge_vpp_layers( + self.mini_layer_topk_idx_list, bs, vp_size, microbatch_group_size_per_vp_stage # type: ignore + ) + else: + losses_reduced["mini_layer_topk_idx_tensor"] = torch.cat( + self.mini_layer_topk_idx_list, dim=0 # type: ignore + ) + self.mini_layer_topk_idx_list = [] + return losses_reduced @GPUMemoryLogger(role="megatron actor", logger=logger) @@ -349,9 +482,11 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> dict: """ metrics = {} - self.prof.start() + if self.use_torch_profiler and self.prof and self.prof.enable: + self.prof.start() for data in dataloader: - data.to(get_device_id()) + if self.config.router_replay.mode in ["R2", "R3"]: + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) self.actor_optimizer.zero_grad() # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm for chunk in self.actor_module: @@ -393,10 +528,17 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> dict: pass else: raise NotImplementedError - self.prof.step() + if self.use_torch_profiler and self.prof and self.prof.enable: + self.prof.step() + + if self.config.router_replay.mode in ["R2", "R3"]: + RouterReplay.clear_global_router_replay_action() + RouterReplay.clear_global_indices() + # add empty cache after each compute - self.prof.stop_and_save() - self.prof.stop_trace() + if 
self.use_torch_profiler and self.prof and self.prof.enable: + self.prof.stop_and_save() + self.prof.stop_trace() get_torch_device().empty_cache() return metrics diff --git a/trinity/trainer/verl/megatron_checkpoint_manager.py b/trinity/trainer/verl/megatron_checkpoint_manager.py index 4fadcae477..0cec19db00 100644 --- a/trinity/trainer/verl/megatron_checkpoint_manager.py +++ b/trinity/trainer/verl/megatron_checkpoint_manager.py @@ -13,10 +13,11 @@ # limitations under the License. """ Megatron Checkpoint Manager. -Modified from https://github.com/volcengine/verl/blob/v0.5.0/verl/utils/checkpoint/megatron_checkpoint_manager.py +Modified from https://github.com/volcengine/verl/blob/v0.7.0/verl/utils/checkpoint/megatron_checkpoint_manager.py """ import json +import os from collections.abc import Callable from dataclasses import asdict @@ -44,7 +45,7 @@ class MegatronCheckpointManager(OldMegatronCheckpointManager): """ - An enhanced version of the original FSDP checkpoint manager that: + An enhanced version of the original Megatron checkpoint manager that: 1. Uploads model state dicts to a remote Synchronizer actor (either directly or via checkpoints). """ @@ -86,9 +87,13 @@ def _save_state_dict(self, local_path, global_step) -> bool: dist_checkpoint_path = get_dist_checkpoint_path(local_path) hf_ckpt_path = get_hf_model_checkpoint_path(local_path) + # Note that model weights, optimizer states, and extra states are generated + # together in a state dict, we save them in one time if self.use_dist_checkpointing: # Generate state dict for saving - state_dict = self.generate_state_dict() + state_dict = self.generate_state_dict( + self.should_save_model, self.should_save_optimizer, self.should_save_extra + ) # log_with_rank(f"Generated state dict for saving: {state_dict.keys()}", rank=self.rank, logger=logger) # for vpp_rank, model in enumerate(self.model): # if len(self.model) > 1: @@ -105,10 +110,6 @@ def _save_state_dict(self, local_path, global_step) -> bool: async_save=self.checkpoint_config.async_save, ) - if self.rank == 0: - # Save huggingface config - self.hf_config.save_pretrained(hf_ckpt_path) - # Synchronize all async save requests if not self.checkpoint_config.async_save: assert ( @@ -118,27 +119,66 @@ def _save_state_dict(self, local_path, global_step) -> bool: else: assert ( self.use_hf_checkpoint - ), "use_hf_checkpoint should be True when not using dist checkpointing" - log_with_rank( - f"Saving HF model checkpoint to {local_path} with bridge", - rank=self.rank, - logger=logger, + ), "When not using distributed checkpointing, use_hf_checkpoint should be True." 
+ # Generate optimizer and exra state dicts + state_dict = self.generate_state_dict( + generate_model=False, + generate_optimizer=self.should_save_optimizer, + generate_extra=self.should_save_extra, ) - self.bridge.save_weights(self.model, hf_ckpt_path) - log_with_rank( - f"Saved bridge checkpoint to {hf_ckpt_path}", rank=self.rank, logger=logger + # Save optimizer and extra states to local path + # Start Async save if enabled + async_save_request = save_dist_checkpointing( + sharded_state_dict=state_dict, + ckpt_path=dist_checkpoint_path, + async_save=self.checkpoint_config.async_save, ) - if self.rank == 0: - if hasattr(self.hf_config, "name_or_path") and self.hf_config.name_or_path: - try: - generation_config = GenerationConfig.from_pretrained( - self.hf_config.name_or_path + # Synchronize all async save requests + if not self.checkpoint_config.async_save: + assert ( + async_save_request is None + ), "Async save request should be None when not using async save." + torch.distributed.barrier() + + if self.should_save_model: + # Save adapter-only checkpoint if PEFT is enabled + if self.peft_cls is not None: + from verl.utils.megatron_peft_utils import save_adapter_checkpoint + + adapter_ckpt_path = os.path.join(local_path, "adapter_checkpoint") + + # Save adapter weights only (much smaller than full model) + save_adapter_checkpoint( + self.model, + adapter_ckpt_path, + self.rank, + ) + + log_with_rank( + f"Saved adapter-only checkpoint to {adapter_ckpt_path}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, + ) + if self.use_hf_checkpoint: + # Use mbridge to save HF model checkpoint + log_with_rank( + f"Saving HF model checkpoint to {local_path} with bridge", + rank=self.rank, + logger=logger, + ) + hf_ckpt_path = get_hf_model_checkpoint_path(local_path) + if self.vanilla_bridge: + self.bridge.save_weights( + self.model, hf_ckpt_path, distributed_filesystem=True, memory_efficient=True ) - generation_config.save_pretrained(hf_ckpt_path) - except Exception: - # if the generation config isn't available, we don't save it - pass + else: + self.bridge.save_hf_weights(self.model, hf_ckpt_path) + + log_with_rank( + f"Saved bridge checkpoint to {hf_ckpt_path}", rank=self.rank, logger=logger + ) def finalize_save_fn(): # Rank 0 uploads checkpoint to HDFS if hdfs_path is provided @@ -158,6 +198,9 @@ def finalize_save_fn(): async_save_request is not None ), "Async save request should not be None when using async save." 
async_save_request.add_finalize_fn(finalize_save_fn) + from megatron.core.dist_checkpointing.strategies.base import async_calls + + async_calls.schedule_async_request(async_save_request) else: finalize_save_fn() @@ -179,18 +222,31 @@ def _save_tokenizer(self, local_path, global_step) -> bool: if self.latest_tokenizer_save_step == global_step: return False - # Only rank 0 saves the hf config and tokenizer to huggingface path - # No matter whether we save hf model or not - if self.rank == 0: - # Save tokenizer - hf_config_tokenizer_path = get_hf_model_checkpoint_path(local_path) - self.processing_class.save_pretrained(hf_config_tokenizer_path) - log_with_rank( - f"Saved Huggingface tokenizer to {hf_config_tokenizer_path}", - rank=self.rank, - logger=logger, - log_only_rank_0=True, - ) + if self.should_save_model: + # Only rank 0 saves the hf config and tokenizer to huggingface path + # No matter whether we save hf model or not + if self.rank == 0: + # Save tokenizer + hf_config_tokenizer_path = get_hf_model_checkpoint_path(local_path) + if self.processing_class is not None: + self.processing_class.save_pretrained(hf_config_tokenizer_path) + # Save huggingface config + self.hf_config.save_pretrained(hf_config_tokenizer_path) + if hasattr(self.hf_config, "name_or_path") and self.hf_config.name_or_path: + try: + generation_config = GenerationConfig.from_pretrained( + self.hf_config.name_or_path + ) + generation_config.save_pretrained(hf_config_tokenizer_path) + except Exception: + # if the generation config isn't available, we don't save it + pass + log_with_rank( + f"Saved Huggingface config and tokenizer to {hf_config_tokenizer_path}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, + ) self.latest_tokenizer_save_step = global_step return self.rank == 0 @@ -212,10 +268,23 @@ def _save_extra_state(self, local_path, global_step) -> bool: if self.rank == 0: # Save transformer config - log_with_rank( - f"Transformer config: {self.transformer_config}", rank=self.rank, logger=logger - ) + print(self.transformer_config) + bypass_keys = [ + "finalize_model_grads_func", + "grad_scale_func", + "no_sync_func", + "grad_sync_func", + "param_sync_func", + "generation_config", + ] + backup = {} + for k in bypass_keys: + if hasattr(self.transformer_config, k): + backup[k] = getattr(self.transformer_config, k, None) + delattr(self.transformer_config, k) transformer_config_dict = asdict(self.transformer_config) + for k in backup: + setattr(self.transformer_config, k, backup[k]) to_convert_types = {torch.dtype: str, AttnBackend: str} ignore_types = [Callable] pop_keys = [] @@ -251,45 +320,56 @@ def _save_hf_model(self, local_path, global_step) -> bool: try: # wait for everyone to dump to local - state_dict = self.weight_saver( - self.model, - self.hf_config, - dtype=self.param_dtype, - is_value_model=self.is_value_model, - tie_word_embeddings=self.share_embeddings_and_output_weights, - ) - - torch.distributed.barrier() - if self.rank == 0: - # TODO: async save or use mbridge to save hf model + if self.bridge is not None: hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path) - import warnings - - from accelerate import init_empty_weights - - with init_empty_weights(), warnings.catch_warnings(): - warnings.simplefilter("ignore") - if "mistral7b-rm" in self.config.model.path: - from transformers import MistralForSequenceClassification + if self.vanilla_bridge: + self.bridge.save_weights( + self.model, + hf_model_ckpt_path, + distributed_filesystem=True, + memory_efficient=True, + ) + else: + 
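# The "bypass_keys" dance in _save_extra_state above detaches callable,
# non-serializable attributes from the transformer config before
# dataclasses.asdict and restores them afterwards. A reusable sketch of the same
# pattern, assuming the bypassed fields have class-level defaults so asdict
# still succeeds; names are illustrative.
import dataclasses
from contextlib import contextmanager

@contextmanager
def without_attrs(obj, keys):
    backup = {}
    for key in keys:
        if hasattr(obj, key):
            backup[key] = getattr(obj, key)
            delattr(obj, key)
    try:
        yield obj
    finally:
        for key, value in backup.items():
            setattr(obj, key, value)

# Usage (illustrative):
# with without_attrs(transformer_config, ["no_sync_func", "grad_sync_func"]):
#     config_dict = dataclasses.asdict(transformer_config)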
self.bridge.save_hf_weights(self.model, hf_model_ckpt_path) + else: + state_dict = self.weight_saver( + self.model, + self.hf_config, + dtype=self.param_dtype, + is_value_model=self.is_value_model, + tie_word_embeddings=self.share_embeddings_and_output_weights, + ) - model = MistralForSequenceClassification.from_pretrained( - self.config.model.path, torch_dtype=torch.bfloat16 - ) # use score head instead of lm_head - state_dict["score.weight"] = state_dict["score.weight"] - else: - from transformers import AutoModelForCausalLM + torch.distributed.barrier() + if self.rank == 0: + hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path) + import warnings + + from accelerate import init_empty_weights + + with init_empty_weights(), warnings.catch_warnings(): + warnings.simplefilter("ignore") + if "mistral7b-rm" in self.config.model.path: + from transformers import MistralForSequenceClassification + + model = MistralForSequenceClassification.from_pretrained( + self.config.model.path + ) # use score head instead of lm_head + state_dict["score.weight"] = state_dict["score.weight"] + else: + from transformers import AutoModelForCausalLM + + model = AutoModelForCausalLM.from_pretrained( + self.config.model.path, torch_dtype="auto" + ) + model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict) + log_with_rank( + f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, + ) - model = AutoModelForCausalLM.from_pretrained( - self.config.model.path, torch_dtype=torch.bfloat16 - ) - state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()} - model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict) - log_with_rank( - f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}", - rank=self.rank, - logger=logger, - log_only_rank_0=True, - ) except Exception: logger.error( f"Failed to save Huggingface model to {local_path}, you can try to set `use_mbridge=true` to save it.", @@ -333,7 +413,8 @@ def save_checkpoint( # remove previous local_path if ( - max_ckpt_to_keep + not self.checkpoint_config.async_save + and max_ckpt_to_keep and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0 and len(self.previous_saved_paths) >= max_ckpt_to_keep # type: ignore @@ -346,16 +427,15 @@ def save_checkpoint( torch.distributed.barrier() state_dict_thread_count = 0 - if self.should_save_model: - if self._save_state_dict(local_path, global_step): - state_dict_thread_count += 1 + if self._save_state_dict(local_path, global_step): + state_dict_thread_count += 1 self._save_tokenizer(local_path, global_step) if self.should_save_extra: self._save_extra_state(local_path, global_step) - if self.should_save_hf_model or save_as_hf: + if (self.should_save_hf_model or save_as_hf) and not self.use_hf_checkpoint: self._save_hf_model(local_path, global_step) ray.get( diff --git a/trinity/trainer/verl/megatron_workers.py b/trinity/trainer/verl/megatron_workers.py index 0735a477e5..e7ef662fcc 100644 --- a/trinity/trainer/verl/megatron_workers.py +++ b/trinity/trainer/verl/megatron_workers.py @@ -13,7 +13,7 @@ # limitations under the License. """ The main entry point to run the PPO algorithm. 
-Modified from https://github.com/volcengine/verl/blob/v0.5.0/verl/workers/megatron_workers.py +Modified from https://github.com/volcengine/verl/blob/v0.7.0/verl/workers/megatron_workers.py """ import datetime @@ -28,23 +28,43 @@ from codetiming import Timer from megatron.core import parallel_state as mpu from omegaconf import DictConfig, OmegaConf, open_dict + +try: + from verl.workers.engine.mindspeed.transformer_impl import repatch +except ImportError: + repatch = None from verl import DataProto -from verl.single_controller.base.decorator import Dispatch, register -from verl.single_controller.base.megatron.worker import MegatronWorker +from verl.models.mcore import get_mcore_weight_converter +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import ( + Dispatch, + make_nd_compute_dataproto_dispatch_fn, + register, +) from verl.utils.config import omega_conf_to_dataclass from verl.utils.device import ( get_device_id, - get_device_name, get_nccl_backend, get_torch_device, + set_expandable_segments, ) +from verl.utils.distributed import set_numa_affinity from verl.utils.flops_counter import FlopsCounter +from verl.utils.fs import copy_to_local +from verl.utils.megatron.router_replay_patch import ( + RouterReplay, + RouterReplayAction, + apply_router_replay_patch, +) from verl.utils.megatron_utils import ( load_megatron_model_to_gpu, load_megatron_optimizer, offload_megatron_model_to_cpu, offload_megatron_optimizer, + per_tensor_generator, + register_megatron_training_hooks, ) +from verl.utils.memory_utils import aggressive_empty_cache from verl.utils.model import ( get_hf_model_path, load_mcore_dist_weights, @@ -54,8 +74,10 @@ DistProfiler, DistProfilerExtension, GPUMemoryLogger, + ProfilerConfig, log_gpu_memory_usage, ) +from verl.workers.config import McoreCriticConfig from verl.workers.critic.megatron_critic import MegatronPPOCritic from verl.workers.megatron_workers import logger, set_random_seed @@ -67,6 +89,145 @@ from trinity.utils.distributed import init_process_group +class MegatronWorker(Worker): + def _init_hf_config_and_tf_config( + self, + model_path, + tokenizer_or_path, + dtype, + override_model_config, + override_transformer_config, + trust_remote_code=False, + megatron_config=None, + ): + from transformers import AutoConfig + from verl.models.mcore import hf_to_mcore_config + from verl.utils import hf_processor, hf_tokenizer + from verl.utils.model import update_model_config + + # Step 1: initialize the tokenizer + self.local_path = copy_to_local(model_path) + if tokenizer_or_path is None: + self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=trust_remote_code) + self.processor = hf_processor(self.local_path, trust_remote_code=trust_remote_code) + elif isinstance(tokenizer_or_path, str): + self.tokenizer = hf_tokenizer( + copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code + ) + self.processor = hf_processor( + copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code + ) + else: + self.tokenizer = tokenizer_or_path + self.processor = tokenizer_or_path + + if self.config.model.get("custom_chat_template", None) is not None: + if self.processor is not None: + self.processor.chat_template = self.config.model.custom_chat_template + else: + self.tokenizer.chat_template = self.config.model.custom_chat_template + + # Step 2: get the hf + hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code) + + # Step 3: override the hf config + override_config_kwargs = { + 
"bos_token_id": self.tokenizer.bos_token_id, + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + } + override_config_kwargs.update(override_model_config.get("model_config", {})) + + # patch for rope + if self.config.model.rope_scaling is not None: + hf_config.rope_scaling = OmegaConf.to_container(self.config.model.rope_scaling) + if self.config.model.rope_theta is not None: + hf_config.rope_theta = self.config.model.rope_theta + + self.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False) + update_model_config(hf_config, override_config_kwargs=override_config_kwargs) + self.architectures = getattr(hf_config, "architectures", None) + if self.rank == 0: + print(f"Model config after override: {hf_config}") + + from verl.models.mcore.config_converter import mapping_string_to_attn_backend + + # todo: remove this line after mcore adopt mbridge 0.15, now for compatibility + override_transformer_config = mapping_string_to_attn_backend(override_transformer_config) + fp16 = dtype == torch.float16 + bf16 = dtype == torch.bfloat16 + if fp16: + assert megatron_config.use_mbridge, "fp16 mode requires use_mbridge to be True" + + self.provider = None + self.vanilla_bridge = megatron_config.get("vanilla_mbridge", True) + if megatron_config.use_mbridge: + if self.vanilla_bridge: + from verl.models.mcore.mbridge import AutoBridge + + bridge = AutoBridge.from_config(hf_config, dtype=dtype) + bridge.set_extra_args(**override_transformer_config) + tf_config = bridge.config + tf_config.fp16 = fp16 + tf_config.bf16 = bf16 + else: + from verl.models.mcore.bridge import AutoBridge + + # Use Megatron-Bridge to convert HF config to Megatron config + bridge = AutoBridge.from_hf_pretrained( + self.local_path, trust_remote_code=trust_remote_code + ) + # Get Megatron provider and configure it + provider = bridge.to_megatron_provider(load_weights=False) + + # In case of invalid overrides, we need to make sure some critical params are set correctly + provider.params_dtype = dtype + + # Pass distributed info + provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size + provider.pipeline_model_parallel_size = megatron_config.pipeline_model_parallel_size + provider.expert_model_parallel_size = megatron_config.expert_model_parallel_size + provider.expert_tensor_parallel_size = megatron_config.expert_tensor_parallel_size + provider.virtual_pipeline_model_parallel_size = ( + megatron_config.virtual_pipeline_model_parallel_size + ) + provider.context_parallel_size = megatron_config.context_parallel_size + provider.sequence_parallel = megatron_config.sequence_parallel + + # Match verl implementation (need variable_seq_lengths) + from megatron.core.transformer.enums import AttnBackend + + provider.attention_backend = AttnBackend.flash + provider.variable_seq_lengths = True + provider.moe_token_dispatcher_type = "alltoall" + provider.moe_router_load_balancing_type = "none" + + # Apply transformer config overrides + for key, value in override_transformer_config.items(): + setattr(provider, key, value) + + provider.finalize() + self.provider = provider + tf_config = None # Will be set after model creation + self.bridge = bridge + else: + tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config) + self.bridge = None + + if torch.distributed.get_rank() == 0: + if tf_config is not None: + print(f"TF config: {tf_config}") + self.hf_config = hf_config + self.tf_config = tf_config + + # Get PEFT config from model.lora if 
specified + from verl.workers.config.megatron_peft import get_peft_cls + + self.peft_cls = get_peft_cls( + model_config=self.config.model, bridge=self.bridge, provider=self.provider, dtype=dtype + ) + + class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension): """ This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy @@ -74,8 +235,17 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension): """ def __init__(self, config: DictConfig, role: str, **kwargs): - MegatronWorker.__init__(self) + Worker.__init__(self) self.config = config + if repatch is not None: + # NPU MindSpeed patch, will be refactored with MindSpeedEngine. + repatch(self.config.actor.megatron.get("override_transformer_config", {})) + + self.role = role + assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"] + + self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"] + self._is_ref = self.role in ["ref", "actor_rollout_ref"] # NOTE(sgm): We utilize colocate WorkerGroup by default. # As a result, Workers for different model share the same process. @@ -84,36 +254,71 @@ def __init__(self, config: DictConfig, role: str, **kwargs): # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models, # 3. and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385 if not torch.distributed.is_initialized(): + set_numa_affinity() rank = int(os.environ["LOCAL_RANK"]) torch.distributed.init_process_group( backend=get_nccl_backend(), - timeout=datetime.timedelta(seconds=self.config.synchronizer.sync_timeout), + timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), init_method=os.environ.get("DIST_INIT_METHOD", None), ) get_torch_device().set_device(rank) - mpu.initialize_model_parallel( - tensor_model_parallel_size=self.config.actor.megatron.tensor_model_parallel_size, - pipeline_model_parallel_size=self.config.actor.megatron.pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size=self.config.actor.megatron.virtual_pipeline_model_parallel_size, - pipeline_model_parallel_split_rank=None, - use_sharp=False, - context_parallel_size=self.config.actor.megatron.context_parallel_size, - expert_model_parallel_size=self.config.actor.megatron.expert_model_parallel_size, - expert_tensor_parallel_size=self.config.actor.megatron.expert_tensor_parallel_size, - nccl_communicator_config_path=None, + if self._is_actor or self._is_ref: + mpu.initialize_model_parallel( + tensor_model_parallel_size=self.config.actor.megatron.tensor_model_parallel_size, + pipeline_model_parallel_size=self.config.actor.megatron.pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size=self.config.actor.megatron.virtual_pipeline_model_parallel_size, + use_sharp=False, + context_parallel_size=self.config.actor.megatron.context_parallel_size, + expert_model_parallel_size=self.config.actor.megatron.expert_model_parallel_size, + expert_tensor_parallel_size=self.config.actor.megatron.expert_tensor_parallel_size, + nccl_communicator_config_path=None, + ) + + if self._is_actor or self._is_ref: + is_collect = ( + mpu.get_tensor_model_parallel_rank() == 0 + and mpu.get_pipeline_model_parallel_rank() + == mpu.get_pipeline_model_parallel_world_size() - 1 + and mpu.get_context_parallel_rank() == 0 + ) + self._register_dispatch_collect_info( + mesh_name="actor", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect ) - 
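# The dispatch/collect registration above marks exactly one rank per
# data-parallel group as the one whose output is gathered: TP rank 0, CP rank 0,
# on the last pipeline stage. The predicate in isolation:
def is_collect_rank(tp_rank: int, cp_rank: int, pp_rank: int, pp_world_size: int) -> bool:
    return tp_rank == 0 and cp_rank == 0 and pp_rank == pp_world_size - 1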
set_random_seed(seed=self.config.actor.megatron.seed) + self.enable_routing_replay = False + if self._is_actor: + self.router_replay = self.config.actor.router_replay + self.enable_routing_replay = self.router_replay.mode != "disabled" - self.role = role - assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"] + if self.enable_routing_replay: + apply_router_replay_patch() - self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"] - self._is_ref = self.role in ["ref", "actor_rollout_ref"] + set_random_seed(seed=self.config.actor.megatron.seed) - profiler_config = omega_conf_to_dataclass(config.get("profiler")) - DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=profiler_config)) + if self._is_actor: + omega_profiler_config = config.actor.get("profiler", {}) + elif self._is_ref: + omega_profiler_config = config.ref.get("profiler", {}) + else: + raise ValueError( + f"Invalid role {self.role}, should be one of " + "['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']" + ) + # omega_profiler_config is DictConfig + # profiler_config is a ProfilerConfig dataclass + profiler_config = omega_conf_to_dataclass( + omega_profiler_config, dataclass_type=ProfilerConfig + ) + if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: + tool_config = omega_conf_to_dataclass( + omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) + ) + else: + tool_config = None + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config) + ) # TODO(sgm): Currently, we only support reference model param offload # will support other offload later @@ -151,157 +356,74 @@ def __init__(self, config: DictConfig, role: str, **kwargs): ) self._ref_is_offload_param = self.config.ref.megatron.get("param_offload", False) - def _init_hf_config_and_tf_config( + def _build_model_optimizer( self, model_path, - tokenizer_or_path, - dtype, + optim_config, override_model_config, override_transformer_config, - trust_remote_code=False, - use_mbridge=False, - ): - from transformers import AutoConfig - from verl.models.mcore import hf_to_mcore_config - from verl.utils import hf_processor, hf_tokenizer - from verl.utils.fs import copy_to_local - from verl.utils.model import update_model_config - - # Step 1: initialize the tokenizer - self.local_path = copy_to_local(model_path) - if tokenizer_or_path is None: - self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=trust_remote_code) - self.processor = hf_processor(self.local_path, trust_remote_code=trust_remote_code) - elif isinstance(tokenizer_or_path, str): - self.tokenizer = hf_tokenizer( - copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code - ) - self.processor = hf_processor( - copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code - ) - else: - self.tokenizer = tokenizer_or_path - self.processor = tokenizer_or_path - - if self.config.model.get("custom_chat_template", None) is not None: - if self.processor is not None: - self.processor.chat_template = self.config.model.custom_chat_template - else: - self.tokenizer.chat_template = self.config.model.custom_chat_template - - # Step 2: get the hf - hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code) - - # Step 3: override the hf config - override_config_kwargs = { - "bos_token_id": self.tokenizer.bos_token_id, - "eos_token_id": self.tokenizer.eos_token_id, - 
"pad_token_id": self.tokenizer.pad_token_id, - } - override_config_kwargs.update(override_model_config.get("model_config", {})) - - # patch for rope - if self.config.model.rope_scaling is not None: - hf_config.rope_scaling = OmegaConf.to_container(self.config.model.rope_scaling) - if self.config.model.rope_theta is not None: - hf_config.rope_theta = self.config.model.rope_theta - - self.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False) - update_model_config(hf_config, override_config_kwargs=override_config_kwargs) - self.architectures = getattr(hf_config, "architectures", None) - if self.rank == 0: - print(f"Model config after override: {hf_config}") - tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config) - - if use_mbridge: - from verl.models.mcore.mbridge import AutoBridge - - bridge = AutoBridge.from_config(hf_config) - bridge.set_extra_args(**override_transformer_config) - tf_config = bridge.config - self.bridge = bridge - else: - self.bridge = None - - print(f"TF config: {tf_config}") - self.hf_config = hf_config - self.tf_config = tf_config - - def _build_model_optimizer( - self, model_path, optim_config, override_model_config, override_transformer_config + override_ddp_config=None, ): from verl.utils.megatron.optimizer import ( get_megatron_optimizer, get_megatron_optimizer_param_scheduler, + init_megatron_optim_config, + ) + from verl.utils.megatron_utils import ( + McoreModuleWrapperConfig, + make_megatron_module, ) - from verl.utils.megatron_utils import get_model, init_megatron_optim_config from verl.utils.model import get_generation_config, print_model_size self._init_hf_config_and_tf_config( model_path, - model_path, + self.config.model.get("tokenizer_path") or model_path, self.dtype, override_model_config, override_transformer_config, self.config.model.get("trust_remote_code", False), - self.config.actor.megatron.use_mbridge, + self.config.actor.megatron if not self._is_ref else self.config.ref.megatron, + ) + self.generation_config = get_generation_config( + self.local_path, + self.config.model.get("trust_remote_code", False), ) - self.generation_config = get_generation_config(self.local_path) - - def make_model(wrap_with_ddp=False): - if self.bridge is not None: - from verl.models.mcore.mbridge import freeze_moe_router - - post_model_creation_callbacks = [] - if override_model_config.get("moe_config", {}).get("freeze_moe_router", False): - post_model_creation_callbacks.append(freeze_moe_router) - return self.bridge.get_model( - post_model_creation_callbacks=post_model_creation_callbacks, - wrap_with_ddp=wrap_with_ddp, - ) - else: - - def megatron_actor_model_provider(pre_process, post_process): - from verl.models.mcore import init_mcore_model - - parallel_model = init_mcore_model( - self.tf_config, - self.hf_config, - pre_process, - post_process, - share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, - value=False, - freeze_moe_router=override_model_config.get("moe_config", {}).get( - "freeze_moe_router", False - ), - ) - parallel_model.to(get_device_name()) - return parallel_model - - override_ddp_config = OmegaConf.to_container( - self.config.actor.megatron.get("override_ddp_config", OmegaConf.create()), - resolve=True, - ) - return get_model( - megatron_actor_model_provider, - wrap_with_ddp=wrap_with_ddp, - use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer, - override_ddp_config=override_ddp_config, - ) if self._is_actor: - actor_module = 
make_model(wrap_with_ddp=True) + wrap_config = McoreModuleWrapperConfig( + is_value_model=False, # actor is not value model + share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, + wrap_with_ddp=True, + use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer, + ) + actor_module, updated_tf_config = make_megatron_module( + wrap_config=wrap_config, + tf_config=self.tf_config, + hf_config=self.hf_config, + bridge=self.bridge, + provider=self.provider, + override_model_config=override_model_config, + override_ddp_config=override_ddp_config, + peft_cls=self.peft_cls, + peft_config=self.config.model.get("lora", None), + ) + self.tf_config = updated_tf_config + print(f"actor_module: {len(actor_module)}") if self.config.actor.load_weight: if self.config.actor.megatron.use_dist_checkpointing: load_mcore_dist_weights( actor_module, self.config.actor.megatron.dist_checkpointing_path, is_value_model=False, + prefix=self.config.actor.megatron.dist_checkpointing_prefix, ) else: if self.bridge is not None: local_model_path = get_hf_model_path(self.config) - self.bridge.load_weights(actor_module, local_model_path) + if self.vanilla_bridge: + self.bridge.load_weights(actor_module, local_model_path) + else: + self.bridge.load_hf_weights(actor_module, local_model_path) else: load_megatron_gptmodel_weights( self.config, @@ -315,8 +437,21 @@ def megatron_actor_model_provider(pre_process, post_process): print_model_size(actor_module[0]) log_gpu_memory_usage("After MegatronPPOActor init", logger=logger) elif self._is_ref: - print(f"self.config.ref.load_weight: {self.config.ref.load_weight}") - ref_module = make_model(wrap_with_ddp=False) + wrap_config = McoreModuleWrapperConfig( + is_value_model=False, # ref is not value model + share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, + wrap_with_ddp=False, + use_distributed_optimizer=self.config.ref.megatron.use_distributed_optimizer, + ) + ref_module, updated_tf_config = make_megatron_module( + wrap_config=wrap_config, + tf_config=self.tf_config, + hf_config=self.hf_config, + bridge=self.bridge, + provider=self.provider, + override_model_config=override_model_config, + ) + self.tf_config = updated_tf_config if self.config.ref.load_weight: # should align with the actor: assert self.config.actor.load_weight == self.config.ref.load_weight print("load ref weight start") @@ -325,11 +460,15 @@ def megatron_actor_model_provider(pre_process, post_process): ref_module, self.config.ref.megatron.dist_checkpointing_path, is_value_model=False, + prefix=self.config.ref.megatron.dist_checkpointing_prefix, ) else: if self.bridge is not None: local_model_path = get_hf_model_path(self.config) - self.bridge.load_weights(ref_module, local_model_path) + if self.vanilla_bridge: + self.bridge.load_weights(ref_module, local_model_path) + else: + self.bridge.load_hf_weights(ref_module, local_model_path) else: load_megatron_gptmodel_weights( self.config, @@ -343,7 +482,11 @@ def megatron_actor_model_provider(pre_process, post_process): # TODO: add more optimizer args into config if self._is_actor: - optim_config_megatron = init_megatron_optim_config(optim_config) + optim_config_megatron = init_megatron_optim_config( + optim_config, + use_distributed_optimizer=wrap_config.use_distributed_optimizer, + fp16=self.dtype == torch.float16, + ) actor_optimizer = get_megatron_optimizer( model=actor_module, config=optim_config_megatron ) @@ -357,6 +500,8 @@ def megatron_actor_model_provider(pre_process, post_process): 
log_gpu_memory_usage("After actor optimizer init", logger=logger) + register_megatron_training_hooks(actor_module, actor_optimizer) + return ( actor_module, actor_optimizer, @@ -376,21 +521,24 @@ def init_model(self): from verl.utils.torch_dtypes import PrecisionType override_model_config = OmegaConf.to_container( - self.config.model.get("override_config", OmegaConf.create()) + OmegaConf.create(self.config.model.get("override_config", {})) ) if self._is_actor: override_transformer_config = OmegaConf.to_container( - self.config.actor.megatron.get("override_transformer_config", OmegaConf.create()), - resolve=True, + OmegaConf.create(self.config.actor.megatron.get("override_transformer_config", {})) + ) + if self.enable_routing_replay: + override_transformer_config["enable_routing_replay"] = True + override_ddp_config = OmegaConf.to_container( + OmegaConf.create(self.config.actor.megatron.get("override_ddp_config", {})) ) elif self._is_ref: override_transformer_config = OmegaConf.to_container( - self.config.ref.megatron.get("override_transformer_config", OmegaConf.create()), - resolve=True, + OmegaConf.create(self.config.ref.megatron.get("override_transformer_config", {})) ) else: override_transformer_config = {} - self.param_dtype = torch.bfloat16 + self.param_dtype = PrecisionType.to_dtype(self.config.actor.megatron.dtype) log_gpu_memory_usage("Before init actor model and optimizer", logger=logger) self.dtype = PrecisionType.to_dtype(self.param_dtype) if self._is_actor: @@ -407,6 +555,7 @@ def init_model(self): optim_config=optim_config, override_model_config=override_model_config, override_transformer_config=override_transformer_config, + override_ddp_config=override_ddp_config, ) if self._is_offload_param: offload_megatron_model_to_cpu(self.actor_module) @@ -430,6 +579,7 @@ def init_model(self): actor_module=self.actor_module, actor_optimizer=self.actor_optimizer, ) + print(f"routing replay layers: {len(RouterReplay.router_instances)}") log_gpu_memory_usage("After MegatronPPOActor init", logger=logger) if self._is_ref: @@ -471,9 +621,22 @@ def init_model(self): use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer, use_checkpoint_opt_param_scheduler=self.config.actor.optim.use_checkpoint_opt_param_scheduler, bridge=self.bridge, + provider=self.provider, use_dist_checkpointing=self.config.actor.megatron.use_dist_checkpointing, + peft_cls=self.peft_cls, ray_namespace=self.config.synchronizer.ray_namespace, ) + + self.layer_name_mapping = { + "qkv_layer_name": "self_attention.linear_qkv.", + "gate_proj_layer_name": "linear_fc1.", + } + self.weight_converter = None + if not self.config.actor.megatron.use_mbridge: + self.weight_converter = get_mcore_weight_converter( + self.actor_model_config, self.dtype + ) + self.synchronizer = Synchronizer.get_actor(namespace=self.config.synchronizer.ray_namespace) get_torch_device().empty_cache() log_gpu_memory_usage("After init_model finish", logger=logger) @@ -484,29 +647,26 @@ def _get_tensor_generator(self): in `verl.workers.megatron_workers.ActorRolloutRefWorker._build_rollout` and its `__enter__` method. When the version of verl changes, please check the related code. 
""" - from verl.models.mcore import get_mcore_weight_converter - from verl.utils.megatron_utils import per_tensor_generator - - weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype) - layer_name_mapping = { - "qkv_layer_name": "self_attention.linear_qkv.", - "gate_proj_layer_name": "linear_fc1.", - } if self.bridge is not None: - per_tensor_param = self.bridge.export_weights(self.actor_module) + if self.vanilla_bridge: + per_tensor_param = self.bridge.export_weights(self.actor.actor_module) + else: + per_tensor_param = self.bridge.export_hf_weights(self.actor.actor_module) else: per_tensor_param = per_tensor_generator( - self.actor_module, + self.actor.actor_module, self.actor_model_config, - weight_converter, + self.weight_converter, self.tf_config, - layer_name_mapping, + self.layer_name_mapping, ) return per_tensor_param @register(dispatch_mode=Dispatch.ONE_TO_ALL) def setup_weight_sync_group(self): if self.config.synchronizer.sync_method == SyncMethod.NCCL: + aggressive_empty_cache(force_sync=True) + set_expandable_segments(False) self.state_dict_meta = [] if self._is_offload_param: @@ -545,6 +705,9 @@ def setup_weight_sync_group(self): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def sync_weight(self): + aggressive_empty_cache(force_sync=True) + set_expandable_segments(False) + if self._is_offload_param: load_megatron_model_to_gpu(self.actor_module) for name, weight in self._get_tensor_generator(): @@ -561,6 +724,9 @@ def sync_weight(self): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def upload_state_dict(self, trainer_step: int): + aggressive_empty_cache(force_sync=True) + set_expandable_segments(False) + if self._is_offload_param: load_megatron_model_to_gpu(self.actor_module) state_dict = {} @@ -579,9 +745,31 @@ def upload_state_dict(self, trainer_step: int): def set_algorithm(self, algo_config: AlgorithmConfig): self.actor.set_algorithm(algo_config) - @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) + async def trainer_mode(self): + """Context switch hybridengine to trainer mode.""" + # if self.config.rollout.free_cache_engine: + # log_gpu_memory_usage("Before rollout offload", logger=logger) + # await self.rollout.release() + # log_gpu_memory_usage("After rollout offload", logger=logger) + + for model in self.actor.actor_module: + model.train() + # add empty cache after each compute + aggressive_empty_cache(force_sync=True) + + # FIXME(@wuxibin): megatron+sglang failed with `expandable_segments:True` in ci, + # can't reproduce it in dev environment, temporary disable it. 
+ # https://github.com/volcengine/verl/actions/runs/17382936845/job/49344264323?pr=3285 + if os.environ.get("MEGATRON_CI_DISABLE_EXPANDABLE_SEGMENTS", "0") == "0": + set_expandable_segments(True) + + # restore random states + self.gen_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.torch_random_states) + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @GPUMemoryLogger(role="update_actor", logger=logger) - @DistProfiler.annotate(color="red") + @DistProfiler.annotate(color="red", role="actor_update") def update_actor(self, data: DataProto): assert self._is_actor if self._is_offload_param: @@ -592,7 +780,6 @@ def update_actor(self, data: DataProto): if self._is_offload_optimizer: load_megatron_optimizer(self.actor_optimizer) log_gpu_memory_usage("After load actor optimizer during update_actor", logger=logger) - data.batch = data.batch.to(get_device_name()) micro_batch_size = self.config.actor.ppo_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size @@ -632,12 +819,12 @@ def update_actor(self, data: DataProto): offload_megatron_optimizer(self.actor_optimizer) log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger) - get_torch_device().empty_cache() + aggressive_empty_cache(force_sync=True) return output - @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @GPUMemoryLogger(role="compute_ref_log_prob", logger=logger) - @DistProfiler.annotate(color="olive") + @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") def compute_ref_log_prob(self, data: DataProto): assert self._is_ref if self._ref_is_offload_param: @@ -650,8 +837,7 @@ def compute_ref_log_prob(self, data: DataProto): data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz data.meta_info["temperature"] = self.config.rollout.temperature - data = data.to(get_device_id()) - output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) + output, _, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) output = DataProto.from_dict(tensors={"ref_log_prob": output}) output = output.to("cpu") if self._ref_is_offload_param: @@ -659,12 +845,12 @@ def compute_ref_log_prob(self, data: DataProto): log_gpu_memory_usage( "After offload ref params and grad during compute_ref_log_prob", logger=logger ) - get_torch_device().empty_cache() + aggressive_empty_cache(force_sync=True) return output - @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @GPUMemoryLogger(role="compute_log_prob", logger=logger) - @DistProfiler.annotate(color="blue") + @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") def compute_log_prob(self, data: DataProto): assert self._is_actor if self._is_offload_param: @@ -677,12 +863,27 @@ def compute_log_prob(self, data: DataProto): data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz data.meta_info["temperature"] = self.config.rollout.temperature - data = data.to(get_device_id()) - output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True) + + if self.enable_routing_replay and self.config.actor.router_replay.mode == "R2": + 
RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD) + + if self.enable_routing_replay and self.config.actor.router_replay.mode == "R3": + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + + output, entropys, layers_topk_idx = self.actor.compute_log_prob( + data=data, calculate_entropy=True + ) output = DataProto.from_dict( tensors={"old_log_probs": output, "entropys": entropys}, meta_info={"temperature": self.config.rollout.temperature}, ) + if self.config.actor.router_replay.mode == "R2": + output.batch["routed_experts"] = layers_topk_idx + + if self.config.actor.router_replay.mode in ["R2", "R3"]: + RouterReplay.clear_global_indices() + RouterReplay.clear_global_router_replay_action() + output = output.to("cpu") # clear kv cache if self._is_offload_param: @@ -690,11 +891,22 @@ def compute_log_prob(self, data: DataProto): log_gpu_memory_usage( "After offload actor params and grad during compute_log_prob", logger=logger ) - get_torch_device().empty_cache() + aggressive_empty_cache(force_sync=True) return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True): + # No checkpoint to load, just offload the model and optimizer to CPU + if checkpoint_path is None: + if self._is_offload_param: + offload_megatron_model_to_cpu(self.actor_module) + if self._is_offload_optimizer: + offload_megatron_optimizer(self.actor_optimizer) + log_gpu_memory_usage( + "After offload actor params and optimizer during load_checkpoint", logger=logger + ) + return + if self._is_offload_param: load_megatron_model_to_gpu(self.actor_module) self.checkpoint_mananager.load_checkpoint( @@ -737,6 +949,8 @@ def save_checkpoint( ): if self._is_offload_param: load_megatron_model_to_gpu(self.actor_module) + if self.checkpoint_mananager.checkpoint_config.async_save and self._is_offload_optimizer: + load_megatron_optimizer(self.actor_optimizer) self.checkpoint_mananager.save_checkpoint( local_path=checkpoint_path, global_step=global_step, @@ -746,31 +960,65 @@ def save_checkpoint( torch.distributed.barrier() if self._is_offload_param: offload_megatron_model_to_cpu(self.actor_module) + if self.checkpoint_mananager.checkpoint_config.async_save and self._is_offload_optimizer: + offload_megatron_optimizer(self.actor_optimizer) @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def clear_optimizer_state(self): - print("Clear actor optimizer state") - if self._is_offload_optimizer: - load_megatron_optimizer(self.actor_optimizer) - self.actor_optimizer.state.clear() - self.actor_optimizer.zero_grad() - if self._is_offload_optimizer: - offload_megatron_optimizer(self.actor_optimizer) + def async_calls_finalize_fn_exec(self, blocking=False): + from megatron.core.dist_checkpointing.strategies.base import async_calls + + async_calls.maybe_finalize_async_calls(blocking=blocking) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def start_profile(self, **kwargs) -> None: + """Start profiling for the current rank in the current training step.""" + self.profiler.start(**kwargs) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def stop_profile(self) -> None: + """Stop profiling for the current rank in the current training step.""" + self.profiler.stop() + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def dump_memory_snapshot(self, tag: str = "manual", sub_dir: str = None) -> None: + """Manually trigger a CUDA memory snapshot dump on all ranks.""" + # Memory snapshot is now handled by the profiler system + # 
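# Router replay in R2 mode records the experts chosen while computing the old
# log-probs (exported as "routed_experts" above) and replays exactly those
# choices during the update, so both passes route tokens through identical
# expert paths. A toy, single-router sketch of the record/replay mechanics;
# class and attribute names are illustrative, not verl's RouterReplay.
import torch

class ToyReplayRouter:
    def __init__(self, top_k: int = 2):
        self.top_k = top_k
        self.mode = "record"      # "record" | "replay" | "off"
        self.recorded_topk = None

    def route(self, router_logits: torch.Tensor):
        if self.mode == "replay" and self.recorded_topk is not None:
            topk_idx = self.recorded_topk                 # reuse the recorded routing
        else:
            topk_idx = router_logits.topk(self.top_k, dim=-1).indices
            if self.mode == "record":
                self.recorded_topk = topk_idx
        topk_weight = torch.gather(router_logits.softmax(dim=-1), -1, topk_idx)
        return topk_idx, topk_weight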
This method is kept for backward compatibility but delegates to profiler + if hasattr(self, "profiler") and hasattr(self.profiler, "_impl"): + try: + # Try to use the profiler's memory snapshot functionality + if hasattr(self.profiler._impl, "sampler"): + out_dir = OmegaConf.select(self.config, "actor.profiler.save_path") or "." + self.profiler._impl.sampler.dump_memory_snapshot( + out_dir=out_dir, tag=tag, sub_dir=sub_dir + ) + except Exception as e: + # Log a warning if memory snapshot fails. This might be expected if the profiler doesn't support it. + logger.warning(f"Failed to dump memory snapshot: {e}") @register(dispatch_mode=Dispatch.ONE_TO_ALL) def wait_on_save_thread(self) -> None: - # currently, we don't need to wait for the save thread because async saving doesn't work. - pass + self.async_calls_finalize_fn_exec(blocking=True) class CriticWorker(MegatronWorker, DistProfilerExtension): - def __init__(self, config): - MegatronWorker.__init__(self) + def __init__(self, config: McoreCriticConfig): + Worker.__init__(self) + + omega_profiler_config = config.get("profiler", {}) + profiler_config = omega_conf_to_dataclass( + omega_profiler_config, dataclass_type=ProfilerConfig + ) + if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: + tool_config = omega_conf_to_dataclass( + omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) + ) + else: + tool_config = None DistProfilerExtension.__init__( - self, - DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler"))), + self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config) ) - self.config = config + self.config: McoreCriticConfig = config # NOTE(sgm): We utilize colocate WorkerGroup by default. # As a result, Workers for different model share the same process. @@ -779,10 +1027,11 @@ def __init__(self, config): # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models, # 3. 
and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385 if not torch.distributed.is_initialized(): + set_numa_affinity() rank = int(os.environ["LOCAL_RANK"]) torch.distributed.init_process_group( backend=get_nccl_backend(), - timeout=datetime.timedelta(seconds=self.config.synchronizer.sync_timeout), + timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), init_method=os.environ.get("DIST_INIT_METHOD", None), ) get_torch_device().set_device(rank) @@ -791,7 +1040,6 @@ def __init__(self, config): tensor_model_parallel_size=self.config.megatron.tensor_model_parallel_size, pipeline_model_parallel_size=self.config.megatron.pipeline_model_parallel_size, virtual_pipeline_model_parallel_size=self.config.megatron.virtual_pipeline_model_parallel_size, - pipeline_model_parallel_split_rank=None, use_sharp=False, context_parallel_size=self.config.megatron.context_parallel_size, expert_model_parallel_size=self.config.megatron.expert_model_parallel_size, @@ -799,6 +1047,16 @@ def __init__(self, config): nccl_communicator_config_path=None, ) + is_collect = ( + mpu.get_tensor_model_parallel_rank() == 0 + and mpu.get_pipeline_model_parallel_rank() + == mpu.get_pipeline_model_parallel_world_size() - 1 + and mpu.get_context_parallel_rank() == 0 + ) + self._register_dispatch_collect_info( + mesh_name="critic", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect + ) + set_random_seed(seed=self.config.megatron.seed) # set FSDP offload params @@ -815,65 +1073,52 @@ def __init__(self, config): # TODO(sgm): support critic model offload def _build_critic_model_optimizer( - self, model_path, optim_config, override_model_config, override_transformer_config + self, + model_path, + optim_config, + override_model_config, + override_transformer_config, + override_ddp_config, ): - from megatron.core.models.gpt.gpt_model import ModelType from verl.utils.megatron.optimizer import ( get_megatron_optimizer, get_megatron_optimizer_param_scheduler, + init_megatron_optim_config, + ) + from verl.utils.megatron_utils import ( + McoreModuleWrapperConfig, + make_megatron_module, ) - from verl.utils.megatron_utils import get_model, init_megatron_optim_config from verl.utils.model import print_model_size self._init_hf_config_and_tf_config( model_path, - self.config.model.tokenizer_path, + self.config.model.get("tokenizer_path") or model_path, self.dtype, override_model_config, override_transformer_config, self.config.model.get("trust_remote_code", False), - self.config.megatron.use_mbridge, + self.config.megatron, ) - if self.bridge is not None: - from verl.models.mcore.mbridge import freeze_moe_router, make_value_model - - post_model_creation_callbacks = [make_value_model] - if override_model_config.get("moe_config", {}).get("freeze_moe_router", False): - post_model_creation_callbacks.append(freeze_moe_router) - critic_module = self.bridge.get_model( - post_model_creation_callbacks=post_model_creation_callbacks, wrap_with_ddp=True - ) - else: - - def megatron_critic_model_provider(pre_process, post_process): - from verl.models.mcore import init_mcore_model - - parallel_model = init_mcore_model( - self.tf_config, - self.hf_config, - pre_process, - post_process, - share_embeddings_and_output_weights=False, - value=True, - freeze_moe_router=override_model_config.get("moe_config", {}).get( - "freeze_moe_router", False - ), - ) - parallel_model.to(get_device_name()) - return parallel_model - - override_ddp_config = OmegaConf.to_container( - 
self.config.megatron.get("override_ddp_config", OmegaConf.create()), resolve=True - ) - # Step 3: initialize the megatron model - critic_module = get_model( - model_provider_func=megatron_critic_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=True, - use_distributed_optimizer=self.config.megatron.use_distributed_optimizer, - override_ddp_config=override_ddp_config, - ) + wrap_config = McoreModuleWrapperConfig( + is_value_model=True, # critic is value model + share_embeddings_and_output_weights=False, + wrap_with_ddp=True, + use_distributed_optimizer=self.config.megatron.use_distributed_optimizer, + ) + critic_module, updated_tf_config = make_megatron_module( + wrap_config=wrap_config, + tf_config=self.tf_config, + hf_config=self.hf_config, + bridge=self.bridge, + provider=self.provider, + override_model_config=override_model_config, + override_ddp_config=override_ddp_config, + peft_cls=self.peft_cls, + peft_config=self.config.model.get("lora", None), + ) + self.tf_config = updated_tf_config # note that here critic_module will be a list to be compatible with the construction of interleaved pp (vpp). # but here, we do not use pp (vpp) yet. For simplicity, we remove the list # critic_module = nn.ModuleList(critic_module) @@ -882,12 +1127,22 @@ def megatron_critic_model_provider(pre_process, post_process): t0 = time.time() if self.config.megatron.use_dist_checkpointing: load_mcore_dist_weights( - critic_module, self.config.megatron.dist_checkpointing_path, is_value_model=True + critic_module, + self.config.megatron.dist_checkpointing_path, + is_value_model=True, + prefix=self.config.megatron.dist_checkpointing_prefix, ) else: if self.bridge is not None: local_model_path = get_hf_model_path(self.config) - self.bridge.load_weights(critic_module, local_model_path) + if self.vanilla_bridge: + self.bridge.load_weights(critic_module, local_model_path) + else: + self.bridge.load_hf_weights( + critic_module, + local_model_path, + allowed_mismatched_params=["output_layer.weight"], + ) else: load_megatron_gptmodel_weights( self.config, @@ -903,12 +1158,19 @@ def megatron_critic_model_provider(pre_process, post_process): print_model_size(critic_module[0]) # TODO: add more optimizer args into config - optim_config_megatron = init_megatron_optim_config(optim_config) + optim_config_megatron = init_megatron_optim_config( + optim_config, + use_distributed_optimizer=wrap_config.use_distributed_optimizer, + fp16=self.dtype == torch.float16, + ) critic_optimizer = get_megatron_optimizer(model=critic_module, config=optim_config_megatron) critic_optimizer_scheduler = get_megatron_optimizer_param_scheduler( optimizer=critic_optimizer, config=optim_config ) get_torch_device().empty_cache() + + register_megatron_training_hooks(critic_module, critic_optimizer) + return ( critic_module, critic_optimizer, @@ -929,13 +1191,15 @@ def init_model(self): importlib.import_module(self.config.model.external_lib) override_model_config = OmegaConf.to_container( - self.config.model.get("override_config", OmegaConf.create()) + OmegaConf.create(self.config.model.get("override_config", {})) ) override_transformer_config = OmegaConf.to_container( - self.config.megatron.get("override_transformer_config", OmegaConf.create()), - resolve=True, + OmegaConf.create(self.config.megatron.get("override_transformer_config", {})) + ) + override_ddp_config = OmegaConf.to_container( + OmegaConf.create(self.config.megatron.get("override_ddp_config", {})) ) - self.param_dtype = torch.bfloat16 + self.param_dtype = 
PrecisionType.to_dtype(self.config.megatron.dtype) self.dtype = PrecisionType.to_dtype(self.param_dtype) ( self.critic_module, @@ -948,6 +1212,7 @@ def init_model(self): optim_config=self.config.optim, override_model_config=override_model_config, override_transformer_config=override_transformer_config, + override_ddp_config=override_ddp_config, ) if self._is_offload_param: offload_megatron_model_to_cpu(self.critic_module) @@ -981,12 +1246,14 @@ def init_model(self): use_distributed_optimizer=self.config.megatron.use_distributed_optimizer, use_checkpoint_opt_param_scheduler=self.config.optim.use_checkpoint_opt_param_scheduler, bridge=self.bridge, + provider=self.provider, use_dist_checkpointing=self.config.megatron.use_dist_checkpointing, + peft_cls=self.peft_cls, ray_namespace=self.config.ray_namespace, ) - @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) - @DistProfiler.annotate(color="cyan") + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) + @DistProfiler.annotate(color="cyan", role="compute_values") def compute_values(self, data: DataProto): micro_batch_size = self.config.ppo_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size @@ -1002,8 +1269,8 @@ def compute_values(self, data: DataProto): offload_megatron_model_to_cpu(self.critic_module) return output - @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) - @DistProfiler.annotate(color="pink") + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) + @DistProfiler.annotate(color="pink", role="critic_update") def update_critic(self, data: DataProto): data = data.to(get_device_id()) @@ -1071,16 +1338,11 @@ def save_checkpoint( offload_megatron_model_to_cpu(self.critic_module) @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def clear_optimizer_state(self): - print("Clear critic optimizer state") - if self._is_offload_optimizer: - load_megatron_optimizer(self.critic_optimizer) - self.critic_optimizer.state.clear() - self.critic_optimizer.zero_grad() - if self._is_offload_optimizer: - offload_megatron_optimizer(self.critic_optimizer) + def async_calls_finalize_fn_exec(self, blocking=False): + from megatron.core.dist_checkpointing.strategies.base import async_calls + + async_calls.maybe_finalize_async_calls(blocking=blocking) @register(dispatch_mode=Dispatch.ONE_TO_ALL) def wait_on_save_thread(self) -> None: - # currently, we don't need to wait for the save thread because async saving doesn't work. 
- pass + self.async_calls_finalize_fn_exec(blocking=True) diff --git a/trinity/trainer/verl/utils.py b/trinity/trainer/verl/utils.py index 640ee2b748..654fdd8d7a 100644 --- a/trinity/trainer/verl/utils.py +++ b/trinity/trainer/verl/utils.py @@ -72,7 +72,7 @@ def to_data_proto( batch_dict.update( { "token_level_scores": token_level_rewards, - "old_log_probs": gather_response_attrs( + "rollout_log_probs": gather_response_attrs( experiences, "logprobs", max_response_length ), } @@ -99,7 +99,8 @@ def to_data_proto( ) else: raise ValueError("Custom fields are not consistent across experiences.") - return DataProto.from_single_dict(batch_dict) + meta_info = {"model_versions": np.array([exp.info["model_version"] for exp in experiences])} + return DataProto.from_single_dict(batch_dict, meta_info=meta_info) def compute_data_metrics(batch: DataProto) -> dict: diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 8583fd3363..51360afafa 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -13,6 +13,7 @@ import torch from omegaconf import OmegaConf from verl import DataProto +from verl.trainer.ppo.core_algos import agg_loss from verl.trainer.ppo.metric_utils import ( compute_throughout_metrics, compute_timing_metrics, @@ -29,6 +30,7 @@ from verl.utils.debug import marked_timer from verl.utils.fs import copy_local_path_from_hdfs from verl.utils.metric import reduce_metrics +from verl.workers.config import FSDPEngineConfig from trinity.algorithm import ADVANTAGE_FN, ALGORITHM_TYPE, KL_FN from trinity.algorithm.utils import prefix_metrics @@ -198,30 +200,23 @@ def __init__( tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) # processor for multimodal LLM, could be None processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + from verl.single_controller.ray import RayWorkerGroup + + ray_worker_group_cls = RayWorkerGroup # define worker classes if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]: - assert config.critic.strategy in ["fsdp", "fsdp2"] - from verl.single_controller.ray import RayWorkerGroup - from trinity.trainer.verl.fsdp_workers import ( ActorRolloutRefWorker, CriticWorker, ) - ray_worker_group_cls = RayWorkerGroup - elif config.actor_rollout_ref.actor.strategy == "megatron": - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup - from trinity.trainer.verl.megatron_workers import ( ActorRolloutRefWorker, CriticWorker, ) - ray_worker_group_cls = NVMegatronRayWorkerGroup - else: raise NotImplementedError @@ -272,12 +267,7 @@ def __init__( ) self.init_workers() - def _validate_config(self): # TODO - algorithm = ALGORITHM_TYPE.get(self.algorithm_config.algorithm_type) - self.use_critic = algorithm.use_critic - super()._validate_config() - - def init_workers(self): + def init_workers(self): # noqa: C901 """Initialize distributed training workers using Ray backend. 
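# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): how the new "model_versions"
# meta info attached by to_data_proto() can be consumed. Only the numpy array
# shape and the "expected version = trainer step - 1" rule come from this
# patch; the helper name and example values are assumptions.
import numpy as np

def batch_matches_trainer_step(model_versions: np.ndarray, trainer_step: int) -> bool:
    """True if every experience was generated by the model version expected
    for this training step (trainer_step - 1), i.e. the batch is not stale."""
    return bool((model_versions == trainer_step - 1).all())

# Example: experiences produced at model version 3, consumed at train step 4.
# batch_matches_trainer_step(np.array([3, 3, 3]), trainer_step=4)  # -> True
# batch_matches_trainer_step(np.array([2, 3, 3]), trainer_step=4)  # -> False (warning is logged later)
# ---------------------------------------------------------------------------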
@@ -295,42 +285,98 @@ def init_workers(self): } # create actor and rollout + actor_role = ( + Role.ActorRolloutRef + if Role.ActorRolloutRef in self.role_worker_mapping + else Role.ActorRollout + ) if self.hybrid_engine: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) + resource_pool = self.resource_pool_manager.get_resource_pool(actor_role) actor_rollout_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.ActorRollout], + cls=self.role_worker_mapping[actor_role], config=self.config.actor_rollout_ref, - role="actor", + role=str(actor_role), ) - self.resource_pool_to_cls[resource_pool]["actor"] = actor_rollout_cls + self.resource_pool_to_cls[resource_pool][str(actor_role)] = actor_rollout_cls else: raise NotImplementedError # create critic if self.use_critic: resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + + critic_cfg = self.config.critic + + if self.use_legacy_worker_impl == "disable": + # convert critic_cfg into TrainingWorkerConfig + from verl.workers.engine_workers import TrainingWorkerConfig + + orig_critic_cfg = critic_cfg + if orig_critic_cfg.strategy == "fsdp": + engine_config: FSDPEngineConfig = orig_critic_cfg.model.fsdp_config + engine_config.infer_max_token_len_per_gpu = ( + critic_cfg.ppo_infer_max_token_len_per_gpu + ) + engine_config.max_token_len_per_gpu = critic_cfg.ppo_max_token_len_per_gpu + else: + raise NotImplementedError(f"Unknown strategy {orig_critic_cfg.strategy=}") + + critic_cfg = TrainingWorkerConfig( + model_type="value_model", + model_config=orig_critic_cfg.model_config, + engine_config=engine_config, + optimizer_config=orig_critic_cfg.optim, + checkpoint_config=orig_critic_cfg.checkpoint, + ) + critic_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.Critic], config=self.config.critic + cls=self.role_worker_mapping[Role.Critic], config=critic_cfg ) - self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls + self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls # create reference policy if needed - if self.use_reference_policy: + if self.use_reference_policy and Role.RefPolicy in self.role_worker_mapping: resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) ref_policy_cls = RayClassWithInitArgs( self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, - role="ref", + role=str(Role.RefPolicy), ) - self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls + self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls + # create a reward model if reward_fn is None - if self.use_rm: - # we create a RM here - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) - rm_cls = RayClassWithInitArgs( - self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model + # for legacy discriminative reward model, we create a reward model worker here + # for reward loop discriminative reward model, we create a reward loop manager here + if not self.use_reward_loop: + # legacy reward model only handle reward-model based scenario + if self.use_rm: + # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model + ) + self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls + else: + # reward loop handle hybrid reward scenario (rule, disrm, genrm, ...) 
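# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the parallelization decision
# made just below, written as a standalone predicate. `use_rm` mirrors
# self.use_rm and `rm_has_own_resource_pool` mirrors
# self.config.reward_model.enable_resource_pool.
def can_parallelize_reward_loop(use_rm: bool, rm_has_own_resource_pool: bool) -> bool:
    # Rule-based rewards (no discriminative RM) never compete with rollout for
    # GPUs, and an RM with its own resource pool does not either. Only an RM
    # that shares the trainer's GPUs forces the synchronous RewardLoopManager path.
    return (not use_rm) or rm_has_own_resource_pool

# can_parallelize_reward_loop(False, False)  # True  -> reward workers launched per rollout worker (agent loop)
# can_parallelize_reward_loop(True,  False)  # False -> RewardLoopManager launched here
# can_parallelize_reward_loop(True,  True)   # True  -> RM runs on its own resource pool
# ---------------------------------------------------------------------------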
+ # Note: mode is always "async" since sync mode is deprecated + can_reward_loop_parallelize = ( + not self.use_rm or self.config.reward_model.enable_resource_pool ) - self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls + # judge if we can asynchronously parallelize reward model with actor rollout + # two condition that we can parallelize reward model with actor rollout: + # 1. reward model is not enabled (rule-based reward can parallelize) + # 2. reward model is enabled but extra resource pool is enabled + # If we cannot parallelize, we should enable synchronous mode here, and launch a reward loop manager here + # else for parallelize mode, we launch a reward worker for each rollout worker (in agent loop, not here) + if not can_reward_loop_parallelize: + from verl.experimental.reward_loop import RewardLoopManager + + self.config.reward_model.n_gpus_per_node = self.config.trainer.n_gpus_per_node + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + self.reward_loop_manager = RewardLoopManager( + config=self.config, + rm_resource_pool=resource_pool, + ) # initialize WorkerGroup # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, @@ -343,33 +389,69 @@ def init_workers(self): wg_kwargs[ "ray_wait_register_center_timeout" ] = self.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(self.config.global_profiler, "steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps") + # Only require nsight worker options when tool is nsys + if OmegaConf.select(self.config.global_profiler, "tool") == "nsys": + assert ( + OmegaConf.select( + self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options" + ) + is not None + ), "worker_nsight_options must be set when using nsys with profile_steps" + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( + OmegaConf.select( + self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options" + ) + ) + wg_kwargs["device_name"] = self.device_name + for resource_pool, class_dict in self.resource_pool_to_cls.items(): worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) wg_dict = self.ray_worker_group_cls( resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, - device_name=self.device_name, **wg_kwargs, ) spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) all_wg.update(spawn_wg) if self.use_critic: - self.critic_wg = all_wg["critic"] - self.critic_wg.init_model() + self.critic_wg = all_wg[str(Role.Critic)] + if self.use_legacy_worker_impl == "disable": + self.critic_wg.reset() + # assign critic loss + from functools import partial - if self.use_reference_policy and not self.ref_in_actor: - self.ref_policy_wg = all_wg["ref"] - self.ref_policy_wg.init_model() + from verl.workers.utils.losses import value_loss + + value_loss_ = partial(value_loss, config=orig_critic_cfg) + self.critic_wg.set_loss_fn(value_loss_) + else: + self.critic_wg.init_model() - if self.use_rm: - self.rm_wg = all_wg["rm"] + if self.use_reference_policy and not self.ref_in_actor: + if str(Role.RefPolicy) in all_wg: + self.ref_policy_wg = all_wg[str(Role.RefPolicy)] + self.ref_policy_wg.init_model() + else: + # Model engine: ActorRolloutRefWorker + assert str(Role.ActorRolloutRef) in all_wg, f"{all_wg.keys()=}" + self.ref_policy_wg = all_wg[str(Role.ActorRolloutRef)] + + self.rm_wg = None + # initalization of rm_wg will be deprecated in the future + if self.use_rm and not 
self.use_reward_loop: + self.rm_wg = all_wg[str(Role.RewardModel)] self.rm_wg.init_model() # we should create rollout at the end so that vllm can have a better estimation of kv cache memory - self.actor_rollout_wg = all_wg["actor"] + self.actor_rollout_wg = all_wg[str(actor_role)] self.actor_rollout_wg.init_model() + if self.ref_in_actor: + self.ref_policy_wg = self.actor_rollout_wg + @property def train_step_num(self) -> int: return self.global_steps @@ -424,13 +506,57 @@ async def train_step(self, batch_exps: List[Experience]) -> Dict: # noqa C901 batch.batch["attention_mask"], dim=-1 ).tolist() + # Operating Mode Selection: + # - Bypass mode: Sets old_log_probs = rollout_log_probs (2 policies: π_rollout, π_θ) + # - Decoupled mode: Recomputes old_log_probs as proximal anchor (3 policies: π_rollout, π_old, π_θ) + # Note: π_old computed once per data batch, serves as stable reference during mini-batch updates + rollout_corr_config = self.config.algorithm.get("rollout_correction", None) + bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get( + "bypass_mode", False + ) + if bypass_recomputing_logprobs: # Use `rollout_log_probs` + from verl.trainer.ppo.rollout_corr_helper import apply_bypass_mode + + apply_bypass_mode( + batch=batch, + rollout_corr_config=rollout_corr_config, + policy_loss_config=self.config.actor_rollout_ref.actor.policy_loss, + ) + else: # Recompute old_log_probs TODO: to be check + if (batch.meta_info["model_versions"] != self.global_steps - 1).any(): + self.logger.warning( + f"model_versions mismatch: {batch.meta_info['model_versions']} vs {self.global_steps - 1}" + ) + with marked_timer("old_log_prob", timing_raw, color="blue"): + old_log_prob, old_log_prob_mfu = self._compute_old_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + actor_config = self.config.actor_rollout_ref.actor + entropy_agg = agg_loss( + loss_mat=entropys, + loss_mask=response_masks, + loss_agg_mode=actor_config.loss_agg_mode, + loss_scale_factor=actor_config.loss_scale_factor, + ) + old_log_prob_metrics = { + "actor/entropy": entropy_agg.detach().item(), + "perf/mfu/actor_infer": old_log_prob_mfu, + } + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. 
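# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the "Operating Mode Selection"
# above in isolation. `recompute_old_log_prob` stands in for the
# _compute_old_log_prob call; the tensor key follows the batch keys used in
# this file.
import torch

def resolve_old_log_probs(batch: dict, bypass_mode: bool, recompute_old_log_prob=None) -> torch.Tensor:
    if bypass_mode:
        # Bypass mode: reuse the rollout engine's log-probs, so pi_old == pi_rollout
        # (two policies total) and no extra forward pass is needed.
        return batch["rollout_log_probs"]
    # Decoupled mode: recompute log-probs once per data batch under the current
    # trainer weights; this pi_old stays fixed across mini-batch updates and acts
    # as the proximal anchor (three policies: pi_rollout, pi_old, pi_theta).
    assert recompute_old_log_prob is not None, "decoupled mode needs a recompute function"
    return recompute_old_log_prob(batch)
# ---------------------------------------------------------------------------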
+ from verl.utils.debug.metrics import calculate_debug_metrics + + metrics.update(calculate_debug_metrics(batch)) + + assert "old_log_probs" in batch.batch, f'"old_log_prob" not in {batch.batch.keys()=}' + if self.algorithm.use_reference: # ref_logprob may not be used # compute reference log_prob - with marked_timer("ref", timing_raw): - if not self.ref_in_actor: - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - else: - ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) + with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"): + ref_log_prob = self._compute_ref_log_prob(batch) batch = batch.union(ref_log_prob) if self.algorithm.use_critic: @@ -451,10 +577,30 @@ async def train_step(self, batch_exps: List[Experience]) -> Dict: # noqa C901 assert "token_level_rewards" not in batch.batch.keys() batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + # TODO: to be check + # Compute rollout correction: IS weights, rejection sampling, and metrics + # Only runs in decoupled mode (computes once per batch using stable π_old) + # In bypass mode, this is skipped - actor computes metrics from evolving π_θ vs π_rollout + if ( + rollout_corr_config is not None + and "rollout_log_probs" in batch.batch + and not bypass_recomputing_logprobs # Only in decoupled mode + ): + from verl.trainer.ppo.rollout_corr_helper import ( + compute_rollout_correction_and_add_to_batch, + ) + + # Compute IS weights, apply rejection sampling, compute metrics + batch, is_metrics = compute_rollout_correction_and_add_to_batch( + batch, rollout_corr_config + ) + # IS and off-policy metrics already have rollout_corr/ prefix + metrics.update(is_metrics) + # update critic if self.algorithm.use_critic: - with marked_timer("update_critic", timing_raw): - critic_output = self.critic_wg.update_critic(batch) + with marked_timer("update_critic", timing_raw, color="pink"): + critic_output = self._update_critic(batch) critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) metrics.update(critic_output_metrics) @@ -464,9 +610,8 @@ async def train_step(self, batch_exps: List[Experience]) -> Dict: # noqa C901 or self.config.trainer.critic_warmup <= self.global_steps ): # update actor - with marked_timer("update_actor", timing_raw): - actor_output = self.actor_rollout_wg.update_actor(batch) - # TODO add send weight explorer + with marked_timer("update_actor", timing_raw, color="red"): + actor_output = self._update_actor(batch) actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) @@ -518,7 +663,7 @@ def _save_checkpoint(self, save_as_hf: bool = False): self.actor_rollout_wg.save_checkpoint( actor_local_path, - self.global_steps, + global_step=self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep, save_as_hf=save_as_hf, ) @@ -527,7 +672,7 @@ def _save_checkpoint(self, save_as_hf: bool = False): critic_local_path = os.path.join(local_global_step_folder, "critic") self.critic_wg.save_checkpoint( critic_local_path, - self.global_steps, + global_step=self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep, ) diff --git a/trinity/utils/lora_utils.py b/trinity/utils/lora_utils.py index bddce2684c..7277eeb562 100644 --- a/trinity/utils/lora_utils.py +++ b/trinity/utils/lora_utils.py @@ -1,9 +1,13 @@ +from typing import Optional + + def create_dummy_lora( model_path: str, checkpoint_job_dir: str, lora_rank: int, lora_alpha: int, target_modules: str, + exclude_modules: Optional[str] = None, ) -> str: import torch 
     from peft import LoraConfig, TaskType, get_peft_model
@@ -16,6 +20,7 @@ def create_dummy_lora(
         "r": lora_rank,
         "lora_alpha": lora_alpha,
         "target_modules": target_modules,
+        "exclude_modules": exclude_modules,
         "bias": "none",
     }
     peft_model = get_peft_model(model, LoraConfig(**lora_config))
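# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the patch) for the extended
# create_dummy_lora() signature. The paths are placeholders and lora_rank=32
# is only an example value; lora_alpha, target_modules, and exclude_modules
# mirror the LoRAConfig defaults in trinity/common/config.py.
from trinity.utils.lora_utils import create_dummy_lora

dummy_adapter_path = create_dummy_lora(
    model_path="/path/to/base_model",       # placeholder: any HF causal-LM checkpoint
    checkpoint_job_dir="/path/to/job_dir",  # placeholder: the adapter is written under this directory
    lora_rank=32,
    lora_alpha=32,
    target_modules="all-linear",
    exclude_modules=None,  # e.g. "lm_head" to keep selected linear layers out of the adapter
)
print(dummy_adapter_path)  # this path is then assigned to lora_config.path with is_dummy=True
# ---------------------------------------------------------------------------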