diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py
index 3c5208778..6a457f172 100644
--- a/modelopt/torch/quantization/conversion.py
+++ b/modelopt/torch/quantization/conversion.py
@@ -198,7 +198,6 @@ def _replace_quant_module(model: nn.Module, version=None, registry=QuantModuleRe
             # REPLACE on the parent (model), not on child
             quantized = registry.convert(child)
             setattr(model, name, quantized)
-            quantized.mopt_ckpt_versn = version
 
         # now recurse into whichever module is now at `model.name`
         _replace_quant_module(getattr(model, name), version=version, registry=registry)
diff --git a/modelopt/torch/quantization/nn/modules/quant_module.py b/modelopt/torch/quantization/nn/modules/quant_module.py
index 710307c07..93df3651b 100644
--- a/modelopt/torch/quantization/nn/modules/quant_module.py
+++ b/modelopt/torch/quantization/nn/modules/quant_module.py
@@ -37,24 +37,6 @@
 class QuantModule(DynamicModule):
     """A base class for quantized modules."""
 
-    @property
-    def mopt_ckpt_versn(self):
-        """Checkpoint version of the modelopt."""
-        for module in self.modules():
-            if isinstance(module, TensorQuantizer):
-                return module.mopt_ckpt_versn
-        return None
-
-    @mopt_ckpt_versn.setter
-    def mopt_ckpt_versn(self, version: str):
-        """Set the checkpoint version for the TensorQuantizer states."""
-
-        def _set_ckpt_version(module):
-            if isinstance(module, TensorQuantizer):
-                module.mopt_ckpt_versn = version
-
-        self.apply(_set_ckpt_version)
-
     def modelopt_post_restore(self, prefix: str = ""):
         """Post-restore to correctly configure the TensorQuantizer states.
 
diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index b1b2543ac..9846f3554 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -29,7 +29,6 @@
 DTensor = None
 
 import torch.nn.functional as F
-from packaging.version import Version
 from torch import nn
 from torch.onnx._globals import GLOBALS
 
@@ -1023,48 +1022,6 @@ def extra_repr(self):
         s += " calib" if (self._if_calib) else ""
         return s
 
-    @property
-    def mopt_ckpt_versn(self):
-        """Version of the checkpoint if it is restored from a checkpoint."""
-        return getattr(self, "_mopt_ckpt_versn", None)
-
-    @mopt_ckpt_versn.setter
-    def mopt_ckpt_versn(self, version: str):
-        self._mopt_ckpt_versn = str(version)
-
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
-        """Special handling for loading older checkpoints.
-
-        This implementation is for backward compatibility and can be deprecated in future versions.
-
-        Args:
-            state_dict: A dict containing the state of the top level module
-            prefix: A string that prefixes all of this modules state in state_dict, e.g. 'model.conv1.'
-        """
-        if self.mopt_ckpt_versn is None or Version(self.mopt_ckpt_versn) >= Version("0.29"):
-            # Warnings below are raised if users use partial state dictionary intentionally (eg:- HF ckpts)
-            # For ModelOpt >= 0.29, the buffers will be correctly created, So lets skip the warnings
-            return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
-
-        _attrs = ["_amax", "_pre_quant_scale", "_svdquant_lora_a", "_svdquant_lora_b"]
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        for attr in _attrs:
-            has_dst = attr in self._buffers
-            has_src = prefix + attr in state_dict
-
-            if not has_src and has_dst:
-                warnings.warn(f"{prefix[:-1]}: No {attr} in state_dict.")
-            elif has_src and not has_dst:
-                warnings.warn(
-                    f"{prefix[:-1]}: No '{attr}' buffer to load {attr} into."
-                    f" '{attr}` is created as a buffer for now. Please move the model to the correct device and "
-                    "dtype after this by calling `model.to(device, dtype)`."
-                )
-                self.register_buffer(attr, state_dict[prefix + attr].clone().detach().to(device))
-
-        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
-
     def _get_properties_for_modelopt_state(self):
         return (
             self.__dict__.keys()
diff --git a/modelopt/torch/quantization/plugins/custom.py b/modelopt/torch/quantization/plugins/custom.py
index 6e89b6668..4227f3c49 100644
--- a/modelopt/torch/quantization/plugins/custom.py
+++ b/modelopt/torch/quantization/plugins/custom.py
@@ -21,9 +21,7 @@
 from types import ModuleType
 
 import torch
-from packaging.version import Version
 
-from modelopt import __version__
 from modelopt.torch.utils.distributed import ParallelState
 
 from ..nn import QuantModule, SequentialQuantizer, TensorQuantizer
@@ -174,19 +172,3 @@
         max_calibrate(self.output_quantizer, lambda oq: oq(dummy_input), distributed_sync=False)
         # If there are any other states, lets move them to the correct device
         super().modelopt_post_restore(prefix=prefix)
-
-    def is_version_less_than(self, version: str) -> bool:
-        self_version = (
-            Version(self.mopt_ckpt_versn)
-            if self.mopt_ckpt_versn is not None
-            else Version(__version__)
-        )
-
-        # version in NeMo container is 0.0.0 if installed from source without history
-        if self_version < Version(version) and self_version != Version("0.0.0"):
-            warnings.warn(
-                f"Checkpoint version {self_version} is less than {version}. "
-                "Please re-save model to avoid this warning."
-            )
-            return True
-        return False
diff --git a/modelopt/torch/quantization/plugins/vllm.py b/modelopt/torch/quantization/plugins/vllm.py
index fb606b0d5..f5c10c87e 100644
--- a/modelopt/torch/quantization/plugins/vllm.py
+++ b/modelopt/torch/quantization/plugins/vllm.py
@@ -185,15 +185,3 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         output = super().forward(hidden_states, router_logits)
         self.invoke_fused_moe_quantized = self._invoke_fused_moe_quantized
         return output
-
-    @property
-    def mopt_ckpt_versn(self):
-        """Checkpoint version of the modelopt."""
-        return None
-
-    @mopt_ckpt_versn.setter
-    def mopt_ckpt_versn(self, version: str):
-        """Set the checkpoint version for the TensorQuantizer states."""
-        # vLLM defined an apply method that overwrites nn.Module.apply
-        # To avoid conflicting, disable the apply call here
-        # self.apply(_set_ckpt_version)
diff --git a/tests/_test_utils/torch_quantization/checkpointing.py b/tests/_test_utils/torch_quantization/checkpointing.py
deleted file mode 100644
index ec29320e3..000000000
--- a/tests/_test_utils/torch_quantization/checkpointing.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import copy
-
-from packaging.version import Version
-
-
-# Deprecation > 0.29
-def format_modelopt_checkpoint_by_version(modelopt_state: dict, version: str):
-    if Version(version) >= Version("0.29"):
-        return modelopt_state
-    modelopt_state = copy.deepcopy(modelopt_state)
-    modelopt_state["modelopt_version"] = version
-    for mode, state in modelopt_state["modelopt_state_dict"]:
-        if "quantizer_state" not in state["metadata"]:
-            continue
-        for quantizer_state in state["metadata"]["quantizer_state"].values():
-            quantizer_state["_mopt_ckpt_versn"] = version
-            pyt_states = quantizer_state.pop("_pytorch_state_metadata", None)
-            if pyt_states is None:
-                continue
-            for k in pyt_states["buffers"]:
-                quantizer_state["_has" + k] = True
-    return modelopt_state
diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py
index 02795099d..505eac2b6 100644
--- a/tests/_test_utils/torch_quantization/quantize_common.py
+++ b/tests/_test_utils/torch_quantization/quantize_common.py
@@ -26,8 +26,6 @@
 from modelopt.torch.quantization.utils import is_quantized_linear
 from modelopt.torch.utils import torch_to
 
-from .checkpointing import format_modelopt_checkpoint_by_version
-
 INT4_AWQ_FULL_CFG = mtq.INT4_AWQ_CFG.copy()
 
 INT4_AWQ_FULL_CFG["algorithm"] = "awq_full"
@@ -84,9 +82,6 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
 
     state_dict = mto.modelopt_state(model_quant)
 
-    if version is not None:
-        state_dict = format_modelopt_checkpoint_by_version(state_dict, version)
-
     mto.restore_from_modelopt_state(model_ref, state_dict)
     model_ref.load_state_dict(model_quant.state_dict())
     assert torch.allclose(model_quant(calib_data[0]), model_ref(calib_data[0]))
diff --git a/tests/gpu/torch/quantization/test_quantize_cuda.py b/tests/gpu/torch/quantization/test_quantize_cuda.py
index 95cc02d37..176e6e10b 100644
--- a/tests/gpu/torch/quantization/test_quantize_cuda.py
+++ b/tests/gpu/torch/quantization/test_quantize_cuda.py
@@ -92,8 +92,3 @@ def test_quantize(model_cls, config):
 )
 def test_save_restore(model_cls, quant_config):
     save_restore_test(model_cls, "cuda", quant_config)
-
-
-@pytest.mark.parametrize("version", [None, "0.29", "0.28"])
-def test_save_restore_all_versions(version):
-    save_restore_test(SimpleLinear, "cuda", mtq.INT8_DEFAULT_CFG, version=version)
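
For reference, the only save/restore path this change leaves behind is the unconditional one exercised by `save_restore_test` above: serialize the ModelOpt state plus the regular PyTorch state dict, then restore both with no checkpoint-version formatting and no `mopt_ckpt_versn` bookkeeping. Below is a minimal sketch of that flow; the toy `SimpleLinear` module, its sizes, and the random calibration data are illustrative stand-ins, not part of this diff.

import torch
import torch.nn as nn

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq


class SimpleLinear(nn.Module):
    """Toy module standing in for the test models used above (hypothetical)."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 16)

    def forward(self, x):
        return self.fc(x)


model_quant, model_ref = SimpleLinear(), SimpleLinear()
model_ref.load_state_dict(model_quant.state_dict())
calib_data = [torch.randn(2, 16)]


def forward_loop(model):
    # Run the calibration batches through the model to collect quantizer statistics.
    for batch in calib_data:
        model(batch)


# Quantize and calibrate one copy.
mtq.quantize(model_quant, mtq.INT8_DEFAULT_CFG, forward_loop)

# Save: the ModelOpt state and the ordinary PyTorch state dict.
modelopt_state = mto.modelopt_state(model_quant)

# Restore: no version shim, no format_modelopt_checkpoint_by_version helper;
# the quantizer buffers are recreated directly from the saved state.
mto.restore_from_modelopt_state(model_ref, modelopt_state)
model_ref.load_state_dict(model_quant.state_dict())

assert torch.allclose(model_quant(calib_data[0]), model_ref(calib_data[0]))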