diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index c98c836fa..fed4dfcbb 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -6,7 +6,6 @@ import logging import os -import shutil import time from collections.abc import Mapping @@ -53,45 +52,6 @@ logger.setLevel(logging.DEBUG) -def cleanup_old_weight_versions( - state_dict_key: str, - delim: str, - current_policy_version: int, -) -> None: - """Delete old weight versions, keeping only current and N-1 versions. - - TODO - issues/194: provide a more robust way to handle eviction. - - Args: - state_dict_key: The base key for state dict storage - delim: The delimiter used between key and version - current_policy_version: The current policy version to keep - """ - if current_policy_version <= 1: - return # No cleanup needed for versions 0 or 1 - - prefix = f"{state_dict_key}{delim}" - current_weights = f"{prefix}{current_policy_version}" - previous_weights = f"{prefix}{current_policy_version - 1}" - - # Find all weight directories that match our pattern - parent_dir = os.path.dirname(prefix) or "." - if os.path.exists(parent_dir): - for item in os.listdir(parent_dir): - item_path = os.path.join(parent_dir, item) - if ( - item.startswith(os.path.basename(prefix)) - and item != os.path.basename(current_weights) - and item != os.path.basename(previous_weights) - and os.path.isdir(item_path) - ): - try: - shutil.rmtree(item_path, ignore_errors=True) - logger.debug(f"Removed old weights at {item_path}") - except OSError as e: - logger.debug(f"Error deleting {item_path}: {e}") - - @dataclass class RLTrainer(ForgeActor): """A reinforcement learning trainer actor for policy optimization training. @@ -135,19 +95,10 @@ class RLTrainer(ForgeActor): dcp_path: str = "forge_dcp_tmp" def __post_init__(self): - """Initializes config types and env variables. - - torchrun normally hands env variables, but we need to do it ourselves - in monarch for now. 
- - """ super().__init__() - if self.use_dcp: - # DCP specific optimization torch.serialization.set_crc32_options(False) - # Instantiate dict fields for f in fields(self): attr = getattr(self, f.name) if isinstance(attr, Mapping): @@ -184,73 +135,23 @@ def forward_backward( ) -> Tensor: model_parts = self.engine.model_parts parallel_dims = self.engine.parallel_dims - - # apply context parallelism if cp is enabled - # ensure CP handles the separate freqs_cis buffer for each pp stage - # if getattr(self.engine.model_args, "use_flex_attn", False): - # cp_mesh = ( - # parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None - # ) - # init_attention_mask( - # inputs, self.engine.tokenizer.base_tokenizer.eos_id, cp_mesh - # ) - - # optional_context_parallel_ctx = ( - # dist_utils.create_context_parallel_ctx( - # cp_mesh=parallel_dims.world_mesh["cp"], - # cp_buffers=[inputs, targets] + [m.freqs_cis for m in model_parts], - # cp_seq_dims=[1, 1] + [0 for _ in model_parts], - # cp_no_restore_buffers={inputs, targets}, - # cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, - # ) - # if parallel_dims.cp_enabled - # else None - # ) optional_context_parallel_ctx = None - if parallel_dims.pp_enabled: raise NotImplementedError("PP not implemented yet") - # TODO implement PP - # # Pipeline Parallel forward / backward inside step() call - # with self.train_context(optional_context_parallel_ctx): - # targets, losses = ( - # (labels, []) if self.pp_has_last_stage else (None, None) - # ) - # if self.pp_has_first_stage: - # self.pp_schedule.step( - # inputs, target=targets, losses=losses, input_batch=inputs - # ) - # else: - # self.pp_schedule.step( - # target=targets, losses=losses, input_batch=inputs - # ) - # - # # accumulate losses across pipeline microbatches - # # TODO: PP+FSDP unexpectedly puts the loss back to the CPU - # loss = ( - # torch.mean(torch.stack(losses)).to(self.device) - # if self.pp_has_last_stage - # else torch.tensor([-1.0], 
device=self.device) - # ) else: - # Non-PP forward / backward with self.engine.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.engine.maybe_enable_amp: logits = model_parts[0](**inputs) loss = self.loss(logits, **targets) - # need to free to before bwd to avoid peaking memory - del logits + del logits # Free logits before bwd to reduce peak memory loss.backward() - return loss @endpoint async def train_step( self, inputs: list[dict[str, Tensor]], targets: list[dict[str, Tensor]] ) -> float: - - # Log timesteps t = Tracer("rl_trainer_perf/step", timer="gpu", track_memory=True) t.start() @@ -259,18 +160,12 @@ async def train_step( local_targets = targets[self.engine.dp_rank] batch_to_device(local_inputs, self.engine.device) batch_to_device(local_targets, self.engine.device) - # compute policy logprobs - # TODO implement gradient accumulation - # with GradientAccumulation( - # self.gradient_accumulation_steps, - # self.model, - # self.data_parallel_size, - # ) as grad_acc: + loss = self.forward_backward(local_inputs, local_targets) torch.distributed.all_reduce(loss) + t.step("forward_backward") - # Get learning rate from scheduler current_lr = ( self.engine.lr_schedulers.get_last_lr()[0] if hasattr(self.engine.lr_schedulers, "get_last_lr") @@ -283,13 +178,11 @@ async def train_step( self.engine.lr_schedulers.step() t.step("optimizer_step") - # Record training metrics # TODO: delete item() to avoid cpu-gpu sync - loss = loss.detach().cpu().item() + loss = loss.detach().item() record_metric("rl_trainer/count_training_steps", 1, Reduce.SUM) record_metric("rl_trainer/avg_grpo_loss", loss, Reduce.MEAN) - # TODO: Extract actual KL divergence and policy entropy from the loss computation # These are placeholder values until the loss function exposes these metrics # record_metric("rl_trainer/step/avg_kl_divergence", 0.0, Reduce.MEAN) # record_metric("rl_trainer/step/std_kl_divergence", 0.0, Reduce.STD) @@ -351,109 +244,3 @@ async def
push_weights(self, policy_version: int) -> None: async def cleanup(self) -> None: if self.engine.checkpointer: self.engine.checkpointer.close() - - -def _shard_and_concat(sources: list[torch.Tensor], dim: int, tp: int) -> torch.Tensor: - """Shard and concatenate tensors along a given dimension. - - Args: - source (list[torch.Tensor]): List of tensors to shard and concatenate. - dim (int): Dimension along which to shard and concatenate. - tp (int): Number of tensor parallel groups. - - Returns: - torch.Tensor: Concatenated tensor. - """ - sharded_sources = [] - for source in sources: - sharded_sources.append(torch.chunk(source, tp, dim=dim)) - - combined_shards = [] - for shard_idx in range(tp): - combined = torch.cat([s[shard_idx] for s in sharded_sources], dim=dim) - combined_shards.append(combined) - return torch.cat(combined_shards, dim=dim) - - -def _qwen3_hf_to_vllm( - sd: dict[str, torch.Tensor], num_layers: int, vllm_tp: int -) -> dict[str, torch.Tensor]: - """Convert transformers state dict to vLLM format. Specifically, this fuses - QKV projection and MLP gate_up_proj layers. - - Args: - sd (dict): State dict from HF model. - num_layers (int): Number of layers in the model. - - Returns: - dict: State dict in vLLM format. - """ - load_sd = {} - - def unwrap(t): - """Unwrap a DTensor to a Tensor.""" - return t.full_tensor() if isinstance(t, torch.distributed.tensor.DTensor) else t - - for key in sd.keys(): - sd[key] = unwrap(sd[key]).cpu() - - # Copy over directly mapped keys - for k in sd: - if any( - x in k - for x in [ - "down_proj", - "input_layernorm", - "post_attention_layernorm", - "o_proj", - "norm.weight", - "embed_tokens.weight", - "lm_head.weight", - ] - ): - load_sd[k] = sd[k] - - for i in range(num_layers): - prefix = f"model.layers.{i}." 
- # QKV fusion - q = sd[prefix + "self_attn.q_proj.weight"] - k = sd[prefix + "self_attn.k_proj.weight"] - v = sd[prefix + "self_attn.v_proj.weight"] - - load_sd[prefix + "self_attn.qkv_proj.weight"] = _shard_and_concat( - [q, k, v], dim=0, tp=vllm_tp - ) - - # Untested: QKV fusion - handle bias if present - q_bias_key = prefix + "self_attn.q_proj.bias" - k_bias_key = prefix + "self_attn.k_proj.bias" - v_bias_key = prefix + "self_attn.v_proj.bias" - - if all(key in sd for key in [q_bias_key, k_bias_key, v_bias_key]): - q_bias = sd[q_bias_key] - k_bias = sd[k_bias_key] - v_bias = sd[v_bias_key] - load_sd[prefix + "self_attn.qkv_proj.bias"] = _shard_and_concat( - [q_bias, k_bias, v_bias], dim=0, tp=vllm_tp - ) - - # MLP gate_up_proj fusion - gate = sd[prefix + "mlp.gate_proj.weight"] - up = sd[prefix + "mlp.up_proj.weight"] - load_sd[prefix + "mlp.gate_up_proj.weight"] = _shard_and_concat( - [gate, up], dim=0, tp=vllm_tp - ) - - # Untested: MLP gate_up_proj fusion - handle bias if present - gate_bias_key = prefix + "mlp.gate_proj.bias" - up_bias_key = prefix + "mlp.up_proj.bias" - - if all(key in sd for key in [gate_bias_key, up_bias_key]): - gate_bias = sd[gate_bias_key] - up_bias = sd[up_bias_key] - # Same sharding has to happen here - load_sd[prefix + "mlp.gate_up_proj.bias"] = _shard_and_concat( - [gate_bias, up_bias], dim=0, tp=vllm_tp - ) - - return load_sd diff --git a/tests/unit_tests/test_trainer.py b/tests/unit_tests/test_trainer.py deleted file mode 100644 index a5a6f290e..000000000 --- a/tests/unit_tests/test_trainer.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import os -import shutil -import tempfile -import unittest - -from forge.actors.trainer import cleanup_old_weight_versions - - -class TestTrainerUtilities(unittest.TestCase): - def setUp(self): - """Set up test environment with temporary directory.""" - self.test_dir = tempfile.mkdtemp() - self.addCleanup(shutil.rmtree, self.test_dir) - - def test_cleanup_old_weight_versions_basic(self): - """Test basic cleanup functionality - keeps current and N-1 versions.""" - # Create test directory structure - state_dict_key = os.path.join(self.test_dir, "model") - delim = "__" - - # Create some mock weight directories - old_version_1 = f"{state_dict_key}{delim}1" - previous_version = f"{state_dict_key}{delim}2" # N-1 version - current_version = f"{state_dict_key}{delim}3" # Current version - unrelated_dir = os.path.join(self.test_dir, "other_model__1") - - for dir_path in [ - old_version_1, - previous_version, - current_version, - unrelated_dir, - ]: - os.makedirs(dir_path) - - # Run cleanup for version 3 - cleanup_old_weight_versions( - state_dict_key=state_dict_key, - delim=delim, - current_policy_version=3, - ) - - # Check that only very old versions were deleted (version 1) - self.assertFalse(os.path.exists(old_version_1)) - - # Check that current and previous versions still exist - self.assertTrue(os.path.exists(previous_version)) # N-1 version should remain - self.assertTrue( - os.path.exists(current_version) - ) # Current version should remain - self.assertTrue(os.path.exists(unrelated_dir)) # Unrelated dirs should remain - - def test_cleanup_old_weight_versions_no_cleanup_version_1(self): - """Test that no cleanup happens when current_policy_version <= 1.""" - # Create test directory structure - state_dict_key = os.path.join(self.test_dir, "model") - delim = "__" - - version_1 = f"{state_dict_key}{delim}1" - os.makedirs(version_1) - - # Run cleanup for version 1 - should do nothing - cleanup_old_weight_versions( - state_dict_key=state_dict_key, - delim=delim, - 
current_policy_version=1, - ) - - # Version 1 should still exist - self.assertTrue(os.path.exists(version_1)) - - def test_cleanup_old_weight_versions_version_2(self): - """Test cleanup with version 2 as current - should keep versions 1 and 2.""" - # Create test directory structure - state_dict_key = os.path.join(self.test_dir, "model") - delim = "__" - - version_1 = f"{state_dict_key}{delim}1" # N-1 version - version_2 = f"{state_dict_key}{delim}2" # Current version - - for dir_path in [version_1, version_2]: - os.makedirs(dir_path) - - # Run cleanup for version 2 - cleanup_old_weight_versions( - state_dict_key=state_dict_key, - delim=delim, - current_policy_version=2, - ) - - # Both versions should still exist (no deletion for version 2) - self.assertTrue(os.path.exists(version_1)) - self.assertTrue(os.path.exists(version_2)) - - -if __name__ == "__main__": - unittest.main()