
Commit d86ae25

Allow per-group quantizers in QuantOptimizer, fix state_dict (#2743)
* Allow per-group quantizers in QuantOptimizer
* Switch PARQ to new QAT API
* Relax constraints in QATConfig
* Fix parq import in QuantOptimizer
* Update Int8DynActOnlyConfig
* Move activation-only config to PARQ prototype
* Further simplify activation-only
1 parent 6794ef5 commit d86ae25
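
As a quick illustration of the per-group override introduced here (a minimal sketch, not taken from the commit: the import paths, layer sizes, and optimizer arguments are assumptions based on the PARQ prototype's layout), a param group can name its own quantizer class via a `quant_cls` key, which the optimizer resolves by name inside `torchao.prototype.parq.quant` at step time:

```python
# Minimal sketch, assuming the PARQ prototype's public import paths.
import torch

from torchao.prototype.parq.optim import ProxHardQuant, QuantOptimizer
from torchao.prototype.parq.quant import UnifQuantizer

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 16))

param_groups = [
    # Overrides the optimizer-level quantizer: "quant_cls" is resolved by
    # class name inside torchao.prototype.parq.quant during step().
    {
        "params": list(model[0].parameters()),
        "quant_bits": 0,
        "quant_cls": "TernaryUnifQuantizer",
    },
    # Falls back to the quantizer passed to QuantOptimizer below.
    {"params": list(model[1].parameters()), "quant_bits": 4},
]

optimizer = QuantOptimizer(
    torch.optim.AdamW(param_groups),
    UnifQuantizer(),
    ProxHardQuant(),
)
```

Constructor keyword arguments can also be attached per group as a JSON string under `quant_kwargs`; see the sketch after the `quantopt.py` diff below.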

7 files changed: +66 −67 lines changed

test/prototype/test_parq.py

Lines changed: 25 additions & 22 deletions
@@ -29,11 +29,7 @@
 from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig
 from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
 from torchao.quantization.granularity import PerGroup
-from torchao.quantization.qat import (
-    FromIntXQuantizationAwareTrainingConfig,
-    IntxFakeQuantizeConfig,
-    IntXQuantizationAwareTrainingConfig,
-)
+from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig
 from torchao.quantization.quant_api import (
     Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
@@ -63,9 +59,18 @@ def get_param_groups(model):
     return params_quant, params_no_quant


-def build_param_groups(model, b: int = 2, group_size: Optional[int] = None):
+def build_param_groups(
+    model,
+    b: int = 2,
+    group_size: Optional[int] = None,
+    quant_cls_name: Optional[str] = None,
+):
     params_quant, params_no_quant = split_param_groups(model)
-    quant_kwargs = {"quant_block_size": group_size} if group_size else {}
+    quant_kwargs = {}
+    if group_size:
+        quant_kwargs["quant_block_size"] = group_size
+    if quant_cls_name is not None:
+        quant_kwargs["quant_cls"] = quant_cls_name
     return [
         {"params": params_quant, "quant_bits": b, **quant_kwargs},
         {"params": params_no_quant},
@@ -164,15 +169,19 @@ def setUp(self):
     @common_utils.parametrize("b", [0, 1, 2, 4])
     @common_utils.parametrize("unif_quant", [True, False])
     @common_utils.parametrize("hard_prox", [True, False])
-    def test_parq_train_loop(self, b: int = 2, unif_quant=True, hard_prox=True):
+    @common_utils.parametrize("per_group_quant_cls", [True, False])
+    def test_parq_train_loop(
+        self, b: int = 2, unif_quant=True, hard_prox=True, per_group_quant_cls=False
+    ):
         self.model.reset_parameters()
-        param_groups = build_param_groups(self.model, b)
-        base_optimizer = torch.optim.AdamW(param_groups)
-
         if unif_quant:
             quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
         else:
             quantizer = LSBQuantizer()
+        quant_cls_name = quantizer.__class__.__name__ if per_group_quant_cls else None
+        param_groups = build_param_groups(self.model, b, quant_cls_name=quant_cls_name)
+        base_optimizer = torch.optim.AdamW(param_groups)
+
         prox_map = (
             ProxHardQuant() if hard_prox else ProxPARQ(anneal_start=0, anneal_end=2)
         )
@@ -283,7 +292,7 @@ def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32
         quantizer_ref = UnifQuantizer()
         quantizer = StretchedUnifTorchaoQuantizer(b)

-        for n, module in model.named_children():
+        for module in model.children():
             if not _is_linear(module):
                 continue

@@ -383,24 +392,18 @@ def test_int8_dynamic_activation_intx_e2e(

         # apply torchao quantized activations on top
         activation_config = IntxFakeQuantizeConfig(
-            torch.int8,
-            granularity="per_token",
-            mapping_type=config.act_mapping_type,
+            torch.int8, "per_token", is_symmetric=False
         )
+        qat_config = QATConfig(activation_config=activation_config, step="prepare")
         filter_fn = optimizer.get_filter_fn(model)
-        quantize_(
-            model,
-            IntXQuantizationAwareTrainingConfig(activation_config=activation_config),
-            filter_fn=filter_fn,
-        )
+        quantize_(model, qat_config, filter_fn=filter_fn)
         out = model(x)
         torch.testing.assert_close(out, ref_out, atol=0, rtol=0)

         # equivalent to torchao's convert step
         model.eval()
         optimizer.restore_latent_params()
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn)
-        quantize_(model, config, filter_fn=filter_fn)
+        quantize_(model, QATConfig(config, step="convert"), filter_fn=filter_fn)
         converted_out = model(x)
         torch.testing.assert_close(converted_out, ref_out, atol=0, rtol=0)
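
For context on the API switch in this test: the unified `QATConfig` entry point replaces `IntXQuantizationAwareTrainingConfig` / `FromIntXQuantizationAwareTrainingConfig`. A hedged sketch of that flow outside the test harness (the model, base-config parameters, and training step are illustrative assumptions, not from the commit):

```python
# Sketch of the prepare/convert flow the test switched to.
import torch

from torchao.quantization import PerGroup, quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig

model = torch.nn.Sequential(torch.nn.Linear(128, 128))

# Prepare: activation-only fake quantization, as in the test above.
act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
quantize_(model, QATConfig(activation_config=act_config, step="prepare"))

# ... train with fake-quantized activations ...

# Convert: hand the base config to QATConfig (parameters here are illustrative).
base_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4, weight_granularity=PerGroup(32)
)
quantize_(model, QATConfig(base_config, step="convert"))
```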

test/quantization/test_qat.py

Lines changed: 3 additions & 2 deletions
@@ -1250,9 +1250,10 @@ def test_qat_config_init(self):
         ):
             QATConfig(base_config, None, None, "prepare")

-        # No configs are provided
+        # No configs were provided in prepare step
         with self.assertRaisesRegex(
-            ValueError, "One of `base_config` or `weight_config` must be specified"
+            ValueError,
+            "Must specify `base_config`, `activation_config`, or `weight_config` in the prepare step",
         ):
             QATConfig(step="prepare")

torchao/prototype/parq/optim/quantopt.py

Lines changed: 18 additions & 26 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

+import json
 from collections import defaultdict
 from collections.abc import Callable
 from functools import partial
@@ -13,8 +14,10 @@
 from torch import Tensor
 from torch.optim import Optimizer

+import torchao.prototype.parq as parq
+
 from ..quant import Quantizer
-from ..utils import HAS_DTENSOR, is_dtensor
+from ..utils import HAS_DTENSOR, instantiate_module, is_dtensor
 from .proxmap import ProxMap

 if HAS_DTENSOR:
@@ -133,27 +136,6 @@ def _filter_fn(module: torch.nn.Module, *args) -> bool:

         return _filter_fn

-    @torch._disable_dynamo
-    def state_dict(self) -> dict[str, Any]:
-        state_dict = self.base_optimizer.state_dict()
-        state_dict["qat_state"] = {"num_steps": self.num_steps}
-        # quantizer and prox_map may also need to save states, can add here
-        return state_dict
-
-    @torch._disable_dynamo
-    def load_state_dict(
-        self, state_dict: dict[str, Any], start_step: Optional[int] = None
-    ) -> None:
-        qat_state = state_dict.get("qat_state")
-        # resume from check points usually not corresponds to saved num_steps
-        # so allow explicit start_step computed from epochs * steps_per_epoc
-        if start_step is not None:
-            self.num_steps = start_step
-        elif qat_state is not None:
-            # hope discrepancy in num_steps does not cause major problem!
-            self.num_steps = qat_state["num_steps"]
-        self.base_optimizer.load_state_dict(state_dict)
-
     @torch.no_grad()
     def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
         """Performs a single optimization step.
@@ -191,6 +173,18 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
         quant_update = False

         for group in self.regularized_param_groups():
+            # Override quantizer if specified in the group
+            if "quant_cls" in group:
+                quant_cls = instantiate_module(
+                    f"{parq.__name__}.quant", group["quant_cls"]
+                )
+                quant_kwargs = (
+                    json.loads(group["quant_kwargs"]) if "quant_kwargs" in group else {}
+                )
+                quantizer = quant_cls(**quant_kwargs)
+            else:
+                quantizer = self.quantizer
+
             # AProx in practice: ensure shrinkage coefficient >= 1
             group["cumu_lr"] += group["lr"]
             gamma = max(1.0, group["cumu_lr"])
@@ -224,7 +218,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
             # update quantization targets periodically
             per_channel = self.quant_per_channel and p.dim() > 1
             if quant_update:
-                quant_size = self.quantizer.get_quant_size(b)
+                quant_size = quantizer.get_quant_size(b)

                 if per_channel:
                     quant_size = (p.size(0), quant_size)
@@ -242,9 +236,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]

             q = None
             if quant_update:
-                qfunc = partial(
-                    self.quantize_, quantizer=self.quantizer, b=b, dim=dim
-                )
+                qfunc = partial(self.quantize_, quantizer=quantizer, b=b, dim=dim)
                 if is_dtensor(p):
                     qfunc = local_map(
                         qfunc,
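
The override logic added above can be exercised on its own. A small sketch of how a group entry is resolved into a quantizer instance, including the optional JSON-encoded `quant_kwargs` (the example group dict is hypothetical; the imports mirror the ones added in this diff):

```python
# Mirrors the per-group resolution added to QuantOptimizer.step().
import json

import torchao.prototype.parq as parq
from torchao.prototype.parq.utils import instantiate_module

group = {"quant_bits": 2, "quant_cls": "UnifQuantizer", "quant_kwargs": json.dumps({})}

quant_cls = instantiate_module(f"{parq.__name__}.quant", group["quant_cls"])
quant_kwargs = json.loads(group["quant_kwargs"]) if "quant_kwargs" in group else {}
quantizer = quant_cls(**quant_kwargs)  # equivalent to UnifQuantizer() here
```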

torchao/prototype/parq/quant/quant_api.py

Lines changed: 7 additions & 4 deletions
@@ -11,14 +11,17 @@
 from torch import nn

 from torchao.dtypes import AffineQuantizedTensor, Layout, QDQLayout
-from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization import (
+    MappingType,
+    PerAxis,
+    PerGroup,
+    ZeroPointDomain,
+    dequantize_affine,
+)
 from torchao.quantization.quant_api import IntxWeightOnlyConfig
 from torchao.quantization.quant_primitives import (
     _SUB_BYTE_UINT_BOUNDS,
-    MappingType,
-    ZeroPointDomain,
     _get_reduction_params,
-    dequantize_affine,
 )
 from torchao.quantization.transform_module import register_quantize_module_handler

torchao/prototype/parq/utils.py

Lines changed: 6 additions & 0 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

+from importlib import import_module
+
 import torch
 from torch import Tensor

@@ -15,6 +17,10 @@
     HAS_DTENSOR = False


+def instantiate_module(module_path, module_suffix):
+    return getattr(import_module(module_path), module_suffix)
+
+
 def is_dtensor(x):
     return HAS_DTENSOR and isinstance(x, DTensor)

torchao/quantization/qat/api.py

Lines changed: 6 additions & 10 deletions
@@ -114,8 +114,8 @@ class QATConfig(AOBaseConfig):
     Raises:
         ValueError: If `base_config` and `activation_config` are both specified
         ValueError: If `base_config` and `weight_config` are both specified
-        ValueError: If neither `base_config` nor `weight_config` is specified
-            and `step` is "prepare"
+        ValueError: If none of `base_config`, `activation_config`, or
+            `weight_config` are specified
         ValueError: If either `activation_config` or `weight_config` is specified
             and `step` is "convert"
         ValueError: If `step` is not one of "prepare" or "convert"
@@ -157,14 +157,13 @@ def __post_init__(self):
             )
         if self.base_config is not None and self.weight_config is not None:
             raise ValueError("Cannot specify both `base_config` and `weight_config`")
-        if (
-            self.step == QATStep.PREPARE
-            and self.base_config is None
-            and self.weight_config is None
+        if self.step == QATStep.PREPARE and not any(
+            (self.base_config, self.activation_config, self.weight_config)
         ):
             raise ValueError(
-                "One of `base_config` or `weight_config` must be specified in the prepare step"
+                "Must specify `base_config`, `activation_config`, or `weight_config` in the prepare step"
             )
+
         if self.step == QATStep.CONVERT and (
             self.activation_config is not None or self.weight_config is not None
         ):
@@ -207,9 +206,6 @@ def _qat_config_transform(
         else:
             act_config = config.activation_config
             weight_config = config.weight_config
-            assert config.weight_config is not None, (
-                "`base_config` and `weight_config` were both None in the prepare step"
-            )
         if isinstance(module, torch.nn.Linear):
             return FakeQuantizedLinear.from_linear(module, act_config, weight_config)
         elif isinstance(module, torch.nn.Embedding):
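
To make the relaxed constraint concrete, a hedged sketch of what `QATConfig` now accepts versus still rejects (the fake-quantize settings are taken from the test above; the keyword-argument form is assumed from the dataclass-style config):

```python
# Sketch of the relaxed QATConfig validation after this change.
import torch

from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

# Now accepted: an activation-only prepare config, with no base_config
# or weight_config.
act_only = QATConfig(
    activation_config=IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False),
    step="prepare",
)

# Still rejected: a prepare step with no configs at all.
try:
    QATConfig(step="prepare")
except ValueError as err:
    print(err)  # Must specify `base_config`, `activation_config`, or `weight_config` ...
```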

torchao/quantization/quant_api.py

Lines changed: 1 addition & 3 deletions
@@ -26,9 +26,7 @@
 import torch.nn.utils.parametrize as parametrize

 import torchao
-from torchao.core.config import (
-    AOBaseConfig,
-)
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import (
     AffineQuantizedTensor,
     CutlassInt4PackedLayout,
