diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 3d45f6a0d..c0b1f59a3 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -378,8 +378,11 @@ async def update_weights(self):
     @endpoint
-    async def _get_model_params(self) -> dict[str, torch.Tensor]:
+    async def _get_model_params(self) -> dict[int, dict[str, torch.Tensor]]:
         """Get the current model parameters. Only for testing purposes."""
-        model_params = await self.policy_worker._get_model_params.choose()
-        return model_params
+        val_mesh = await self.policy_worker._get_model_params.call()
+        sharded_state_dicts = {}
+        for idx, val in val_mesh.items():
+            sharded_state_dicts[idx["gpus"]] = val
+        return sharded_state_dicts
 
     @endpoint
     async def get_version(self) -> int:
diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py
index e86d62f8a..c83d53fae 100644
--- a/tests/integration_tests/test_policy_update.py
+++ b/tests/integration_tests/test_policy_update.py
@@ -31,6 +31,9 @@
 logger.setLevel(logging.INFO)
 
+# Run tests: pytest tests/integration_tests/test_policy_update.py::TestWeightSync
+
+
 def convert_state_dict(saved_sd):
     """
     Convert transformers state dict to vLLM format.
     """
@@ -80,7 +83,6 @@
 def calculate_expected_shard(
     full_tensor: torch.Tensor,
     param_name: str,
-    expected_shape: torch.Size,
     tensor_parallel_size: int,
     rank: int,
 ) -> torch.Tensor:
@@ -126,19 +128,18 @@ def validate_loaded_tensors_equals_original(
     """
     for param_name, loaded_tensor in loaded_state_dict.items():
         if param_name in original_state_dict:
-            expected_tensor = original_state_dict[param_name]
+            original_tensor = original_state_dict[param_name]
 
             if tensor_parallel_size > 1:
-                expected_shard = calculate_expected_shard(
-                    expected_tensor,
+                original_shard = calculate_expected_shard(
+                    original_tensor,
                     param_name,
-                    loaded_tensor.shape,
                     tensor_parallel_size,
                     rank,
                 )
-                tensor_to_compare = expected_shard.cpu().float()
+                tensor_to_compare = original_shard.cpu().float()
             else:
-                tensor_to_compare = expected_tensor.cpu().float()
+                tensor_to_compare = original_tensor.cpu().float()
 
-            # Training trainer emitted and loaded tensors are of type bfloat16,
-            # where as a HF model loaded(expected) tensor has type float16.
+            # Trainer-emitted and loaded tensors are of type bfloat16,
+            # whereas a HF-loaded (expected) tensor has type float16.
@@ -150,9 +151,9 @@
             ):
                 logger.warning(
                     f"Loaded tensor {param_name} does not equal original.\n"
-                    f"dtype = {loaded_tensor.dtype} vs {expected_tensor.dtype}\n"
-                    f"shape= {loaded_tensor.shape} vs {expected_tensor.shape}\n,"
-                    f"values = {loaded_tensor} vs {expected_tensor}"
+                    f"dtype = {loaded_tensor.dtype} vs {original_tensor.dtype}\n"
+                    f"shape = {loaded_tensor.shape} vs {original_tensor.shape}\n"
+                    f"values = {loaded_tensor} vs {original_tensor}"
                 )
                 raise ValueError(
                     f"Loaded tensor {param_name} does not equal original "
@@ -166,10 +167,12 @@
         )
 
 
-def get_configs(worker_size: int, model_name: str) -> tuple[dict, ServiceConfig]:
+def get_configs(
+    worker_size: int, tp_size: int, model_name: str
+) -> tuple[dict, ServiceConfig]:
     engine_config = EngineConfig(
         model=model_name,
-        tensor_parallel_size=worker_size,
+        tensor_parallel_size=tp_size,
         pipeline_parallel_size=1,
         enforce_eager=True,
     )
@@ -210,6 +213,26 @@
             },
         }
 
+    @pytest_asyncio.fixture
+    def trainer_cfg_tp(self):
+        # NB: TP size is set to 2.
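+        # With tensor_parallel_degree=2, the trainer shards each weight across
+        # two trainer ranks, so this config needs at least 2 GPUs.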
+        cached_dir = snapshot_download(repo_id=self.model)
+        return {
+            "model": {
+                "name": "qwen3",
+                "flavor": "1.7B",
+            },
+            "parallelism": {"tensor_parallel_degree": 2},
+            "checkpoint": {
+                "enable": True,
+                "folder": "/tmp/saved_checkpoints",
+                "initial_load_path": cached_dir,
+                "initial_load_in_hf": True,
+            },
+        }
+
     @pytest_asyncio.fixture
     def expected_sd(self):
         model = AutoModelForCausalLM.from_pretrained(
@@ -241,7 +264,7 @@
         await rl_trainer.push_weights.choose(policy_version=0)
         # 3. Policy pull weights
         policy_config, service_config = get_configs(
-            worker_size=worker_size, model_name=self.model
+            worker_size=worker_size, tp_size=worker_size, model_name=self.model
         )
         policy = await Policy.options(service_config=service_config).as_service(
             **policy_config
@@ -252,3 +275,49 @@
         validate_loaded_tensors_equals_original(
             loaded_state_dict, expected_sd, tensor_parallel_size=1, rank=0
         )
+
+    @pytest.mark.asyncio
+    @requires_cuda
+    async def test_policy_update_tp(self, expected_sd, trainer_cfg_tp):
+        """
+        1. Init RLTrainer over multiple workers with a TP parallelism strategy.
+        2. Push weights into torchstore.
+        3. Init Policy over multiple workers and call update_weights() to load weights from torchstore.
+        4. Validate the policy weights against manually loaded original HF weights.
+        """
+        # Test configs / parallelism.
+        trainer_worker_size = 2
+        policy_worker_size = 2
+        tp_size = 2
+
+        if torch.cuda.device_count() < 2:
+            pytest.skip(
+                f"Only {torch.cuda.device_count()} GPU(s) available, need 2+ for tensor parallel"
+            )
+        # 1. Initialize TS
+        await ts.initialize()
+        # 2. Trainer push
+        rl_trainer = await RLTrainer.options(
+            procs_per_replica=trainer_worker_size, with_gpus=True, num_replicas=1
+        ).as_service(**trainer_cfg_tp)
+
+        await rl_trainer.push_weights.call(policy_version=0)
+        # 3. Policy pull weights
+        policy_config, service_config = get_configs(
+            worker_size=policy_worker_size, tp_size=tp_size, model_name=self.model
+        )
+        policy = await Policy.options(service_config=service_config).as_service(
+            **policy_config
+        )
+        await policy.update_weights.call()
+
+        # 4. Validate weight shards: _get_model_params.call() returns one result
+        # per replica, keyed by GPU rank; compare each worker's vLLM-loaded shard
+        # against the expected shard computed from the original HF weights.
+        sharded_state_dicts = await policy._get_model_params.call()
+        validate_loaded_tensors_equals_original(
+            sharded_state_dicts[0][0], expected_sd, tensor_parallel_size=tp_size, rank=0
+        )
+        validate_loaded_tensors_equals_original(
+            sharded_state_dicts[0][1], expected_sd, tensor_parallel_size=tp_size, rank=1
+        )
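
Note on the expected shard: calculate_expected_shard is defined outside the
hunks shown above, so the snippet below is only a hedged sketch of the slicing
such a helper typically performs under the usual vLLM tensor-parallel layout
(column-parallel weights such as qkv_proj/gate_up_proj split along dim 0,
row-parallel weights such as o_proj/down_proj split along dim 1, norms and
other 1-D params replicated). The name tp_shard and the substring matching are
illustrative assumptions, and a real helper must also account for fused
qkv/gate_up interleaving.

    import torch

    def tp_shard(
        full_tensor: torch.Tensor,
        param_name: str,
        tensor_parallel_size: int,
        rank: int,
    ) -> torch.Tensor:
        # Norms and 1-D params are replicated, identical on every rank.
        if full_tensor.dim() < 2 or "norm" in param_name:
            return full_tensor
        # Row-parallel layers shard the input dim (1); column-parallel
        # layers shard the output dim (0).
        dim = 1 if ("o_proj" in param_name or "down_proj" in param_name) else 0
        return torch.chunk(full_tensor, tensor_parallel_size, dim=dim)[rank]

    # e.g. for rank 0 of 2: tp_shard(w, "model.layers.0.mlp.down_proj.weight", 2, 0)
    # returns the left half of w, i.e. w[:, : w.shape[1] // 2].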