Skip to content

Commit 62c1c99

Browse files
realAsma and yeyu-nvidia
authored and committed
Remove unused utilities for ModelOpt <0.29 MCore checkpoints (#322)
Signed-off-by: realAsma <[email protected]>
Signed-off-by: Ye Yu <[email protected]>
1 parent 742429d commit 62c1c99

File tree

8 files changed

+0
-139
lines changed

8 files changed

+0
-139
lines changed

modelopt/torch/quantization/conversion.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,6 @@ def _replace_quant_module(model: nn.Module, version=None, registry=QuantModuleRe
198198
# REPLACE on the parent (model), not on child
199199
quantized = registry.convert(child)
200200
setattr(model, name, quantized)
201-
quantized.mopt_ckpt_versn = version
202201

203202
# now recurse into whichever module is now at `model.name`
204203
_replace_quant_module(getattr(model, name), version=version, registry=registry)

modelopt/torch/quantization/nn/modules/quant_module.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,24 +37,6 @@
3737
class QuantModule(DynamicModule):
3838
"""A base class for quantized modules."""
3939

40-
@property
41-
def mopt_ckpt_versn(self):
42-
"""Checkpoint version of the modelopt."""
43-
for module in self.modules():
44-
if isinstance(module, TensorQuantizer):
45-
return module.mopt_ckpt_versn
46-
return None
47-
48-
@mopt_ckpt_versn.setter
49-
def mopt_ckpt_versn(self, version: str):
50-
"""Set the checkpoint version for the TensorQuantizer states."""
51-
52-
def _set_ckpt_version(module):
53-
if isinstance(module, TensorQuantizer):
54-
module.mopt_ckpt_versn = version
55-
56-
self.apply(_set_ckpt_version)
57-
5840
def modelopt_post_restore(self, prefix: str = ""):
5941
"""Post-restore to correctly configure the TensorQuantizer states.
6042

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
DTensor = None
3030

3131
import torch.nn.functional as F
32-
from packaging.version import Version
3332
from torch import nn
3433
from torch.onnx._globals import GLOBALS
3534

@@ -1023,48 +1022,6 @@ def extra_repr(self):
10231022
s += " calib" if (self._if_calib) else ""
10241023
return s
10251024

1026-
@property
1027-
def mopt_ckpt_versn(self):
1028-
"""Version of the checkpoint if it is restored from a checkpoint."""
1029-
return getattr(self, "_mopt_ckpt_versn", None)
1030-
1031-
@mopt_ckpt_versn.setter
1032-
def mopt_ckpt_versn(self, version: str):
1033-
self._mopt_ckpt_versn = str(version)
1034-
1035-
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
1036-
"""Special handling for loading older checkpoints.
1037-
1038-
This implementation is for backward compatibility and can be deprecated in future versions.
1039-
1040-
Args:
1041-
state_dict: A dict containing the state of the top level module
1042-
prefix: A string that prefixes all of this modules state in state_dict, e.g. 'model.conv1.'
1043-
"""
1044-
if self.mopt_ckpt_versn is None or Version(self.mopt_ckpt_versn) >= Version("0.29"):
1045-
# Warnings below are raised if users use partial state dictionary intentionally (eg:- HF ckpts)
1046-
# For ModelOpt >= 0.29, the buffers will be correctly created, So lets skip the warnings
1047-
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
1048-
1049-
_attrs = ["_amax", "_pre_quant_scale", "_svdquant_lora_a", "_svdquant_lora_b"]
1050-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1051-
1052-
for attr in _attrs:
1053-
has_dst = attr in self._buffers
1054-
has_src = prefix + attr in state_dict
1055-
1056-
if not has_src and has_dst:
1057-
warnings.warn(f"{prefix[:-1]}: No {attr} in state_dict.")
1058-
elif has_src and not has_dst:
1059-
warnings.warn(
1060-
f"{prefix[:-1]}: No '{attr}' buffer to load {attr} into."
1061-
f" '{attr}` is created as a buffer for now. Please move the model to the correct device and "
1062-
"dtype after this by calling `model.to(device, dtype)`."
1063-
)
1064-
self.register_buffer(attr, state_dict[prefix + attr].clone().detach().to(device))
1065-
1066-
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
1067-
10681025
def _get_properties_for_modelopt_state(self):
10691026
return (
10701027
self.__dict__.keys()

modelopt/torch/quantization/plugins/custom.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@
2121
from types import ModuleType
2222

2323
import torch
24-
from packaging.version import Version
2524

26-
from modelopt import __version__
2725
from modelopt.torch.utils.distributed import ParallelState
2826

2927
from ..nn import QuantModule, SequentialQuantizer, TensorQuantizer
@@ -174,19 +172,3 @@ def _has_state(quantizer, name):
174172
max_calibrate(self.output_quantizer, lambda oq: oq(dummy_input), distributed_sync=False)
175173
# If there are any other states, lets move them to the correct device
176174
super().modelopt_post_restore(prefix=prefix)
177-
178-
def is_version_less_than(self, version: str) -> bool:
179-
self_version = (
180-
Version(self.mopt_ckpt_versn)
181-
if self.mopt_ckpt_versn is not None
182-
else Version(__version__)
183-
)
184-
185-
# version in NeMo container is 0.0.0 if installed from source without history
186-
if self_version < Version(version) and self_version != Version("0.0.0"):
187-
warnings.warn(
188-
f"Checkpoint version {self_version} is less than {version}. "
189-
"Please re-save model to avoid this warning."
190-
)
191-
return True
192-
return False

modelopt/torch/quantization/plugins/vllm.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -185,15 +185,3 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
185185
output = super().forward(hidden_states, router_logits)
186186
self.invoke_fused_moe_quantized = self._invoke_fused_moe_quantized
187187
return output
188-
189-
@property
190-
def mopt_ckpt_versn(self):
191-
"""Checkpoint version of the modelopt."""
192-
return None
193-
194-
@mopt_ckpt_versn.setter
195-
def mopt_ckpt_versn(self, version: str):
196-
"""Set the checkpoint version for the TensorQuantizer states."""
197-
# vLLM defined an apply method that overwrites nn.Module.apply
198-
# To avoid conflicting, disable the apply call here
199-
# self.apply(_set_ckpt_version)

tests/_test_utils/torch_quantization/checkpointing.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

tests/_test_utils/torch_quantization/quantize_common.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626
from modelopt.torch.quantization.utils import is_quantized_linear
2727
from modelopt.torch.utils import torch_to
2828

29-
from .checkpointing import format_modelopt_checkpoint_by_version
30-
3129
INT4_AWQ_FULL_CFG = mtq.INT4_AWQ_CFG.copy()
3230

3331
INT4_AWQ_FULL_CFG["algorithm"] = "awq_full"
@@ -84,9 +82,6 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
8482

8583
state_dict = mto.modelopt_state(model_quant)
8684

87-
if version is not None:
88-
state_dict = format_modelopt_checkpoint_by_version(state_dict, version)
89-
9085
mto.restore_from_modelopt_state(model_ref, state_dict)
9186
model_ref.load_state_dict(model_quant.state_dict())
9287
assert torch.allclose(model_quant(calib_data[0]), model_ref(calib_data[0]))

tests/gpu/torch/quantization/test_quantize_cuda.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,3 @@ def test_quantize(model_cls, config):
9292
)
9393
def test_save_restore(model_cls, quant_config):
9494
save_restore_test(model_cls, "cuda", quant_config)
95-
96-
97-
@pytest.mark.parametrize("version", [None, "0.29", "0.28"])
98-
def test_save_restore_all_versions(version):
99-
save_restore_test(SimpleLinear, "cuda", mtq.INT8_DEFAULT_CFG, version=version)

0 commit comments

Comments
 (0)