diff --git a/benchmarks/benchmark_aq.py b/benchmarks/benchmark_aq.py
index 7dd732debc..5106eb5494 100644
--- a/benchmarks/benchmark_aq.py
+++ b/benchmarks/benchmark_aq.py
@@ -75,11 +75,13 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs
     """
     from torchao.quantization.quant_api import (
         _get_subclass_inserter,
-        _in_features_greater_than_16,
         _is_linear,
     )
     from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight
 
+    def _in_features_greater_than_16(mod, *args):
+        return hasattr(mod, "in_features") and mod.in_features > 16
+
     if filter_fn is None:
         filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
             *args
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 5c29f0b8ad..455a51061b 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -40,7 +40,6 @@
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     _replace_with_custom_fn_if_matches_filter,
-    change_linear_weights_to_int8_dqtensors,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
@@ -1852,11 +1851,6 @@ class TestAOTI(unittest.TestCase):
         list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
     )
     def test_aoti(self, api, test_device, test_dtype):
-        if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda":
-            self.skipTest(
-                f"{api} in {test_device} is not support for aoti compilation yet"
-            )
-
         if (
             test_device == "cuda"
             and torch.cuda.is_available()
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 3b26cd25d6..f979c9a588 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -146,32 +146,6 @@ def forward(self, x):
         return x
 
 
-def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
-    """
-    The deprecated implementation for int8 dynamic quant API, used as a reference for
-    numerics and performance
-    """
-    from torchao.quantization.quant_api import (
-        _get_subclass_inserter,
-        _in_features_greater_than_16,
-        _is_linear,
-    )
-    from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight
-
-    if filter_fn is None:
-        filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
-            *args
-        )
-
-    _replace_with_custom_fn_if_matches_filter(
-        model,
-        _get_subclass_inserter(
-            Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs
-        ),
-        filter_fn,
-    )
-
-
 def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
     def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
         """
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index fa0293bf82..ab3a27f05a 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -125,7 +125,6 @@ be applied individually. While there are a large variety of quantization apis, t
 #### A16W4 WeightOnly Quantization
 
 ```python
-# for torch 2.4+
 from torchao.quantization import quantize_, Int4WeightOnlyConfig
 group_size = 32
 
@@ -133,10 +132,6 @@ group_size = 32
 # use_hqq flag for `Int4WeightOnlyConfig` quantization
 use_hqq = False
 quantize_(model, Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq))
-
-# for torch 2.2.2 and 2.3
-from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
-change_linear_weights_to_int4_woqtensors(model)
 ```
 
 Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
@@ -144,25 +139,15 @@ Note: The quantization error incurred by applying int4 quantization to your mode
 #### A16W8 Int8 WeightOnly Quantization
 
 ```python
-# for torch 2.4+
 from torchao.quantization import quantize_, Int8WeightOnlyConfig
 quantize_(model, Int8WeightOnlyConfig())
-
-# for torch 2.2.2 and 2.3
-from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
-change_linear_weights_to_int8_woqtensors(model)
 ```
 
 #### A8W8 Int8 Dynamic Quantization
 
 ```python
-# for torch 2.4+
 from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
 quantize_(model, Int8DynamicActivationInt8WeightConfig())
-
-# for torch 2.2.2 and 2.3
-from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
-change_linear_weights_to_int8_dqtensors(model)
 ```
 
 ### A16W8 Float8 WeightOnly Quantization
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 4e6bf7fa41..25492d00b8 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -120,9 +120,6 @@
     ZeroPointDomain,
 )
 from .subclass import (
-    Int4WeightOnlyQuantizedLinearWeight,
-    Int8DynamicallyQuantizedLinearWeight,
-    Int8WeightOnlyQuantizedLinearWeight,
     QuantizedLinearWeightBase,
 )
 from .unified import Quantizer, TwoStepQuantizer
@@ -172,109 +169,6 @@
 }
 
 
-######
-# TO BE DEPRECATED START
-######
-def _in_features_greater_than_16(mod, *args):
-    return hasattr(mod, "in_features") and mod.in_features > 16
-
-
-# TODO: delete
-def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
-    """
-    Converts all linear weight tensors to the `Int8DynamicallyQuantizedLinearWeight`
-    Tensor subclass, effectively applying the same form of quantization
-    as apply_dynamic_quant while not modifying the linear modules.
-    """
-    raise ImportError(
-        "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs"
-    )
-
-    if filter_fn is None:
-        filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
-            *args
-        )
-
-    _replace_with_custom_fn_if_matches_filter(
-        model,
-        _get_subclass_inserter(
-            Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs
-        ),
-        filter_fn,
-    )
-
-
-# TODO: delete
-def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs):
-    """
-    Converts all linear weight tensors to the
-    `Int8WeightOnlyQuantizedLinearWeight` tensor subclass,
-    effectively applying the same form of quantization
-    as apply_weight_only_int8_quant while not modifying the linear modules.
-    """
-    raise ImportError(
-        "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs"
-    )
-
-    _replace_with_custom_fn_if_matches_filter(
-        model,
-        _get_subclass_inserter(
-            Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=False, **kwargs
-        ),
-        _is_linear if filter_fn is None else filter_fn,
-    )
-
-
-# TODO: delete
-def change_linear_weights_to_int4_woqtensors(
-    model,
-    groupsize=128,
-    inner_k_tiles=8,
-    filter_fn=None,
-    zero_point_domain=ZeroPointDomain.FLOAT,
-    preserve_zero=False,
-):
-    """
-    Converts all linear weight tensors to the
-    `Int4WeightOnlyQuantizedLinearWeight` tensor subclass,
-    effectively applying the same form of quantization
-    as apply_dynamic_quant while not modifying the linear modules.
-    Args:
-        `groupsize`: parameter for quantization, controls the granularity of quantization, smaller
-         size is more fine grained, choices are [256, 128, 64, 32]
-        `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2]
-        `filter_fn`: function that takes a nn.Module instance and fully qualified name of the module, \
-         returns True if we want to run `config` on
-        `zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, \
-         ZeroPointDomain.INT, ZeroPointDomain.NONE]
-        `preserve_zero`: whether to preserve zero, default is False
-    """
-    raise ImportError(
-        "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs"
-    )
-
-    if filter_fn is None:
-        filter_fn = _is_linear
-
-    _replace_with_custom_fn_if_matches_filter(
-        model,
-        _get_subclass_inserter(
-            Int4WeightOnlyQuantizedLinearWeight,
-            enable_parametrization=False,
-            groupsize=groupsize,
-            inner_k_tiles=inner_k_tiles,
-            zero_point_domain=zero_point_domain,
-            preserve_zero=preserve_zero,
-        ),
-        filter_fn,
-    )
-
-
-########
-# TO BE DEPRECATED END
-########
-
-
 def _replace_with_custom_fn_if_matches_filter(
     model,
     replacement_fn,