66
77import asyncio
88import logging
9- from tempfile import TemporaryDirectory
9+ import shutil
10+ from pathlib import Path
1011
1112import pytest
13+ import pytest_asyncio
1214
1315import torch
1416import torchstore as ts
1517from forge .actors .generator import Generator
1618
1719from forge .actors .trainer import RLTrainer
1820from forge .cli .config import resolve_hf_hub_paths
21+ from forge .controller .provisioner import init_provisioner
1922
2023from forge .controller .service .service import uuid
24+ from forge .types import LauncherConfig , ProvisionerConfig
2125from monarch .actor import endpoint
2226
2327from omegaconf import DictConfig , OmegaConf
"""
Run tests:

PYTHONPATH=. pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
    --config tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml --use_dcp=false

PYTHONPATH=. pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
    --config apps/grpo/qwen3_8b.yaml
"""
4448
# DCP checkpoint scratch directory, relative to the working directory.
# A temp directory (e.g. TemporaryDirectory under /tmp or /dev/shm) won't work
# for multi-node runs because NFS does not cover the tmp path — presumably the
# working directory is on a shared filesystem; confirm for your cluster.
TEST_DCP_DIR = "test_dcp_tmp"
51+
4552
4653class MockRLTrainer (RLTrainer ):
4754 @endpoint
@@ -58,13 +65,27 @@ async def zero_out_model_states(self):
5865 sd [k ] *= 0.0
5966
6067
61- # exceptions sometimes are not propogated in monarch, do it manually
62- def validate_fn (prev_params , curr_model , logger ) -> Exception | None :
def _load_config(config_path: str) -> DictConfig:
    """Load a YAML config file and resolve HuggingFace Hub paths in it.

    Fails the running test (instead of raising) when the file cannot be
    loaded or does not parse into a ``DictConfig``.
    """
    try:
        loaded = OmegaConf.load(config_path)
    except Exception as e:
        # pytest.fail raises, so execution never continues past a bad load.
        pytest.fail(f"Failed to load config file {config_path}: {e}")

    assert isinstance(loaded, DictConfig)

    return resolve_hf_hub_paths(loaded)
79+
80+
81+ def _test_validate_params_unchanged (
82+ prev_params , curr_model , logger
83+ ) -> Exception | None :
6384 """Validate that current parameters are the same as prev_params."""
6485 verified = set ()
6586 skipped = set ()
6687 logger .info (
67- f"Validating model params, all named_parameters() = { curr_model .named_parameters ()} "
88+ f"Validating model params, all named_parameters() = { curr_model .named_parameters ()} "
6889 )
6990 errs = []
7091 for name , param in curr_model .named_parameters ():
@@ -83,7 +104,6 @@ def validate_fn(prev_params, curr_model, logger) -> Exception | None:
83104 )
84105 verified .add (name )
85106 except Exception as e :
86- # logger.error(f"Validation failed with exception: {e}")
87107 errs .append ((name , e ))
88108 logger .info (f"Verified params = { verified } " )
89109 logger .info (f"Skipped params = { skipped } " )
@@ -94,14 +114,15 @@ def validate_fn(prev_params, curr_model, logger) -> Exception | None:
94114 return AssertionError (f"Validation failed: { errs } " )
95115
96116
97- # exceptions sometimes are not propogated in monarch, do it manually
98- def validate_fn_all_zeros (prev_params , curr_model , logger ) -> Exception | None :
117+ def _test_validate_params_all_zeros (
118+ prev_params , curr_model , logger
119+ ) -> Exception | None :
99120 """Validate all parameters are set to zero. prev_params is actually not used."""
100121 _ = prev_params
101122 verified = set ()
102123 skipped = set ()
103124 logger .info (
104- f"Validating model params, all named_parameters() = { curr_model .named_parameters ()} "
125+ f"Validating model params, all named_parameters() = { curr_model .named_parameters ()} "
105126 )
106127 errs = []
107128 for name , param in curr_model .named_parameters ():
@@ -113,10 +134,9 @@ def validate_fn_all_zeros(prev_params, curr_model, logger) -> Exception | None:
113134 param = param .cpu ()
114135 assert torch .allclose (
115136 torch .zeros_like (param ), param , atol = 1e-4 , rtol = 1e-3
116- ), "param {name} is not zero."
137+ ), f "param { name } is not zero."
117138 verified .add (name )
118139 except Exception as e :
119- # logger.error(f"Validation failed with exception: {e}")
120140 errs .append ((name , e ))
121141 logger .info (f"Verified params = { verified } " )
122142 logger .info (f"Skipped params = { skipped } " )
@@ -127,24 +147,93 @@ def validate_fn_all_zeros(prev_params, curr_model, logger) -> Exception | None:
127147 return AssertionError (f"Validation failed: { errs } " )
128148
129149
130- class TestWeightSync :
131- """Tests for weight sync between trainer and policy."""
@pytest_asyncio.fixture(autouse=True)
async def _setup_and_teardown(request):
    """Build the Generator service and MockRLTrainer actor for a test, then tear them down.

    Yields a ``(policy, rl_trainer)`` pair. Skips the test when no ``--config``
    is given; fails fast on config inconsistencies.
    """
    # ---- setup ---- #
    config_path = request.config.getoption("--config", default=None)
    if not config_path:
        pytest.skip(
            "No config file provided. Use --config <path> to specify a YAML config file"
        )

    use_dcp_override = request.config.getoption("--use_dcp")
    cfg = _load_config(config_path=config_path)

    trainer_proc_size = cfg.actors.trainer.procs
    policy_tp_size = cfg.policy.engine_args.tensor_parallel_size

    # The policy service must dedicate one proc per tensor-parallel rank.
    if policy_tp_size != cfg.services.policy.procs:
        pytest.fail(
            f"Expect policy proc = {cfg.services.policy.procs} to be equal to tensor parallel size = {policy_tp_size}"
        )

    model_card = cfg.model
    logger.info(f"Running sanity check with config: {config_path}")
    logger.info(f"Model name: {model_card}")
    logger.info(f"Trainer proc size: {trainer_proc_size}")
    logger.info(f"Policy tensor parallel size: {policy_tp_size}")

    logger.info("Downloading model checkpoint from HuggingFace Hub")
    cached_dir = snapshot_download(repo_id=model_card)
    logger.info("Finished downloading model checkpoint from HuggingFace Hub")

    # A single replica is enough for the weight-sync sanity check.
    services_policy_cfg = cfg.services.policy
    services_policy_cfg.num_replicas = 1

    trainer_cfg = cfg.trainer
    trainer_cfg.dcp_path = TEST_DCP_DIR
    trainer_cfg.checkpoint = {
        "enable": True,
        "folder": "/tmp/saved_checkpoints",
        "initial_load_path": cached_dir,
        "initial_load_in_hf": True,
    }

    if use_dcp_override is not None:
        trainer_cfg["use_dcp"] = use_dcp_override
        logger.info(f"`trainer.use_dcp` is overriden to {use_dcp_override}")

    if cfg.get("provisioner", None) is not None:
        await init_provisioner(
            ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
        )
    await ts.initialize(strategy=ts.ControllerStorageVolumes())

    # Spin up the policy service and the trainer actor concurrently.
    policy, rl_trainer = await asyncio.gather(
        Generator.options(**services_policy_cfg).as_service(**cfg.policy),
        MockRLTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg),
    )

    yield policy, rl_trainer

    # ---- teardown ---- #
    logger.info("Shutting down services and cleaning up DCP directory..")

    await asyncio.gather(
        policy.shutdown(),
        ts.shutdown(),
        RLTrainer.shutdown(rl_trainer),
    )

    # Remove the DCP scratch directory if it exists; best-effort cleanup.
    dcp_dir = Path(TEST_DCP_DIR)
    if not dcp_dir.exists() or not dcp_dir.is_dir():
        return
    try:
        shutil.rmtree(dcp_dir)
        logger.info(f"Successfully removed {TEST_DCP_DIR}")
    except Exception as e:
        logger.error(f"Failed to remove {TEST_DCP_DIR}: {e}")
139229
140- assert isinstance (cfg , DictConfig )
141230
142- cfg = resolve_hf_hub_paths ( cfg )
143- return cfg
231+ class TestWeightSync :
232+ """Tests for weight sync between trainer and policy."""
144233
145234 @pytest .mark .asyncio
146235 @requires_cuda
147- async def test_sanity_check (self , request ):
236+ async def test_sanity_check (self , _setup_and_teardown ):
148237 """
149238 Sanity check for weight sync sharding between RLTrainer and Policy for a given model config.
150239
@@ -155,89 +244,41 @@ async def test_sanity_check(self, request):
155244 - Load weights v1 and check the policy has all the weights back
156245
157246 """
158- # Test setup
159- config_path = request .config .getoption ("--config" , default = None )
160- if not config_path :
161- pytest .skip (
162- "No config file provided. Use --config <path> to specify a YAML config file"
163- )
164247
165- use_dcp_override = request .config .getoption ("--use_dcp" )
166- cfg = self ._load_config (config_path = config_path )
248+ policy , rl_trainer = _setup_and_teardown
167249
168- trainer_proc_size = cfg . actors . trainer . procs
169- policy_tp_size = cfg . policy . engine_args . tensor_parallel_size
250+ v0 = uuid . uuid4 (). int
251+ v1 = v0 + 1
170252
171- if policy_tp_size != cfg .services .policy .procs :
172- pytest .fail (
173- f"Expect policy proc = { cfg .services .policy .procs } to be equal to tensor parallel size = { policy_tp_size } "
174- )
253+ await rl_trainer .push_weights .call (policy_version = v0 )
254+ # Setting everything to zero
255+ await rl_trainer .zero_out_model_states .call ()
256+ await rl_trainer .push_weights .call (policy_version = v1 )
257+ await policy ._test_save_model_params .fanout ()
175258
176- model_card = cfg .model
177-
178- logger .info (f"Running sanity check with config: { config_path } " )
179- logger .info (f"Model name: { model_card } " )
180- logger .info (f"Trainer proc size: { trainer_proc_size } " )
181- logger .info (f"Policy tensor parallel size: { policy_tp_size } " )
182-
183- logger .info ("Downloading model checkpoint from HuggingFace Hub" )
184- cached_dir = snapshot_download (repo_id = model_card )
185- logger .info ("Finished downloading model checkpoint from HuggingFace Hub" )
186-
187- await ts .initialize ()
188- services_policy_cfg = cfg .services .policy
189- services_policy_cfg .num_replicas = 1
190-
191- trainer_cfg = cfg .trainer
192- trainer_cfg .checkpoint = {
193- "enable" : True ,
194- "folder" : "/tmp/saved_checkpoints" ,
195- "initial_load_path" : cached_dir ,
196- "initial_load_in_hf" : True ,
197- }
198- if use_dcp_override is not None :
199- trainer_cfg ["use_dcp" ] = use_dcp_override
200- logger .info (f"`trainer.use_dcp` is overriden to { use_dcp_override } " )
201-
202- with TemporaryDirectory (dir = "/dev/shm/" ) as tmpdir :
203- trainer_cfg ["dcp_path" ] = tmpdir
204- policy , rl_trainer = await asyncio .gather (
205- * [
206- Generator .options (** services_policy_cfg ).as_service (** cfg .policy ),
207- MockRLTrainer .options (** cfg .actors .trainer ).as_actor (** trainer_cfg ),
208- ]
209- )
259+ # Sanity check that before update all the tests pass
260+ all_errs = await policy ._test_validate_model_params .fanout (
261+ _test_validate_params_unchanged
262+ )
263+ for errs in all_errs :
264+ for _ , e in errs .items ():
265+ assert not e , f"Validation failed with exception: { e } "
210266
211- # Main logic begins here
212- v0 = uuid .uuid4 ().int
213- v1 = v0 + 1
214-
215- await rl_trainer .push_weights .call (policy_version = v0 )
216- # Setting everything to zero
217- await rl_trainer .zero_out_model_states .call ()
218- await rl_trainer .push_weights .call (policy_version = v1 )
219- await policy ._test_save_model_params .fanout ()
220-
221- # Sanity check that before update all the tests pass
222- all_errs = await policy ._test_validate_model_params .fanout (validate_fn )
223- for errs in all_errs :
224- for _ , e in errs .items ():
225- assert not e , f"Validation failed with exception: { e } "
226-
227- await policy .update_weights .fanout (version = v1 )
228- all_errs = await policy ._test_validate_model_params .fanout (
229- validate_fn_all_zeros
230- )
231- for errs in all_errs :
232- for _ , e in errs .items ():
233- assert not e , f"Validation failed with exception: { e } "
234-
235- # Reloading v0, getting back original weights
236- await policy .update_weights .fanout (version = v0 )
237- all_errs = await policy ._test_validate_model_params .fanout (validate_fn )
238- for errs in all_errs :
239- for _ , e in errs .items ():
240- assert not e , f"Validation failed with exception: { e } "
241-
242- logger .info ("✅ Weight sharding sanity check passed!" )
243- await ts .shutdown ()
267+ await policy .update_weights .fanout (version = v1 )
268+ all_errs = await policy ._test_validate_model_params .fanout (
269+ _test_validate_params_all_zeros
270+ )
271+ for errs in all_errs :
272+ for _ , e in errs .items ():
273+ assert not e , f"Validation failed with exception: { e } "
274+
275+ # Reloading v0, getting back original weights
276+ await policy .update_weights .fanout (version = v0 )
277+ all_errs = await policy ._test_validate_model_params .fanout (
278+ _test_validate_params_unchanged
279+ )
280+ for errs in all_errs :
281+ for _ , e in errs .items ():
282+ assert not e , f"Validation failed with exception: { e } "
283+
284+ logger .info ("✅ Weight sharding sanity check passed!" )