Skip to content

Commit f16d27b

Browse files
committed
fix example after policy service APIs.
1 parent 3755490 commit f16d27b

File tree

2 files changed

+47
-42
lines changed

2 files changed

+47
-42
lines changed

src/forge/actors/policy.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,12 @@ async def update_weights(self) -> int:
331331
self.weights_version = new_version
332332
return self.weights_version
333333

334+
@endpoint
335+
async def get_model_params(self) -> Dict[str, torch.Tensor]:
336+
"""Get the current model parameters. Only for testing purposes."""
337+
model_params = await self.policy_worker.get_model_params.choose()
338+
return model_params
339+
334340
@endpoint
335341
async def get_version(self) -> int:
336342
"""Get the current policy version."""
@@ -480,7 +486,7 @@ async def get_vllm_args(self):
480486
return self.vllm_args
481487

482488
@endpoint
483-
async def get_model_params(self):
489+
async def get_model_params(self) -> Dict[str, torch.Tensor]:
484490
model = self.worker.model_runner.model
485491
state_dict = {}
486492

tests/integration_tests/test_policy_update.py

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import os
8-
from typing import Tuple
7+
from typing import Dict, Tuple
98

109
import pytest
1110
import pytest_asyncio
@@ -15,13 +14,10 @@
1514
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
1615
from forge.controller.service import ServiceConfig, spawn_service
1716
from forge.data.sharding import VLLMSharding
18-
from monarch.actor import proc_mesh
1917
from torchstore import MultiProcessStore
2018
from torchstore._state_dict_utils import push_state_dict
2119
from transformers import AutoModelForCausalLM
2220

23-
from vllm.utils import get_open_port
24-
2521
requires_cuda = pytest.mark.skipif(
2622
not torch.cuda.is_available(),
2723
reason="CUDA not available",
@@ -197,41 +193,37 @@ def get_configs(
197193
return policy_config, service_config
198194

199195

200-
async def run_policy_integration(store, original_state_dict, num_gpus):
196+
async def run_policy_integration(
197+
store, original_state_dict, worker_size
198+
) -> Dict[str, torch.Tensor]:
201199
"""
202200
Common helper function to test Policy integration with different GPU configurations.
203201
204202
Args:
205203
store: TorchStore instance
206204
original_state_dict: Original state dict for validation
207205
num_gpus: Number of GPUs to use (1 for single GPU, 2+ for tensor parallel)
208-
test_name: Name for test identification in validation messages
209206
"""
210-
print(f"=== PHASE 2: Testing Policy Integration (GPUs: {num_gpus}) ===")
211-
212-
state_dict_key = "llama3_8b_state_dict"
207+
print(f"=== PHASE 2: Testing Policy Integration (Workers: {worker_size}) ===")
213208

214-
policy_config, service_config = get_configs(1, "meta-llama/Llama-3.1-8B-Instruct")
209+
policy_config, service_config = get_configs(
210+
worker_size=1, model_name="meta-llama/Llama-3.1-8B-Instruct"
211+
)
215212
policy = await spawn_service(
216213
service_config, Policy, config=policy_config, store=store
217214
)
215+
216+
# Policy engine start with default version 0 that gets incremented.
218217
print("Calling Policy.update() to load weights from torchstore...")
219218
await policy.update_weights.call()
220219
print(
221220
"Successfully called Policy.update_weights() to load weights from torchstore!"
222221
)
223-
224-
model_params = await policy.get_model_params.call()
225-
loaded_state_dict = (
226-
model_params._values[0] if hasattr(model_params, "_values") else model_params
227-
)
222+
# We get the result as a list.
223+
results = await policy.get_model_params.call()
224+
assert len(results) == 1
228225
print("Successfully got model state dict after update")
229-
230-
# validate_loaded_tensors_equals_original(
231-
# loaded_state_dict, original_state_dict, tensor_parallel_size=num_gpus, rank=rank
232-
# )
233-
234-
print("Test passed! State dict successfully loaded into Policy!")
226+
return results[0]
235227

236228

237229
@pytest_asyncio.fixture(scope="session")
@@ -261,7 +253,7 @@ async def llama3_torchstore_setup():
261253
converted_state_dict = convert_state_dict(original_state_dict)
262254
print(f"Converted state dict has {len(converted_state_dict)} parameters")
263255

264-
state_dict_key = "llama3_8b_state_dict"
256+
state_dict_key = "model_state_dict/1" # {app_namespace}/{version}
265257
await save_state_dict(store, converted_state_dict, state_dict_key)
266258
print(
267259
f"Successfully wrote converted state dict to torchstore with key: {state_dict_key}"
@@ -277,27 +269,34 @@ async def test_llama3_policy_update_single(llama3_torchstore_setup):
277269

278270
store, original_state_dict = llama3_torchstore_setup
279271

280-
await run_policy_integration(store, original_state_dict, num_gpus=1)
272+
loaded_state_dict = await run_policy_integration(
273+
store, original_state_dict, worker_size=1
274+
)
275+
276+
# validating for single resource case.
277+
validate_loaded_tensors_equals_original(
278+
loaded_state_dict, original_state_dict, tensor_parallel_size=0, rank=0
279+
)
281280

282281
print(
283282
"Single GPU test passed! Llama 3.1 8B-Instruct model successfully loaded into Policy via TorchStore!"
284283
)
285284

286285

287-
@pytest.mark.asyncio
288-
@requires_cuda
289-
async def test_llama3_policy_update_tp(llama3_torchstore_setup):
290-
print("Starting tensor parallel test (load full state dict into sharded model)...")
291-
292-
if torch.cuda.device_count() < 2:
293-
pytest.skip(
294-
f"Only {torch.cuda.device_count()} GPU(s) available, need 2+ for tensor parallel"
295-
)
296-
297-
store, original_state_dict = llama3_torchstore_setup
298-
299-
await run_policy_integration(store, original_state_dict, num_gpus=2)
300-
301-
print(
302-
"Tensor parallel test passed! Full state dict successfully loaded into tensor parallel model!"
303-
)
286+
# @pytest.mark.asyncio
287+
# @requires_cuda
288+
# async def test_llama3_policy_update_tp(llama3_torchstore_setup):
289+
# print("Starting tensor parallel test (load full state dict into sharded model)...")
290+
#
291+
# if torch.cuda.device_count() < 2:
292+
# pytest.skip(
293+
# f"Only {torch.cuda.device_count()} GPU(s) available, need 2+ for tensor parallel"
294+
# )
295+
#
296+
# store, original_state_dict = llama3_torchstore_setup
297+
#
298+
# await run_policy_integration(store, original_state_dict, num_gpus=2)
299+
#
300+
# print(
301+
# "Tensor parallel test passed! Full state dict successfully loaded into tensor parallel model!"
302+
# )

0 commit comments

Comments
 (0)