
Allow per-group quantizers in QuantOptimizer, fix state_dict #2743


Merged: 7 commits, Aug 13, 2025
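A quick sketch of what this PR enables (illustrative only; the toy model, group split, and hyperparameters below are not taken from this PR). Each param group may now name its own quantizer class via a `"quant_cls"` entry, resolved from `torchao.prototype.parq.quant` at `step()` time, with optional constructor kwargs passed as a JSON string under `"quant_kwargs"`; groups without an override keep the optimizer-level quantizer. The PR also removes `QuantOptimizer`'s custom `state_dict`/`load_state_dict` overrides, deferring checkpointing to the base optimizer.

```python
# Illustrative sketch, assuming the import paths used elsewhere in PARQ;
# the toy model and hyperparameters are made up for demonstration.
import torch

from torchao.prototype.parq.optim import ProxPARQ, QuantOptimizer
from torchao.prototype.parq.quant import UnifQuantizer

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 10))
param_groups = [
    # This group overrides the optimizer-level quantizer by class name
    {
        "params": model[0].parameters(),
        "quant_bits": 0,
        "quant_cls": "TernaryUnifQuantizer",
    },
    # This group falls back to the optimizer-level quantizer below
    {"params": model[1].parameters(), "quant_bits": 4},
]
base_optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
optimizer = QuantOptimizer(
    base_optimizer,
    UnifQuantizer(),
    ProxPARQ(anneal_start=0, anneal_end=2),
)
```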
47 changes: 25 additions & 22 deletions test/prototype/test_parq.py
@@ -29,11 +29,7 @@
 from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig
 from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
 from torchao.quantization.granularity import PerGroup
-from torchao.quantization.qat import (
-    FromIntXQuantizationAwareTrainingConfig,
-    IntxFakeQuantizeConfig,
-    IntXQuantizationAwareTrainingConfig,
-)
+from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig
 from torchao.quantization.quant_api import (
     Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
@@ -63,9 +59,18 @@ def get_param_groups(model):
     return params_quant, params_no_quant


-def build_param_groups(model, b: int = 2, group_size: Optional[int] = None):
+def build_param_groups(
+    model,
+    b: int = 2,
+    group_size: Optional[int] = None,
+    quant_cls_name: Optional[str] = None,
+):
     params_quant, params_no_quant = split_param_groups(model)
-    quant_kwargs = {"quant_block_size": group_size} if group_size else {}
+    quant_kwargs = {}
+    if group_size:
+        quant_kwargs["quant_block_size"] = group_size
+    if quant_cls_name is not None:
+        quant_kwargs["quant_cls"] = quant_cls_name
     return [
         {"params": params_quant, "quant_bits": b, **quant_kwargs},
         {"params": params_no_quant},
@@ -164,15 +169,19 @@ def setUp(self):
     @common_utils.parametrize("b", [0, 1, 2, 4])
     @common_utils.parametrize("unif_quant", [True, False])
     @common_utils.parametrize("hard_prox", [True, False])
-    def test_parq_train_loop(self, b: int = 2, unif_quant=True, hard_prox=True):
+    @common_utils.parametrize("per_group_quant_cls", [True, False])
+    def test_parq_train_loop(
+        self, b: int = 2, unif_quant=True, hard_prox=True, per_group_quant_cls=False
+    ):
         self.model.reset_parameters()
-        param_groups = build_param_groups(self.model, b)
-        base_optimizer = torch.optim.AdamW(param_groups)
-
         if unif_quant:
             quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
         else:
             quantizer = LSBQuantizer()
+        quant_cls_name = quantizer.__class__.__name__ if per_group_quant_cls else None
+        param_groups = build_param_groups(self.model, b, quant_cls_name=quant_cls_name)
+        base_optimizer = torch.optim.AdamW(param_groups)
+
         prox_map = (
             ProxHardQuant() if hard_prox else ProxPARQ(anneal_start=0, anneal_end=2)
         )
@@ -283,7 +292,7 @@ def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32
         quantizer_ref = UnifQuantizer()
         quantizer = StretchedUnifTorchaoQuantizer(b)

-        for n, module in model.named_children():
+        for module in model.children():
             if not _is_linear(module):
                 continue

@@ -383,24 +392,18 @@ def test_int8_dynamic_activation_intx_e2e(

         # apply torchao quantized activations on top
         activation_config = IntxFakeQuantizeConfig(
-            torch.int8,
-            granularity="per_token",
-            mapping_type=config.act_mapping_type,
+            torch.int8, "per_token", is_symmetric=False
         )
+        qat_config = QATConfig(activation_config=activation_config, step="prepare")
         filter_fn = optimizer.get_filter_fn(model)
-        quantize_(
-            model,
-            IntXQuantizationAwareTrainingConfig(activation_config=activation_config),
-            filter_fn=filter_fn,
-        )
+        quantize_(model, qat_config, filter_fn=filter_fn)
         out = model(x)
         torch.testing.assert_close(out, ref_out, atol=0, rtol=0)

         # equivalent to torchao's convert step
         model.eval()
         optimizer.restore_latent_params()
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn)
-        quantize_(model, config, filter_fn=filter_fn)
+        quantize_(model, QATConfig(config, step="convert"), filter_fn=filter_fn)
         converted_out = model(x)
         torch.testing.assert_close(converted_out, ref_out, atol=0, rtol=0)

5 changes: 3 additions & 2 deletions test/quantization/test_qat.py
@@ -1239,9 +1239,10 @@ def test_qat_config_init(self):
         ):
             QATConfig(base_config, None, None, "prepare")

-        # No configs are provided
+        # No configs were provided in prepare step
         with self.assertRaisesRegex(
-            ValueError, "One of `base_config` or `weight_config` must be specified"
+            ValueError,
+            "Must specify `base_config`, `activation_config`, or `weight_config` in the prepare step",
         ):
             QATConfig(step="prepare")

44 changes: 18 additions & 26 deletions torchao/prototype/parq/optim/quantopt.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

+import json
 from collections import defaultdict
 from collections.abc import Callable
 from functools import partial
@@ -13,8 +14,10 @@
 from torch import Tensor
 from torch.optim import Optimizer

+import torchao.prototype.parq as parq
+
 from ..quant import Quantizer
-from ..utils import HAS_DTENSOR, is_dtensor
+from ..utils import HAS_DTENSOR, instantiate_module, is_dtensor
 from .proxmap import ProxMap

 if HAS_DTENSOR:
@@ -133,27 +136,6 @@ def _filter_fn(module: torch.nn.Module, *args) -> bool:

         return _filter_fn

-    @torch._disable_dynamo
-    def state_dict(self) -> dict[str, Any]:
-        state_dict = self.base_optimizer.state_dict()
-        state_dict["qat_state"] = {"num_steps": self.num_steps}
-        # quantizer and prox_map may also need to save states, can add here
-        return state_dict
-
-    @torch._disable_dynamo
-    def load_state_dict(
-        self, state_dict: dict[str, Any], start_step: Optional[int] = None
-    ) -> None:
-        qat_state = state_dict.get("qat_state")
-        # resume from check points usually not corresponds to saved num_steps
-        # so allow explicit start_step computed from epochs * steps_per_epoc
-        if start_step is not None:
-            self.num_steps = start_step
-        elif qat_state is not None:
-            # hope discrepancy in num_steps does not cause major problem!
-            self.num_steps = qat_state["num_steps"]
-        self.base_optimizer.load_state_dict(state_dict)
-
     @torch.no_grad()
     def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
         """Performs a single optimization step.
@@ -191,6 +173,18 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
         quant_update = False

         for group in self.regularized_param_groups():
+            # Override quantizer if specified in the group
+            if "quant_cls" in group:
+                quant_cls = instantiate_module(
+                    f"{parq.__name__}.quant", group["quant_cls"]
+                )
+                quant_kwargs = (
+                    json.loads(group["quant_kwargs"]) if "quant_kwargs" in group else {}
+                )
+                quantizer = quant_cls(**quant_kwargs)
+            else:
+                quantizer = self.quantizer
+
             # AProx in practice: ensure shrinkage coefficient >= 1
             group["cumu_lr"] += group["lr"]
             gamma = max(1.0, group["cumu_lr"])
Expand Down Expand Up @@ -224,7 +218,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
# update quantization targets periodically
per_channel = self.quant_per_channel and p.dim() > 1
if quant_update:
quant_size = self.quantizer.get_quant_size(b)
quant_size = quantizer.get_quant_size(b)

if per_channel:
quant_size = (p.size(0), quant_size)
@@ -242,9 +236,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]

                 q = None
                 if quant_update:
-                    qfunc = partial(
-                        self.quantize_, quantizer=self.quantizer, b=b, dim=dim
-                    )
+                    qfunc = partial(self.quantize_, quantizer=quantizer, b=b, dim=dim)
                     if is_dtensor(p):
                         qfunc = local_map(
                             qfunc,
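For reference, a standalone sketch of the resolution the loop above performs (the `group` dict and its kwargs here are hypothetical; `center` is assumed to be a valid kwarg of the chosen quantizer class). Note that `"quant_kwargs"` is kept as a JSON string rather than a dict, which keeps the group entry plainly serializable through the base optimizer's `state_dict`:

```python
# Standalone equivalent of the per-group override in QuantOptimizer.step().
import json

from torchao.prototype.parq.utils import instantiate_module

# Hypothetical param group, as a user might pass to the base optimizer
group = {
    "quant_bits": 2,
    "quant_cls": "UnifQuantizer",
    "quant_kwargs": '{"center": true}',
}

# Look the class up by name inside torchao.prototype.parq.quant,
# decode the JSON kwargs, and instantiate the per-group quantizer
quant_cls = instantiate_module("torchao.prototype.parq.quant", group["quant_cls"])
quant_kwargs = json.loads(group["quant_kwargs"]) if "quant_kwargs" in group else {}
quantizer = quant_cls(**quant_kwargs)
```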
11 changes: 7 additions & 4 deletions torchao/prototype/parq/quant/quant_api.py
@@ -11,14 +11,17 @@
 from torch import nn

 from torchao.dtypes import AffineQuantizedTensor, Layout, QDQLayout
-from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization import (
+    MappingType,
+    PerAxis,
+    PerGroup,
+    ZeroPointDomain,
+    dequantize_affine,
+)
 from torchao.quantization.quant_api import IntxWeightOnlyConfig
 from torchao.quantization.quant_primitives import (
     _SUB_BYTE_UINT_BOUNDS,
-    MappingType,
-    ZeroPointDomain,
     _get_reduction_params,
-    dequantize_affine,
 )
 from torchao.quantization.transform_module import register_quantize_module_handler

6 changes: 6 additions & 0 deletions torchao/prototype/parq/utils.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

+from importlib import import_module
+
 import torch
 from torch import Tensor

@@ -15,6 +17,10 @@
     HAS_DTENSOR = False


+def instantiate_module(module_path, module_suffix):
+    return getattr(import_module(module_path), module_suffix)
+
+
 def is_dtensor(x):
     return HAS_DTENSOR and isinstance(x, DTensor)

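`instantiate_module` is a thin wrapper over `importlib`: it imports `module_path` and returns the attribute named `module_suffix`. A minimal usage sketch:

```python
from torchao.prototype.parq.utils import instantiate_module

# Resolve a quantizer class by name, as QuantOptimizer.step() now does
quant_cls = instantiate_module("torchao.prototype.parq.quant", "UnifQuantizer")
assert quant_cls.__name__ == "UnifQuantizer"
```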
16 changes: 6 additions & 10 deletions torchao/quantization/qat/api.py
@@ -114,8 +114,8 @@ class QATConfig(AOBaseConfig):
     Raises:
         ValueError: If `base_config` and `activation_config` are both specified
         ValueError: If `base_config` and `weight_config` are both specified
-        ValueError: If neither `base_config` nor `weight_config` is specified
-            and `step` is "prepare"
+        ValueError: If none of `base_config`, `activation_config`, or
+            `weight_config` are specified
         ValueError: If either `activation_config` or `weight_config` is specified
             and `step` is "convert"
         ValueError: If `step` is not one of "prepare" or "convert"
@@ -156,14 +156,13 @@ def __post_init__(self):
             )
         if self.base_config is not None and self.weight_config is not None:
             raise ValueError("Cannot specify both `base_config` and `weight_config`")
-        if (
-            self.step == QATStep.PREPARE
-            and self.base_config is None
-            and self.weight_config is None
+        if self.step == QATStep.PREPARE and not any(
+            (self.base_config, self.activation_config, self.weight_config)
         ):
             raise ValueError(
-                "One of `base_config` or `weight_config` must be specified in the prepare step"
+                "Must specify `base_config`, `activation_config`, or `weight_config` in the prepare step"
             )
+
         if self.step == QATStep.CONVERT and (
             self.activation_config is not None or self.weight_config is not None
         ):
@@ -206,9 +205,6 @@ def _qat_config_transform(
         else:
            act_config = config.activation_config
            weight_config = config.weight_config
-           assert config.weight_config is not None, (
-               "`base_config` and `weight_config` were both None in the prepare step"
-           )
         if isinstance(module, torch.nn.Linear):
             return FakeQuantizedLinear.from_linear(module, act_config, weight_config)
         elif isinstance(module, torch.nn.Embedding):
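With the loosened `__post_init__` check, an activation-only prepare step is now accepted. A minimal sketch mirroring the updated PARQ test (the toy model here is illustrative):

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

# Previously this raised "One of `base_config` or `weight_config` must be
# specified"; now any one of base/activation/weight config satisfies prepare.
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
qat_config = QATConfig(activation_config=activation_config, step="prepare")

model = torch.nn.Sequential(torch.nn.Linear(32, 32))
quantize_(model, qat_config)
```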
4 changes: 1 addition & 3 deletions torchao/quantization/quant_api.py
@@ -26,9 +26,7 @@
 import torch.nn.utils.parametrize as parametrize

 import torchao
-from torchao.core.config import (
-    AOBaseConfig,
-)
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import (
     AffineQuantizedTensor,
     CutlassInt4PackedLayout,