diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py
index ca4ca2e44e..a25ce2301d 100644
--- a/test/prototype/test_parq.py
+++ b/test/prototype/test_parq.py
@@ -21,6 +21,7 @@
 from torchao.prototype.parq.quant import (
     Int4UnifTorchaoQuantizer,
     LSBQuantizer,
+    Quantizer,
     StretchedUnifTorchaoQuantizer,
     TernaryUnifQuantizer,
     UnifQuantizer,
@@ -63,14 +64,14 @@ def build_param_groups(
     model,
     b: int = 2,
     group_size: Optional[int] = None,
-    quant_cls_name: Optional[str] = None,
+    quantizer: Optional[Quantizer] = None,
 ):
     params_quant, params_no_quant = split_param_groups(model)
     quant_kwargs = {}
     if group_size:
         quant_kwargs["quant_block_size"] = group_size
-    if quant_cls_name is not None:
-        quant_kwargs["quant_cls"] = quant_cls_name
+    if quantizer is not None:
+        quant_kwargs["quantizer"] = quantizer
     return [
         {"params": params_quant, "quant_bits": b, **quant_kwargs},
         {"params": params_no_quant},
@@ -169,17 +170,18 @@ def setUp(self):
     @common_utils.parametrize("b", [0, 1, 2, 4])
     @common_utils.parametrize("unif_quant", [True, False])
     @common_utils.parametrize("hard_prox", [True, False])
-    @common_utils.parametrize("per_group_quant_cls", [True, False])
+    @common_utils.parametrize("per_group_quantizer", [True, False])
     def test_parq_train_loop(
-        self, b: int = 2, unif_quant=True, hard_prox=True, per_group_quant_cls=False
+        self, b: int = 2, unif_quant=True, hard_prox=True, per_group_quantizer=False
     ):
         self.model.reset_parameters()
         if unif_quant:
             quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
         else:
             quantizer = LSBQuantizer()
-        quant_cls_name = quantizer.__class__.__name__ if per_group_quant_cls else None
-        param_groups = build_param_groups(self.model, b, quant_cls_name=quant_cls_name)
+        param_groups = build_param_groups(
+            self.model, b, quantizer=quantizer if per_group_quantizer else None
+        )
 
         base_optimizer = torch.optim.AdamW(param_groups)
         prox_map = (
diff --git a/torchao/prototype/parq/optim/quantopt.py b/torchao/prototype/parq/optim/quantopt.py
index ce43b86c1c..194c0b0c67 100644
--- a/torchao/prototype/parq/optim/quantopt.py
+++ b/torchao/prototype/parq/optim/quantopt.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-import json
 from collections import defaultdict
 from collections.abc import Callable
 from functools import partial
@@ -14,10 +13,8 @@
 from torch import Tensor
 from torch.optim import Optimizer
 
-import torchao.prototype.parq as parq
-
 from ..quant import Quantizer
-from ..utils import HAS_DTENSOR, instantiate_module, is_dtensor
+from ..utils import HAS_DTENSOR, is_dtensor
 from .proxmap import ProxMap
 
 if HAS_DTENSOR:
@@ -136,6 +133,14 @@ def _filter_fn(module: torch.nn.Module, *args) -> bool:
 
         return _filter_fn
 
+    @torch._disable_dynamo
+    def state_dict(self) -> dict[str, Any]:
+        return self.base_optimizer.state_dict()
+
+    @torch._disable_dynamo
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self.base_optimizer.load_state_dict(state_dict)
+
     @torch.no_grad()
     def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
         """Performs a single optimization step.
@@ -174,16 +179,8 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
 
         for group in self.regularized_param_groups():
             # Override quantizer if specified in the group
-            if "quant_cls" in group:
-                quant_cls = instantiate_module(
-                    f"{parq.__name__}.quant", group["quant_cls"]
-                )
-                quant_kwargs = (
-                    json.loads(group["quant_kwargs"]) if "quant_kwargs" in group else {}
-                )
-                quantizer = quant_cls(**quant_kwargs)
-            else:
-                quantizer = self.quantizer
+            quantizer = group.get("quantizer", self.quantizer)
+            assert isinstance(quantizer, Quantizer), f"Invalid {quantizer=}"
 
             # AProx in practice: ensure shrinkage coefficient >= 1
             group["cumu_lr"] += group["lr"]
diff --git a/torchao/prototype/parq/quant/uniform_torchao.py b/torchao/prototype/parq/quant/uniform_torchao.py
index 6d895452e8..56c4ad268d 100644
--- a/torchao/prototype/parq/quant/uniform_torchao.py
+++ b/torchao/prototype/parq/quant/uniform_torchao.py
@@ -142,7 +142,7 @@ def quantize(
 
 
 class StretchedUnifTorchaoQuantizer(UnifTorchaoQuantizer):
-    def __init__(self, b: int, int_shift: float = 0.5) -> None:
+    def __init__(self, b: int, int_shift: float = 0.5, **kwargs) -> None:
         quant_absmax = 2 ** (b - 1) - int_shift
         self.quant_min = -quant_absmax
         self.quant_max = quant_absmax
@@ -152,6 +152,7 @@ def __init__(self, b: int, int_shift: float = 0.5) -> None:
             mapping_type=MappingType.ASYMMETRIC,
             quant_min=self.quant_min,
             quant_max=self.quant_max,
+            **kwargs,
         )
 
         self._choose_qparams = partial(choose_qparams_stretched_affine, b=b)
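A minimal usage sketch of the param-group API after this change, for reviewers; it is not part of the patch. The toy model and group layout are invented, and `QuantOptimizer`, `ProxHardQuant`, and the quantizer classes are used the same way `test_parq.py` uses them.

```python
import torch

from torchao.prototype.parq.optim import ProxHardQuant, QuantOptimizer
from torchao.prototype.parq.quant import StretchedUnifTorchaoQuantizer, UnifQuantizer

# Toy two-layer model, for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))

param_groups = [
    # Per-group override: a Quantizer *instance* under the "quantizer" key,
    # replacing the old "quant_cls" class-name string + JSON "quant_kwargs".
    {
        "params": model[0].parameters(),
        "quant_bits": 2,
        "quantizer": StretchedUnifTorchaoQuantizer(b=2),
    },
    # No "quantizer" key: step() falls back to the optimizer-level default.
    {"params": model[1].parameters(), "quant_bits": 4},
]

optimizer = QuantOptimizer(
    torch.optim.AdamW(param_groups),
    quantizer=UnifQuantizer(),  # default for groups without an override
    prox_map=ProxHardQuant(),
)

# state_dict()/load_state_dict() now delegate to the wrapped base optimizer.
state = optimizer.state_dict()
optimizer.load_state_dict(state)
```

Compared to the old path, `step()` no longer re-instantiates the quantizer from its class name and JSON-encoded kwargs on every call; it reads the instance straight from the group (falling back to `self.quantizer`) and asserts that it is a `Quantizer`.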