     FullyShardedDataParallel as FSDP,
     MixedPrecisionPolicy,
 )
+from torchtnt.framework._test_utils import DummyAutoUnit, generate_random_dataloader
+from torchtnt.framework.train import train
+from torchtnt.utils.distributed import get_global_rank
+from torchtnt.utils.prepare_module import get_module_state_dict

 try:
     from torch.distributed.fsdp import fully_shard
@@ -404,6 +408,62 @@ def _test_prepare_fsdp2_meta_device() -> None:
             # linear and SimpleModule are fsdp modules
             tc.assertTrue(_is_fsdp_module(submodule))

+    def test_get_module_state_dict(self) -> None:
+        spawn_multi_process(
+            2,
+            "nccl",
+            self._test_get_module_state_dict,
+        )
+
+    @staticmethod
+    def _test_get_module_state_dict() -> None:
+        rank = get_global_rank()
+
+        fsdp_strategy = FSDPStrategy(
+            sharding_strategy="FULL_SHARD",
+            auto_wrap_policy=lambda module, recurse, nonwrapped_numel: True,
+        )
+        ddp_strategy = DDPStrategy()
+
+        for strategy, rank0_only in (
+            (fsdp_strategy, True),
+            (fsdp_strategy, False),
+            (ddp_strategy, True),
+            (ddp_strategy, False),
+            (None, True),
+            (None, False),
+        ):
+            module = torch.nn.Sequential(
+                torch.nn.Linear(2, 100),
+                torch.nn.Linear(100, 2),
+            )
+
+            unit = DummyAutoUnit(
+                module=module,
+                strategy=strategy,
+            )
+
+            dataloader = generate_random_dataloader(10, 2, 10)
+            train(unit, dataloader, max_epochs=1)
+
+            module_sd = get_module_state_dict(unit.module, rank0_only=rank0_only)
+
+            tc = unittest.TestCase()
+
+            # For FSDP, if the user passed rank0_only=True, we should get an empty
+            # state dict on all ranks except rank 0.
+            if rank0_only and isinstance(strategy, FSDPStrategy) and rank != 0:
+                tc.assertEqual(module_sd, {})
+
+            else:
+                # Make sure that the generated state dict has the actual model keys,
+                # and that the values are plain tensors as opposed to ShardedTensors.
+                tc.assertCountEqual(
+                    ["0.weight", "0.bias", "1.weight", "1.bias"],
+                    list(module_sd.keys()),
+                )
+                tc.assertIsInstance(module_sd["0.weight"], torch.Tensor)
+


 class SimpleModule(torch.nn.Module):
     def __init__(self, meta_device: bool = False) -> None:
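
Usage note (not part of the diff): the pattern exercised by the new test is roughly what a training script would do to checkpoint a sharded module on rank 0. The sketch below reuses only the APIs the test itself imports (get_module_state_dict, get_global_rank); the function name, the checkpoint path, and the torch.save call are illustrative assumptions, and it presumes a distributed process group is already initialized and the module has been prepared/trained (e.g., inside an AutoUnit).

# Hypothetical sketch: consolidate a wrapped module's weights on rank 0 and
# write them to disk. Mirrors the call pattern in the test above; the file
# path and the torch.save usage are assumptions, not part of torchtnt.
import torch

from torchtnt.utils.distributed import get_global_rank
from torchtnt.utils.prepare_module import get_module_state_dict


def save_consolidated_checkpoint(module: torch.nn.Module, path: str) -> None:
    # With rank0_only=True the FSDP branch of the test expects the full,
    # unsharded state dict on rank 0 and an empty dict on every other rank.
    state_dict = get_module_state_dict(module, rank0_only=True)
    if get_global_rank() == 0:
        torch.save(state_dict, path)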