 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import os
-from typing import Tuple
+from typing import Dict, Tuple
 
 import pytest
 import pytest_asyncio
 from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
 from forge.controller.service import ServiceConfig, spawn_service
 from forge.data.sharding import VLLMSharding
-from monarch.actor import proc_mesh
 from torchstore import MultiProcessStore
 from torchstore._state_dict_utils import push_state_dict
 from transformers import AutoModelForCausalLM
 
-from vllm.utils import get_open_port
-
 requires_cuda = pytest.mark.skipif(
     not torch.cuda.is_available(),
     reason="CUDA not available",
@@ -197,41 +193,37 @@ def get_configs(
     return policy_config, service_config
 
 
-async def run_policy_integration(store, original_state_dict, num_gpus):
+async def run_policy_integration(
+    store, original_state_dict, worker_size
+) -> Dict[str, torch.Tensor]:
201199 """
202200 Common helper function to test Policy integration with different GPU configurations.
203201
204202 Args:
205203 store: TorchStore instance
206204 original_state_dict: Original state dict for validation
         worker_size: Number of workers to use (1 for single GPU, 2+ for tensor parallel)
-        test_name: Name for test identification in validation messages
     """
210- print (f"=== PHASE 2: Testing Policy Integration (GPUs: { num_gpus } ) ===" )
211-
212- state_dict_key = "llama3_8b_state_dict"
207+ print (f"=== PHASE 2: Testing Policy Integration (Workers: { worker_size } ) ===" )
213208
214- policy_config , service_config = get_configs (1 , "meta-llama/Llama-3.1-8B-Instruct" )
209+ policy_config , service_config = get_configs (
210+ worker_size = 1 , model_name = "meta-llama/Llama-3.1-8B-Instruct"
211+ )
     policy = await spawn_service(
         service_config, Policy, config=policy_config, store=store
     )
+
+    # The Policy engine starts at the default version 0, which then gets incremented.
218217 print ("Calling Policy.update() to load weights from torchstore..." )
219218 await policy .update_weights .call ()
220219 print (
221220 "Successfully called Policy.update_weights() to load weights from torchstore!"
222221 )
-
-    model_params = await policy.get_model_params.call()
-    loaded_state_dict = (
-        model_params._values[0] if hasattr(model_params, "_values") else model_params
-    )
+    # The service call returns a list with one result per replica.
+    results = await policy.get_model_params.call()
+    assert len(results) == 1
     print("Successfully got model state dict after update")
-
-    # validate_loaded_tensors_equals_original(
-    #     loaded_state_dict, original_state_dict, tensor_parallel_size=num_gpus, rank=rank
-    # )
-
-    print("Test passed! State dict successfully loaded into Policy!")
+    return results[0]
 
 
 @pytest_asyncio.fixture(scope="session")
@@ -261,7 +253,7 @@ async def llama3_torchstore_setup():
     converted_state_dict = convert_state_dict(original_state_dict)
     print(f"Converted state dict has {len(converted_state_dict)} parameters")
 
-    state_dict_key = "llama3_8b_state_dict"
+    state_dict_key = "model_state_dict/1"  # {app_namespace}/{version}
     await save_state_dict(store, converted_state_dict, state_dict_key)
     print(
         f"Successfully wrote converted state dict to torchstore with key: {state_dict_key}"
@@ -277,27 +269,34 @@ async def test_llama3_policy_update_single(llama3_torchstore_setup):
 
     store, original_state_dict = llama3_torchstore_setup
 
-    await run_policy_integration(store, original_state_dict, num_gpus=1)
+    loaded_state_dict = await run_policy_integration(
+        store, original_state_dict, worker_size=1
+    )
+
+    # Validate the loaded weights for the single-worker (unsharded) case.
+    validate_loaded_tensors_equals_original(
+        loaded_state_dict, original_state_dict, tensor_parallel_size=0, rank=0
+    )
 
     print(
         "Single GPU test passed! Llama 3.1 8B-Instruct model successfully loaded into Policy via TorchStore!"
     )
 
 
-@pytest.mark.asyncio
-@requires_cuda
-async def test_llama3_policy_update_tp(llama3_torchstore_setup):
-    print("Starting tensor parallel test (load full state dict into sharded model)...")
-
-    if torch.cuda.device_count() < 2:
-        pytest.skip(
-            f"Only {torch.cuda.device_count()} GPU(s) available, need 2+ for tensor parallel"
-        )
-
-    store, original_state_dict = llama3_torchstore_setup
-
-    await run_policy_integration(store, original_state_dict, num_gpus=2)
-
-    print(
-        "Tensor parallel test passed! Full state dict successfully loaded into tensor parallel model!"
-    )
+# @pytest.mark.asyncio
+# @requires_cuda
+# async def test_llama3_policy_update_tp(llama3_torchstore_setup):
+#     print("Starting tensor parallel test (load full state dict into sharded model)...")
+#
+#     if torch.cuda.device_count() < 2:
+#         pytest.skip(
+#             f"Only {torch.cuda.device_count()} GPU(s) available, need 2+ for tensor parallel"
+#         )
+#
+#     store, original_state_dict = llama3_torchstore_setup
+#
+#     await run_policy_integration(store, original_state_dict, num_gpus=2)
+#
+#     print(
+#         "Tensor parallel test passed! Full state dict successfully loaded into tensor parallel model!"
+#     )
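
Reviewer note, not part of the diff: a minimal sketch of how the versioned key written by the fixture ("model_state_dict/1", i.e. {app_namespace}/{version}) lines up with the version the Policy reads after update_weights(), assuming the engine starts at version 0 and bumps it by one. The names make_key and CURRENT_VERSION are hypothetical and used only for illustration.

# Illustrative sketch only; not part of this PR. Assumes the Policy engine
# starts at version 0 and update_weights() advances it by one.
CURRENT_VERSION = 0  # hypothetical starting version of the Policy engine

def make_key(app_namespace: str, version: int) -> str:
    # Follows the "{app_namespace}/{version}" convention noted in the fixture.
    return f"{app_namespace}/{version}"

# After one update the Policy would look up version 1, which matches the key
# the fixture wrote: "model_state_dict/1".
next_version = CURRENT_VERSION + 1
assert make_key("model_state_dict", next_version) == "model_state_dict/1"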