1 change: 0 additions & 1 deletion modelopt/torch/quantization/conversion.py
@@ -198,7 +198,6 @@ def _replace_quant_module(model: nn.Module, version=None, registry=QuantModuleRe
             # REPLACE on the parent (model), not on child
             quantized = registry.convert(child)
             setattr(model, name, quantized)
-            quantized.mopt_ckpt_versn = version

         # now recurse into whichever module is now at `model.name`
         _replace_quant_module(getattr(model, name), version=version, registry=registry)
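The removed assignment was the only place conversion stamped `mopt_ckpt_versn` onto freshly quantized modules; the replace-then-recurse pattern around it is untouched. For context, a minimal standalone sketch of that pattern, with hypothetical names (`replace_children`, `factory`) rather than the modelopt API:

import torch.nn as nn

def replace_children(model: nn.Module, target: type, factory) -> None:
    # Replace on the parent via setattr, never on the child itself,
    # then recurse into whatever module now lives at that attribute.
    for name, child in model.named_children():
        if isinstance(child, target):
            setattr(model, name, factory(child))
        replace_children(getattr(model, name), target, factory)

# Toy usage: swap every Linear for an Identity (Identity has no
# children, so the recursion terminates immediately).
model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4)))
replace_children(model, nn.Linear, lambda _m: nn.Identity())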
18 changes: 0 additions & 18 deletions modelopt/torch/quantization/nn/modules/quant_module.py
@@ -37,24 +37,6 @@
 class QuantModule(DynamicModule):
     """A base class for quantized modules."""

-    @property
-    def mopt_ckpt_versn(self):
-        """Checkpoint version of the modelopt."""
-        for module in self.modules():
-            if isinstance(module, TensorQuantizer):
-                return module.mopt_ckpt_versn
-        return None
-
-    @mopt_ckpt_versn.setter
-    def mopt_ckpt_versn(self, version: str):
-        """Set the checkpoint version for the TensorQuantizer states."""
-
-        def _set_ckpt_version(module):
-            if isinstance(module, TensorQuantizer):
-                module.mopt_ckpt_versn = version
-
-        self.apply(_set_ckpt_version)
-
     def modelopt_post_restore(self, prefix: str = ""):
         """Post-restore to correctly configure the TensorQuantizer states.

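The deleted property pair read the version from the first `TensorQuantizer` found in the module tree and broadcast it to all of them on write. A condensed sketch of the two traversal idioms involved, using a hypothetical `Leaf` module in place of `TensorQuantizer`:

import torch.nn as nn

class Leaf(nn.Module):
    value = None  # stand-in for mopt_ckpt_versn

def first_attr(root: nn.Module, cls: type):
    # Getter idiom: modules() yields the root and all descendants;
    # return the attribute from the first match, else None.
    for m in root.modules():
        if isinstance(m, cls):
            return m.value
    return None

def broadcast_attr(root: nn.Module, cls: type, value) -> None:
    # Setter idiom: nn.Module.apply(fn) invokes fn on every submodule.
    def _set(m):
        if isinstance(m, cls):
            m.value = value
    root.apply(_set)

model = nn.Sequential(Leaf(), nn.Sequential(Leaf()))
broadcast_attr(model, Leaf, "0.29")
assert first_attr(model, Leaf) == "0.29"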
43 changes: 0 additions & 43 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -29,7 +29,6 @@
     DTensor = None

 import torch.nn.functional as F
-from packaging.version import Version
 from torch import nn
 from torch.onnx._globals import GLOBALS

@@ -1023,48 +1022,6 @@ def extra_repr(self):
         s += " calib" if (self._if_calib) else ""
         return s

-    @property
-    def mopt_ckpt_versn(self):
-        """Version of the checkpoint if it is restored from a checkpoint."""
-        return getattr(self, "_mopt_ckpt_versn", None)
-
-    @mopt_ckpt_versn.setter
-    def mopt_ckpt_versn(self, version: str):
-        self._mopt_ckpt_versn = str(version)
-
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
-        """Special handling for loading older checkpoints.
-
-        This implementation is for backward compatibility and can be deprecated in future versions.
-
-        Args:
-            state_dict: A dict containing the state of the top level module
-            prefix: A string that prefixes all of this modules state in state_dict, e.g. 'model.conv1.'
-        """
-        if self.mopt_ckpt_versn is None or Version(self.mopt_ckpt_versn) >= Version("0.29"):
-            # Warnings below are raised if users use partial state dictionary intentionally (eg:- HF ckpts)
-            # For ModelOpt >= 0.29, the buffers will be correctly created, So lets skip the warnings
-            return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
-
-        _attrs = ["_amax", "_pre_quant_scale", "_svdquant_lora_a", "_svdquant_lora_b"]
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        for attr in _attrs:
-            has_dst = attr in self._buffers
-            has_src = prefix + attr in state_dict
-
-            if not has_src and has_dst:
-                warnings.warn(f"{prefix[:-1]}: No {attr} in state_dict.")
-            elif has_src and not has_dst:
-                warnings.warn(
-                    f"{prefix[:-1]}: No '{attr}' buffer to load {attr} into."
-                    f" '{attr}` is created as a buffer for now. Please move the model to the correct device and "
-                    "dtype after this by calling `model.to(device, dtype)`."
-                )
-                self.register_buffer(attr, state_dict[prefix + attr].clone().detach().to(device))
-
-        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
-
     def _get_properties_for_modelopt_state(self):
         return (
             self.__dict__.keys()
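Most of this deletion is the pre-0.29 loading shim: when an old checkpoint carried quantizer buffers (`_amax`, `_pre_quant_scale`, ...) that a freshly constructed module did not yet have, `_load_from_state_dict` registered them on the fly before deferring to PyTorch. A condensed, runnable sketch of that pattern (`CompatModule` is a hypothetical stand-in, not the modelopt class):

import warnings

import torch
import torch.nn as nn

class CompatModule(nn.Module):
    # Buffers an old checkpoint may carry that a fresh instance lacks.
    _COMPAT_ATTRS = ("_amax", "_pre_quant_scale")

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        for attr in self._COMPAT_ATTRS:
            key = prefix + attr
            if key in state_dict and attr not in self._buffers:
                # Register the missing buffer so the stock loader below
                # has a destination tensor to copy into.
                self.register_buffer(attr, state_dict[key].detach().clone())
            elif key not in state_dict and attr in self._buffers:
                warnings.warn(f"{prefix[:-1]}: no {attr} in state_dict")
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

m = CompatModule()
m.load_state_dict({"_amax": torch.tensor(448.0)})  # buffer adopted on the fly
assert m._amax.item() == 448.0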
18 changes: 0 additions & 18 deletions modelopt/torch/quantization/plugins/custom.py
@@ -21,9 +21,7 @@
 from types import ModuleType

 import torch
-from packaging.version import Version

-from modelopt import __version__
 from modelopt.torch.utils.distributed import ParallelState

 from ..nn import QuantModule, SequentialQuantizer, TensorQuantizer
@@ -174,19 +172,3 @@ def _has_state(quantizer, name):
         max_calibrate(self.output_quantizer, lambda oq: oq(dummy_input), distributed_sync=False)
         # If there are any other states, lets move them to the correct device
         super().modelopt_post_restore(prefix=prefix)
-
-    def is_version_less_than(self, version: str) -> bool:
-        self_version = (
-            Version(self.mopt_ckpt_versn)
-            if self.mopt_ckpt_versn is not None
-            else Version(__version__)
-        )
-
-        # version in NeMo container is 0.0.0 if installed from source without history
-        if self_version < Version(version) and self_version != Version("0.0.0"):
-            warnings.warn(
-                f"Checkpoint version {self_version} is less than {version}. "
-                "Please re-save model to avoid this warning."
-            )
-            return True
-        return False
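`is_version_less_than` was the last consumer of `mopt_ckpt_versn`, `__version__`, and `packaging.version` in this plugin. Its core check, reduced to a pure function for illustration (hypothetical name; same fallback and the same carve-out for `0.0.0` source builds):

from packaging.version import Version

def is_less_than(ckpt_version, installed_version, threshold) -> bool:
    # Fall back to the installed package version when the checkpoint
    # carries none; source builds without git history report 0.0.0 and
    # are treated as current rather than ancient.
    v = Version(ckpt_version if ckpt_version is not None else installed_version)
    return v < Version(threshold) and v != Version("0.0.0")

assert is_less_than("0.28", "0.30", "0.29") is True
assert is_less_than(None, "0.0.0", "0.29") is False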
12 changes: 0 additions & 12 deletions modelopt/torch/quantization/plugins/vllm.py
@@ -185,15 +185,3 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         output = super().forward(hidden_states, router_logits)
         self.invoke_fused_moe_quantized = self._invoke_fused_moe_quantized
         return output
-
-    @property
-    def mopt_ckpt_versn(self):
-        """Checkpoint version of the modelopt."""
-        return None
-
-    @mopt_ckpt_versn.setter
-    def mopt_ckpt_versn(self, version: str):
-        """Set the checkpoint version for the TensorQuantizer states."""
-        # vLLM defined an apply method that overwrites nn.Module.apply
-        # To avoid conflicting, disable the apply call here
-        # self.apply(_set_ckpt_version)
37 changes: 0 additions & 37 deletions tests/_test_utils/torch_quantization/checkpointing.py
This file was deleted.
5 changes: 0 additions & 5 deletions tests/_test_utils/torch_quantization/quantize_common.py
@@ -26,8 +26,6 @@
 from modelopt.torch.quantization.utils import is_quantized_linear
 from modelopt.torch.utils import torch_to

-from .checkpointing import format_modelopt_checkpoint_by_version
-
 INT4_AWQ_FULL_CFG = mtq.INT4_AWQ_CFG.copy()

 INT4_AWQ_FULL_CFG["algorithm"] = "awq_full"
@@ -84,9 +82,6 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N

     state_dict = mto.modelopt_state(model_quant)

-    if version is not None:
-        state_dict = format_modelopt_checkpoint_by_version(state_dict, version)
-
     mto.restore_from_modelopt_state(model_ref, state_dict)
     model_ref.load_state_dict(model_quant.state_dict())
     assert torch.allclose(model_quant(calib_data[0]), model_ref(calib_data[0]))
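With the version formatting gone, `save_restore_test` exercises the plain round-trip only. A minimal sketch of that flow on a toy model, assuming the conventional aliases `mto` = `modelopt.torch.opt` and `mtq` = `modelopt.torch.quantization`:

import copy

import torch
import torch.nn as nn

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq

model = nn.Sequential(nn.Linear(16, 16))
ref = copy.deepcopy(model)  # keep an unquantized twin to restore into
calib = torch.randn(4, 16)

# Quantize and calibrate the source model.
mtq.quantize(model, mtq.INT8_DEFAULT_CFG, lambda m: m(calib))

# Restore: rebuild quantizer modules from the modelopt state, then load
# the tensors through the ordinary state_dict path.
mto.restore_from_modelopt_state(ref, mto.modelopt_state(model))
ref.load_state_dict(model.state_dict())
assert torch.allclose(model(calib), ref(calib))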
5 changes: 0 additions & 5 deletions tests/gpu/torch/quantization/test_quantize_cuda.py
@@ -92,8 +92,3 @@ def test_quantize(model_cls, config):
 )
 def test_save_restore(model_cls, quant_config):
     save_restore_test(model_cls, "cuda", quant_config)
-
-
-@pytest.mark.parametrize("version", [None, "0.29", "0.28"])
-def test_save_restore_all_versions(version):
-    save_restore_test(SimpleLinear, "cuda", mtq.INT8_DEFAULT_CFG, version=version)