2 changes: 2 additions & 0 deletions tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
@@ -66,6 +66,8 @@ services:
procs: ${policy.engine_args.tensor_parallel_size}
num_replicas: 1
with_gpus: true

actors:
trainer:
procs: 1
num_replicas: 1
2 changes: 2 additions & 0 deletions tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
@@ -68,6 +68,8 @@ services:
procs: ${policy.engine_args.tensor_parallel_size}
num_replicas: 1
with_gpus: true

actors:
trainer:
procs: 2
num_replicas: 1
253 changes: 147 additions & 106 deletions tests/integration_tests/test_policy_update.py
@@ -6,18 +6,22 @@

import asyncio
import logging
from tempfile import TemporaryDirectory
import shutil
from pathlib import Path

import pytest
import pytest_asyncio

import torch
import torchstore as ts
from forge.actors.generator import Generator

from forge.actors.trainer import RLTrainer
from forge.cli.config import resolve_hf_hub_paths
from forge.controller.provisioner import init_provisioner

from forge.controller.service.service import uuid
from forge.types import LauncherConfig, ProvisionerConfig
from monarch.actor import endpoint

from omegaconf import DictConfig, OmegaConf
@@ -35,13 +39,16 @@
"""
Run tests:

pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
--config tests/integration_tests/artifacts/qwen3_1_7b_tp.yaml --use_dcp=false
PYTHONPATH=. pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
--config tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml --use_dcp=false

pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
PYTHONPATH=. pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
--config apps/grpo/qwen3_8b.yaml
"""

# Temp directory won't work for multi-node because NFS does not cover the tmp path
TEST_DCP_DIR = "test_dcp_tmp"


class MockRLTrainer(RLTrainer):
@endpoint
@@ -58,13 +65,27 @@ async def zero_out_model_states(self):
sd[k] *= 0.0


# exceptions sometimes are not propagated in monarch, do it manually
def validate_fn(prev_params, curr_model, logger) -> Exception | None:
def _load_config(config_path: str) -> DictConfig:
cfg = None
try:
cfg = OmegaConf.load(config_path)
except Exception as e:
pytest.fail(f"Failed to load config file {config_path}: {e}")

assert isinstance(cfg, DictConfig)

cfg = resolve_hf_hub_paths(cfg)
return cfg


def _test_validate_params_unchanged(
prev_params, curr_model, logger
) -> Exception | None:
"""Validate that current parameters are the same as prev_params."""
verified = set()
skipped = set()
logger.info(
f"Validating model params, all named_parameters() = {curr_model.named_parameters()}"
f"Validating model params, all named_parameters() = {curr_model.named_parameters()}"
)
errs = []
for name, param in curr_model.named_parameters():
@@ -83,7 +104,6 @@ def validate_fn(prev_params, curr_model, logger) -> Exception | None:
)
verified.add(name)
except Exception as e:
# logger.error(f"Validation failed with exception: {e}")
errs.append((name, e))
logger.info(f"Verified params = {verified}")
logger.info(f"Skipped params = {skipped}")
@@ -94,14 +114,15 @@ def validate_fn(prev_params, curr_model, logger) -> Exception | None:
return AssertionError(f"Validation failed: {errs}")


# exceptions sometimes are not propagated in monarch, do it manually
def validate_fn_all_zeros(prev_params, curr_model, logger) -> Exception | None:
def _test_validate_params_all_zeros(
prev_params, curr_model, logger
) -> Exception | None:
"""Validate all parameters are set to zero. prev_params is actually not used."""
_ = prev_params
verified = set()
skipped = set()
logger.info(
f"Validating model params, all named_parameters() = {curr_model.named_parameters()}"
f"Validating model params, all named_parameters() = {curr_model.named_parameters()}"
)
errs = []
for name, param in curr_model.named_parameters():
@@ -113,10 +134,9 @@ def validate_fn_all_zeros(prev_params, curr_model, logger) -> Exception | None:
param = param.cpu()
assert torch.allclose(
torch.zeros_like(param), param, atol=1e-4, rtol=1e-3
), "param {name} is not zero."
), f"param {name} is not zero."
verified.add(name)
except Exception as e:
# logger.error(f"Validation failed with exception: {e}")
errs.append((name, e))
logger.info(f"Verified params = {verified}")
logger.info(f"Skipped params = {skipped}")
@@ -127,24 +147,93 @@ def validate_fn_all_zeros(prev_params, curr_model, logger) -> Exception | None:
return AssertionError(f"Validation failed: {errs}")


class TestWeightSync:
"""Tests for weight sync between trainer and policy."""
@pytest_asyncio.fixture(autouse=True)
async def _setup_and_teardown(request):
# ---- setup ---- #
config_path = request.config.getoption("--config", default=None)
if not config_path:
pytest.skip(
"No config file provided. Use --config <path> to specify a YAML config file"
)

def _load_config(self, config_path: str) -> DictConfig:
cfg = None
try:
cfg = OmegaConf.load(config_path)
except Exception as e:
pytest.fail(f"Failed to load config file {config_path}: {e}")
use_dcp_override = request.config.getoption("--use_dcp")
cfg = _load_config(config_path=config_path)

trainer_proc_size = cfg.actors.trainer.procs
policy_tp_size = cfg.policy.engine_args.tensor_parallel_size

if policy_tp_size != cfg.services.policy.procs:
pytest.fail(
f"Expect policy proc = {cfg.services.policy.procs} to be equal to tensor parallel size = {policy_tp_size}"
)

model_card = cfg.model
logger.info(f"Running sanity check with config: {config_path}")
logger.info(f"Model name: {model_card}")
logger.info(f"Trainer proc size: {trainer_proc_size}")
logger.info(f"Policy tensor parallel size: {policy_tp_size}")

logger.info("Downloading model checkpoint from HuggingFace Hub")
cached_dir = snapshot_download(repo_id=model_card)
logger.info("Finished downloading model checkpoint from HuggingFace Hub")

services_policy_cfg = cfg.services.policy
services_policy_cfg.num_replicas = 1

trainer_cfg = cfg.trainer
trainer_cfg.dcp_path = TEST_DCP_DIR
trainer_cfg.checkpoint = {
"enable": True,
"folder": "/tmp/saved_checkpoints",
"initial_load_path": cached_dir,
"initial_load_in_hf": True,
}

if use_dcp_override is not None:
trainer_cfg["use_dcp"] = use_dcp_override
logger.info(f"`trainer.use_dcp` is overriden to {use_dcp_override}")

if cfg.get("provisioner", None) is not None:
await init_provisioner(
ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
)
await ts.initialize(strategy=ts.ControllerStorageVolumes())

policy, rl_trainer = await asyncio.gather(
*[
Generator.options(**services_policy_cfg).as_service(**cfg.policy),
MockRLTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg),
]
)

yield policy, rl_trainer

# ---- teardown ---- #
logger.info("Shutting down services and cleaning up DCP directory..")

await asyncio.gather(
policy.shutdown(),
ts.shutdown(),
RLTrainer.shutdown(rl_trainer),
)

# Cleanup DCP directory
path = Path(TEST_DCP_DIR)
Contributor: You can use TemporaryDirectory which handles this automatically

Contributor (Author): I have to change this because NFS does not work for tmp directory
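For reference, the TemporaryDirectory pattern mentioned above is what the previous version of this test used (visible further down in this diff): the DCP path lived in a node-local temporary directory that was removed automatically on exit. A minimal sketch of that pattern, reusing the test's trainer_cfg dict:

from tempfile import TemporaryDirectory

# Node-local scratch space, removed automatically when the block exits.
# /dev/shm (like /tmp) is node-local and not covered by NFS, so other
# nodes cannot read this path; hence the switch to TEST_DCP_DIR above.
with TemporaryDirectory(dir="/dev/shm/") as tmpdir:
    trainer_cfg["dcp_path"] = tmpdir
    # ... run the test; tmpdir is deleted when the with-block exits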

if not path.exists() or not path.is_dir():
return
try:
shutil.rmtree(path)
logger.info(f"Successfully removed {TEST_DCP_DIR}")
except Exception as e:
logger.error(f"Failed to remove {TEST_DCP_DIR}: {e}")

assert isinstance(cfg, DictConfig)

cfg = resolve_hf_hub_paths(cfg)
return cfg
class TestWeightSync:
"""Tests for weight sync between trainer and policy."""

@pytest.mark.asyncio
@requires_cuda
async def test_sanity_check(self, request):
async def test_sanity_check(self, _setup_and_teardown):
"""
Sanity check for weight sync sharding between RLTrainer and Policy for a given model config.

@@ -155,89 +244,41 @@ async def test_sanity_check(self, request):
- Load weights v1 and check the policy has all the weights back

"""
# Test setup
config_path = request.config.getoption("--config", default=None)
if not config_path:
pytest.skip(
"No config file provided. Use --config <path> to specify a YAML config file"
)

use_dcp_override = request.config.getoption("--use_dcp")
cfg = self._load_config(config_path=config_path)
policy, rl_trainer = _setup_and_teardown

trainer_proc_size = cfg.actors.trainer.procs
policy_tp_size = cfg.policy.engine_args.tensor_parallel_size
v0 = uuid.uuid4().int
v1 = v0 + 1

if policy_tp_size != cfg.services.policy.procs:
pytest.fail(
f"Expect policy proc = {cfg.services.policy.procs} to be equal to tensor parallel size = {policy_tp_size}"
)
await rl_trainer.push_weights.call(policy_version=v0)
# Setting everything to zero
await rl_trainer.zero_out_model_states.call()
await rl_trainer.push_weights.call(policy_version=v1)
await policy._test_save_model_params.fanout()

model_card = cfg.model

logger.info(f"Running sanity check with config: {config_path}")
logger.info(f"Model name: {model_card}")
logger.info(f"Trainer proc size: {trainer_proc_size}")
logger.info(f"Policy tensor parallel size: {policy_tp_size}")

logger.info("Downloading model checkpoint from HuggingFace Hub")
cached_dir = snapshot_download(repo_id=model_card)
logger.info("Finished downloading model checkpoint from HuggingFace Hub")

await ts.initialize()
services_policy_cfg = cfg.services.policy
services_policy_cfg.num_replicas = 1

trainer_cfg = cfg.trainer
trainer_cfg.checkpoint = {
"enable": True,
"folder": "/tmp/saved_checkpoints",
"initial_load_path": cached_dir,
"initial_load_in_hf": True,
}
if use_dcp_override is not None:
trainer_cfg["use_dcp"] = use_dcp_override
logger.info(f"`trainer.use_dcp` is overriden to {use_dcp_override}")

with TemporaryDirectory(dir="/dev/shm/") as tmpdir:
trainer_cfg["dcp_path"] = tmpdir
policy, rl_trainer = await asyncio.gather(
*[
Generator.options(**services_policy_cfg).as_service(**cfg.policy),
MockRLTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg),
]
)
# Sanity check that before update all the tests pass
all_errs = await policy._test_validate_model_params.fanout(
_test_validate_params_unchanged
)
for errs in all_errs:
for _, e in errs.items():
assert not e, f"Validation failed with exception: {e}"

# Main logic begins here
v0 = uuid.uuid4().int
v1 = v0 + 1

await rl_trainer.push_weights.call(policy_version=v0)
# Setting everything to zero
await rl_trainer.zero_out_model_states.call()
await rl_trainer.push_weights.call(policy_version=v1)
await policy._test_save_model_params.fanout()

# Sanity check that before update all the tests pass
all_errs = await policy._test_validate_model_params.fanout(validate_fn)
for errs in all_errs:
for _, e in errs.items():
assert not e, f"Validation failed with exception: {e}"

await policy.update_weights.fanout(version=v1)
all_errs = await policy._test_validate_model_params.fanout(
validate_fn_all_zeros
)
for errs in all_errs:
for _, e in errs.items():
assert not e, f"Validation failed with exception: {e}"

# Reloading v0, getting back original weights
await policy.update_weights.fanout(version=v0)
all_errs = await policy._test_validate_model_params.fanout(validate_fn)
for errs in all_errs:
for _, e in errs.items():
assert not e, f"Validation failed with exception: {e}"

logger.info("✅ Weight sharding sanity check passed!")
await ts.shutdown()
await policy.update_weights.fanout(version=v1)
all_errs = await policy._test_validate_model_params.fanout(
_test_validate_params_all_zeros
)
for errs in all_errs:
for _, e in errs.items():
assert not e, f"Validation failed with exception: {e}"

# Reloading v0, getting back original weights
await policy.update_weights.fanout(version=v0)
all_errs = await policy._test_validate_model_params.fanout(
_test_validate_params_unchanged
)
for errs in all_errs:
for _, e in errs.items():
assert not e, f"Validation failed with exception: {e}"

logger.info("✅ Weight sharding sanity check passed!")