diff --git a/README.md b/README.md
index 72fd2d7403..1f711fb0dc 100644
--- a/README.md
+++ b/README.md
@@ -180,10 +180,10 @@ Post-training quantization can result in a fast and compact model, but may also

 ```python
 from torchao.quantization import quantize_
-from torchao.quantization.qat import FakeQuantizeConfig, IntXQuantizationAwareTrainingConfig
-activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
-qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+from torchao.quantization.qat import IntxFakeQuantizeConfig, QuantizationAwareTrainingConfig
+activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
+qat_config = QuantizationAwareTrainingConfig(activation_config, weight_config)
 quantize_(my_model, qat_config)
 ```

diff --git a/docs/source/api_ref_qat.rst b/docs/source/api_ref_qat.rst
index 046a1b74a4..9c20cc1f12 100644
--- a/docs/source/api_ref_qat.rst
+++ b/docs/source/api_ref_qat.rst
@@ -15,8 +15,8 @@ please refer to the `QAT README torch.Tensor:
@@ -1059,7 +1064,7 @@ def test_fake_quantized_linear_4w(self):
         Test that we can express int4 weight only (tinygemm) with `FakeQuantizedLinear`.
         """
         group_size = 128
-        weight_config = FakeQuantizeConfig(
+        weight_config = IntxFakeQuantizeConfig(
             dtype=torch.uint4,
             group_size=group_size,
             is_symmetric=False,
@@ -1172,7 +1177,9 @@ def test_fake_quantized_embedding_4w(self):
         fq_embedding = FakeQuantizedEmbedding(
             num_embeddings,
             embedding_dim,
-            weight_config=FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size),
+            weight_config=IntxFakeQuantizeConfig(
+                TorchAODType.INT4, group_size=group_size
+            ),
         )

         def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
@@ -1258,7 +1265,7 @@ def test_quantize_api_standalone(self):
         """
         Test that the following:

-            quantize_(model, intx_quantization_aware_training(...))
+            quantize_(model, QuantizationAwareTrainingConfig(...))

         can produce the same results as `ComposableQATQuantizer`.
         """
@@ -1283,19 +1290,19 @@ def test_quantize_api_standalone(self):
         baseline_model = baseline_quantizer.prepare(baseline_model)

         # quantize_ API
-        activation_config = FakeQuantizeConfig(
+        activation_config = IntxFakeQuantizeConfig(
             torch.int8,
             "per_token",
             is_symmetric=False,
         )
-        weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
+        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
         quantize_(
             m,
-            intx_quantization_aware_training(activation_config, weight_config),
+            QuantizationAwareTrainingConfig(activation_config, weight_config),
         )
         quantize_(
             m,
-            intx_quantization_aware_training(weight_config=weight_config),
+            QuantizationAwareTrainingConfig(weight_config=weight_config),
             filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
         )

@@ -1315,7 +1322,7 @@ def test_quantize_api_errors(self):
         Test that we throw exceptions with helpful error messages if
         `quantize_` runs into unexpected configurations.
""" - my_config = FakeQuantizeConfig(torch.int8, group_size=32) + my_config = IntxFakeQuantizeConfig(torch.int8, group_size=32) m = M3() # Embedding currently only supports weight-only quantization @@ -1324,7 +1331,7 @@ def test_quantize_api_errors(self): ): quantize_( m, - intx_quantization_aware_training(my_config, my_config), + QuantizationAwareTrainingConfig(my_config, my_config), lambda m, _: isinstance(m, torch.nn.Embedding), ) @@ -1332,7 +1339,7 @@ def test_quantize_api_errors(self): with self.assertRaisesRegex(ValueError, "does not have QAT support"): quantize_( m, - intx_quantization_aware_training(my_config, my_config), + QuantizationAwareTrainingConfig(my_config, my_config), lambda m, _: isinstance(m, torch.nn.ReLU), ) @@ -1343,8 +1350,8 @@ def test_quantize_api_convert_path(self): """ Test that the following: - quantize_(model, intx_quantization_aware_training(...)) - quantize_(model, from_intx_quantization_aware_training(...)) + quantize_(model, QuantizationAwareTrainingConfig(...)) + quantize_(model, FromQuantizationAwareTrainingConfig(...)) quantize_(model, int8_dynamic_activation_int4_weight()) can produce the same results as `Int8DynActInt4WeightQATQuantizer` prepare + convert. @@ -1363,15 +1370,15 @@ def test_quantize_api_convert_path(self): baseline_model = baseline_quantizer.prepare(baseline_model) # quantize_ prepare - activation_config = FakeQuantizeConfig( + activation_config = IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=False, ) - weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) + weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) quantize_( m, - intx_quantization_aware_training(activation_config, weight_config), + QuantizationAwareTrainingConfig(activation_config, weight_config), ) # Compare prepared values @@ -1386,7 +1393,7 @@ def test_quantize_api_convert_path(self): baseline_model = baseline_quantizer.convert(baseline_model) # quantize_ convert - quantize_(m, from_intx_quantization_aware_training()) + quantize_(m, FromQuantizationAwareTrainingConfig()) quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size)) # Compare converted values @@ -1402,11 +1409,11 @@ def test_quantize_api_convert_path(self): ) def test_fake_quantize_config_torch_intx(self): """ - Test that `FakeQuantizeConfig` works with torch.intx. + Test that `IntxFakeQuantizeConfig` works with torch.intx. """ group_size = 16 - config1 = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) - config2 = FakeQuantizeConfig(torch.int4, group_size=group_size) + config1 = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) + config2 = IntxFakeQuantizeConfig(torch.int4, group_size=group_size) linear1 = FakeQuantizedLinear(32, 64, weight_config=config1) linear2 = FakeQuantizedLinear(32, 64, weight_config=config2) linear2.weight = linear1.weight @@ -1424,7 +1431,7 @@ def test_fake_quantizer_repr(self): """ Test that `repr(FakeQuantizer(config))` exposes useful config details. """ - config = FakeQuantizeConfig(torch.int4, group_size=128) + config = IntxFakeQuantizeConfig(torch.int4, group_size=128) fake_quantizer = FakeQuantizer(config) fake_quantizer_repr = repr(fake_quantizer) self.assertTrue("dtype=torch.int4" in fake_quantizer_repr) @@ -1440,13 +1447,13 @@ def test_qat_linear_bias(self): Test that QAT supports linear bias. 
""" m = ModelWithLinearBias() - activation_config = FakeQuantizeConfig( + activation_config = IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=False ) - weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=32) + weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=32) quantize_( m, - intx_quantization_aware_training(activation_config, weight_config), + QuantizationAwareTrainingConfig(activation_config, weight_config), ) example_inputs = m.example_inputs() m(*example_inputs) @@ -1465,7 +1472,7 @@ def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype): torch.manual_seed(self.SEED) x = torch.randn(1, 235, 2048).to(dtype) - config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) + config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) fake_quantizer = FakeQuantizer(config) fake_quantizer_out = fake_quantizer(x) baseline_out = per_token_dynamic_quant(x) @@ -1518,7 +1525,7 @@ def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype): ) def test_fake_quantize_config_eps(self): """ - Test that users can set arbitrary eps value in `FakeQuantizeConfig`. + Test that users can set arbitrary eps value in `IntxFakeQuantizeConfig`. """ eps = 0.00123 x = torch.randn(2, 3).to(torch.float32) @@ -1532,7 +1539,7 @@ def test_fake_quantize_config_eps(self): eps=eps, ) expected_out = _fake_quantize_per_token(x, scale, zp, -128, 127) - config = FakeQuantizeConfig( + config = IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=False, @@ -1598,7 +1605,7 @@ def test_fake_quantizer_range_learning(self): """ Test that range learning requires `FakeQuantizer`s to be initialized correctly. """ - config = FakeQuantizeConfig( + config = IntxFakeQuantizeConfig( torch.int8, "per_channel", is_dynamic=False, @@ -1636,7 +1643,7 @@ def test_qat_range_learning(self): """ Test end-to-end QAT flow with range learning. 
""" - config = FakeQuantizeConfig( + config = IntxFakeQuantizeConfig( torch.int8, "per_channel", is_dynamic=False, @@ -1646,7 +1653,7 @@ def test_qat_range_learning(self): ) m = M() example_inputs = m.example_inputs() - quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config)) + quantize_(m, QuantizationAwareTrainingConfig(weight_config=config)) # Not initialized, should fail for t in m._get_all_weight_qparams(): diff --git a/torchao/experimental/tests/test_embedding_xbit_quantizer.py b/torchao/experimental/tests/test_embedding_xbit_quantizer.py index 442612410e..65cedf2682 100644 --- a/torchao/experimental/tests/test_embedding_xbit_quantizer.py +++ b/torchao/experimental/tests/test_embedding_xbit_quantizer.py @@ -21,10 +21,10 @@ ) from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.qat import ( - FakeQuantizeConfig, - FromIntXQuantizationAwareTrainingConfig, + FromQuantizationAwareTrainingConfig, Int4WeightOnlyEmbeddingQATQuantizer, - IntXQuantizationAwareTrainingConfig, + IntxFakeQuantizeConfig, + QuantizationAwareTrainingConfig, ) from torchao.quantization.quant_api import ( Int8DynamicActivationIntxWeightConfig, @@ -259,7 +259,7 @@ def test_identical_to_IntxWeightOnlyConfig( ], name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}", ) - def test_identical_to_IntXQuantizationAwareTrainingConfig( + def test_identical_to_QuantizationAwareTrainingConfig( self, weight_dtype, granularity, mapping_type, scale_dtype, model_dtype ): # ASYMMETRIC in QAT is very different that PTQ configs @@ -282,7 +282,7 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( ) embedding_filter = lambda m, fqn: isinstance(m, torch.nn.Embedding) - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( weight_dtype, group_size=group_size, is_symmetric=is_symmetric, @@ -290,12 +290,12 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( ) quantize_( model, - IntXQuantizationAwareTrainingConfig(weight_config=weight_config), + QuantizationAwareTrainingConfig(weight_config=weight_config), embedding_filter, ) prepared_out = model(indices) - quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter) + quantize_(model, FromQuantizationAwareTrainingConfig(), embedding_filter) quantize_( model, IntxWeightOnlyConfig( @@ -357,7 +357,7 @@ def test_identical_to_Int4WeightOnlyEmbeddingQATQuantizer( prepared_out = model(indices) # Convert model method 1 - quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter) + quantize_(model, FromQuantizationAwareTrainingConfig(), embedding_filter) quantize_( model, IntxWeightOnlyConfig( diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index 08548b9e9e..d792734186 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -15,10 +15,10 @@ from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.qat import ( - FakeQuantizeConfig, - FromIntXQuantizationAwareTrainingConfig, + FromQuantizationAwareTrainingConfig, Int8DynActInt4WeightQATQuantizer, - IntXQuantizationAwareTrainingConfig, + IntxFakeQuantizeConfig, + QuantizationAwareTrainingConfig, ) from torchao.quantization.quant_api import ( Int8DynamicActivationInt4WeightConfig, @@ 
-498,7 +498,7 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig( ], name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}", ) - def test_identical_to_IntXQuantizationAwareTrainingConfig( + def test_identical_to_QuantizationAwareTrainingConfig( self, weight_dtype, group_size, @@ -530,12 +530,12 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( model = model.to(model_dtype) activations = activations.to(model_dtype) - activation_config = FakeQuantizeConfig( + activation_config = IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=is_act_symmetric, ) - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( weight_dtype, group_size=group_size, is_symmetric=is_symmetric, @@ -544,7 +544,7 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( quantize_( model, - IntXQuantizationAwareTrainingConfig(activation_config, weight_config), + QuantizationAwareTrainingConfig(activation_config, weight_config), ) try: prepared_out = model(activations) @@ -554,7 +554,7 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( return raise e - quantize_(model, FromIntXQuantizationAwareTrainingConfig()) + quantize_(model, FromQuantizationAwareTrainingConfig()) quantize_( model, Int8DynamicActivationIntxWeightConfig( @@ -608,7 +608,7 @@ def test_identical_to_Int8DynActInt4WeightQATQuantizer( prepared_out = model(activations) # Convert model method 1 - quantize_(model, FromIntXQuantizationAwareTrainingConfig()) + quantize_(model, FromQuantizationAwareTrainingConfig()) quantize_( model, Int8DynamicActivationIntxWeightConfig( diff --git a/torchao/quantization/qat/README.md b/torchao/quantization/qat/README.md index 6395952ab5..785dae1868 100644 --- a/torchao/quantization/qat/README.md +++ b/torchao/quantization/qat/README.md @@ -71,7 +71,7 @@ def train_loop(m: torch.nn.Module): The recommended way to run QAT in torchao is through the `quantize_` API: 1. **Prepare:** specify how weights and/or activations are to be quantized through -[`FakeQuantizeConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.FakeQuantizeConfig.html#torchao.quantization.qat.FakeQuantizeConfig) and passing these to [`IntXQuantizationAwareTrainingConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.IntXQuantizationAwareTrainingConfig.html#torchao.quantization.qat.IntXQuantizationAwareTrainingConfig) +[`IntxFakeQuantizeConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.IntxFakeQuantizeConfig.html#torchao.quantization.qat.IntxFakeQuantizeConfig) and passing these to [`QuantizationAwareTrainingConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.QuantizationAwareTrainingConfig.html#torchao.quantization.qat.QuantizationAwareTrainingConfig) 2. 
**Convert:** quantize the model using the standard post-training quantization (PTQ) functions such as [`Int8DynamicActivationInt4WeightConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int8DynamicActivationInt4WeightConfig.html#torchao.quantization.Int8DynamicActivationInt4WeightConfig) @@ -84,19 +84,19 @@ from torchao.quantization import ( Int8DynamicActivationInt4WeightConfig, ) from torchao.quantization.qat import ( - FakeQuantizeConfig, - FromIntXQuantizationAwareTrainingConfig, - IntXQuantizationAwareTrainingConfig, + IntxFakeQuantizeConfig, + FromQuantizationAwareTrainingConfig, + QuantizationAwareTrainingConfig, ) model = get_model() # prepare: insert fake quantization ops # swaps `torch.nn.Linear` with `FakeQuantizedLinear` -activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) -weight_config = FakeQuantizeConfig(torch.int4, group_size=32) +activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) +weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32) quantize_( model, - IntXQuantizationAwareTrainingConfig(activation_config, weight_config), + QuantizationAwareTrainingConfig(activation_config, weight_config), ) # train @@ -105,7 +105,7 @@ train_loop(model) # convert: transform fake quantization ops into actual quantized ops # swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts # quantized activation and weight tensor subclasses -quantize_(model, FromIntXQuantizationAwareTrainingConfig()) +quantize_(model, FromQuantizationAwareTrainingConfig()) quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32)) # inference or generate @@ -116,18 +116,18 @@ the following with a filter function during the prepare step: ``` # first apply linear transformation to the model as above -activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) -weight_config = FakeQuantizeConfig(torch.int4, group_size=32) +activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) +weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32) quantize_( model, - IntXQuantizationAwareTrainingConfig(activation_config, weight_config), + QuantizationAwareTrainingConfig(activation_config, weight_config), ) # then apply weight-only transformation to embedding layers # activation fake quantization is not supported for embedding layers quantize_( m, - IntXQuantizationAwareTrainingConfig(weight_config=weight_config), + QuantizationAwareTrainingConfig(weight_config=weight_config), filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding) ) ``` diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py index 72cecfd254..6bd409248e 100644 --- a/torchao/quantization/qat/__init__.py +++ b/torchao/quantization/qat/__init__.py @@ -1,8 +1,9 @@ from .api import ( ComposableQATQuantizer, - FakeQuantizeConfig, FromIntXQuantizationAwareTrainingConfig, + FromQuantizationAwareTrainingConfig, IntXQuantizationAwareTrainingConfig, + QuantizationAwareTrainingConfig, from_intx_quantization_aware_training, initialize_fake_quantizers, intx_quantization_aware_training, @@ -11,6 +12,11 @@ FakeQuantizedEmbedding, Int4WeightOnlyEmbeddingQATQuantizer, ) +from .fake_quantize_config import ( + FakeQuantizeConfig, + FakeQuantizeConfigBase, + IntxFakeQuantizeConfig, +) from .fake_quantizer import FakeQuantizer from .linear import ( FakeQuantizedLinear, @@ -21,17 +27,22 @@ __all__ = [ "ComposableQATQuantizer", - "FakeQuantizeConfig", + 
"FakeQuantizeConfigBase", "FakeQuantizedLinear", "FakeQuantizedEmbedding", "FakeQuantizer", "Float8ActInt4WeightQATQuantizer", - "FromIntXQuantizationAwareTrainingConfig", + "FromQuantizationAwareTrainingConfig", "Int4WeightOnlyEmbeddingQATQuantizer", "Int4WeightOnlyQATQuantizer", "Int8DynActInt4WeightQATQuantizer", - "IntXQuantizationAwareTrainingConfig", + "IntxFakeQuantizeConfig", "initialize_fake_quantizers", - "intx_quantization_aware_training", + "QuantizationAwareTrainingConfig", + # for BC + "FakeQuantizeConfig", "from_intx_quantization_aware_training", + "FromIntXQuantizationAwareTrainingConfig", + "intx_quantization_aware_training", + "IntXQuantizationAwareTrainingConfig", ] diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py index b7df56409f..2e63b1e504 100644 --- a/torchao/quantization/qat/api.py +++ b/torchao/quantization/qat/api.py @@ -5,256 +5,24 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple import torch from torchao.core.config import AOBaseConfig -from torchao.quantization.granularity import ( - Granularity, - PerAxis, - PerGroup, - PerToken, -) -from torchao.quantization.quant_primitives import ( - _SUB_BYTE_INT_BOUNDS, - _SUB_BYTE_UINT_BOUNDS, - MappingType, - TorchAODType, - ZeroPointDomain, -) from torchao.quantization.transform_module import ( register_quantize_module_handler, ) from torchao.quantization.unified import TwoStepQuantizer - -@dataclass -class FakeQuantizeConfig: - """ - Config for how to fake quantize weights or activations. - - Args: - dtype: dtype to simulate during fake quantization, e.g. torch.int8. - For PyTorch versions older than 2.6, you may use `TorchAODType` to represent - torch.int1 to torch.int7 instead, e.g. TorchAODType.INT4. - granularity: granularity of scales and zero points, e.g. PerGroup(32). - We also support the following strings: - 1) 'per_token': equivalent to PerToken() - 2) 'per_channel': equivalent to PerAxis(0) - 3) 'per_group': equivalent to PerGroup(group_size), must be combined - with separate `group_size` kwarg, Alternatively, just set the - `group_size` kwarg and leave this field empty. - mapping_type: whether to use symmetric (default) or asymmetric quantization - Alternatively, set `is_symmetric` (bool) and leave this field empty. - scale_precision: scale dtype (default torch.fp32) - zero_point_precision: zero point dtype (default torch.int32) - zero_point_domain: whether zero point is in integer (default) or float domain - is_dynamic: whether to use dynamic (default) or static scale and zero points - range_learning (prototype): whether to learn scale and zero points during training - (default false), not compatible with `is_dynamic`. 
- - Keyword args: - group_size: size of each group in per group fake quantization, - can be set instead of `granularity` - is_symmetric: whether to use symmetric or asymmetric quantization, - can be set instead of `mapping_type` - - Example usage:: - - # Per token asymmetric quantization - FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) - FakeQuantizeConfig(torch.int8, PerToken(), MappingType.ASYMMETRIC) - - # Per channel symmetric quantization - FakeQuantizeConfig(torch.int4, "per_channel") - FakeQuantizeConfig(torch.int4, "per_channel", is_symmetric=True) - FakeQuantizeConfig(torch.int4, PerAxis(0), MappingType.SYMMETRIC) - - # Per group symmetric quantization - FakeQuantizeConfig(torch.int4, group_size=32) - FakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True) - FakeQuantizeConfig(torch.int4, "per_group", group_size=32, is_symmetric=True) - FakeQuantizeConfig(torch.int4, PerGroup(32), MappingType.SYMMETRIC) - """ - - dtype: Union[torch.dtype, TorchAODType] - granularity: Granularity - mapping_type: MappingType - scale_precision: torch.dtype - zero_point_precision: torch.dtype - zero_point_domain: ZeroPointDomain - is_dynamic: bool = True - range_learning: bool = False - eps: Optional[float] = None - - def __init__( - self, - dtype: Union[torch.dtype, TorchAODType], - granularity: Union[Granularity, str, None] = None, - mapping_type: Optional[MappingType] = None, - scale_precision: torch.dtype = torch.float32, - zero_point_precision: torch.dtype = torch.int32, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, - is_dynamic: bool = True, - range_learning: bool = False, - eps: Optional[float] = None, - *, - group_size: Optional[int] = None, - is_symmetric: Optional[bool] = None, - ): - if zero_point_domain is None: - raise ValueError("Please use ZeroPointDomain.NONE instead of None") - self.dtype = dtype - self.granularity = self._get_granularity(granularity, group_size) - self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric) - self.scale_precision = scale_precision - self.zero_point_precision = zero_point_precision - self.zero_point_domain = zero_point_domain - self.is_dynamic = is_dynamic - self.range_learning = range_learning - self.eps = eps - - # Validate dtype - all_dtypes = [torch.int8, torch.uint8] - all_dtypes.extend(list(_SUB_BYTE_INT_BOUNDS.keys())) - all_dtypes.extend(list(_SUB_BYTE_UINT_BOUNDS.keys())) - if dtype not in all_dtypes: - raise ValueError( - "Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes) - ) - - # Dynamic is not compatible with range learning - if is_dynamic and range_learning: - raise ValueError("`is_dynamic` is not compatible with `range_learning`") - - def _get_granularity( - self, - granularity: Union[Granularity, str, None], - group_size: Optional[int], - ) -> Granularity: - """ - Parse the `Granularity` represented in the args. 
- - Granularity can be specified in one of three ways: - 1) `Granularity` object: one of PerToken(), PerAxis(), and PerGroup(group_size) - 2) str: one of 'per_token', 'per_channel', and 'per_group' - 3) None: `group_size` must be set instead, represents per group granularity - """ - # If group_size is set, then granularity must be either "per_group" or None - if ( - group_size is not None - and granularity != "per_group" - and granularity is not None - ): - raise ValueError( - "`group_size` conflicts with granularity '%s'" % granularity - ) - - # Case 1: Granularity object - if isinstance(granularity, Granularity): - if not isinstance(granularity, (PerToken, PerAxis, PerGroup)): - raise ValueError("Granularity '%s' is not supported" % granularity) - if isinstance(granularity, PerAxis) and granularity.axis != 0: - raise ValueError("Only axis=0 is supported for PerAxis granularity") - return granularity - - # Case 2: str granularity - if granularity == "per_token": - return PerToken() - elif granularity == "per_channel": - return PerAxis(axis=0) - elif granularity == "per_group": - if group_size is None: - raise ValueError( - "Granularity was 'per_group' but no `group_size` was set" - ) - return PerGroup(group_size) - elif isinstance(granularity, str): - raise ValueError( - "Unexpected granularity: '%s', must be one of %s" - % (granularity, ["per_token", "per_channel", "per_group"]) - ) - - # Case 3: None granularity + group_size was specified - if granularity is not None: - raise ValueError( - "Granularity '%s' has unexpected type %s" - % (granularity, type(granularity)) - ) - if group_size is None: - raise ValueError( - "At least one of `granularity` or `group_size` must be set" - ) - return PerGroup(group_size) - - def _get_mapping_type( - self, - mapping_type: Optional[MappingType], - is_symmetric: Optional[bool], - ) -> MappingType: - """ - Parse the `MappingType` represented in the args. - - Mapping type can be specified in one of two ways: - 1): `MappingType` object: one of SYMMETRIC or ASYMMETRIC - 2): is_symmetric bool - """ - if mapping_type is not None and is_symmetric is not None: - raise ValueError("Cannot set both `mapping_type` and `is_symmetric`") - - # Case 0: Default to symmetric - if mapping_type is None and is_symmetric is None: - return MappingType.SYMMETRIC - - # Case 1: MappingType object - if mapping_type is not None: - if mapping_type not in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]: - raise ValueError("MappingType '%s' is not supported" % mapping_type) - return mapping_type - - # Case 2: is_symmetric flag - assert is_symmetric is not None - if is_symmetric: - return MappingType.SYMMETRIC - else: - return MappingType.ASYMMETRIC - - @property - def group_size(self) -> int: - """ - If this is per group granularity, return the group size. - Otherwise, throw an error. - """ - if isinstance(self.granularity, PerGroup): - return self.granularity.group_size - else: - raise ValueError( - "`group_size` is undefined for %s granularity" % self.granularity - ) - - @property - def is_symmetric(self) -> bool: - """ - Return True if mapping type is symmetric, else False (asymmetric). - """ - return self.mapping_type == MappingType.SYMMETRIC - - def __setattr__(self, name: str, value: Any): - """ - Support setting `group_size` and `is_symmetric`. 
- """ - if name == "group_size": - super().__setattr__("granularity", PerGroup(value)) - elif name == "is_symmetric": - mapping_type = MappingType.SYMMETRIC if value else MappingType.ASYMMETRIC - super().__setattr__("mapping_type", mapping_type) - else: - super().__setattr__(name, value) +from .fake_quantize_config import ( + FakeQuantizeConfig, # noqa: F401, for BC + FakeQuantizeConfigBase, +) @dataclass -class IntXQuantizationAwareTrainingConfig(AOBaseConfig): +class QuantizationAwareTrainingConfig(AOBaseConfig): """ Config for applying fake quantization to a `torch.nn.Module`. to be used with :func:`~torchao.quantization.quant_api.quantize_`. @@ -262,16 +30,16 @@ class IntXQuantizationAwareTrainingConfig(AOBaseConfig): Example usage:: from torchao.quantization import quantize_ - from torchao.quantization.qat import FakeQuantizeConfig - activation_config = FakeQuantizeConfig( + from torchao.quantization.qat import IntxFakeQuantizeConfig + activation_config = IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=False, ) - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( torch.int4, group_size=32, is_symmetric=True, ) quantize_( model, - IntXQuantizationAwareTrainingConfig(activation_config, weight_config), + QuantizationAwareTrainingConfig(activation_config, weight_config), ) Note: If the config is applied on a module that is not @@ -280,18 +48,19 @@ class IntXQuantizationAwareTrainingConfig(AOBaseConfig): ValueError as these are not supported. """ - activation_config: Optional[FakeQuantizeConfig] = None - weight_config: Optional[FakeQuantizeConfig] = None + activation_config: Optional[FakeQuantizeConfigBase] = None + weight_config: Optional[FakeQuantizeConfigBase] = None # for BC -intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig +IntXQuantizationAwareTrainingConfig = QuantizationAwareTrainingConfig +intx_quantization_aware_training = QuantizationAwareTrainingConfig -@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig) +@register_quantize_module_handler(QuantizationAwareTrainingConfig) def _intx_quantization_aware_training_transform( module: torch.nn.Module, - config: IntXQuantizationAwareTrainingConfig, + config: QuantizationAwareTrainingConfig, ) -> torch.nn.Module: from .embedding import FakeQuantizedEmbedding from .linear import FakeQuantizedLinear @@ -316,7 +85,7 @@ def _intx_quantization_aware_training_transform( raise ValueError("Module of type '%s' does not have QAT support" % type(mod)) -class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig): +class FromQuantizationAwareTrainingConfig(AOBaseConfig): """ Config for converting a model with fake quantized modules, such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear` @@ -330,7 +99,7 @@ class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig): from torchao.quantization import quantize_ quantize_( model_with_fake_quantized_linears, - FromIntXQuantizationAwareTrainingConfig(), + FromQuantizationAwareTrainingConfig(), ) """ @@ -338,13 +107,14 @@ class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig): # for BC -from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig +FromIntXQuantizationAwareTrainingConfig = FromQuantizationAwareTrainingConfig +from_intx_quantization_aware_training = FromQuantizationAwareTrainingConfig -@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig) +@register_quantize_module_handler(FromQuantizationAwareTrainingConfig) def 
_from_intx_quantization_aware_training_transform( mod: torch.nn.Module, - config: FromIntXQuantizationAwareTrainingConfig, + config: FromQuantizationAwareTrainingConfig, ) -> torch.nn.Module: """ If the given module is a fake quantized module, return the original diff --git a/torchao/quantization/qat/embedding.py b/torchao/quantization/qat/embedding.py index aec23712ed..778ba2b83c 100644 --- a/torchao/quantization/qat/embedding.py +++ b/torchao/quantization/qat/embedding.py @@ -13,7 +13,10 @@ from torchao.quantization.unified import TwoStepQuantizer from torchao.quantization.utils import get_group_qparams_symmetric -from .api import FakeQuantizeConfig +from .fake_quantize_config import ( + FakeQuantizeConfigBase, + IntxFakeQuantizeConfig, +) from .fake_quantizer import FakeQuantizer from .utils import ( _get_qmin_qmax, @@ -29,7 +32,7 @@ class FakeQuantizedEmbedding(torch.nn.Embedding): Example usage:: - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( dtype=torch.int4, group_size=8, symmetric=True, @@ -47,7 +50,7 @@ def __init__( norm_type: float = 2.0, scale_grad_by_freq: bool = False, sparse: bool = False, - weight_config: Optional[FakeQuantizeConfig] = None, + weight_config: Optional[FakeQuantizeConfigBase] = None, *args, **kwargs, ) -> None: @@ -105,7 +108,7 @@ def to_embedding(self) -> torch.nn.Embedding: def from_embedding( cls, mod: torch.nn.Embedding, - weight_config: Optional[FakeQuantizeConfig] = None, + weight_config: Optional[FakeQuantizeConfigBase] = None, ): new_embedding = FakeQuantizedEmbedding( mod.num_embeddings, @@ -285,7 +288,7 @@ def __init__( *args, **kwargs, ): - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( dtype=TorchAODType.INT4, group_size=group_size, is_symmetric=True, diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py new file mode 100644 index 0000000000..041c4e5395 --- /dev/null +++ b/torchao/quantization/qat/fake_quantize_config.py @@ -0,0 +1,262 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import abc +from dataclasses import dataclass +from typing import Any, Optional, Union + +import torch + +from torchao.quantization.granularity import ( + Granularity, + PerAxis, + PerGroup, + PerToken, +) +from torchao.quantization.quant_primitives import ( + _SUB_BYTE_INT_BOUNDS, + _SUB_BYTE_UINT_BOUNDS, + MappingType, + TorchAODType, + ZeroPointDomain, +) + + +@dataclass +class FakeQuantizeConfigBase(abc.ABC): + """ + Base class for representing fake quantization config. + """ + + pass + + +@dataclass +class IntxFakeQuantizeConfig(FakeQuantizeConfigBase): + """ + Config for how to fake quantize weights or activations. + + Args: + dtype: dtype to simulate during fake quantization, e.g. torch.int8. + For PyTorch versions older than 2.6, you may use `TorchAODType` to represent + torch.int1 to torch.int7 instead, e.g. TorchAODType.INT4. + granularity: granularity of scales and zero points, e.g. PerGroup(32). + We also support the following strings: + 1) 'per_token': equivalent to PerToken() + 2) 'per_channel': equivalent to PerAxis(0) + 3) 'per_group': equivalent to PerGroup(group_size), must be combined + with separate `group_size` kwarg, Alternatively, just set the + `group_size` kwarg and leave this field empty. 
+        mapping_type: whether to use symmetric (default) or asymmetric quantization
+            Alternatively, set `is_symmetric` (bool) and leave this field empty.
+        scale_precision: scale dtype (default torch.float32)
+        zero_point_precision: zero point dtype (default torch.int32)
+        zero_point_domain: whether zero point is in integer (default) or float domain
+        is_dynamic: whether to use dynamic (default) or static scale and zero points
+        range_learning (prototype): whether to learn scale and zero points during training
+            (default false), not compatible with `is_dynamic`.
+
+    Keyword args:
+        group_size: size of each group in per group fake quantization,
+            can be set instead of `granularity`
+        is_symmetric: whether to use symmetric or asymmetric quantization,
+            can be set instead of `mapping_type`
+
+    Example usage::
+
+        # Per token asymmetric quantization
+        IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+        IntxFakeQuantizeConfig(torch.int8, PerToken(), MappingType.ASYMMETRIC)
+
+        # Per channel symmetric quantization
+        IntxFakeQuantizeConfig(torch.int4, "per_channel")
+        IntxFakeQuantizeConfig(torch.int4, "per_channel", is_symmetric=True)
+        IntxFakeQuantizeConfig(torch.int4, PerAxis(0), MappingType.SYMMETRIC)
+
+        # Per group symmetric quantization
+        IntxFakeQuantizeConfig(torch.int4, group_size=32)
+        IntxFakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True)
+        IntxFakeQuantizeConfig(torch.int4, "per_group", group_size=32, is_symmetric=True)
+        IntxFakeQuantizeConfig(torch.int4, PerGroup(32), MappingType.SYMMETRIC)
+    """
+
+    dtype: Union[torch.dtype, TorchAODType]
+    granularity: Granularity
+    mapping_type: MappingType
+    scale_precision: torch.dtype
+    zero_point_precision: torch.dtype
+    zero_point_domain: ZeroPointDomain
+    is_dynamic: bool = True
+    range_learning: bool = False
+    eps: Optional[float] = None
+
+    def __init__(
+        self,
+        dtype: Union[torch.dtype, TorchAODType],
+        granularity: Union[Granularity, str, None] = None,
+        mapping_type: Optional[MappingType] = None,
+        scale_precision: torch.dtype = torch.float32,
+        zero_point_precision: torch.dtype = torch.int32,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        is_dynamic: bool = True,
+        range_learning: bool = False,
+        eps: Optional[float] = None,
+        *,
+        group_size: Optional[int] = None,
+        is_symmetric: Optional[bool] = None,
+    ):
+        if zero_point_domain is None:
+            raise ValueError("Please use ZeroPointDomain.NONE instead of None")
+        self.dtype = dtype
+        self.granularity = self._get_granularity(granularity, group_size)
+        self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric)
+        self.scale_precision = scale_precision
+        self.zero_point_precision = zero_point_precision
+        self.zero_point_domain = zero_point_domain
+        self.is_dynamic = is_dynamic
+        self.range_learning = range_learning
+        self.eps = eps
+
+        # Validate dtype
+        all_dtypes = [torch.int8, torch.uint8]
+        all_dtypes.extend(list(_SUB_BYTE_INT_BOUNDS.keys()))
+        all_dtypes.extend(list(_SUB_BYTE_UINT_BOUNDS.keys()))
+        if dtype not in all_dtypes:
+            raise ValueError(
+                "Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes)
+            )
+
+        # Dynamic is not compatible with range learning
+        if is_dynamic and range_learning:
+            raise ValueError("`is_dynamic` is not compatible with `range_learning`")
+
+    def _get_granularity(
+        self,
+        granularity: Union[Granularity, str, None],
+        group_size: Optional[int],
+    ) -> Granularity:
+        """
+        Parse the `Granularity` represented in the args.
+ + Granularity can be specified in one of three ways: + 1) `Granularity` object: one of PerToken(), PerAxis(), and PerGroup(group_size) + 2) str: one of 'per_token', 'per_channel', and 'per_group' + 3) None: `group_size` must be set instead, represents per group granularity + """ + # If group_size is set, then granularity must be either "per_group" or None + if ( + group_size is not None + and granularity != "per_group" + and granularity is not None + ): + raise ValueError( + "`group_size` conflicts with granularity '%s'" % granularity + ) + + # Case 1: Granularity object + if isinstance(granularity, Granularity): + if not isinstance(granularity, (PerToken, PerAxis, PerGroup)): + raise ValueError("Granularity '%s' is not supported" % granularity) + if isinstance(granularity, PerAxis) and granularity.axis != 0: + raise ValueError("Only axis=0 is supported for PerAxis granularity") + return granularity + + # Case 2: str granularity + if granularity == "per_token": + return PerToken() + elif granularity == "per_channel": + return PerAxis(axis=0) + elif granularity == "per_group": + if group_size is None: + raise ValueError( + "Granularity was 'per_group' but no `group_size` was set" + ) + return PerGroup(group_size) + elif isinstance(granularity, str): + raise ValueError( + "Unexpected granularity: '%s', must be one of %s" + % (granularity, ["per_token", "per_channel", "per_group"]) + ) + + # Case 3: None granularity + group_size was specified + if granularity is not None: + raise ValueError( + "Granularity '%s' has unexpected type %s" + % (granularity, type(granularity)) + ) + if group_size is None: + raise ValueError( + "At least one of `granularity` or `group_size` must be set" + ) + return PerGroup(group_size) + + def _get_mapping_type( + self, + mapping_type: Optional[MappingType], + is_symmetric: Optional[bool], + ) -> MappingType: + """ + Parse the `MappingType` represented in the args. + + Mapping type can be specified in one of two ways: + 1): `MappingType` object: one of SYMMETRIC or ASYMMETRIC + 2): is_symmetric bool + """ + if mapping_type is not None and is_symmetric is not None: + raise ValueError("Cannot set both `mapping_type` and `is_symmetric`") + + # Case 0: Default to symmetric + if mapping_type is None and is_symmetric is None: + return MappingType.SYMMETRIC + + # Case 1: MappingType object + if mapping_type is not None: + if mapping_type not in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]: + raise ValueError("MappingType '%s' is not supported" % mapping_type) + return mapping_type + + # Case 2: is_symmetric flag + assert is_symmetric is not None + if is_symmetric: + return MappingType.SYMMETRIC + else: + return MappingType.ASYMMETRIC + + @property + def group_size(self) -> int: + """ + If this is per group granularity, return the group size. + Otherwise, throw an error. + """ + if isinstance(self.granularity, PerGroup): + return self.granularity.group_size + else: + raise ValueError( + "`group_size` is undefined for %s granularity" % self.granularity + ) + + @property + def is_symmetric(self) -> bool: + """ + Return True if mapping type is symmetric, else False (asymmetric). + """ + return self.mapping_type == MappingType.SYMMETRIC + + def __setattr__(self, name: str, value: Any): + """ + Support setting `group_size` and `is_symmetric`. 
+ """ + if name == "group_size": + super().__setattr__("granularity", PerGroup(value)) + elif name == "is_symmetric": + mapping_type = MappingType.SYMMETRIC if value else MappingType.ASYMMETRIC + super().__setattr__("mapping_type", mapping_type) + else: + super().__setattr__(name, value) + + +# for BC +FakeQuantizeConfig = IntxFakeQuantizeConfig diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py index b7ad792dc1..3cb873f3ff 100644 --- a/torchao/quantization/qat/fake_quantizer.py +++ b/torchao/quantization/qat/fake_quantizer.py @@ -26,8 +26,9 @@ get_groupwise_affine_qparams, ) -from .api import ( - FakeQuantizeConfig, +from .fake_quantize_config import ( + FakeQuantizeConfigBase, + IntxFakeQuantizeConfig, ) from .utils import ( _fake_quantize_per_channel_group, @@ -41,7 +42,7 @@ class FakeQuantizer(torch.nn.Module): Generic module for applying fake quantization to a tensor, as specified in the config. """ - def __init__(self, config: FakeQuantizeConfig): + def __init__(self, config: FakeQuantizeConfigBase): super().__init__() self.config = config self.enabled = True @@ -61,6 +62,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if not self.enabled: return x + if not isinstance(self.config, IntxFakeQuantizeConfig): + raise ValueError("Only IntxFakeQuantizeConfig is supported currently") + if ( self.config.range_learning and not self._initialized diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index 02b48fc5e3..c9c8f8ea5d 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -27,7 +27,10 @@ from torchao.quantization.utils import get_group_qparams_symmetric from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 -from .api import FakeQuantizeConfig +from .fake_quantize_config import ( + FakeQuantizeConfigBase, + IntxFakeQuantizeConfig, +) from .fake_quantizer import ( FakeQuantizer, _Float8RowwiseActivationFakeQuantizer, @@ -46,12 +49,12 @@ class FakeQuantizedLinear(torch.nn.Linear): Example usage:: - activation_config = FakeQuantizeConfig( + activation_config = IntxFakeQuantizeConfig( dtype=torch.int8, granularity="per_token", is_symmetric=False, ) - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( dtype=torch.int4, group_size=8, is_symmetric=True, @@ -67,8 +70,8 @@ def __init__( in_features: int, out_features: int, bias: bool = False, - activation_config: Optional[FakeQuantizeConfig] = None, - weight_config: Optional[FakeQuantizeConfig] = None, + activation_config: Optional[FakeQuantizeConfigBase] = None, + weight_config: Optional[FakeQuantizeConfigBase] = None, *args, **kwargs, ) -> None: @@ -127,8 +130,8 @@ def to_linear(self) -> torch.nn.Linear: def from_linear( cls, mod: torch.nn.Linear, - activation_config: Optional[FakeQuantizeConfig] = None, - weight_config: Optional[FakeQuantizeConfig] = None, + activation_config: Optional[FakeQuantizeConfigBase] = None, + weight_config: Optional[FakeQuantizeConfigBase] = None, ): new_linear = FakeQuantizedLinear( mod.in_features, @@ -179,10 +182,10 @@ class _LegacyQATQuantizer(TwoStepQuantizer): Base class for sharing common methods across legacy QAT quantizers. 
""" - def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: return None - def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: return None @@ -281,10 +284,10 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module): else: self._convert_qat_linear_8da4w(child) - def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: return _get_8da4w_activation_config(self.activation_scales_precision) - def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: return _get_8da4w_weight_config(self.groupsize, self.scales_precision) @@ -354,13 +357,15 @@ def disable_8da4w_fake_quant(mod: torch.nn.Module): mod.disable_fake_quant() -def _get_8da4w_activation_config(qparams_precision: torch.dtype) -> FakeQuantizeConfig: +def _get_8da4w_activation_config( + qparams_precision: torch.dtype, +) -> IntxFakeQuantizeConfig: """ - Return the activation `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`. + Return the activation `IntxFakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`. """ # TODO: generalize this assert qparams_precision == torch.float32 - return FakeQuantizeConfig( + return IntxFakeQuantizeConfig( dtype=torch.int8, granularity="per_token", is_symmetric=False, @@ -374,11 +379,11 @@ def _get_8da4w_activation_config(qparams_precision: torch.dtype) -> FakeQuantize def _get_8da4w_weight_config( group_size: int, qparams_precision: torch.dtype, -) -> FakeQuantizeConfig: +) -> IntxFakeQuantizeConfig: """ - Return the weight `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`. + Return the weight `IntxFakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`. """ - return FakeQuantizeConfig( + return IntxFakeQuantizeConfig( dtype=TorchAODType.INT4, group_size=group_size, is_symmetric=True, @@ -482,7 +487,7 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module): else: self._convert_qat_linear_4w(child) - def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: return _get_4w_weight_config(self.groupsize, self.scales_precision) @@ -553,11 +558,11 @@ def disable_4w_fake_quant(mod: torch.nn.Module): def _get_4w_weight_config( group_size: int, qparams_precision: torch.dtype, -) -> FakeQuantizeConfig: +) -> IntxFakeQuantizeConfig: """ - Return the weight `FakeQuantizeConfig` for `Int4WeightOnlyQATQuantizer`. + Return the weight `IntxFakeQuantizeConfig` for `Int4WeightOnlyQATQuantizer`. 
""" - return FakeQuantizeConfig( + return IntxFakeQuantizeConfig( dtype=torch.uint4, group_size=group_size, is_symmetric=False, @@ -595,7 +600,7 @@ def __init__( weight_granularity = "per_group" else: weight_granularity = "per_channel" - self._weight_config = FakeQuantizeConfig( + self._weight_config = IntxFakeQuantizeConfig( dtype=torch.int4, granularity=weight_granularity, group_size=group_size, @@ -632,8 +637,8 @@ def convert( ) -> torch.nn.Module: raise NotImplementedError - def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: raise NotImplementedError("Float8 FakeQuantizeConfig does not exist yet") - def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: return self.weight_config