7 changes: 5 additions & 2 deletions src/forge/actors/policy.py
@@ -378,8 +378,11 @@ async def update_weights(self):
@endpoint
async def _get_model_params(self) -> dict[str, torch.Tensor]:
"""Get the current model parameters. Only for testing purposes."""
model_params = await self.policy_worker._get_model_params.choose()
return model_params
val_mesh = await self.policy_worker._get_model_params.call()
sharded_state_dicts = {}
for idx, val in val_mesh.items():
sharded_state_dicts[idx["gpus"]] = val
return sharded_state_dicts

@endpoint
async def get_version(self) -> int:
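For context on the change above: `_get_model_params` now fans out to every policy worker with `.call()` (instead of picking a single replica via `.choose()`) and keys each worker's state dict by its GPU index, so tests can inspect every tensor-parallel shard. Below is a minimal sketch of how a caller might consume the returned structure; the helper name is hypothetical, and it assumes the service-level `.call()` yields one such mapping per replica, as the test at the end of this diff does.

```python
# Hypothetical consumer of the updated endpoint (not part of the PR).
# Assumes policy._get_model_params.call() returns one result per replica,
# each being the {gpu_index: state_dict} mapping built above.
async def inspect_policy_shards(policy, param_name: str) -> None:
    results = await policy._get_model_params.call()
    replica0 = results[0]
    for gpu_idx, state_dict in replica0.items():
        tensor = state_dict[param_name]
        print(f"gpu {gpu_idx}: {param_name} shape={tuple(tensor.shape)}")
```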
94 changes: 81 additions & 13 deletions tests/integration_tests/test_policy_update.py
@@ -31,6 +31,9 @@
logger.setLevel(logging.INFO)


# Run tests: pytest tests/integration_tests/test_policy_update.py::TestWeightSync::<test_name>


def convert_state_dict(saved_sd):
"""
Convert transformers state dict to vLLM format.
@@ -80,7 +83,6 @@ def convert_state_dict(saved_sd):
def calculate_expected_shard(
full_tensor: torch.Tensor,
param_name: str,
expected_shape: torch.Size,
tensor_parallel_size: int,
rank: int,
) -> torch.Tensor:
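The body of `calculate_expected_shard` is collapsed in this diff. For orientation only, here is a minimal sketch of the per-rank slicing such a helper performs under vLLM's usual tensor-parallel layout; the key lists and the treatment of fused projections are assumptions, not the actual implementation used by the test.

```python
import torch

# Illustrative only -- not the helper used by the test. Assumed vLLM TP layout:
# row-parallel weights (o_proj, down_proj) are split along dim 1, vocab/column-
# parallel weights (embed_tokens, lm_head) along dim 0, and everything else
# (norms, biases) is replicated. Fused projections (qkv_proj, gate_up_proj)
# interleave several column-parallel components and need per-component slicing,
# which this sketch deliberately leaves out.
def tp_shard_sketch(
    full_tensor: torch.Tensor,
    param_name: str,
    tensor_parallel_size: int,
    rank: int,
) -> torch.Tensor:
    if any(key in param_name for key in ("qkv_proj", "gate_up_proj")):
        raise NotImplementedError("fused projections need per-component slicing")
    if any(key in param_name for key in ("o_proj", "down_proj")) and full_tensor.dim() > 1:
        return torch.chunk(full_tensor, tensor_parallel_size, dim=1)[rank]
    if any(key in param_name for key in ("embed_tokens", "lm_head")):
        return torch.chunk(full_tensor, tensor_parallel_size, dim=0)[rank]
    return full_tensor  # replicated parameters are identical on every rank
```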
@@ -126,19 +128,18 @@ def validate_loaded_tensors_equals_original(
"""
for param_name, loaded_tensor in loaded_state_dict.items():
if param_name in original_state_dict:
expected_tensor = original_state_dict[param_name]
original_tensor = original_state_dict[param_name]

if tensor_parallel_size > 1:
expected_shard = calculate_expected_shard(
expected_tensor,
original_shard = calculate_expected_shard(
original_tensor,
param_name,
loaded_tensor.shape,
tensor_parallel_size,
rank,
)
tensor_to_compare = expected_shard.cpu().float()
tensor_to_compare = original_shard.cpu().float()
else:
tensor_to_compare = expected_tensor.cpu().float()
tensor_to_compare = original_tensor.cpu().float()

# Trainer-emitted and loaded tensors are of type bfloat16,
# whereas a tensor loaded from the HF model (expected) has type float16.
@@ -150,9 +151,9 @@
):
logger.warning(
f"Loaded tensor {param_name} does not equal original.\n"
f"dtype = {loaded_tensor.dtype} vs {expected_tensor.dtype}\n"
f"shape= {loaded_tensor.shape} vs {expected_tensor.shape}\n,"
f"values = {loaded_tensor} vs {expected_tensor}"
f"dtype = {loaded_tensor.dtype} vs {original_tensor.dtype}\n"
f"shape= {loaded_tensor.shape} vs {original_tensor.shape}\n,"
f"values = {loaded_tensor} vs {original_tensor}"
)
raise ValueError(
f"Loaded tensor {param_name} does not equal original "
@@ -166,10 +167,12 @@ def validate_loaded_tensors_equals_original(
)


def get_configs(worker_size: int, model_name: str) -> tuple[dict, ServiceConfig]:
def get_configs(
worker_size: int, tp_size: int, model_name: str
) -> tuple[dict, ServiceConfig]:
engine_config = EngineConfig(
model=model_name,
tensor_parallel_size=worker_size,
tensor_parallel_size=tp_size,
pipeline_parallel_size=1,
enforce_eager=True,
)
@@ -210,6 +213,24 @@ def trainer_cfg(self):
},
}

@pytest_asyncio.fixture
def trainer_cfg_tp(self):
# NB: TP size is set to 2.
cached_dir = snapshot_download(repo_id=self.model)
return {
"model": {
"name": "qwen3",
"flavor": "1.7B",
},
"parallelism": {"tensor_parallel_degree": 2},
"checkpoint": {
"enable": True,
"folder": "/tmp/saved_checkpoints",
"initial_load_path": cached_dir,
"initial_load_in_hf": True,
},
}

@pytest_asyncio.fixture
def expected_sd(self):
model = AutoModelForCausalLM.from_pretrained(
@@ -241,7 +262,7 @@ async def test_policy_update_single(self, expected_sd, trainer_cfg):
await rl_trainer.push_weights.choose(policy_version=0)
# 3. Policy pull weights
policy_config, service_config = get_configs(
worker_size=worker_size, model_name=self.model
worker_size=worker_size, tp_size=worker_size, model_name=self.model
)
policy = await Policy.options(service_config=service_config).as_service(
**policy_config
@@ -252,3 +273,50 @@ async def test_policy_update_single(self, expected_sd, trainer_cfg):
validate_loaded_tensors_equals_original(
loaded_state_dict, expected_sd, tensor_parallel_size=1, rank=0
)

@pytest.mark.asyncio
@requires_cuda
async def test_policy_update_tp(self, expected_sd, trainer_cfg_tp):
"""
1. Init RLTrainer over multiple workers with TP parallelism strategy.
2. Push weights in to torchstore.
3. Initializes Policy over multiple workers, and calls update_weights() to load weights from torchstore.
4. Validate the policy weights against manually loaded origina HF weights.
"""
# test configs / parallelism
trainer_worker_size = 2
policy_worker_size = 2
tp_size = 2

if torch.cuda.device_count() < 2:
pytest.skip(
f"Only {torch.cuda.device_count()} GPU(s) available, need 2+ for tensor parallel"
)
# 1. Initialize TS
await ts.initialize()
# 2. Trainer push
rl_trainer = await RLTrainer.options(
procs_per_replica=trainer_worker_size, with_gpus=True, num_replicas=1
).as_service(**trainer_cfg_tp)

await rl_trainer.push_weights.call(policy_version=0)
# 3. Policy pull weights
policy_config, service_config = get_configs(
worker_size=policy_worker_size, tp_size=tp_size, model_name=self.model
)
policy = await Policy.options(service_config=service_config).as_service(
**policy_config
)
await policy.update_weights.call()

# 4. Validate weight shards: compare each worker's vLLM-loaded shard against
# the manually calculated (expected) shard of the directly loaded HF weights.
sharded_state_dicts = await policy._get_model_params.call()
validate_loaded_tensors_equals_original(
sharded_state_dicts[0][0], expected_sd, tensor_parallel_size=tp_size, rank=0
)
validate_loaded_tensors_equals_original(
sharded_state_dicts[0][1], expected_sd, tensor_parallel_size=tp_size, rank=1
)
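The final validation above checks replica 0's two GPU shards explicitly. A generic variant (hypothetical, not part of the PR) could walk every replica and rank returned by `_get_model_params.call()`, assuming the result is indexable per replica and each entry maps GPU rank to a state dict:

```python
# Hypothetical generalization of the validation step above; assumes the same
# result shape as in the test (per-replica collection of {rank: state_dict}).
async def validate_all_shards(policy, expected_sd, tp_size: int) -> None:
    sharded_state_dicts = await policy._get_model_params.call()
    for per_gpu in sharded_state_dicts:
        for rank, loaded_sd in per_gpu.items():
            validate_loaded_tensors_equals_original(
                loaded_sd, expected_sd, tensor_parallel_size=tp_size, rank=rank
            )
```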