Skip to content
50 changes: 34 additions & 16 deletions src/forge/actors/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from __future__ import annotations

import asyncio

import logging
import os
import sys
Expand Down Expand Up @@ -393,18 +392,26 @@ async def update_weights(self, policy_version: int):
logger.info(f"Weight update completed (now v{self.policy_version})")

@endpoint
async def _get_model_params(self) -> dict[str, torch.Tensor]:
"""Get the current model parameters. Only for testing purposes."""
val_mesh = await self.policy_worker._get_model_params.call()
sharded_state_dicts = {}
for idx, val in val_mesh.items():
sharded_state_dicts[idx["gpus"]] = val
return sharded_state_dicts
async def get_version(self) -> int:
"""Get the current policy version."""
return self.policy_version

@endpoint
async def stop(self):
self.running = False

@endpoint
async def _test_save_model_params(self):
"""Save model parameters before weight update, used for tesing purposes only."""
logger.info("[Policy] save model parameters for testing.")
await self.policy_worker._test_save_model_params.call()

@endpoint
async def _test_validate_model_params(self, validate_fn):
"""Validate updated model params using validate_fn."""
logger.info("[Policy] start validating model parameters.")
return await self.policy_worker._test_validate_model_params.call(validate_fn)

def _to_completions(self, request_output: RequestOutput) -> list[Completion]:
"""Convert a RequestOutput to a list of Completion objects."""
completions = []
Expand Down Expand Up @@ -449,6 +456,9 @@ class PolicyWorker(ForgeActor):
state_dict_key: str = "model_state_dict"
use_dcp: bool = True

# used for tesing purposes only
_test_prev_params = {}

def __post_init__(self):
super().__init__()

Expand Down Expand Up @@ -541,15 +551,23 @@ async def setup_kv_cache(self):
return kv_cache_config

@endpoint
async def _get_model_params(self) -> dict[str, torch.Tensor]:
model = self.worker.model_runner.model
state_dict = {}
async def _test_save_model_params(self):
"""Save model parameters before weight update, used for tesing purposes only."""
logger.info("[PolicyWorker] save model parameters for testing.")
for name, param in self.worker.model_runner.model.named_parameters():
self._test_prev_params[name] = param.detach().cpu()
logger.info(
"[PolicyWorker] finished saving model parameters, len = %d",
len(self._test_prev_params),
)

for name, param in model.named_parameters():
if "layers.0" not in name:
continue
state_dict[name] = param.cpu().detach()
return state_dict
@endpoint
async def _test_validate_model_params(self, validate_fn):
"""Validate updated model params using validate_fn."""
logger.info("[PolicyWorker] start validating model parameters.")
return validate_fn(
self._test_prev_params, self.worker.model_runner.model, logger
)

def setup_worker(self):
"""Build and Instantiate vLLM worker"""
Expand Down
5 changes: 5 additions & 0 deletions tests/integration_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Loading
Loading