diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 774f494ac..c4bea0889 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -7,7 +7,6 @@
 from __future__ import annotations
 
 import asyncio
-
 import logging
 import os
 import sys
@@ -393,18 +392,26 @@ async def update_weights(self, policy_version: int):
         logger.info(f"Weight update completed (now v{self.policy_version})")
 
     @endpoint
-    async def _get_model_params(self) -> dict[str, torch.Tensor]:
-        """Get the current model parameters. Only for testing purposes."""
-        val_mesh = await self.policy_worker._get_model_params.call()
-        sharded_state_dicts = {}
-        for idx, val in val_mesh.items():
-            sharded_state_dicts[idx["gpus"]] = val
-        return sharded_state_dicts
+    async def get_version(self) -> int:
+        """Get the current policy version."""
+        return self.policy_version
 
     @endpoint
     async def stop(self):
         self.running = False
 
+    @endpoint
+    async def _test_save_model_params(self):
+        """Save model parameters before weight update, used for testing purposes only."""
+        logger.info("[Policy] save model parameters for testing.")
+        await self.policy_worker._test_save_model_params.call()
+
+    @endpoint
+    async def _test_validate_model_params(self, validate_fn):
+        """Validate updated model params using validate_fn."""
+        logger.info("[Policy] start validating model parameters.")
+        return await self.policy_worker._test_validate_model_params.call(validate_fn)
+
     def _to_completions(self, request_output: RequestOutput) -> list[Completion]:
         """Convert a RequestOutput to a list of Completion objects."""
         completions = []
@@ -449,6 +456,9 @@ class PolicyWorker(ForgeActor):
     state_dict_key: str = "model_state_dict"
     use_dcp: bool = True
 
+    # used for testing purposes only
+    _test_prev_params = {}
+
     def __post_init__(self):
         super().__init__()
 
@@ -541,15 +551,23 @@ async def setup_kv_cache(self):
         return kv_cache_config
 
     @endpoint
-    async def _get_model_params(self) -> dict[str, torch.Tensor]:
-        model = self.worker.model_runner.model
-        state_dict = {}
+    async def _test_save_model_params(self):
+        """Save model parameters before weight update, used for testing purposes only."""
+        logger.info("[PolicyWorker] save model parameters for testing.")
+        for name, param in self.worker.model_runner.model.named_parameters():
+            self._test_prev_params[name] = param.detach().cpu()
+        logger.info(
+            "[PolicyWorker] finished saving model parameters, len = %d",
+            len(self._test_prev_params),
+        )
 
-        for name, param in model.named_parameters():
-            if "layers.0" not in name:
-                continue
-            state_dict[name] = param.cpu().detach()
-        return state_dict
+    @endpoint
+    async def _test_validate_model_params(self, validate_fn):
+        """Validate updated model params using validate_fn."""
+        logger.info("[PolicyWorker] start validating model parameters.")
+        return validate_fn(
+            self._test_prev_params, self.worker.model_runner.model, logger
+        )
 
     def setup_worker(self):
         """Build and Instantiate vLLM worker"""
diff --git a/tests/integration_tests/__init__.py b/tests/integration_tests/__init__.py
new file mode 100644
index 000000000..2e41cd717
--- /dev/null
+++ b/tests/integration_tests/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
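Note on the new test-only endpoints: `_test_save_model_params` snapshots every worker's parameters to CPU, and `_test_validate_model_params` runs a caller-supplied `validate_fn(prev_params, curr_model, logger)` on each worker and returns its result rather than raising, since exceptions do not always propagate back through monarch. A minimal usage sketch, assuming a running `policy` service handle like the one built in the test below; the helper name `check_weight_update` is illustrative and not part of the patch:

    async def check_weight_update(policy, validate_fn, version: int) -> int:
        # Snapshot current parameters on every PolicyWorker (illustrative helper).
        await policy._test_save_model_params.fanout()
        # Pull the requested version from torchstore into vLLM.
        await policy.update_weights.fanout(policy_version=version)
        # Each replica returns a dict of per-worker results; None means success.
        all_errs = await policy._test_validate_model_params.fanout(validate_fn)
        for errs in all_errs:
            for _, e in errs.items():
                assert not e, f"Validation failed with exception: {e}"
        return await policy.get_version.route()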
diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py
index c737078b7..543956877 100644
--- a/tests/integration_tests/test_policy_update.py
+++ b/tests/integration_tests/test_policy_update.py
@@ -4,9 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import asyncio
 import logging
 from dataclasses import asdict
-from typing import Callable
 
 import pytest
 import pytest_asyncio
@@ -17,154 +17,23 @@
 from forge.actors.trainer import RLTrainer
 from forge.controller.service import ServiceConfig
-from forge.data.sharding import VLLMSharding
-from transformers import AutoModelForCausalLM
+
+from forge.controller.service.service import uuid
+from monarch.actor import endpoint
+
 requires_cuda = pytest.mark.skipif(
     not torch.cuda.is_available(),
     reason="CUDA not available",
 )
-from forge.actors.trainer import _qwen3_hf_to_vllm
+
 from huggingface_hub import snapshot_download
 
 logger: logging.Logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
-# Run tests: pytest tests/integration_tests/test_policy_update.py::TestWeightSync::
-
-
-def convert_state_dict(saved_sd):
-    """
-    Convert transformers state dict to vLLM format.
-
-    Key conversions:
-    1. Copy over directly mapped keys (down_proj, input_layernorm, etc.)
-    2. Fuse QKV projections: combine q_proj, k_proj, v_proj into qkv_proj
-    3. Fuse MLP projections: combine gate_proj and up_proj into gate_up_proj
-    """
-    load_sd = {}
-    num_layers = 32  # For Llama-8B-3.1
-
-    # Copy over directly mapped keys
-    for k in saved_sd:
-        if any(
-            x in k
-            for x in [
-                "down_proj",
-                "input_layernorm",
-                "post_attention_layernorm",
-                "o_proj",
-                "norm.weight",
-                "embed_tokens.weight",
-                "lm_head.weight",
-            ]
-        ):
-            load_sd[k] = saved_sd[k]
-
-    # Fuse QKV and gate_up_proj
-    for i in range(num_layers):
-        prefix = f"model.layers.{i}."
-
-        # QKV fusion
-        q = saved_sd[prefix + "self_attn.q_proj.weight"]
-        k = saved_sd[prefix + "self_attn.k_proj.weight"]
-        v = saved_sd[prefix + "self_attn.v_proj.weight"]
-        load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0)
-
-        # MLP gate_up_proj fusion
-        gate = saved_sd[prefix + "mlp.gate_proj.weight"]
-        up = saved_sd[prefix + "mlp.up_proj.weight"]
-        load_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0)
-
-    return load_sd
-
-
-def calculate_expected_shard(
-    full_tensor: torch.Tensor,
-    param_name: str,
-    tensor_parallel_size: int,
-    rank: int,
-) -> torch.Tensor:
-    """
-    Calculate the expected shard of a full tensor for comparison with loaded tensor.
-    This is mainly used for validation in tests.
-
-    Args:
-        full_tensor: The full tensor to shard
-        param_name: Name of the parameter (used to determine sharding strategy)
-        expected_shape: Expected shape of the sharded tensor
-        tensor_parallel_size: Number of tensor parallel ranks
-        rank: Current rank
-
-    Returns:
-        torch.Tensor: The expected sharded tensor for this rank
-    """
-
-    sharding = VLLMSharding(tensor_parallel_size, rank)
-    shard_dim, is_sharded = sharding._get_tensor_parallel_sharding_strategy(param_name)
-
-    if not is_sharded:
-        return full_tensor
-
-    sharded_tensor = sharding._calculate_tensor_shard(
-        full_tensor, shard_dim, tensor_parallel_size, rank
-    )
-    return sharded_tensor
-
-
-def validate_loaded_tensors_equals_original(
-    loaded_state_dict: dict[str, torch.Tensor],
-    original_state_dict: dict[str, torch.Tensor],
-    tensor_parallel_size: int,
-    rank: int,
-):
-    """
-    Shared validation function to verify that every tensor loaded by the policy
-    equals the original tensor.
-
-    For tensor parallel cases, instead of gathering sharded tensors, we shard
-    the original tensor and compare it with the loaded shard.
-    """
-    for param_name, loaded_tensor in loaded_state_dict.items():
-        if param_name in original_state_dict:
-            original_tensor = original_state_dict[param_name]
-
-            if tensor_parallel_size > 1:
-                original_shard = calculate_expected_shard(
-                    original_tensor,
-                    param_name,
-                    tensor_parallel_size,
-                    rank,
-                )
-                tensor_to_compare = original_shard.cpu().float()
-            else:
-                tensor_to_compare = original_tensor.cpu().float()
-
-            # Training trainer emitted and loaded tensors are of type bfloat16,
-            # whereas a HF model loaded (expected) tensor has type float16.
-            if not torch.allclose(
-                loaded_tensor.float(),
-                tensor_to_compare,
-                rtol=1e-2,
-                atol=1e-3,
-            ):
-                logger.warning(
-                    f"Loaded tensor {param_name} does not equal original.\n"
-                    f"dtype = {loaded_tensor.dtype} vs {original_tensor.dtype}\n"
-                    f"shape = {loaded_tensor.shape} vs {original_tensor.shape}\n"
-                    f"values = {loaded_tensor} vs {original_tensor}"
-                )
-                raise ValueError(
-                    f"Loaded tensor {param_name} does not equal original "
-                    f"(shapes: loaded={loaded_tensor.shape}, expected={tensor_to_compare.shape})"
-                )
-            else:
-                print(f"Loaded tensor {param_name} correctly validated")
-
-    print(
-        f"Successfully validated that all {len(loaded_state_dict)} loaded tensors equal original"
-    )
+# Run tests: pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::
 
 
 def get_configs(
@@ -188,15 +57,97 @@
     return policy_config, service_config
 
+
+class MockRLTrainer(RLTrainer):
+    @endpoint
+    async def zero_out_model_states(self):
+        """This simply sets all model weights to zero."""
+        for model_part in self.engine.model_parts:
+            sd = model_part.state_dict()
+            for k in sd.keys():
+                if not torch.is_floating_point(sd[k]):
+                    logger.info(
+                        f"[MockRLTrainer] zero_out_model_states(): skipping non-float param {k}"
+                    )
+                    continue
+                sd[k] *= 0.0
+
+
+# exceptions sometimes are not propagated in monarch, so surface them manually
+def validate_fn(prev_params, curr_model, logger) -> Exception | None:
+    """Validate that current parameters are the same as prev_params."""
+    verified = set()
+    skipped = set()
+    logger.info(
+        f"Validating model params, all named_parameters() = {curr_model.named_parameters()}"
+    )
+    errs = []
+    for name, param in curr_model.named_parameters():
+        if not torch.is_floating_point(param):
+            logger.info(f"Skipping non-float param {name}")
+            skipped.add(name)
+            continue
+        try:
+            assert name in prev_params, f"Param {name} not found in prev_params"
+            assert torch.allclose(
+                prev_params[name], param.cpu(), atol=1e-3, rtol=1e-2
+            ), (
+                f"current param {name} does not match expected value; "
+                f"previous param ({prev_params[name].size()}) = {prev_params[name]} "
+                f"vs got = {param.cpu().size()} {param.cpu()}"
+            )
+            verified.add(name)
+        except Exception as e:
+            # logger.error(f"Validation failed with exception: {e}")
+            errs.append((name, e))
+    logger.info(f"Verified params = {verified}")
+    logger.info(f"Skipped params = {skipped}")
+    if errs:
+        logger.error(
+            f"Validation failed for the following params: {[e[0] for e in errs]}"
+        )
+        return AssertionError(f"Validation failed: {errs}")
+
+
+# exceptions sometimes are not propagated in monarch, so surface them manually
+def validate_fn_all_zeros(prev_params, curr_model, logger) -> Exception | None:
+    """Validate all parameters are set to zero. prev_params is actually not used."""
+    _ = prev_params
+    verified = set()
+    skipped = set()
+    logger.info(
+        f"Validating model params, all named_parameters() = {curr_model.named_parameters()}"
+    )
+    errs = []
+    for name, param in curr_model.named_parameters():
+        if not torch.is_floating_point(param):
+            logger.info(f"Skipping non-float param {name}")
+            skipped.add(name)
+            continue
+        try:
+            param = param.cpu()
+            assert torch.allclose(
+                torch.zeros_like(param), param, atol=1e-4, rtol=1e-3
+            ), f"param {name} is not zero."
+            verified.add(name)
+        except Exception as e:
+            # logger.error(f"Validation failed with exception: {e}")
+            errs.append((name, e))
+    logger.info(f"Verified params = {verified}")
+    logger.info(f"Skipped params = {skipped}")
+    if errs:
+        logger.error(
+            f"Validation failed for the following params: {[e[0] for e in errs]}"
+        )
+        return AssertionError(f"Validation failed: {errs}")
+
+
 class TestWeightSync:
     """Tests for weight sync between trainer and policy. Currently hardcoded to Qwen3-1.7B."""
 
     model = "Qwen/Qwen3-1.7B"
-    to_vllm_fn: Callable = _qwen3_hf_to_vllm
-    num_layers = 28
 
     @pytest_asyncio.fixture
-    def trainer_cfg(self):
+    async def trainer_cfg(self):
         cached_dir = snapshot_download(repo_id=self.model)
         return {
             "model": {
@@ -212,7 +163,7 @@ def trainer_cfg(self):
         }
 
     @pytest_asyncio.fixture
-    def trainer_cfg_tp(self):
+    async def trainer_cfg_tp(self):
         # NB: TP size is set to 2.
         cached_dir = snapshot_download(repo_id=self.model)
         return {
@@ -229,57 +180,79 @@ def trainer_cfg_tp(self):
             },
         }
 
-    @pytest_asyncio.fixture
-    def expected_sd(self):
-        model = AutoModelForCausalLM.from_pretrained(
-            self.model,
-            dtype=torch.bfloat16,
-            trust_remote_code=True,
-        )
-        original_state_dict = model.state_dict()
-        # Hack to access through class without passing in self param
-        return self.__class__.to_vllm_fn(original_state_dict, self.num_layers)
-
     @pytest.mark.asyncio
     @requires_cuda
-    async def test_policy_update_single(self, expected_sd, trainer_cfg):
+    async def test_policy_update_single(self, trainer_cfg):
         """
-        1. Loads weights from HF model into in-memory state-dict (source of truth)
-        2. Initializes RLTrainer, make the weights available in torchstore.
-        3. Initializes Policy, and calls update_weights() to load weights from torchstore.
-        4. Validate the policy weights against source of truth.
+        Test the weight synchronization process between RLTrainer and Policy.
+
+        This test performs the following steps:
+        - Initialize the trainer and push weights v0 (the original Hugging Face checkpoint)
+        - Step the trainer, setting all weights to zero, and push weights v1
+        - Load weights v1 and check that the policy weights are all zeros
+        - Load weights v0 and check that the policy has the original weights back
         """
-        worker_size = 1
-        # 1. Initialize TS
+        trainer_worker_size = 1
+        policy_worker_size = 1
+        tp_size = 1
+
         await ts.initialize()
-        # 2. Trainer push
-        rl_trainer = await RLTrainer.options(
-            procs=worker_size, with_gpus=True, num_replicas=1
-        ).as_service(**trainer_cfg)
-        await rl_trainer.push_weights.route(policy_version=0)
-        # 3. Policy pull weights
         policy_config, service_config = get_configs(
-            worker_size=worker_size, tp_size=worker_size, model_name=self.model
+            worker_size=policy_worker_size, tp_size=tp_size, model_name=self.model
         )
-        policy = await Policy.options(**asdict(service_config)).as_service(
-            **policy_config
+        policy, rl_trainer = await asyncio.gather(
+            *[
+                Policy.options(**asdict(service_config)).as_service(**policy_config),
+                MockRLTrainer.options(
+                    procs=trainer_worker_size, with_gpus=True, num_replicas=1
+                ).as_service(**trainer_cfg),
+            ]
         )
-        await policy.update_weights.fanout()
-        # 4. Validate weights
-        loaded_state_dict = await policy._get_model_params.route()
-        validate_loaded_tensors_equals_original(
-            loaded_state_dict, expected_sd, tensor_parallel_size=1, rank=0
+
+        v0 = uuid.uuid4().int
+        v1 = v0 + 1
+
+        await rl_trainer.push_weights.fanout(policy_version=v0)
+        # Setting everything to zero
+        await rl_trainer.zero_out_model_states.fanout()
+        await rl_trainer.push_weights.fanout(policy_version=v1)
+        await policy._test_save_model_params.fanout()
+
+        # Sanity check that before the update all the validations pass
+        all_errs = await policy._test_validate_model_params.fanout(validate_fn)
+        for errs in all_errs:
+            for _, e in errs.items():
+                assert not e, f"Validation failed with exception: {e}"
+
+        await policy.update_weights.fanout(policy_version=v1)
+        all_errs = await policy._test_validate_model_params.fanout(
+            validate_fn_all_zeros
         )
+        for errs in all_errs:
+            for _, e in errs.items():
+                assert not e, f"Validation failed with exception: {e}"
+
+        # Reloading v0, getting back the original weights
+        await policy.update_weights.fanout(policy_version=v0)
+        all_errs = await policy._test_validate_model_params.fanout(validate_fn)
+        for errs in all_errs:
+            for _, e in errs.items():
+                assert not e, f"Validation failed with exception: {e}"
+
+        await ts.shutdown()
 
     @pytest.mark.asyncio
     @requires_cuda
-    async def test_policy_update_tp(self, expected_sd, trainer_cfg_tp):
+    async def test_policy_update_tp(self, trainer_cfg_tp):
         """
-        1. Init RLTrainer over multiple workers with TP parallelism strategy.
-        2. Push weights in to torchstore.
-        3. Initializes Policy over multiple workers, and calls update_weights() to load weights from torchstore.
-        4. Validate the policy weights against manually loaded original HF weights.
+        Test the weight synchronization process between RLTrainer and Policy.
+
+        This test performs the following steps:
+        - Initialize the trainer and push weights v0 (the original Hugging Face checkpoint)
+        - Step the trainer, setting all weights to zero, and push weights v1
+        - Load weights v1 and check that the policy weights are all zeros
+        - Load weights v0 and check that the policy has the original weights back
         """
         # test configs/paralleism
         trainer_worker_size = 2
@@ -290,31 +263,48 @@ async def test_policy_update_tp(self, expected_sd, trainer_cfg_tp):
             pytest.skip(
                 f"Only {torch.cuda.device_count()} GPU(s) available, need 2+ for tensor parallel"
            )
-        # 1. Initialize TS
+
         await ts.initialize()
-        # 2. Trainer push
-        rl_trainer = await RLTrainer.options(
-            procs=trainer_worker_size, with_gpus=True, num_replicas=1
-        ).as_service(**trainer_cfg_tp)
-        await rl_trainer.push_weights.fanout(policy_version=0)
-        # 3. Policy pull weights
         policy_config, service_config = get_configs(
             worker_size=policy_worker_size, tp_size=tp_size, model_name=self.model
         )
-        policy = await Policy.options(**asdict(service_config)).as_service(
-            **policy_config
+        policy, rl_trainer = await asyncio.gather(
+            *[
+                Policy.options(**asdict(service_config)).as_service(**policy_config),
+                MockRLTrainer.options(
+                    procs=trainer_worker_size, with_gpus=True, num_replicas=1
+                ).as_service(**trainer_cfg_tp),
+            ]
         )
-        await policy.update_weights.fanout()
-        # validate loaded shard of each worker againt manually calculated shard (expected shard).
+        v0 = uuid.uuid4().int
+        v1 = v0 + 1
 
-        # 4. Validate weight shards. We compare vLLM loades shard content with
-        # Directly loaded HF shard content.
-        sharded_state_dicts = await policy._get_model_params.fanout()
-        validate_loaded_tensors_equals_original(
-            sharded_state_dicts[0][0], expected_sd, tensor_parallel_size=tp_size, rank=0
-        )
-        validate_loaded_tensors_equals_original(
-            sharded_state_dicts[0][1], expected_sd, tensor_parallel_size=tp_size, rank=1
+        await rl_trainer.push_weights.fanout(policy_version=v0)
+        # Setting everything to zero
+        await rl_trainer.zero_out_model_states.fanout()
+        await rl_trainer.push_weights.fanout(policy_version=v1)
+        await policy._test_save_model_params.fanout()
+
+        # Sanity check that before the update all the validations pass
+        all_errs = await policy._test_validate_model_params.fanout(validate_fn)
+        for errs in all_errs:
+            for _, e in errs.items():
+                assert not e, f"Validation failed with exception: {e}"
+
+        await policy.update_weights.fanout(policy_version=v1)
+        all_errs = await policy._test_validate_model_params.fanout(
+            validate_fn_all_zeros
        )
+        for errs in all_errs:
+            for _, e in errs.items():
+                assert not e, f"Validation failed with exception: {e}"
+
+        # Reloading v0, getting back the original weights
+        await policy.update_weights.fanout(policy_version=v0)
+        all_errs = await policy._test_validate_model_params.fanout(validate_fn)
+        for errs in all_errs:
+            for _, e in errs.items():
+                assert not e, f"Validation failed with exception: {e}"
+
+        await ts.shutdown()
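For reference, both validation helpers follow a return-the-error convention: they collect assertion failures and return an AssertionError (or None on success) instead of raising, because exceptions are not always propagated back through monarch endpoints. A standalone sketch of that contract against a toy module; the `nn.Linear` model and logger setup here are illustrative only and not part of the patch, and `validate_fn` / `validate_fn_all_zeros` are the helpers defined in tests/integration_tests/test_policy_update.py:

    import logging

    import torch
    import torch.nn as nn

    # Hypothetical local check, assuming validate_fn and validate_fn_all_zeros are importable
    # from tests/integration_tests/test_policy_update.py.
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    model = nn.Linear(4, 4, bias=False)
    prev_params = {n: p.detach().cpu() for n, p in model.named_parameters()}

    # Unchanged weights: validate_fn succeeds and returns None.
    assert validate_fn(prev_params, model, logger) is None

    # Zero out the weights: validate_fn now returns an AssertionError (it does not raise),
    # while validate_fn_all_zeros succeeds.
    with torch.no_grad():
        model.weight.mul_(0.0)
    assert isinstance(validate_fn(prev_params, model, logger), AssertionError)
    assert validate_fn_all_zeros(prev_params, model, logger) is None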