19 changes: 11 additions & 8 deletions src/forge/actors/policy.py
@@ -16,6 +16,12 @@

import torch
import torchstore as ts
+
+from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
+
+from forge.data.sharding import VLLMSharding
+from forge.interfaces import Policy as PolicyInterface
+from forge.types import ProcessConfig
from monarch.actor import current_rank, endpoint, ProcMesh
from torchstore.state_dict_utils import DELIM
from vllm.config import VllmConfig
@@ -40,12 +46,6 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

-from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
-
-from forge.data.sharding import VLLMSharding
-from forge.interfaces import Policy as PolicyInterface
-from forge.types import ProcessConfig
-

@dataclass
class SamplingConfig:
@@ -378,8 +378,11 @@ async def update_weights(self):
    @endpoint
    async def _get_model_params(self) -> dict[str, torch.Tensor]:
        """Get the current model parameters. Only for testing purposes."""
-        model_params = await self.policy_worker._get_model_params.choose()
-        return model_params
+        val_mesh = await self.policy_worker._get_model_params.call()
+        sharded_state_dicts = {}
+        for idx, val in val_mesh.items():
+            sharded_state_dicts[idx["gpus"]] = val
+        return sharded_state_dicts

    @endpoint
    async def get_version(self) -> int:
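A note on the change above: `.choose()` routes to a single worker, so with tensor parallelism it could only ever return one rank's shard, whereas `.call()` fans out to every worker in the mesh and returns a value mesh. A minimal, self-contained sketch of the flattening logic, using a stand-in for the Monarch ValueMesh (hypothetical class names; the real Point is assumed indexable by the `"gpus"` dimension name, as `idx["gpus"]` in the diff suggests):

```python
# Sketch with stand-ins for Monarch types; FakePoint/FakeValueMesh are
# hypothetical and only mimic the access pattern the diff relies on.
from typing import Any


class FakePoint(dict):
    """Stand-in for a mesh point; supports point["gpus"] like the diff assumes."""


class FakeValueMesh:
    def __init__(self, values: list[Any]):
        self._values = values

    def items(self):
        # Yield (point, value) pairs, one per worker along the "gpus" dim.
        for gpu_idx, val in enumerate(self._values):
            yield FakePoint(gpus=gpu_idx), val


val_mesh = FakeValueMesh([{"w": "rank0-shard"}, {"w": "rank1-shard"}])
sharded_state_dicts = {point["gpus"]: sd for point, sd in val_mesh.items()}
assert sharded_state_dicts == {0: {"w": "rank0-shard"}, 1: {"w": "rank1-shard"}}
```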
95 changes: 82 additions & 13 deletions tests/integration_tests/test_policy_update.py
@@ -31,6 +31,9 @@
logger.setLevel(logging.INFO)


+# Run tests: pytest tests/integration_tests/test_policy_update.py::TestWeightSync::<test_name>
+
+
def convert_state_dict(saved_sd):
    """
    Convert transformers state dict to vLLM format.
@@ -80,7 +83,6 @@ def convert_state_dict(saved_sd):
def calculate_expected_shard(
    full_tensor: torch.Tensor,
    param_name: str,
-    expected_shape: torch.Size,
    tensor_parallel_size: int,
    rank: int,
) -> torch.Tensor:
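With `expected_shape` dropped, the helper now derives the shard purely from the parameter name, TP degree, and rank. For orientation, here is a hedged, self-contained sketch of the usual derivation; the name patterns are illustrative assumptions, not the repo's actual VLLMSharding rules:

```python
# Sketch (not the repo's implementation): column-parallel params are split
# along dim 0, row-parallel params along dim 1, everything else is replicated.
import torch


def expected_shard_sketch(
    full_tensor: torch.Tensor,
    param_name: str,
    tensor_parallel_size: int,
    rank: int,
) -> torch.Tensor:
    # Hypothetical name patterns used purely for illustration.
    col_parallel = ("qkv_proj", "gate_up_proj", "embed_tokens", "lm_head")
    row_parallel = ("o_proj", "down_proj")
    if any(key in param_name for key in col_parallel):
        dim = 0
    elif any(key in param_name for key in row_parallel):
        dim = 1
    else:
        return full_tensor  # replicated params are identical on every rank
    size = full_tensor.shape[dim] // tensor_parallel_size
    return full_tensor.narrow(dim, rank * size, size)
```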
@@ -126,19 +128,18 @@ def validate_loaded_tensors_equals_original(
"""
for param_name, loaded_tensor in loaded_state_dict.items():
if param_name in original_state_dict:
expected_tensor = original_state_dict[param_name]
original_tensor = original_state_dict[param_name]

if tensor_parallel_size > 1:
expected_shard = calculate_expected_shard(
expected_tensor,
original_shard = calculate_expected_shard(
original_tensor,
param_name,
loaded_tensor.shape,
tensor_parallel_size,
rank,
)
tensor_to_compare = expected_shard.cpu().float()
tensor_to_compare = original_shard.cpu().float()
else:
tensor_to_compare = expected_tensor.cpu().float()
tensor_to_compare = original_tensor.cpu().float()

# Training trainer emitted and loaded tensors are of type bfloat16,
# where as a HF model loaded(expected) tensor has type float16.
@@ -150,9 +151,9 @@
            ):
                logger.warning(
                    f"Loaded tensor {param_name} does not equal original.\n"
-                    f"dtype = {loaded_tensor.dtype} vs {expected_tensor.dtype}\n"
-                    f"shape= {loaded_tensor.shape} vs {expected_tensor.shape}\n,"
-                    f"values = {loaded_tensor} vs {expected_tensor}"
+                    f"dtype = {loaded_tensor.dtype} vs {original_tensor.dtype}\n"
+                    f"shape = {loaded_tensor.shape} vs {original_tensor.shape}\n"
+                    f"values = {loaded_tensor} vs {original_tensor}"
                )
                raise ValueError(
                    f"Loaded tensor {param_name} does not equal original "
@@ -166,10 +167,12 @@
    )


-def get_configs(worker_size: int, model_name: str) -> tuple[dict, ServiceConfig]:
+def get_configs(
+    worker_size: int, tp_size: int, model_name: str
+) -> tuple[dict, ServiceConfig]:
    engine_config = EngineConfig(
        model=model_name,
-        tensor_parallel_size=worker_size,
+        tensor_parallel_size=tp_size,
        pipeline_parallel_size=1,
        enforce_eager=True,
    )
@@ -210,6 +213,24 @@ def trainer_cfg(self):
            },
        }

+    @pytest_asyncio.fixture
+    def trainer_cfg_tp(self):
+        # NB: TP size is set to 2.
+        cached_dir = snapshot_download(repo_id=self.model)
+        return {
+            "model": {
+                "name": "qwen3",
+                "flavor": "1.7B",
+            },
+            "parallelism": {"tensor_parallel_degree": 2},
+            "checkpoint": {
+                "enable": True,
+                "folder": "/tmp/saved_checkpoints",
+                "initial_load_path": cached_dir,
+                "initial_load_in_hf": True,
+            },
+        }
+
    @pytest_asyncio.fixture
    def expected_sd(self):
        model = AutoModelForCausalLM.from_pretrained(
@@ -244,7 +265,7 @@ async def test_policy_update_single(self, expected_sd, trainer_cfg):
        await rl_trainer.push_weights.choose(policy_version=0)
        # 3. Policy pull weights
        policy_config, service_config = get_configs(
-            worker_size=worker_size, model_name=self.model
+            worker_size=worker_size, tp_size=worker_size, model_name=self.model
        )
        policy = await spawn_service(service_config, Policy, **policy_config)
        await policy.update_weights.call()
@@ -253,3 +274,51 @@
        validate_loaded_tensors_equals_original(
            loaded_state_dict, expected_sd, tensor_parallel_size=1, rank=0
        )
+
+    @pytest.mark.asyncio
+    @requires_cuda
+    async def test_policy_update_tp(self, expected_sd, trainer_cfg_tp):
+        """
+        1. Init RLTrainer over multiple workers with a TP parallelism strategy.
+        2. Push weights into torchstore.
+        3. Init Policy over multiple workers and call update_weights() to load weights from torchstore.
+        4. Validate the policy weights against manually loaded original HF weights.
+        """
+        # test configs/parallelism
+        trainer_worker_size = 2
+        policy_worker_size = 2
+        tp_size = 2
+
+        if torch.cuda.device_count() < 2:
+            pytest.skip(
+                f"Only {torch.cuda.device_count()} GPU(s) available, need 2+ for tensor parallel"
+            )
+        # 1. Initialize TS
+        await ts.initialize()
+        # 2. Trainer push
+        rl_trainer = await spawn_service(
+            ServiceConfig(
+                procs_per_replica=trainer_worker_size, with_gpus=True, num_replicas=1
+            ),
+            RLTrainer,
+            **trainer_cfg_tp,
+        )
+        await rl_trainer.push_weights.call(policy_version=0)
+        # 3. Policy pull weights
+        policy_config, service_config = get_configs(
+            worker_size=policy_worker_size, tp_size=tp_size, model_name=self.model
+        )
+        policy = await spawn_service(service_config, Policy, **policy_config)
+        await policy.update_weights.call()
+
+        # 4. Validate weight shards: compare each worker's vLLM-loaded shard
+        # against the manually calculated expected shard, derived directly
+        # from the original HF weights.
+
+        sharded_state_dicts = await policy._get_model_params.call()
+        validate_loaded_tensors_equals_original(
+            sharded_state_dicts[0][0], expected_sd, tensor_parallel_size=tp_size, rank=0
+        )
+        validate_loaded_tensors_equals_original(
+            sharded_state_dicts[0][1], expected_sd, tensor_parallel_size=tp_size, rank=1
+        )