@@ -28,7 +28,6 @@
 from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
 from torch.distributed.tensor import Replicate
 
-from modelopt.torch.quantization.qtensor.base_qtensor import QFSDPParam, QTensorWrapper
 from modelopt.torch.utils import get_unwrapped_name, print_rank_0
 
 if TYPE_CHECKING:
@@ -479,6 +478,7 @@ def set_quantizer_state_dict(model: nn.Module, quantizer_state_dict: dict):
         module.load_state_dict(quantizer_state_dict[key])
 
 
+@contextmanager
 def patch_fsdp_mp_dtypes():
     """Patch FSDP2 to handle mixed dtypes properly during quantization."""
 
@@ -509,10 +509,15 @@ def _init_mp_dtypes(self) -> None:
     original_init_mp_dtypes = (
         torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup._init_mp_dtypes
     )
-    torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup._init_mp_dtypes = (
-        _init_mp_dtypes
-    )
-    return original_init_mp_dtypes
+    try:
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup._init_mp_dtypes = (
+            _init_mp_dtypes
+        )
+        yield
+    finally:
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup._init_mp_dtypes = (
+            original_init_mp_dtypes
+        )
 
 
 def get_prefixed_param_names(parent_model, target_module):
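This hunk turns `patch_fsdp_mp_dtypes` from a function that returned the original `_init_mp_dtypes` (leaving restoration to the caller) into a context manager whose `try`/`finally` reinstalls the original method automatically, even if the body raises. A minimal usage sketch of the new calling convention; the import path and the workload inside the block are assumptions, since the diff does not show the file name or call sites:

```python
# Assumed import path -- the diff does not name the module being edited.
from modelopt.torch.quantization.plugins.fsdp2 import patch_fsdp_mp_dtypes

with patch_fsdp_mp_dtypes():
    # FSDPParamGroup._init_mp_dtypes is monkeypatched only inside this
    # block, so FSDP2 tolerates the mixed parameter dtypes that
    # quantization introduces.
    do_quantized_weight_updates(model)  # illustrative placeholder

# On exit, the finally clause restores the original _init_mp_dtypes,
# even if the body raised.
```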
@@ -623,6 +628,8 @@ def fsdp2_aware_weight_update(root_model, modules_to_update):
         # Yields for necessary weight updates/processing
         yield
     finally:
+        from modelopt.torch.quantization.qtensor.base_qtensor import QFSDPParam, QTensorWrapper
+
         if isinstance(root_model, FSDPModule):
             # Update FSDPParam list
             for module in modules_to_update:
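The `QFSDPParam`/`QTensorWrapper` import removed from module scope in the first hunk reappears here inside the `finally` block. Deferring an import into a function body like this is typically done to break an import cycle: the imported module is resolved when `fsdp2_aware_weight_update` actually runs, not when this file loads. A generic sketch of the pattern, with illustrative module names that are not from this PR:

```python
# module_a.py
import module_b  # safe: module_b needs no names from module_a at load time

def f():
    return "from a"

# module_b.py
def g():
    # Deferred import: by the time g() is called, module_a has finished
    # loading, so the mutual dependency no longer cycles at import time.
    from module_a import f
    return f()
```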