Skip to content
Merged
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ cython_debug/
slogs/
slurm-*

# DCP checkpoints
model_state_dict/

# Celery stuff
celerybeat-schedule
celerybeat.pid
Expand Down
53 changes: 36 additions & 17 deletions src/forge/actors/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations

import asyncio
import logging
import os
import sys
import time
Expand Down Expand Up @@ -50,6 +51,8 @@
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig

# Module-level logger; used by the test-only endpoints of Policy and PolicyWorker.
logger: logging.Logger = logging.getLogger(__name__)


@dataclass
class SamplingConfig:
Expand Down Expand Up @@ -382,15 +385,6 @@ async def update_weights(self, policy_version: int):
self.policy_version = policy_version
self.logger.info(f"Weight update completed (now v{self.policy_version})")

@endpoint
async def _get_model_params(self) -> dict[str, torch.Tensor]:
    """Collect the model parameters reported by the policy workers.

    Testing helper only. Awaits ``_get_model_params`` on
    ``self.policy_worker`` and re-keys every returned entry by the
    ``"gpus"`` component of its key.

    Returns:
        A dict mapping each ``"gpus"`` index to the value returned for it.
    """
    worker_results = await self.policy_worker._get_model_params.call()
    # Re-key by the "gpus" coordinate; a repeated coordinate keeps the last value.
    return {point["gpus"]: params for point, params in worker_results.items()}

@endpoint
async def get_version(self) -> int:
"""Get the current policy version."""
Expand All @@ -400,6 +394,18 @@ async def get_version(self) -> int:
async def stop(self):
    """Clear the ``running`` flag on this policy."""
    self.running = False

@endpoint
async def _test_save_model_params(self):
    """Ask the policy workers to snapshot their model parameters.

    Testing helper only: meant to be called before a weight update. Logs the
    request, then awaits ``_test_save_model_params`` on
    ``self.policy_worker``.
    """
    logger.info("[Policy] start saving model parameters before update for testing")
    save_on_workers = self.policy_worker._test_save_model_params
    await save_on_workers.call()

@endpoint
async def _test_validate_model_params(self, validate_fn):
    """Run ``validate_fn`` on the policy workers after a weight update.

    Args:
        validate_fn: Callable forwarded unchanged to the
            ``_test_validate_model_params`` endpoint of ``self.policy_worker``.

    Returns:
        Whatever the awaited worker endpoint call returns.
    """
    logger.info("[Policy] start validating model parameters post update")
    worker_results = await self.policy_worker._test_validate_model_params.call(
        validate_fn
    )
    return worker_results

Comment on lines 410 to 414
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really smart! Having the validate_fn as an argument greatly simplified the logic of validating the weights on each worker

def _to_completions(self, request_output: RequestOutput) -> list[Completion]:
"""Convert a RequestOutput to a list of Completion objects."""
completions = []
Expand Down Expand Up @@ -443,6 +449,9 @@ class PolicyWorker(ForgeActor):
state_dict_key: str = "model_state_dict"
use_dcp: bool = True

# used for testing purposes only
_test_prev_params = {}

@endpoint
async def setup(self):
# TODO: remove ["gpus"] when monarch implements a flat rank
Expand Down Expand Up @@ -532,15 +541,25 @@ async def setup_kv_cache(self):
return kv_cache_config

@endpoint
async def _get_model_params(self) -> dict[str, torch.Tensor]:
model = self.worker.model_runner.model
state_dict = {}
async def _test_save_model_params(self):
    """Snapshot the current model parameters on CPU, for testing only.

    Intended to run before a weight update. Every entry yielded by
    ``self.worker.model_runner.model.named_parameters()`` is detached,
    copied to CPU, and stored in ``self._test_prev_params`` under its name.
    Entries stored by an earlier call under other names are not removed.
    """
    logger.info(
        "[PolicyWorker] start saving model parameters before update for testing"
    )
    model = self.worker.model_runner.model
    for name, param in model.named_parameters():
        # copy=True forces a real copy: plain ``.cpu()`` hands back the very
        # same tensor when it already lives on CPU, so an in-place weight
        # update would silently rewrite the saved "previous" value as well.
        self._test_prev_params[name] = param.detach().to("cpu", copy=True)
    logger.info(
        "[PolicyWorker] finished saving model parameters, len = %d",
        len(self._test_prev_params),
    )

for name, param in model.named_parameters():
if "layers.0" not in name:
continue
state_dict[name] = param.cpu().detach()
return state_dict
@endpoint
async def _test_validate_model_params(self, validate_fn):
    """Check the updated model parameters with ``validate_fn``.

    Args:
        validate_fn: Callable invoked as
            ``validate_fn(prev_params, model, logger)``, where ``prev_params``
            is ``self._test_prev_params``, ``model`` is
            ``self.worker.model_runner.model`` and ``logger`` is this
            module's logger.

    Returns:
        The value returned by ``validate_fn``.
    """
    logger.info("[PolicyWorker] start validating model parameters post update")
    saved_params = self._test_prev_params
    current_model = self.worker.model_runner.model
    return validate_fn(saved_params, current_model, logger)

def setup_worker(self):
"""Build and Instantiate vLLM worker"""
Expand Down
5 changes: 5 additions & 0 deletions tests/integration_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Loading