diff --git a/.gitignore b/.gitignore
index 14e5f66e1..82c22144b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,6 +19,9 @@ cython_debug/
 slogs/
 slurm-*
 
+# DCP checkpoints
+model_state_dict/
+
 # Celery stuff
 celerybeat-schedule
 celerybeat.pid
diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 15642272c..170c2bb6e 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -18,6 +18,7 @@
 from forge.actors.policy import Policy
 from forge.actors.reference_model import ReferenceModel
 from forge.actors.replay_buffer import ReplayBuffer
+from forge.actors.torchstore_utils import get_param_key
 from forge.actors.trainer import RLTrainer
 from forge.cli.config import parse
 from forge.controller.actor import ForgeActor
@@ -155,6 +156,23 @@ def simple_grpo_loss(
         / (padding_mask.sum(dim=1).clamp(min=1.0))
     ).mean()
     return loss
+        loss = self.loss(logprobs, ref_logprobs, advantages, mask)
+        loss.backward()
+        self.optimizer.step()
+        self.optimizer.zero_grad(set_to_none=True)
+
+        return loss.item()
+
+    @endpoint
+    async def push_weights(self, version: int):
+        """Update policy model weights with trainer's current weights."""
+        start_time = time.perf_counter()
+        hf_state_dict = self.model.state_dict()
+        for name, param in hf_state_dict.items():
+            key = get_param_key(version, name)
+            await ts.put(key, param)
+        end_time = time.perf_counter()
+        self.logger.debug(f"Pushed weights in {end_time - start_time:.2f} seconds")
 
 
 @dataclass
@@ -245,7 +263,7 @@ async def main(cfg: DictConfig):
     mlogger = get_metric_logger(
         "wandb",
         freq=1,
-        project="grpo-training",
+        project="yuxuanh-grpo-training-debug",
     )
 
     # ---- Setup services ---- #
@@ -351,8 +369,20 @@ async def continuous_training():
             loss = await trainer.train_step.choose(inputs, targets)
             training_step += 1
             mlogger.log("loss/training_step", loss, training_step)
+            start_time = time.perf_counter()
             await trainer.push_weights.call(training_step)
+            mlogger.log(
+                "push_weights_time/training_step",
+                time.perf_counter() - start_time,
+                training_step,
+            )
+            start_time = time.perf_counter()
             await policy.update_weights.call(training_step)
+            mlogger.log(
+                "update_weights_time/training_step",
+                time.perf_counter() - start_time,
+                training_step,
+            )
 
     print("Starting GRPO training loops...")
     # TODO: Start multiple rollouts once all serivces support it
diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 38889656d..c63bc5939 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -7,7 +7,6 @@
 from __future__ import annotations
 
 import asyncio
-
 import logging
 import os
 import sys
@@ -19,8 +18,21 @@
 import torch
 import torch.distributed.checkpoint as dcp
 import torchstore as ts
+
+from forge.actors.torchstore_utils import (
+    extract_param_name,
+    get_param_key,
+    get_param_prefix,
+)
+
+from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
+from forge.data.sharding import VLLMSharding
+from forge.data_models.completion import Completion
+from forge.data_models.prompt import to_prompt
+
+from forge.interfaces import Policy as PolicyInterface
+from forge.types import ProcessConfig
 from monarch.actor import current_rank, endpoint, ProcMesh
-from torchstore.state_dict_utils import DELIM
 from vllm.config import VllmConfig
 from vllm.engine.arg_utils import EngineArgs
@@ -43,15 +55,7 @@
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.worker.worker_base import WorkerWrapperBase
 
-from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
-
-from forge.data.sharding import VLLMSharding
-from forge.data_models.completion import Completion
-from forge.data_models.prompt import to_prompt
-
-from forge.interfaces import Policy as PolicyInterface
-from forge.types import ProcessConfig
-
+logger: logging.Logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
@@ -388,15 +392,6 @@ async def update_weights(self, policy_version: int):
         self.policy_version = policy_version
         logger.info(f"Weight update completed (now v{self.policy_version})")
 
-    @endpoint
-    async def _get_model_params(self) -> dict[str, torch.Tensor]:
-        """Get the current model parameters. Only for testing purposes."""
-        val_mesh = await self.policy_worker._get_model_params.call()
-        sharded_state_dicts = {}
-        for idx, val in val_mesh.items():
-            sharded_state_dicts[idx["gpus"]] = val
-        return sharded_state_dicts
-
     @endpoint
     async def get_version(self) -> int:
         """Get the current policy version."""
@@ -406,6 +401,18 @@ async def get_version(self) -> int:
     async def stop(self):
         self.running = False
 
+    @endpoint
+    async def _test_save_model_params(self):
+        """Save model parameters before weight update, used for testing purposes only."""
+        logger.info("[Policy] start saving model parameters before update for testing")
+        await self.policy_worker._test_save_model_params.call()
+
+    @endpoint
+    async def _test_validate_model_params(self, validate_fn):
+        """Validate updated model params using validate_fn."""
+        logger.info("[Policy] start validating model parameters post update")
+        return await self.policy_worker._test_validate_model_params.call(validate_fn)
+
     def _to_completions(self, request_output: RequestOutput) -> list[Completion]:
         """Convert a RequestOutput to a list of Completion objects."""
         completions = []
@@ -449,6 +456,9 @@ class PolicyWorker(ForgeActor):
     state_dict_key: str = "model_state_dict"
     use_dcp: bool = True
 
+    # used for testing purposes only
+    _test_prev_params = {}
+
     @endpoint
     async def setup(self):
         # TODO: remove ["gpus"] when monarch implements a flat rank
@@ -498,12 +508,24 @@ async def _load_tensor_parallel_state_dict(
     @endpoint
     async def update(self, version: int):
         """Update model weights by reading state dict from torchstore"""
-        key = f"{self.state_dict_key}{DELIM}{version}"
         model = self.worker.model_runner.model
-        current_state_dict = model.state_dict()
-        start = time.time()
-        await self._load_tensor_parallel_state_dict(current_state_dict, version)
-        logger.debug(f"Loaded state dict from {key} in {time.time() - start} seconds")
+        prefix = get_param_prefix(version)
+        self.logger.debug(f"{prefix=}")
+        matching_keys = await ts.keys(prefix)
+        self.logger.debug(f"{matching_keys=}")
+        # TODO: find a way to save the original huggingface parameter names.
+        hf_names = [extract_param_name(key) for key in matching_keys]
+        self.logger.debug(f"{hf_names=}")
+        loaded_weights = set()
+        # We can't pass a generator since vllm load_weights is not async.
+        # Instead, we just call load_weights with one parameter at a time.
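+        # Fetching one tensor at a time also bounds peak memory: each fetched param is freed (del) once vLLM has copied it in.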
+        for name in hf_names:
+            param = await ts.get(get_param_key(version, name))
+            loaded = model.load_weights([(name, param)])
+            del param
+            loaded_weights.update(loaded)
+        self.logger.info(f"Updated {len(loaded_weights)} parameters")
 
     @endpoint
     async def setup_kv_cache(self):
@@ -536,15 +557,25 @@ async def setup_kv_cache(self):
         return kv_cache_config
 
     @endpoint
-    async def _get_model_params(self) -> dict[str, torch.Tensor]:
-        model = self.worker.model_runner.model
-        state_dict = {}
+    async def _test_save_model_params(self):
+        """Save model parameters before weight update, used for testing purposes only."""
+        logger.info(
+            "[PolicyWorker] start saving model parameters before update for testing"
+        )
+        for name, param in self.worker.model_runner.model.named_parameters():
+            self._test_prev_params[name] = param.detach().cpu()
+        logger.info(
+            "[PolicyWorker] finished saving model parameters, len = %d",
+            len(self._test_prev_params),
+        )
 
-        for name, param in model.named_parameters():
-            if "layers.0" not in name:
-                continue
-            state_dict[name] = param.cpu().detach()
-        return state_dict
+    @endpoint
+    async def _test_validate_model_params(self, validate_fn):
+        """Validate updated model params using validate_fn."""
+        logger.info("[PolicyWorker] start validating model parameters post update")
+        return validate_fn(
+            self._test_prev_params, self.worker.model_runner.model, logger
+        )
 
     def setup_worker(self):
         """Build and Instantiate vLLM worker"""
diff --git a/src/forge/actors/torchstore_utils.py b/src/forge/actors/torchstore_utils.py
new file mode 100644
index 000000000..b8835039b
--- /dev/null
+++ b/src/forge/actors/torchstore_utils.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+KEY_DELIM = "."
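+# Keys stored in torchstore take the form f"policy_ver_{version}{KEY_DELIM}{name}",
+# e.g. "policy_ver_0.model.embed_tokens.weight".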
+
+
+def get_param_prefix(policy_version: int) -> str:
+    return f"policy_ver_{policy_version}"
+
+
+def get_param_key(policy_version: int, name: str) -> str:
+    return f"policy_ver_{policy_version}{KEY_DELIM}{name}"
+
+
+def extract_param_name(key: str) -> str:
+    return KEY_DELIM.join(key.split(KEY_DELIM)[1:])
diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index f6ffe3af7..014d94037 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -17,6 +17,15 @@
 import torch.distributed.checkpoint as dcp
 import torchstore as ts
 
+from forge.actors.torchstore_utils import (
+    extract_param_name,
+    get_param_key,
+    get_param_prefix,
+)
+
+from forge.controller import ForgeActor
+from forge.data.utils import batch_to_device
+
 from monarch.actor import current_rank, current_size, endpoint
 from torch import Tensor
 from torch.distributed.checkpoint._nested_dict import flatten_state_dict
@@ -36,9 +45,6 @@
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 
-from forge.controller import ForgeActor
-from forge.data.utils import batch_to_device
-
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -290,6 +296,23 @@ async def push_weights(self, policy_version: int) -> None:
         logger.debug(f"Pushed weights to {key} in {end_time - start_time:.2f} seconds")
 
+    @endpoint
+    async def push_weights_hf_nonsharded(self, policy_version: int) -> None:
+        """Push weights to torchstore in HF format, non-sharded."""
+        if "model" not in self.engine.checkpointer.states:
+            raise RuntimeError("Model state not found in checkpointer state")
+
+        sd = self.engine.checkpointer.states["model"].state_dict()
+        flattened_state_dict, _ = flatten_state_dict(sd)
+        if self.engine.checkpointer.sd_adapter is None:
+            raise RuntimeError(
+                "Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
+            )
+        hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
+        for name, param in hf_state_dict.items():
+            key = get_param_key(policy_version, name)
+            await ts.put(key, param)
+
     @endpoint
     async def cleanup(self) -> None:
         if self.engine.checkpointer:
diff --git a/tests/integration_tests/__init__.py b/tests/integration_tests/__init__.py
new file mode 100644
index 000000000..2e41cd717
--- /dev/null
+++ b/tests/integration_tests/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py
index 5fdce0b6a..d38520714 100644
--- a/tests/integration_tests/test_policy_update.py
+++ b/tests/integration_tests/test_policy_update.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import asyncio
 import logging
 from typing import Callable
 
@@ -16,155 +17,26 @@
 from forge.actors.trainer import RLTrainer
 from forge.controller.service import ServiceConfig
-from forge.data.sharding import VLLMSharding
-from transformers import AutoModelForCausalLM
+from forge.controller.service.service import uuid
+
+from monarch.actor import current_rank, endpoint
+from torch.distributed.checkpoint._nested_dict import flatten_state_dict
+from torch.distributed.tensor import DTensor, Replicate
+
 requires_cuda = pytest.mark.skipif(
     not torch.cuda.is_available(),
     reason="CUDA not available",
 )
-from forge.actors.trainer import _qwen3_hf_to_vllm
+
 from huggingface_hub import snapshot_download
 
 logger: logging.Logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
-# Run tests: pytest tests/integration_tests/test_policy_update.py::TestWeightSync::
-
-
-def convert_state_dict(saved_sd):
-    """
-    Convert transformers state dict to vLLM format.
-
-    Key conversions:
-    1. Copy over directly mapped keys (down_proj, input_layernorm, etc.)
-    2. Fuse QKV projections: combine q_proj, k_proj, v_proj into qkv_proj
-    3. Fuse MLP projections: combine gate_proj and up_proj into gate_up_proj
-    """
-    load_sd = {}
-    num_layers = 32  # For Llama-8B-3.1
-
-    # Copy over directly mapped keys
-    for k in saved_sd:
-        if any(
-            x in k
-            for x in [
-                "down_proj",
-                "input_layernorm",
-                "post_attention_layernorm",
-                "o_proj",
-                "norm.weight",
-                "embed_tokens.weight",
-                "lm_head.weight",
-            ]
-        ):
-            load_sd[k] = saved_sd[k]
-
-    # Fuse QKV and gate_up_proj
-    for i in range(num_layers):
-        prefix = f"model.layers.{i}."
-
-        # QKV fusion
-        q = saved_sd[prefix + "self_attn.q_proj.weight"]
-        k = saved_sd[prefix + "self_attn.k_proj.weight"]
-        v = saved_sd[prefix + "self_attn.v_proj.weight"]
-        load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0)
-
-        # MLP gate_up_proj fusion
-        gate = saved_sd[prefix + "mlp.gate_proj.weight"]
-        up = saved_sd[prefix + "mlp.up_proj.weight"]
-        load_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0)
-
-    return load_sd
-
-
-def calculate_expected_shard(
-    full_tensor: torch.Tensor,
-    param_name: str,
-    tensor_parallel_size: int,
-    rank: int,
-) -> torch.Tensor:
-    """
-    Calculate the expected shard of a full tensor for comparison with loaded tensor.
-    This is mainly used for validation in tests.
-
-    Args:
-        full_tensor: The full tensor to shard
-        param_name: Name of the parameter (used to determine sharding strategy)
-        expected_shape: Expected shape of the sharded tensor
-        tensor_parallel_size: Number of tensor parallel ranks
-        rank: Current rank
-
-    Returns:
-        torch.Tensor: The expected sharded tensor for this rank
-    """
-
-    sharding = VLLMSharding(tensor_parallel_size, rank)
-    shard_dim, is_sharded = sharding._get_tensor_parallel_sharding_strategy(param_name)
-
-    if not is_sharded:
-        return full_tensor
-
-    sharded_tensor = sharding._calculate_tensor_shard(
-        full_tensor, shard_dim, tensor_parallel_size, rank
-    )
-    return sharded_tensor
-
-
-def validate_loaded_tensors_equals_original(
-    loaded_state_dict: dict[str, torch.Tensor],
-    original_state_dict: dict[str, torch.Tensor],
-    tensor_parallel_size: int,
-    rank: int,
-):
-    """
-    Shared validation function to verify that every tensor loaded by the policy
-    equals the original tensor.
-
-    For tensor parallel cases, instead of gathering sharded tensors, we shard
-    the original tensor and compare it with the loaded shard.
- """ - for param_name, loaded_tensor in loaded_state_dict.items(): - if param_name in original_state_dict: - original_tensor = original_state_dict[param_name] - - if tensor_parallel_size > 1: - original_shard = calculate_expected_shard( - original_tensor, - param_name, - tensor_parallel_size, - rank, - ) - tensor_to_compare = original_shard.cpu().float() - else: - tensor_to_compare = original_tensor.cpu().float() - - # Training trainer emitted and loaded tensors are of type bfloat16, - # where as a HF model loaded(expected) tensor has type float16. - if not torch.allclose( - loaded_tensor.float(), - tensor_to_compare, - rtol=1e-2, - atol=1e-3, - ): - logger.warning( - f"Loaded tensor {param_name} does not equal original.\n" - f"dtype = {loaded_tensor.dtype} vs {original_tensor.dtype}\n" - f"shape= {loaded_tensor.shape} vs {original_tensor.shape}\n," - f"values = {loaded_tensor} vs {original_tensor}" - ) - raise ValueError( - f"Loaded tensor {param_name} does not equal original " - f"(shapes: loaded={loaded_tensor.shape}, expected={tensor_to_compare.shape})" - ) - else: - print(f"Loaded tensor {param_name} correctly validated") - - print( - f"Successfully validated that all {len(loaded_state_dict)} loaded tensors equal original" - ) +# Run tests: pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync:: def get_configs( @@ -188,15 +60,97 @@ def get_configs( return policy_config, service_config +class MockRLTrainer(RLTrainer): + @endpoint + async def mock_train_step(self): + """Mock train step. This simply sets all model weights to zero.""" + for model_part in self.engine.model_parts: + sd = model_part.state_dict() + for k in sd.keys(): + if not torch.is_floating_point(sd[k]): + logger.info( + f"[MockRLTrainer] mock_train_step(): skipping non-float param {k}" + ) + continue + sd[k] *= 0.0 + + +# exceptions sometimes are not propogated in monarch, do it manually +def validate_fn(prev_params, curr_model, logger) -> Exception | None: + """Validate that current parameters are the same as prev_params.""" + verified = set() + skipped = set() + logger.info( + f"Validating model params, all named_parameters() = {curr_model.named_parameters()}" + ) + errs = [] + for name, param in curr_model.named_parameters(): + if not torch.is_floating_point(param): + logger.info(f"Skipping non-float param {name}") + skipped.add(name) + continue + try: + assert name in prev_params, f"Param {name} not found in prev_params" + assert torch.allclose( + prev_params[name], param.cpu(), atol=1e-3, rtol=1e-2 + ), ( + f"current param {name} does not match expected value; " + f"previous param ({prev_params[name].size()})= {prev_params[name]}; " + f"expected = {prev_params[name]} vs got = {param.cpu().size()} {param.cpu()}" + ) + verified.add(name) + except Exception as e: + # logger.error(f"Validation failed with exception: {e}") + errs.append((name, e)) + logger.info(f"Verified params = {verified}") + logger.info(f"Skipped params = {skipped}") + if errs: + logger.error( + f"Validation failed for the following params: {[e[0] for e in errs]}" + ) + return AssertionError(f"Validation failed: {errs}") + + +# exceptions sometimes are not propogated in monarch, do it manually +def validate_fn_all_zeros(prev_params, curr_model, logger) -> Exception | None: + """Validate all parameters are set to zero. 
prev_params is actually not used.""" + _ = prev_params + verified = set() + skipped = set() + logger.info( + f"Validating model params, all named_parameters() = {curr_model.named_parameters()}" + ) + errs = [] + for name, param in curr_model.named_parameters(): + if not torch.is_floating_point(param): + logger.info(f"Skipping non-float param {name}") + skipped.add(name) + continue + try: + param = param.cpu() + assert torch.allclose( + torch.zeros_like(param), param, atol=1e-4, rtol=1e-3 + ), "param {name} is not zero." + verified.add(name) + except Exception as e: + # logger.error(f"Validation failed with exception: {e}") + errs.append((name, e)) + logger.info(f"Verified params = {verified}") + logger.info(f"Skipped params = {skipped}") + if errs: + logger.error( + f"Validation failed for the following params: {[e[0] for e in errs]}" + ) + return AssertionError(f"Validation failed: {errs}") + + class TestWeightSync: """Tests for weight sync between trainer and policy. Currently hardcoded to Qwen3-1.7B.""" model = "Qwen/Qwen3-1.7B" - to_vllm_fn: Callable = _qwen3_hf_to_vllm - num_layers = 28 @pytest_asyncio.fixture - def trainer_cfg(self): + async def trainer_cfg(self): cached_dir = snapshot_download(repo_id=self.model) return { "model": { @@ -212,7 +166,7 @@ def trainer_cfg(self): } @pytest_asyncio.fixture - def trainer_cfg_tp(self): + async def trainer_cfg_tp(self): # NB: TP size is set to 2. cached_dir = snapshot_download(repo_id=self.model) return { @@ -229,57 +183,79 @@ def trainer_cfg_tp(self): }, } - @pytest_asyncio.fixture - def expected_sd(self): - model = AutoModelForCausalLM.from_pretrained( - self.model, - dtype=torch.bfloat16, - trust_remote_code=True, - ) - original_state_dict = model.state_dict() - # Hack to access through class without passing in self param - return self.__class__.to_vllm_fn(original_state_dict, self.num_layers) - @pytest.mark.asyncio @requires_cuda - async def test_policy_update_single(self, expected_sd, trainer_cfg): + async def test_policy_update_single(self, trainer_cfg): """ - 1. Loads weights from HF model into in-memory state-dict (source of truth) - 2. Initializes RLTrainer, make the weights available in torchstore. - 3. Initializes Policy, and calls update_weights() to load weights from torchstore. - 4. Validate the policy weights against source of truth. + Test the weight synchronization process between RLTrainer and Policy. + + This test performs the following steps: + - Initialize trainer and push weights v0 (original huggingface ckpt) + - Step the trainer, setting all weights to zero and push weights v1 + - Load weights v0 and check the policy has all zero weights + - Load weights v1 and check the policy has all the weights back """ - worker_size = 1 - # 1. Initialize TS + trainer_worker_size = 1 + policy_worker_size = 1 + tp_size = 1 + await ts.initialize() - # 2. Trainer push - rl_trainer = await RLTrainer.options( - procs=worker_size, with_gpus=True, num_replicas=1 - ).as_service(**trainer_cfg) - await rl_trainer.push_weights.choose(policy_version=0) - # 3. Policy pull weights policy_config, service_config = get_configs( - worker_size=worker_size, tp_size=worker_size, model_name=self.model - ) - policy = await Policy.options(service_config=service_config).as_service( - **policy_config + worker_size=policy_worker_size, tp_size=tp_size, model_name=self.model ) - await policy.update_weights.call() - # 4. 
Validate weights - loaded_state_dict = await policy._get_model_params.choose() - validate_loaded_tensors_equals_original( - loaded_state_dict, expected_sd, tensor_parallel_size=1, rank=0 + policy, rl_trainer = await asyncio.gather( + *[ + Policy.options(service_config=service_config).as_service( + **policy_config + ), + MockRLTrainer.options( + procs=trainer_worker_size, with_gpus=True, num_replicas=1 + ).as_service(**trainer_cfg), + ] ) + v0 = uuid.uuid4().int + v1 = v0 + 1 + + await rl_trainer.push_weights_hf_nonsharded.call(policy_version=v0) + # Setting everything to zero + await rl_trainer.mock_train_step.call() + await rl_trainer.push_weights_hf_nonsharded.call(policy_version=v1) + await policy._test_save_model_params.call() + + # Sanity check that before update all the tests pass + all_errs = await policy._test_validate_model_params.call(validate_fn) + for errs in all_errs: + for _, e in errs.items(): + assert not e, f"Validation failed with exception: {e}" + + await policy.update_weights.call(policy_version=v1) + all_errs = await policy._test_validate_model_params.call(validate_fn_all_zeros) + for errs in all_errs: + for _, e in errs.items(): + assert not e, f"Validation failed with exception: {e}" + + # Reloading v0, getting back original weights + await policy.update_weights.call(policy_version=v0) + all_errs = await policy._test_validate_model_params.call(validate_fn) + for errs in all_errs: + for _, e in errs.items(): + assert not e, f"Validation failed with exception: {e}" + + await ts.shutdown() + @pytest.mark.asyncio @requires_cuda - async def test_policy_update_tp(self, expected_sd, trainer_cfg_tp): + async def test_policy_update_tp(self, trainer_cfg_tp): """ - 1. Init RLTrainer over multiple workers with TP parallelism strategy. - 2. Push weights in to torchstore. - 3. Initializes Policy over multiple workers, and calls update_weights() to load weights from torchstore. - 4. Validate the policy weights against manually loaded origina HF weights. + Test the weight synchronization process between RLTrainer and Policy. + + This test performs the following steps: + - Initialize trainer and push weights v0 (original huggingface ckpt) + - Step the trainer, setting all weights to zero and push weights v1 + - Load weights v0 and check the policy has all zero weights + - Load weights v1 and check the policy has all the weights back """ # test configs/paralleism trainer_worker_size = 2 @@ -290,31 +266,48 @@ async def test_policy_update_tp(self, expected_sd, trainer_cfg_tp): pytest.skip( f"Only {torch.cuda.device_count()} GPU(s) available, need 2+ for tensor parallel" ) - # 1. Initialize TS + await ts.initialize() - # 2. Trainer push - rl_trainer = await RLTrainer.options( - procs=trainer_worker_size, with_gpus=True, num_replicas=1 - ).as_service(**trainer_cfg_tp) - await rl_trainer.push_weights.call(policy_version=0) - # 3. Policy pull weights policy_config, service_config = get_configs( worker_size=policy_worker_size, tp_size=tp_size, model_name=self.model ) - policy = await Policy.options(service_config=service_config).as_service( - **policy_config + policy, rl_trainer = await asyncio.gather( + *[ + Policy.options(service_config=service_config).as_service( + **policy_config + ), + MockRLTrainer.options( + procs=trainer_worker_size, with_gpus=True, num_replicas=1 + ).as_service(**trainer_cfg_tp), + ] ) - await policy.update_weights.call() - - # validate loaded shard of each worker againt manually calculated shard (expected shard). - # 4. Validate weight shards. 
We compare vLLM loades shard content with - # Directly loaded HF shard content. - sharded_state_dicts = await policy._get_model_params.call() - validate_loaded_tensors_equals_original( - sharded_state_dicts[0][0], expected_sd, tensor_parallel_size=tp_size, rank=0 - ) - validate_loaded_tensors_equals_original( - sharded_state_dicts[0][1], expected_sd, tensor_parallel_size=tp_size, rank=1 - ) + v0 = uuid.uuid4().int + v1 = v0 + 1 + + await rl_trainer.push_weights_hf_nonsharded.call(policy_version=v0) + # Setting everything to zero + await rl_trainer.mock_train_step.call() + await rl_trainer.push_weights_hf_nonsharded.call(policy_version=v1) + await policy._test_save_model_params.call() + + # Sanity check that before update all the tests pass + all_errs = await policy._test_validate_model_params.call(validate_fn) + for errs in all_errs: + for _, e in errs.items(): + assert not e, f"Validation failed with exception: {e}" + + await policy.update_weights.call(policy_version=v1) + all_errs = await policy._test_validate_model_params.call(validate_fn_all_zeros) + for errs in all_errs: + for _, e in errs.items(): + assert not e, f"Validation failed with exception: {e}" + # Reloading v0, getting back original weights + await policy.update_weights.call(policy_version=v0) + all_errs = await policy._test_validate_model_params.call(validate_fn) + for errs in all_errs: + for _, e in errs.items(): + assert not e, f"Validation failed with exception: {e}" + + await ts.shutdown()