diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py
index 85a6e2b0c2..ca4ca2e44e 100644
--- a/test/prototype/test_parq.py
+++ b/test/prototype/test_parq.py
@@ -29,11 +29,7 @@
 from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig
 from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
 from torchao.quantization.granularity import PerGroup
-from torchao.quantization.qat import (
-    FromIntXQuantizationAwareTrainingConfig,
-    IntxFakeQuantizeConfig,
-    IntXQuantizationAwareTrainingConfig,
-)
+from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig
 from torchao.quantization.quant_api import (
     Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
@@ -63,9 +59,18 @@ def get_param_groups(model):
     return params_quant, params_no_quant


-def build_param_groups(model, b: int = 2, group_size: Optional[int] = None):
+def build_param_groups(
+    model,
+    b: int = 2,
+    group_size: Optional[int] = None,
+    quant_cls_name: Optional[str] = None,
+):
     params_quant, params_no_quant = split_param_groups(model)
-    quant_kwargs = {"quant_block_size": group_size} if group_size else {}
+    quant_kwargs = {}
+    if group_size:
+        quant_kwargs["quant_block_size"] = group_size
+    if quant_cls_name is not None:
+        quant_kwargs["quant_cls"] = quant_cls_name
     return [
         {"params": params_quant, "quant_bits": b, **quant_kwargs},
         {"params": params_no_quant},
@@ -164,15 +169,19 @@ def setUp(self):
     @common_utils.parametrize("b", [0, 1, 2, 4])
     @common_utils.parametrize("unif_quant", [True, False])
     @common_utils.parametrize("hard_prox", [True, False])
-    def test_parq_train_loop(self, b: int = 2, unif_quant=True, hard_prox=True):
+    @common_utils.parametrize("per_group_quant_cls", [True, False])
+    def test_parq_train_loop(
+        self, b: int = 2, unif_quant=True, hard_prox=True, per_group_quant_cls=False
+    ):
         self.model.reset_parameters()
-        param_groups = build_param_groups(self.model, b)
-        base_optimizer = torch.optim.AdamW(param_groups)
-
         if unif_quant:
             quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
         else:
             quantizer = LSBQuantizer()
+        quant_cls_name = quantizer.__class__.__name__ if per_group_quant_cls else None
+        param_groups = build_param_groups(self.model, b, quant_cls_name=quant_cls_name)
+        base_optimizer = torch.optim.AdamW(param_groups)
+
         prox_map = (
             ProxHardQuant() if hard_prox else ProxPARQ(anneal_start=0, anneal_end=2)
         )
@@ -283,7 +292,7 @@ def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32
         quantizer_ref = UnifQuantizer()
         quantizer = StretchedUnifTorchaoQuantizer(b)

-        for n, module in model.named_children():
+        for module in model.children():
             if not _is_linear(module):
                 continue

@@ -383,24 +392,18 @@ def test_int8_dynamic_activation_intx_e2e(

         # apply torchao quantized activations on top
         activation_config = IntxFakeQuantizeConfig(
-            torch.int8,
-            granularity="per_token",
-            mapping_type=config.act_mapping_type,
+            torch.int8, "per_token", is_symmetric=False
         )
+        qat_config = QATConfig(activation_config=activation_config, step="prepare")
         filter_fn = optimizer.get_filter_fn(model)
-        quantize_(
-            model,
-            IntXQuantizationAwareTrainingConfig(activation_config=activation_config),
-            filter_fn=filter_fn,
-        )
+        quantize_(model, qat_config, filter_fn=filter_fn)
         out = model(x)
         torch.testing.assert_close(out, ref_out, atol=0, rtol=0)

         # equivalent to torchao's convert step
         model.eval()
         optimizer.restore_latent_params()
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn)
-        quantize_(model, config, filter_fn=filter_fn)
+        quantize_(model, QATConfig(config, step="convert"), filter_fn=filter_fn)
         converted_out = model(x)
         torch.testing.assert_close(converted_out, ref_out, atol=0, rtol=0)
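The test_parq.py changes above migrate PARQ's end-to-end test from the removed IntXQuantizationAwareTrainingConfig / FromIntXQuantizationAwareTrainingConfig pair to the unified QATConfig. A minimal sketch of the new prepare/convert flow, assuming `model` and a base PTQ `config` (e.g. Int8DynamicActivationIntxWeightConfig) are defined as in the test:

    import torch
    from torchao.quantization import quantize_
    from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

    # prepare: insert fake quantization for activations only
    activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
    quantize_(model, QATConfig(activation_config=activation_config, step="prepare"))

    # ... train with fake-quantized activations ...

    # convert: swap the fake quantization for the real quantized representation
    quantize_(model, QATConfig(config, step="convert"))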
diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index bb613d2c99..250ee44cfd 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -1239,9 +1239,10 @@ def test_qat_config_init(self):
         ):
             QATConfig(base_config, None, None, "prepare")

-        # No configs are provided
+        # No configs were provided in prepare step
         with self.assertRaisesRegex(
-            ValueError, "One of `base_config` or `weight_config` must be specified"
+            ValueError,
+            "Must specify `base_config`, `activation_config`, or `weight_config` in the prepare step",
         ):
             QATConfig(step="prepare")

diff --git a/torchao/prototype/parq/optim/quantopt.py b/torchao/prototype/parq/optim/quantopt.py
index 2cdd34536d..ce43b86c1c 100644
--- a/torchao/prototype/parq/optim/quantopt.py
+++ b/torchao/prototype/parq/optim/quantopt.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

+import json
 from collections import defaultdict
 from collections.abc import Callable
 from functools import partial
@@ -13,8 +14,10 @@
 from torch import Tensor
 from torch.optim import Optimizer

+import torchao.prototype.parq as parq
+
 from ..quant import Quantizer
-from ..utils import HAS_DTENSOR, is_dtensor
+from ..utils import HAS_DTENSOR, instantiate_module, is_dtensor
 from .proxmap import ProxMap

 if HAS_DTENSOR:
@@ -133,27 +136,6 @@ def _filter_fn(module: torch.nn.Module, *args) -> bool:

         return _filter_fn

-    @torch._disable_dynamo
-    def state_dict(self) -> dict[str, Any]:
-        state_dict = self.base_optimizer.state_dict()
-        state_dict["qat_state"] = {"num_steps": self.num_steps}
-        # quantizer and prox_map may also need to save states, can add here
-        return state_dict
-
-    @torch._disable_dynamo
-    def load_state_dict(
-        self, state_dict: dict[str, Any], start_step: Optional[int] = None
-    ) -> None:
-        qat_state = state_dict.get("qat_state")
-        # resume from check points usually not corresponds to saved num_steps
-        # so allow explicit start_step computed from epochs * steps_per_epoc
-        if start_step is not None:
-            self.num_steps = start_step
-        elif qat_state is not None:
-            # hope discrepancy in num_steps does not cause major problem!
-            self.num_steps = qat_state["num_steps"]
-        self.base_optimizer.load_state_dict(state_dict)
-
     @torch.no_grad()
     def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
         """Performs a single optimization step.
@@ -191,6 +173,18 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
         quant_update = False

         for group in self.regularized_param_groups():
+            # Override quantizer if specified in the group
+            if "quant_cls" in group:
+                quant_cls = instantiate_module(
+                    f"{parq.__name__}.quant", group["quant_cls"]
+                )
+                quant_kwargs = (
+                    json.loads(group["quant_kwargs"]) if "quant_kwargs" in group else {}
+                )
+                quantizer = quant_cls(**quant_kwargs)
+            else:
+                quantizer = self.quantizer
+
             # AProx in practice: ensure shrinkage coefficient >= 1
             group["cumu_lr"] += group["lr"]
             gamma = max(1.0, group["cumu_lr"])
@@ -224,7 +218,7 @@
                 # update quantization targets periodically
                 per_channel = self.quant_per_channel and p.dim() > 1
                 if quant_update:
-                    quant_size = self.quantizer.get_quant_size(b)
+                    quant_size = quantizer.get_quant_size(b)

                     if per_channel:
                         quant_size = (p.size(0), quant_size)
@@ -242,9 +236,7 @@

                     q = None
                     if quant_update:
-                        qfunc = partial(
-                            self.quantize_, quantizer=self.quantizer, b=b, dim=dim
-                        )
+                        qfunc = partial(self.quantize_, quantizer=quantizer, b=b, dim=dim)
                         if is_dtensor(p):
                             qfunc = local_map(
                                 qfunc,
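With the quantopt.py changes above, each param group may name its own quantizer: `quant_cls` is a class name resolved inside `torchao.prototype.parq.quant`, and the optional `quant_kwargs` entry is a JSON string (kept as a string so the group state stays serializable) that is decoded into constructor kwargs at step time. A sketch of a group-level override, assuming the split param lists from the test helpers and the `QuantOptimizer(base_optimizer, quantizer, prox_map)` construction shown in the PARQ README; groups without `quant_cls` still fall back to the optimizer-level quantizer:

    param_groups = [
        # this group builds its own TernaryUnifQuantizer() at step time
        {"params": params_quant, "quant_bits": 0, "quant_cls": "TernaryUnifQuantizer"},
        {"params": params_no_quant},
    ]
    base_optimizer = torch.optim.AdamW(param_groups)
    optimizer = QuantOptimizer(base_optimizer, UnifQuantizer(), ProxHardQuant())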
diff --git a/torchao/prototype/parq/quant/quant_api.py b/torchao/prototype/parq/quant/quant_api.py
index 47dabb73f6..4ea2500ecb 100644
--- a/torchao/prototype/parq/quant/quant_api.py
+++ b/torchao/prototype/parq/quant/quant_api.py
@@ -11,14 +11,17 @@
 from torch import nn

 from torchao.dtypes import AffineQuantizedTensor, Layout, QDQLayout
-from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization import (
+    MappingType,
+    PerAxis,
+    PerGroup,
+    ZeroPointDomain,
+    dequantize_affine,
+)
 from torchao.quantization.quant_api import IntxWeightOnlyConfig
 from torchao.quantization.quant_primitives import (
     _SUB_BYTE_UINT_BOUNDS,
-    MappingType,
-    ZeroPointDomain,
     _get_reduction_params,
-    dequantize_affine,
 )
 from torchao.quantization.transform_module import register_quantize_module_handler
diff --git a/torchao/prototype/parq/utils.py b/torchao/prototype/parq/utils.py
index ac5024fb5d..d4c0a603b6 100644
--- a/torchao/prototype/parq/utils.py
+++ b/torchao/prototype/parq/utils.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

+from importlib import import_module
+
 import torch
 from torch import Tensor

@@ -15,6 +17,10 @@
     HAS_DTENSOR = False


+def instantiate_module(module_path, module_suffix):
+    return getattr(import_module(module_path), module_suffix)
+
+
 def is_dtensor(x):
     return HAS_DTENSOR and isinstance(x, DTensor)
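`instantiate_module` is the small reflection helper behind the `quant_cls` lookup: quantopt.py calls it with `f"{parq.__name__}.quant"`, i.e. `"torchao.prototype.parq.quant"`. Standalone it behaves like this sketch:

    from torchao.prototype.parq.utils import instantiate_module

    quant_cls = instantiate_module("torchao.prototype.parq.quant", "UnifQuantizer")
    quantizer = quant_cls()  # same as constructing UnifQuantizer() directly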
diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py
index 5aa46548a2..b3d801574b 100644
--- a/torchao/quantization/qat/api.py
+++ b/torchao/quantization/qat/api.py
@@ -114,8 +114,8 @@ class QATConfig(AOBaseConfig):
     Raises:
         ValueError: If `base_config` and `activation_config` are both specified
         ValueError: If `base_config` and `weight_config` are both specified
-        ValueError: If neither `base_config` nor `weight_config` is specified
-            and `step` is "prepare"
+        ValueError: If none of `base_config`, `activation_config`, or
+            `weight_config` are specified
         ValueError: If either `activation_config` or `weight_config` is specified
             and `step` is "convert"
         ValueError: If `step` is not one of "prepare" or "convert"
@@ -156,14 +156,13 @@ def __post_init__(self):
             )
         if self.base_config is not None and self.weight_config is not None:
             raise ValueError("Cannot specify both `base_config` and `weight_config`")
-        if (
-            self.step == QATStep.PREPARE
-            and self.base_config is None
-            and self.weight_config is None
+        if self.step == QATStep.PREPARE and not any(
+            (self.base_config, self.activation_config, self.weight_config)
         ):
             raise ValueError(
-                "One of `base_config` or `weight_config` must be specified in the prepare step"
+                "Must specify `base_config`, `activation_config`, or `weight_config` in the prepare step"
             )
+
         if self.step == QATStep.CONVERT and (
             self.activation_config is not None or self.weight_config is not None
         ):
@@ -206,9 +205,6 @@ def _qat_config_transform(
         else:
             act_config = config.activation_config
             weight_config = config.weight_config
-            assert config.weight_config is not None, (
-                "`base_config` and `weight_config` were both None in the prepare step"
-            )
         if isinstance(module, torch.nn.Linear):
             return FakeQuantizedLinear.from_linear(module, act_config, weight_config)
         elif isinstance(module, torch.nn.Embedding):
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index fea37376a5..71f6721026 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -26,9 +26,7 @@
 import torch.nn.utils.parametrize as parametrize

 import torchao
-from torchao.core.config import (
-    AOBaseConfig,
-)
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import (
     AffineQuantizedTensor,
     CutlassInt4PackedLayout,
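The net effect of the qat/api.py validation change, mirrored by the updated test_qat.py expectation: an activation-only prepare config, which the old check and the removed assert rejected, is now accepted, while an empty prepare config still raises. A quick before/after sketch:

    import torch
    from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

    # previously raised ValueError, now valid: fake-quantize activations only
    QATConfig(
        activation_config=IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False),
        step="prepare",
    )

    # still raises ValueError: no configs at all in the prepare step
    QATConfig(step="prepare")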