From 0bde1e12e49c0d5a56eb52769969244fab583433 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Fri, 10 Oct 2025 08:21:12 -0700
Subject: [PATCH 1/2] Remove config functions like `int4_weight_only`

**Summary:** As a follow-up to https://github.com/pytorch/ao/pull/2994, this
commit removes all quantization functions that were used as configs. These
functions were deprecated in 0.14.0 and are removed here, in the next
release, 0.15.0. (A migration sketch for callers appears at the end of this
series.)

**Test Plan:** CI

[ghstack-poisoned]
---
 README.md                           |  2 +-
 test/quantization/test_quant_api.py | 51 --------------------
 torchao/quantization/__init__.py    | 24 ---------
 torchao/quantization/quant_api.py   | 75 +----------------------------
 torchao/utils.py                    | 21 +-------
 5 files changed, 3 insertions(+), 170 deletions(-)

diff --git a/README.md b/README.md
index 9330900300..a1d474ae02 100644
--- a/README.md
+++ b/README.md
@@ -258,7 +258,7 @@ Our framework makes it straightforward to add tensor parallel support to your cu
 We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow
 
-1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
+1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))`
 2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
 3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
 
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index d7d2b4a5b4..2b3538195e 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -10,7 +10,6 @@
 import gc
 import tempfile
 import unittest
-import warnings
 from pathlib import Path
 
 import torch
@@ -847,56 +846,6 @@ def test_int4wo_cuda_serialization(self):
         # load state_dict in cuda
         model.load_state_dict(sd, assign=True)
 
-    def test_config_deprecation(self):
-        """
-        Test that old config functions like `int4_weight_only` trigger deprecation warnings.
- """ - from torchao.quantization import ( - float8_dynamic_activation_float8_weight, - float8_static_activation_float8_weight, - float8_weight_only, - fpx_weight_only, - gemlite_uintx_weight_only, - int4_dynamic_activation_int4_weight, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, - uintx_weight_only, - ) - - # Reset deprecation warning state, otherwise we won't log warnings here - warnings.resetwarnings() - - # Map from deprecated API to the args needed to instantiate it - deprecated_apis_to_args = { - float8_dynamic_activation_float8_weight: (), - float8_static_activation_float8_weight: (torch.randn(3)), - float8_weight_only: (), - fpx_weight_only: (3, 2), - gemlite_uintx_weight_only: (), - int4_dynamic_activation_int4_weight: (), - int4_weight_only: (), - int8_dynamic_activation_int4_weight: (), - int8_dynamic_activation_int8_weight: (), - int8_weight_only: (), - uintx_weight_only: (torch.uint4,), - } - - with warnings.catch_warnings(record=True) as _warnings: - # Call each deprecated API twice - for cls, args in deprecated_apis_to_args.items(): - cls(*args) - cls(*args) - - # Each call should trigger the warning only once - self.assertEqual(len(_warnings), len(deprecated_apis_to_args)) - for w in _warnings: - self.assertIn( - "is deprecated and will be removed in a future release", - str(w.message), - ) - common_utils.instantiate_parametrized_tests(TestQuantFlow) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index c8774e9426..aa19aa1890 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -64,21 +64,9 @@ PlainLayout, TensorCoreTiledLayout, UIntXWeightOnlyConfig, - float8_dynamic_activation_float8_weight, - float8_static_activation_float8_weight, - float8_weight_only, - fpx_weight_only, - gemlite_uintx_weight_only, - int4_dynamic_activation_int4_weight, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_semi_sparse_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, intx_quantization_aware_training, quantize_, swap_conv2d_1x1_to_linear, - uintx_weight_only, ) from .quant_primitives import ( MappingType, @@ -131,19 +119,7 @@ "ALL_AUTOQUANT_CLASS_LIST", # top level API - manual "quantize_", - "int4_dynamic_activation_int4_weight", - "int8_dynamic_activation_int4_weight", - "int8_dynamic_activation_int8_weight", - "int8_dynamic_activation_int8_semi_sparse_weight", - "int4_weight_only", - "int8_weight_only", "intx_quantization_aware_training", - "float8_weight_only", - "float8_dynamic_activation_float8_weight", - "float8_static_activation_float8_weight", - "uintx_weight_only", - "fpx_weight_only", - "gemlite_uintx_weight_only", "swap_conv2d_1x1_to_linear", "Int4DynamicActivationInt4WeightConfig", "Int8DynamicActivationInt4WeightConfig", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 3bda8f91ab..4a2fe99c1d 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -96,7 +96,6 @@ to_weight_tensor_with_linear_activation_quantization_metadata, ) from torchao.utils import ( - _ConfigDeprecationWrapper, is_MI300, is_sm_at_least_89, is_sm_at_least_90, @@ -148,18 +147,7 @@ "autoquant", "_get_subclass_inserter", "quantize_", - "int8_dynamic_activation_int4_weight", - "int8_dynamic_activation_int8_weight", - "int8_dynamic_activation_int8_semi_sparse_weight", - "int4_weight_only", - "int8_weight_only", 
"intx_quantization_aware_training", - "float8_weight_only", - "uintx_weight_only", - "fpx_weight_only", - "gemlite_uintx_weight_only", - "float8_dynamic_activation_float8_weight", - "float8_static_activation_float8_weight", "Int8DynActInt4WeightQuantizer", "Float8DynamicActivationFloat8SemiSparseWeightConfig", "ModuleFqnToConfig", @@ -519,7 +507,7 @@ def quantize_( # Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile) # Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile) # Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile - from torchao.quantization.quant_api import int4_weight_only + from torchao.quantization.quant_api import Int4WeightOnlyConfig m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1)) @@ -641,12 +629,6 @@ def __post_init__(self): ) -# for BC -int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper( - "int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig -) - - @register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig) def _int8_dynamic_activation_int4_weight_transform( module: torch.nn.Module, @@ -1012,12 +994,6 @@ def __post_init__(self): ) -# for bc -int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper( - "int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig -) - - @register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig) def _int4_dynamic_activation_int4_weight_transform( module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig @@ -1075,12 +1051,6 @@ def __post_init__(self): ) -# for BC -gemlite_uintx_weight_only = _ConfigDeprecationWrapper( - "gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig -) - - @register_quantize_module_handler(GemliteUIntXWeightOnlyConfig) def _gemlite_uintx_weight_only_transform( module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig @@ -1158,11 +1128,6 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig") -# for BC -# TODO maybe change other callsites -int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig) - - def _int4_weight_only_quantize_tensor(weight, config): # TODO(future PR): perhaps move this logic to a different file, to keep the API # file clean of implementation details @@ -1374,10 +1339,6 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig") -# for BC -int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig) - - def _int8_weight_only_quantize_tensor(weight, config): mapping_type = MappingType.SYMMETRIC target_dtype = torch.int8 @@ -1535,12 +1496,6 @@ def __post_init__(self): ) -# for BC -int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper( - "int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig -) - - def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): layout = config.layout act_mapping_type = config.act_mapping_type @@ -1646,12 +1601,6 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig") -# for BC -float8_weight_only = _ConfigDeprecationWrapper( - "float8_weight_only", Float8WeightOnlyConfig -) - - def _float8_weight_only_quant_tensor(weight, config): if config.version == 1: warnings.warn( @@ -1806,12 +1755,6 @@ def __post_init__(self): self.granularity = [activation_granularity, weight_granularity] -# for bc 
-float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper(
-    "float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig
-)
-
-
 def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     activation_dtype = config.activation_dtype
     weight_dtype = config.weight_dtype
@@ -1981,12 +1924,6 @@ def __post_init__(self):
         )
 
 
-# for bc
-float8_static_activation_float8_weight = _ConfigDeprecationWrapper(
-    "float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig
-)
-
-
 @register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
 def _float8_static_activation_float8_weight_transform(
     module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
@@ -2066,12 +2003,6 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.UIntXWeightOnlyConfig")
 
 
-# for BC
-uintx_weight_only = _ConfigDeprecationWrapper(
-    "uintx_weight_only", UIntXWeightOnlyConfig
-)
-
-
 @register_quantize_module_handler(UIntXWeightOnlyConfig)
 def _uintx_weight_only_transform(
     module: torch.nn.Module, config: UIntXWeightOnlyConfig
@@ -2350,10 +2281,6 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.FPXWeightOnlyConfig")
 
 
-# for BC
-fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig)
-
-
 @register_quantize_module_handler(FPXWeightOnlyConfig)
 def _fpx_weight_only_transform(
     module: torch.nn.Module, config: FPXWeightOnlyConfig
diff --git a/torchao/utils.py b/torchao/utils.py
index 5af3e00cfa..9dfebfb6fb 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -12,7 +12,7 @@
 from functools import reduce
 from importlib.metadata import version
 from math import gcd
-from typing import Any, Callable, Optional, Type
+from typing import Any, Callable, Optional
 
 import torch
 import torch.nn.utils.parametrize as parametrize
@@ -433,25 +433,6 @@ def __eq__(self, other):
 TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev")
 
 
-class _ConfigDeprecationWrapper:
-    """
-    A deprecation wrapper that directs users from a deprecated "config function"
-    (e.g. `int4_weight_only`) to the replacement config class.
-    """
-
-    def __init__(self, deprecated_name: str, config_cls: Type):
-        self.deprecated_name = deprecated_name
-        self.config_cls = config_cls
-
-    def __call__(self, *args, **kwargs):
-        warnings.warn(
-            f"`{self.deprecated_name}` is deprecated and will be removed in a future release. "
-            f"Please use `{self.config_cls.__name__}` instead. Example usage:\n"
-            f"    quantize_(model, {self.config_cls.__name__}(...))"
-        )
-        return self.config_cls(*args, **kwargs)
-
-
 """
 Helper function for implementing aten op or torch function dispatch
 and dispatching to these implementations.

From b0a4f3990ba4d47f22c8fee414ee061474f7831f Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Fri, 10 Oct 2025 08:21:15 -0700
Subject: [PATCH 2/2] Remove old TORCH_VERSION variables

**Summary:** As a follow-up to https://github.com/pytorch/ao/pull/2719, which
deprecated these variables in 0.13.0, we now remove them in the next release,
0.15.0.
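For downstream code that still branches on these variables, here is a minimal
migration sketch (assuming only the `torch_version_at_least` helper that
remains in `torchao.utils`; the `print` calls stand in for version-gated code
paths):

```python
from torchao.utils import torch_version_at_least

# Before (deprecated in 0.13.0, removed by this patch):
#   from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
#   if TORCH_VERSION_AT_LEAST_2_6:
#       ...

# After: call the helper with an explicit version string.
if torch_version_at_least("2.6.0"):
    print("torch >= 2.6 code path")

# The removed TORCH_VERSION_AFTER_* variables compared against ".dev"
# versions (and were always True in fbcode), so TORCH_VERSION_AFTER_2_5
# is roughly equivalent to:
if torch_version_at_least("2.5.0.dev"):
    print("torch 2.5 nightly-or-newer code path")
```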
**Test Plan:** CI

[ghstack-poisoned]
---
 test/test_utils.py | 50 ----------------------------
 torchao/utils.py   | 67 ----------------------------------------------
 2 files changed, 117 deletions(-)

diff --git a/test/test_utils.py b/test/test_utils.py
index b46d600053..0e77388f13 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import unittest
-import warnings
 from unittest.mock import patch
 
 import torch
@@ -37,55 +36,6 @@ def test_torch_version_at_least(self):
             f"Failed for torch.__version__={torch_version}, comparing with {compare_version}",
         )
 
-    def test_torch_version_deprecation(self):
-        """
-        Test that TORCH_VERSION_AT_LEAST* and TORCH_VERSION_AFTER*
-        trigger deprecation warnings on use, not on import.
-        """
-        # Reset deprecation warning state, otherwise we won't log warnings here
-        warnings.resetwarnings()
-
-        # Importing and referencing should not trigger deprecation warning
-        with warnings.catch_warnings(record=True) as _warnings:
-            from torchao.utils import (
-                TORCH_VERSION_AFTER_2_2,
-                TORCH_VERSION_AFTER_2_3,
-                TORCH_VERSION_AFTER_2_4,
-                TORCH_VERSION_AFTER_2_5,
-                TORCH_VERSION_AT_LEAST_2_2,
-                TORCH_VERSION_AT_LEAST_2_3,
-                TORCH_VERSION_AT_LEAST_2_4,
-                TORCH_VERSION_AT_LEAST_2_5,
-                TORCH_VERSION_AT_LEAST_2_6,
-                TORCH_VERSION_AT_LEAST_2_7,
-                TORCH_VERSION_AT_LEAST_2_8,
-            )
-
-        deprecated_api_to_name = [
-            (TORCH_VERSION_AT_LEAST_2_8, "TORCH_VERSION_AT_LEAST_2_8"),
-            (TORCH_VERSION_AT_LEAST_2_7, "TORCH_VERSION_AT_LEAST_2_7"),
-            (TORCH_VERSION_AT_LEAST_2_6, "TORCH_VERSION_AT_LEAST_2_6"),
-            (TORCH_VERSION_AT_LEAST_2_5, "TORCH_VERSION_AT_LEAST_2_5"),
-            (TORCH_VERSION_AT_LEAST_2_4, "TORCH_VERSION_AT_LEAST_2_4"),
-            (TORCH_VERSION_AT_LEAST_2_3, "TORCH_VERSION_AT_LEAST_2_3"),
-            (TORCH_VERSION_AT_LEAST_2_2, "TORCH_VERSION_AT_LEAST_2_2"),
-            (TORCH_VERSION_AFTER_2_5, "TORCH_VERSION_AFTER_2_5"),
-            (TORCH_VERSION_AFTER_2_4, "TORCH_VERSION_AFTER_2_4"),
-            (TORCH_VERSION_AFTER_2_3, "TORCH_VERSION_AFTER_2_3"),
-            (TORCH_VERSION_AFTER_2_2, "TORCH_VERSION_AFTER_2_2"),
-        ]
-        self.assertEqual(len(_warnings), 0)
-
-        # Accessing the boolean value should trigger deprecation warning
-        with warnings.catch_warnings(record=True) as _warnings:
-            for api, name in deprecated_api_to_name:
-                num_warnings_before = len(_warnings)
-                if api:
-                    pass
-                regex = f"{name} is deprecated and will be removed"
-                self.assertEqual(len(_warnings), num_warnings_before + 1)
-                self.assertIn(regex, str(_warnings[-1].message))
-
 
 class TestTorchAOBaseTensor(unittest.TestCase):
     def test_print_arg_types(self):
diff --git a/torchao/utils.py b/torchao/utils.py
index 9dfebfb6fb..b010c4f9b8 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -8,7 +8,6 @@
 import itertools
 import re
 import time
-import warnings
 from functools import reduce
 from importlib.metadata import version
 from math import gcd
@@ -34,17 +33,6 @@
     "is_sm_at_least_90",
     "is_package_at_least",
     "DummyModule",
-    # Deprecated
-    "TORCH_VERSION_AT_LEAST_2_2",
-    "TORCH_VERSION_AT_LEAST_2_3",
-    "TORCH_VERSION_AT_LEAST_2_4",
-    "TORCH_VERSION_AT_LEAST_2_5",
-    "TORCH_VERSION_AT_LEAST_2_6",
-    "TORCH_VERSION_AT_LEAST_2_7",
-    "TORCH_VERSION_AFTER_2_2",
-    "TORCH_VERSION_AFTER_2_3",
-    "TORCH_VERSION_AFTER_2_4",
-    "TORCH_VERSION_AFTER_2_5",
 ]
 
 
@@ -378,61 +366,6 @@ def torch_version_at_least(min_version):
     return parse_version(torch.__version__) >= parse_version(min_version)
 
 
-def _deprecated_torch_version_at_least(version_str: str) -> str:
-    """
-    Wrapper for existing TORCH_VERSION_AT_LEAST* variables that will log
-    a deprecation warning if the variable is used.
-    """
-    version_str_var_name = "_".join(version_str.split(".")[:2])
-    deprecation_msg = f"TORCH_VERSION_AT_LEAST_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0"
-    return _BoolDeprecationWrapper(
-        torch_version_at_least(version_str),
-        deprecation_msg,
-    )
-
-
-def _deprecated_torch_version_after(version_str: str) -> str:
-    """
-    Wrapper for existing TORCH_VERSION_AFTER* variables that will log
-    a deprecation warning if the variable is used.
-    """
-    bool_value = is_fbcode() or version("torch") >= version_str
-    version_str_var_name = "_".join(version_str.split(".")[:2])
-    deprecation_msg = f"TORCH_VERSION_AFTER_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0"
-    return _BoolDeprecationWrapper(bool_value, deprecation_msg)
-
-
-class _BoolDeprecationWrapper:
-    """
-    A deprecation wrapper that logs a warning when the given bool value is accessed.
-    """
-
-    def __init__(self, bool_value: bool, msg: str):
-        self.bool_value = bool_value
-        self.msg = msg
-
-    def __bool__(self):
-        warnings.warn(self.msg)
-        return self.bool_value
-
-    def __eq__(self, other):
-        return bool(self) == bool(other)
-
-
-# Deprecated, use `torch_version_at_least` directly instead
-TORCH_VERSION_AT_LEAST_2_8 = _deprecated_torch_version_at_least("2.8.0")
-TORCH_VERSION_AT_LEAST_2_7 = _deprecated_torch_version_at_least("2.7.0")
-TORCH_VERSION_AT_LEAST_2_6 = _deprecated_torch_version_at_least("2.6.0")
-TORCH_VERSION_AT_LEAST_2_5 = _deprecated_torch_version_at_least("2.5.0")
-TORCH_VERSION_AT_LEAST_2_4 = _deprecated_torch_version_at_least("2.4.0")
-TORCH_VERSION_AT_LEAST_2_3 = _deprecated_torch_version_at_least("2.3.0")
-TORCH_VERSION_AT_LEAST_2_2 = _deprecated_torch_version_at_least("2.2.0")
-TORCH_VERSION_AFTER_2_5 = _deprecated_torch_version_after("2.5.0.dev")
-TORCH_VERSION_AFTER_2_4 = _deprecated_torch_version_after("2.4.0.dev")
-TORCH_VERSION_AFTER_2_3 = _deprecated_torch_version_after("2.3.0.dev")
-TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev")
-
-
 """
 Helper function for implementing aten op or torch function dispatch
 and dispatching to these implementations.
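As noted in the first commit message, here is a minimal migration sketch for
callers of the removed config functions, mirroring the `Int4WeightOnlyConfig`
example in the updated `quantize_` docstring; the same rename applies to each
removed function and its corresponding `*Config` class (e.g.
`fpx_weight_only(3, 2)` becomes `FPXWeightOnlyConfig(3, 2)`):

```python
import torch.nn as nn

from torchao.quantization import Int4WeightOnlyConfig, quantize_

model = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))

# Before (deprecated in 0.14.0, removed by this series):
#   from torchao.quantization import int4_weight_only
#   quantize_(model, int4_weight_only(group_size=32))

# After: construct the config class directly and pass it to quantize_.
quantize_(model, Int4WeightOnlyConfig(group_size=32, version=1))
```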