
Commit 6e07dff

CodeRabbit suggestions + minor fix
Signed-off-by: Suguna Velury <[email protected]>
1 parent 067f02d commit 6e07dff

File tree: 3 files changed, +24 -8 lines changed

examples/llm_ptq/multinode-ptq.py

Lines changed: 6 additions & 2 deletions
@@ -274,7 +274,9 @@ def export_model(
     export_dir = Path(export_path)
     export_dir.mkdir(parents=True, exist_ok=True)
 
-    post_state_dict, hf_quant_config = _export_hf_checkpoint(model, torch.bfloat16)
+    post_state_dict, hf_quant_config = _export_hf_checkpoint(
+        model, torch.bfloat16, is_fsdp2=True, accelerator=accelerator
+    )
 
     if accelerator.is_main_process:
         # Save hf_quant_config.json for backward compatibility

@@ -389,4 +391,6 @@ def main(args):
 
 if __name__ == "__main__":
     args = parse_args()
-    main(args)
+    # This context manager can be removed once the update to FSDP2 function is reflected in torch
+    with patch_fsdp_mp_dtypes():
+        main(args)
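Note that `patch_fsdp_mp_dtypes` is now a context manager rather than a function that returns the original method (see the utils.py hunk below), so callers switch from capturing a return value to a `with` block. A minimal before/after sketch of the calling convention, assuming `main` and `parse_args` as defined in multinode-ptq.py:

    from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes

    # Before: the patch was applied globally and the original method was
    # handed back to the caller, who was responsible for restoring it.
    # orig_init_mp_dtypes = patch_fsdp_mp_dtypes()

    # After: the patch is scoped to the block and restored automatically
    # on exit, even if main() raises.
    args = parse_args()
    with patch_fsdp_mp_dtypes():
        main(args)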

modelopt/torch/quantization/utils.py

Lines changed: 12 additions & 5 deletions
@@ -28,7 +28,6 @@
 from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
 from torch.distributed.tensor import Replicate
 
-from modelopt.torch.quantization.qtensor.base_qtensor import QFSDPParam, QTensorWrapper
 from modelopt.torch.utils import get_unwrapped_name, print_rank_0
 
 if TYPE_CHECKING:

@@ -479,6 +478,7 @@ def set_quantizer_state_dict(model: nn.Module, quantizer_state_dict: dict):
         module.load_state_dict(quantizer_state_dict[key])
 
 
+@contextmanager
 def patch_fsdp_mp_dtypes():
     """Patch FSDP2 to handle mixed dtypes properly during quantization."""
 

@@ -509,10 +509,15 @@ def _init_mp_dtypes(self) -> None:
     original_init_mp_dtypes = (
         torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup._init_mp_dtypes
     )
-    torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup._init_mp_dtypes = (
-        _init_mp_dtypes
-    )
-    return original_init_mp_dtypes
+    try:
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup._init_mp_dtypes = (
+            _init_mp_dtypes
+        )
+        yield
+    finally:
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup._init_mp_dtypes = (
+            original_init_mp_dtypes
+        )
 
 
 def get_prefixed_param_names(parent_model, target_module):

@@ -623,6 +628,8 @@ def fsdp2_aware_weight_update(root_model, modules_to_update):
         # Yields for necessary weight updates/processing
         yield
     finally:
+        from modelopt.torch.quantization.qtensor.base_qtensor import QFSDPParam, QTensorWrapper
+
         if isinstance(root_model, FSDPModule):
             # Update FSDPParam list
             for module in modules_to_update:
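The rewritten `patch_fsdp_mp_dtypes` follows the standard `@contextmanager` patch-and-restore pattern: apply the monkey-patch inside `try`, `yield` to the caller, and restore the original attribute in `finally` so the patch cannot leak even when the body raises. A self-contained sketch of the same pattern with illustrative names (`Target` and `patch_method` are not from modelopt):

    from contextlib import contextmanager

    class Target:
        def method(self):
            return "original"

    def _patched(self):
        return "patched"

    @contextmanager
    def patch_method():
        """Temporarily replace Target.method, restoring it on exit."""
        original = Target.method
        try:
            Target.method = _patched  # apply the monkey-patch
            yield                     # caller's code runs here
        finally:
            Target.method = original  # restored even if the body raises

    t = Target()
    with patch_method():
        assert t.method() == "patched"
    assert t.method() == "original"  # the patch did not leak

The last hunk also moves the `QFSDPParam`/`QTensorWrapper` import from module scope into the `finally` block; deferring an import into the function body like this is a common way to avoid an import cycle, though the diff itself does not state the motivation.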

tests/gpu/torch/export/test_fsdp2_export.py

Lines changed: 6 additions & 1 deletion
@@ -29,7 +29,12 @@
 )
 from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, patch_fsdp_mp_dtypes
 
-orig_init_mp_dtypes = patch_fsdp_mp_dtypes()
+
+@pytest.fixture(autouse=True)
+def patch_fsdp_dtypes():
+    """Automatically patch FSDP mixed precision dtypes for all tests in this module."""
+    with patch_fsdp_mp_dtypes():
+        yield
 
 
 def _update_weight_test(rank, size):
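The old module-level call would no longer apply the patch at all under the new semantics, since calling a `@contextmanager` function only creates a context object. With `autouse=True`, pytest instead applies the fixture to every test in the module without it being requested by name: the context manager is entered before each test and exited (restoring FSDP2's original `_init_mp_dtypes`) afterwards. A minimal sketch of the pattern, assuming `patch_fsdp_mp_dtypes` is importable as in the diff (`test_example` is illustrative):

    import pytest

    from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes

    @pytest.fixture(autouse=True)
    def patch_fsdp_dtypes():
        """Wrap every test in this module in the FSDP2 dtype patch."""
        with patch_fsdp_mp_dtypes():
            yield  # the test body executes here; exit restores the patch

    def test_example():
        # Runs with FSDPParamGroup._init_mp_dtypes patched, without
        # naming the fixture as an argument.
        assert True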
