Skip to content

Commit f580316

Browse files
JenniferWang and allenwang28
authored and committed
Fix policy update test (#365)
1 parent 5c07503 commit f580316

File tree

3 files changed

+151
-106
lines changed

3 files changed

+151
-106
lines changed

tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ services:
6666
procs: ${policy.engine_args.tensor_parallel_size}
6767
num_replicas: 1
6868
with_gpus: true
69+
70+
actors:
6971
trainer:
7072
procs: 1
7173
num_replicas: 1

tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ services:
6868
procs: ${policy.engine_args.tensor_parallel_size}
6969
num_replicas: 1
7070
with_gpus: true
71+
72+
actors:
7173
trainer:
7274
procs: 2
7375
num_replicas: 1

tests/integration_tests/test_policy_update.py

Lines changed: 147 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,22 @@
66

77
import asyncio
88
import logging
9-
from tempfile import TemporaryDirectory
9+
import shutil
10+
from pathlib import Path
1011

1112
import pytest
13+
import pytest_asyncio
1214

1315
import torch
1416
import torchstore as ts
1517
from forge.actors.generator import Generator
1618

1719
from forge.actors.trainer import RLTrainer
1820
from forge.cli.config import resolve_hf_hub_paths
21+
from forge.controller.provisioner import init_provisioner
1922

2023
from forge.controller.service.service import uuid
24+
from forge.types import LauncherConfig, ProvisionerConfig
2125
from monarch.actor import endpoint
2226

2327
from omegaconf import DictConfig, OmegaConf
@@ -35,13 +39,16 @@
3539
"""
3640
Run tests:
3741
38-
pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
39-
--config tests/integration_tests/artifacts/qwen3_1_7b_tp.yaml --use_dcp=false
42+
PYTHONPATH=. pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
43+
--config tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml --use_dcp=false
4044
41-
pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
45+
PYTHONPATH=. pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
4246
--config apps/grpo/qwen3_8b.yaml
4347
"""
4448

49+
# Temp directory won't work for multi-node because NFS does not cover the tmp path
50+
TEST_DCP_DIR = "test_dcp_tmp"
51+
4552

4653
class MockRLTrainer(RLTrainer):
4754
@endpoint
@@ -58,13 +65,27 @@ async def zero_out_model_states(self):
5865
sd[k] *= 0.0
5966

6067

61-
# exceptions sometimes are not propogated in monarch, do it manually
62-
def validate_fn(prev_params, curr_model, logger) -> Exception | None:
68+
def _load_config(config_path: str) -> DictConfig:
69+
cfg = None
70+
try:
71+
cfg = OmegaConf.load(config_path)
72+
except Exception as e:
73+
pytest.fail(f"Failed to load config file {config_path}: {e}")
74+
75+
assert isinstance(cfg, DictConfig)
76+
77+
cfg = resolve_hf_hub_paths(cfg)
78+
return cfg
79+
80+
81+
def _test_validate_params_unchanged(
82+
prev_params, curr_model, logger
83+
) -> Exception | None:
6384
"""Validate that current parameters are the same as prev_params."""
6485
verified = set()
6586
skipped = set()
6687
logger.info(
67-
f"Validating model params, all named_parameters() = {curr_model.named_parameters()}"
88+
f"Validating model params, all named_parameters() = {curr_model.named_parameters()}"
6889
)
6990
errs = []
7091
for name, param in curr_model.named_parameters():
@@ -83,7 +104,6 @@ def validate_fn(prev_params, curr_model, logger) -> Exception | None:
83104
)
84105
verified.add(name)
85106
except Exception as e:
86-
# logger.error(f"Validation failed with exception: {e}")
87107
errs.append((name, e))
88108
logger.info(f"Verified params = {verified}")
89109
logger.info(f"Skipped params = {skipped}")
@@ -94,14 +114,15 @@ def validate_fn(prev_params, curr_model, logger) -> Exception | None:
94114
return AssertionError(f"Validation failed: {errs}")
95115

96116

97-
# exceptions sometimes are not propogated in monarch, do it manually
98-
def validate_fn_all_zeros(prev_params, curr_model, logger) -> Exception | None:
117+
def _test_validate_params_all_zeros(
118+
prev_params, curr_model, logger
119+
) -> Exception | None:
99120
"""Validate all parameters are set to zero. prev_params is actually not used."""
100121
_ = prev_params
101122
verified = set()
102123
skipped = set()
103124
logger.info(
104-
f"Validating model params, all named_parameters() = {curr_model.named_parameters()}"
125+
f"Validating model params, all named_parameters() = {curr_model.named_parameters()}"
105126
)
106127
errs = []
107128
for name, param in curr_model.named_parameters():
@@ -113,10 +134,9 @@ def validate_fn_all_zeros(prev_params, curr_model, logger) -> Exception | None:
113134
param = param.cpu()
114135
assert torch.allclose(
115136
torch.zeros_like(param), param, atol=1e-4, rtol=1e-3
116-
), "param {name} is not zero."
137+
), f"param {name} is not zero."
117138
verified.add(name)
118139
except Exception as e:
119-
# logger.error(f"Validation failed with exception: {e}")
120140
errs.append((name, e))
121141
logger.info(f"Verified params = {verified}")
122142
logger.info(f"Skipped params = {skipped}")
@@ -127,24 +147,93 @@ def validate_fn_all_zeros(prev_params, curr_model, logger) -> Exception | None:
127147
return AssertionError(f"Validation failed: {errs}")
128148

129149

130-
class TestWeightSync:
131-
"""Tests for weight sync between trainer and policy."""
150+
@pytest_asyncio.fixture(autouse=True)
151+
async def _setup_and_teardown(request):
152+
# ---- setup ---- #
153+
config_path = request.config.getoption("--config", default=None)
154+
if not config_path:
155+
pytest.skip(
156+
"No config file provided. Use --config <path> to specify a YAML config file"
157+
)
132158

133-
def _load_config(self, config_path: str) -> DictConfig:
134-
cfg = None
135-
try:
136-
cfg = OmegaConf.load(config_path)
137-
except Exception as e:
138-
pytest.fail(f"Failed to load config file {config_path}: {e}")
159+
use_dcp_override = request.config.getoption("--use_dcp")
160+
cfg = _load_config(config_path=config_path)
161+
162+
trainer_proc_size = cfg.actors.trainer.procs
163+
policy_tp_size = cfg.policy.engine_args.tensor_parallel_size
164+
165+
if policy_tp_size != cfg.services.policy.procs:
166+
pytest.fail(
167+
f"Expect policy proc = {cfg.services.policy.procs} to be equal to tensor parallel size = {policy_tp_size}"
168+
)
169+
170+
model_card = cfg.model
171+
logger.info(f"Running sanity check with config: {config_path}")
172+
logger.info(f"Model name: {model_card}")
173+
logger.info(f"Trainer proc size: {trainer_proc_size}")
174+
logger.info(f"Policy tensor parallel size: {policy_tp_size}")
175+
176+
logger.info("Downloading model checkpoint from HuggingFace Hub")
177+
cached_dir = snapshot_download(repo_id=model_card)
178+
logger.info("Finished downloading model checkpoint from HuggingFace Hub")
179+
180+
services_policy_cfg = cfg.services.policy
181+
services_policy_cfg.num_replicas = 1
182+
183+
trainer_cfg = cfg.trainer
184+
trainer_cfg.dcp_path = TEST_DCP_DIR
185+
trainer_cfg.checkpoint = {
186+
"enable": True,
187+
"folder": "/tmp/saved_checkpoints",
188+
"initial_load_path": cached_dir,
189+
"initial_load_in_hf": True,
190+
}
191+
192+
if use_dcp_override is not None:
193+
trainer_cfg["use_dcp"] = use_dcp_override
194+
logger.info(f"`trainer.use_dcp` is overriden to {use_dcp_override}")
195+
196+
if cfg.get("provisioner", None) is not None:
197+
await init_provisioner(
198+
ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
199+
)
200+
await ts.initialize(strategy=ts.ControllerStorageVolumes())
201+
202+
policy, rl_trainer = await asyncio.gather(
203+
*[
204+
Generator.options(**services_policy_cfg).as_service(**cfg.policy),
205+
MockRLTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg),
206+
]
207+
)
208+
209+
yield policy, rl_trainer
210+
211+
# ---- teardown ---- #
212+
logger.info("Shutting down services and cleaning up DCP directory..")
213+
214+
await asyncio.gather(
215+
policy.shutdown(),
216+
ts.shutdown(),
217+
RLTrainer.shutdown(rl_trainer),
218+
)
219+
220+
# Cleanup DCP directory
221+
path = Path(TEST_DCP_DIR)
222+
if not path.exists() or not path.is_dir():
223+
return
224+
try:
225+
shutil.rmtree(path)
226+
logger.info(f"Successfully removed {TEST_DCP_DIR}")
227+
except Exception as e:
228+
logger.error(f"Failed to remove {TEST_DCP_DIR}: {e}")
139229

140-
assert isinstance(cfg, DictConfig)
141230

142-
cfg = resolve_hf_hub_paths(cfg)
143-
return cfg
231+
class TestWeightSync:
232+
"""Tests for weight sync between trainer and policy."""
144233

145234
@pytest.mark.asyncio
146235
@requires_cuda
147-
async def test_sanity_check(self, request):
236+
async def test_sanity_check(self, _setup_and_teardown):
148237
"""
149238
Sanity check for weight sync sharding between RLTrainer and Policy for a given model config.
150239
@@ -155,89 +244,41 @@ async def test_sanity_check(self, request):
155244
- Load weights v1 and check the policy has all the weights back
156245
157246
"""
158-
# Test setup
159-
config_path = request.config.getoption("--config", default=None)
160-
if not config_path:
161-
pytest.skip(
162-
"No config file provided. Use --config <path> to specify a YAML config file"
163-
)
164247

165-
use_dcp_override = request.config.getoption("--use_dcp")
166-
cfg = self._load_config(config_path=config_path)
248+
policy, rl_trainer = _setup_and_teardown
167249

168-
trainer_proc_size = cfg.actors.trainer.procs
169-
policy_tp_size = cfg.policy.engine_args.tensor_parallel_size
250+
v0 = uuid.uuid4().int
251+
v1 = v0 + 1
170252

171-
if policy_tp_size != cfg.services.policy.procs:
172-
pytest.fail(
173-
f"Expect policy proc = {cfg.services.policy.procs} to be equal to tensor parallel size = {policy_tp_size}"
174-
)
253+
await rl_trainer.push_weights.call(policy_version=v0)
254+
# Setting everything to zero
255+
await rl_trainer.zero_out_model_states.call()
256+
await rl_trainer.push_weights.call(policy_version=v1)
257+
await policy._test_save_model_params.fanout()
175258

176-
model_card = cfg.model
177-
178-
logger.info(f"Running sanity check with config: {config_path}")
179-
logger.info(f"Model name: {model_card}")
180-
logger.info(f"Trainer proc size: {trainer_proc_size}")
181-
logger.info(f"Policy tensor parallel size: {policy_tp_size}")
182-
183-
logger.info("Downloading model checkpoint from HuggingFace Hub")
184-
cached_dir = snapshot_download(repo_id=model_card)
185-
logger.info("Finished downloading model checkpoint from HuggingFace Hub")
186-
187-
await ts.initialize()
188-
services_policy_cfg = cfg.services.policy
189-
services_policy_cfg.num_replicas = 1
190-
191-
trainer_cfg = cfg.trainer
192-
trainer_cfg.checkpoint = {
193-
"enable": True,
194-
"folder": "/tmp/saved_checkpoints",
195-
"initial_load_path": cached_dir,
196-
"initial_load_in_hf": True,
197-
}
198-
if use_dcp_override is not None:
199-
trainer_cfg["use_dcp"] = use_dcp_override
200-
logger.info(f"`trainer.use_dcp` is overriden to {use_dcp_override}")
201-
202-
with TemporaryDirectory(dir="/dev/shm/") as tmpdir:
203-
trainer_cfg["dcp_path"] = tmpdir
204-
policy, rl_trainer = await asyncio.gather(
205-
*[
206-
Generator.options(**services_policy_cfg).as_service(**cfg.policy),
207-
MockRLTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg),
208-
]
209-
)
259+
# Sanity check that before update all the tests pass
260+
all_errs = await policy._test_validate_model_params.fanout(
261+
_test_validate_params_unchanged
262+
)
263+
for errs in all_errs:
264+
for _, e in errs.items():
265+
assert not e, f"Validation failed with exception: {e}"
210266

211-
# Main logic begins here
212-
v0 = uuid.uuid4().int
213-
v1 = v0 + 1
214-
215-
await rl_trainer.push_weights.call(policy_version=v0)
216-
# Setting everything to zero
217-
await rl_trainer.zero_out_model_states.call()
218-
await rl_trainer.push_weights.call(policy_version=v1)
219-
await policy._test_save_model_params.fanout()
220-
221-
# Sanity check that before update all the tests pass
222-
all_errs = await policy._test_validate_model_params.fanout(validate_fn)
223-
for errs in all_errs:
224-
for _, e in errs.items():
225-
assert not e, f"Validation failed with exception: {e}"
226-
227-
await policy.update_weights.fanout(version=v1)
228-
all_errs = await policy._test_validate_model_params.fanout(
229-
validate_fn_all_zeros
230-
)
231-
for errs in all_errs:
232-
for _, e in errs.items():
233-
assert not e, f"Validation failed with exception: {e}"
234-
235-
# Reloading v0, getting back original weights
236-
await policy.update_weights.fanout(version=v0)
237-
all_errs = await policy._test_validate_model_params.fanout(validate_fn)
238-
for errs in all_errs:
239-
for _, e in errs.items():
240-
assert not e, f"Validation failed with exception: {e}"
241-
242-
logger.info("✅ Weight sharding sanity check passed!")
243-
await ts.shutdown()
267+
await policy.update_weights.fanout(version=v1)
268+
all_errs = await policy._test_validate_model_params.fanout(
269+
_test_validate_params_all_zeros
270+
)
271+
for errs in all_errs:
272+
for _, e in errs.items():
273+
assert not e, f"Validation failed with exception: {e}"
274+
275+
# Reloading v0, getting back original weights
276+
await policy.update_weights.fanout(version=v0)
277+
all_errs = await policy._test_validate_model_params.fanout(
278+
_test_validate_params_unchanged
279+
)
280+
for errs in all_errs:
281+
for _, e in errs.items():
282+
assert not e, f"Validation failed with exception: {e}"
283+
284+
logger.info("✅ Weight sharding sanity check passed!")

0 commit comments

Comments (0)