@@ -186,15 +186,12 @@ def get_configs(
 
 
 @pytest_asyncio.fixture(scope="session")
-async def llama3_torchstore_setup():
+async def setup_test():
     """
-    Pytest fixture to load Llama 3.1 8B-Instruct. We use the loaded state dict as SOT for validation.
-    Uses session scope so it's only called once when both tests are run.
+    Pytest fixture to load Llama 3.1 8B-Instruct. We use the loaded state dict
+    as the SOT for validation. Uses session scope so it's only called once
+    across all unit tests.
     """
-    print("=== PHASE 1: Writing Llama 3.1 8B-Instruct to TorchStore ===")
-
-    store = await ts.initialize()
-
     model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 
     # Load the model from local path - using device_map="auto" for efficient loading
@@ -207,76 +204,69 @@ async def llama3_torchstore_setup():
 
     original_state_dict = model.state_dict()
     print(f"Original state dict has {len(original_state_dict)} parameters")
-    converted_state_dict = convert_state_dict(original_state_dict)
-    print(f"Converted state dict has {len(converted_state_dict)} parameters")
+    hf_state_dict = convert_state_dict(original_state_dict)
+    print(f"Converted state dict has {len(hf_state_dict)} parameters")
 
-    return store, converted_state_dict
+    return hf_state_dict
 
 
 async def run_rl_trainer(worker_size) -> None:
     """
-    1. Spawn the trainer.
-    2. Inject torchstore references via setup call.
-    2. Call push weights.
+    Spawn the RL trainer.
+    Args:
+        worker_size: Number of workers/procs.
     """
     cfg: DictConfig = OmegaConf.load("apps/rl/llama3_8b.yaml")
     rl_trainer = await spawn_service(
-        ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
+        ServiceConfig(procs_per_replica=worker_size, with_gpus=True, num_replicas=1),
         RLTrainer,
         **cfg.trainer,
     )
     # Push the weights to torchstore
     await rl_trainer.push_weights.choose()
 
 
-async def run_policy_integration(store, worker_size) -> Dict[str, torch.Tensor]:
+async def run_policy_integration(worker_size) -> Dict[str, torch.Tensor]:
     """
-    Common helper function to test Policy integration with different GPU configurations.
+    Launch the policy service.
 
     Args:
-        store: TorchStore instance
-        original_state_dict: Original state dict for validation
-        num_gpus: Number of GPUs to use (1 for single GPU, 2+ for tensor parallel)
+        worker_size: Number of workers/procs (2+ for tensor parallel)
     """
-    print(f"=== PHASE 2: Testing Policy Integration (Workers: {worker_size}) ===")
+    print(f"=== PHASE 2: Launching Policy Engine (Workers: {worker_size}) ===")
 
     policy_config, service_config = get_configs(
         worker_size=worker_size, model_name="meta-llama/Llama-3.1-8B-Instruct"
     )
-    policy = await spawn_service(
-        service_config, Policy, config=policy_config, store=store
-    )
+    policy = await spawn_service(service_config, Policy, config=policy_config)
 
     # Policy engine starts with default version 0 that gets incremented.
     print("Calling Policy.update() to load weights from torchstore...")
     await policy.update_weights.call()
     print(
         "Successfully called Policy.update_weights() to load weights from torchstore!"
     )
-    # We get the result as a list.
-    #results = await policy._get_model_params.call()
-    #assert len(results) == 1
-    #print("Successfully got model state dict after update")
-    #return results[0]
-    return {}
+    results = await policy._get_model_params.call()
+    assert len(results) == 1
+    print("Successfully got model state dict after update")
+    return results[0]
 
 
 @pytest.mark.asyncio
 @requires_cuda
-async def test_llama3_policy_update_single():
+async def test_llama3_policy_update_single(setup_test):
     print("Starting Llama 3 8B torchstore test (single GPU)...")
 
-    # store, original_state_dict = llama3_torchstore_setup
     await ts.initialize()
+    expected_state_dict = setup_test
     await run_rl_trainer(worker_size=1)
-    loaded_state_dict = await run_policy_integration(None, worker_size=1)
-    assert False, "Planned failure"
+    loaded_state_dict = await run_policy_integration(worker_size=1)
 
     # Validate the single-resource case.
-    # validate_loaded_tensors_equals_original(
-    #     loaded_state_dict, original_state_dict, tensor_parallel_size=0, rank=0
-    # )
-
+    validate_loaded_tensors_equals_original(
+        loaded_state_dict, expected_state_dict, tensor_parallel_size=0, rank=0
+    )
     print(
         "Single GPU test passed! Llama 3.1 8B-Instruct model successfully loaded into Policy via TorchStore!"
    )
+    assert False, "Planned failure"
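For reviewers unfamiliar with the pattern, here is a minimal, self-contained sketch of how a session-scoped pytest-asyncio fixture feeds its return value into a test, mirroring the `setup_test` / `test_llama3_policy_update_single` pairing in this diff. The tiny tensor dict is a stand-in for the converted Llama 3.1 8B state dict; all names below are illustrative only.

```python
import pytest
import pytest_asyncio
import torch


@pytest_asyncio.fixture(scope="session")
async def setup_test():
    # Stand-in for loading Llama 3.1 8B-Instruct and converting its state
    # dict; with session scope this body runs once, and every test that
    # requests the fixture reuses the same returned dict.
    return {"layers.0.weight": torch.ones(2, 2)}


@pytest.mark.asyncio
async def test_reads_fixture(setup_test):
    # The fixture's return value arrives as the argument named after it.
    expected_state_dict = setup_test
    assert torch.equal(expected_state_dict["layers.0.weight"], torch.ones(2, 2))
```

Session scope is the right choice here because loading an 8B-parameter model is expensive; the fixture body executes once and both the single-GPU and tensor-parallel tests can share the result.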
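`validate_loaded_tensors_equals_original` is called above but defined outside this hunk. As a hedged sketch only, assuming the helper compares full (non-sharded) tensors when `tensor_parallel_size=0`, it might look roughly like the following; the actual implementation in this test module may differ:

```python
from typing import Dict

import torch


def validate_loaded_tensors_equals_original(
    loaded_state_dict: Dict[str, torch.Tensor],
    expected_state_dict: Dict[str, torch.Tensor],
    tensor_parallel_size: int,
    rank: int,
) -> None:
    # Hypothetical sketch for the tensor_parallel_size=0 path: every rank
    # holds full tensors, so keys and values can be compared directly. In
    # the tensor-parallel case, rank would instead select which shard of
    # the expected tensor to compare against.
    assert loaded_state_dict.keys() == expected_state_dict.keys()
    for name, expected in expected_state_dict.items():
        loaded = loaded_state_dict[name].to(expected.device, expected.dtype)
        assert torch.allclose(loaded, expected), f"Tensor mismatch for {name}"
```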