Skip to content

Commit 8cd6708

Browse files
committed
Switch PARQ to new QAT API
1 parent eac6c3d commit 8cd6708

File tree

4 files changed

+69
-23
lines changed

4 files changed

+69
-23
lines changed

test/prototype/test_parq.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,13 @@
2626
UnifQuantizer,
2727
UnifTorchaoQuantizer,
2828
)
29-
from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig
29+
from torchao.prototype.parq.quant.quant_api import (
30+
PARQATConfig,
31+
StretchedIntxWeightOnlyConfig,
32+
)
3033
from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
3134
from torchao.quantization.granularity import PerGroup
32-
from torchao.quantization.qat import (
33-
FromIntXQuantizationAwareTrainingConfig,
34-
IntxFakeQuantizeConfig,
35-
IntXQuantizationAwareTrainingConfig,
36-
)
35+
from torchao.quantization.qat import IntxFakeQuantizeConfig
3736
from torchao.quantization.quant_api import (
3837
Int8DynamicActivationIntxWeightConfig,
3938
IntxWeightOnlyConfig,
@@ -283,7 +282,7 @@ def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32
283282
quantizer_ref = UnifQuantizer()
284283
quantizer = StretchedUnifTorchaoQuantizer(b)
285284

286-
for n, module in model.named_children():
285+
for module in model.children():
287286
if not _is_linear(module):
288287
continue
289288

@@ -382,24 +381,31 @@ def test_int8_dynamic_activation_intx_e2e(
382381
optimizer.step()
383382

384383
# apply torchao quantized activations on top
384+
base_config = None
385385
activation_config = IntxFakeQuantizeConfig(
386-
torch.int8,
387-
granularity="per_token",
388-
mapping_type=config.act_mapping_type,
386+
torch.int8, "per_token", is_symmetric=False
389387
)
390-
filter_fn = optimizer.get_filter_fn(model)
391-
quantize_(
392-
model,
393-
IntXQuantizationAwareTrainingConfig(activation_config=activation_config),
394-
filter_fn=filter_fn,
388+
qat_config = PARQATConfig(
389+
base_config,
390+
activation_config=activation_config,
391+
weight_config=None,
392+
step="prepare",
395393
)
394+
filter_fn = optimizer.get_filter_fn(model)
395+
quantize_(model, qat_config, filter_fn=filter_fn)
396396
out = model(x)
397397
torch.testing.assert_close(out, ref_out, atol=0, rtol=0)
398398

399399
# equivalent to torchao's convert step
400400
model.eval()
401401
optimizer.restore_latent_params()
402-
quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn)
402+
qat_config = PARQATConfig(
403+
base_config,
404+
activation_config=activation_config,
405+
weight_config=None,
406+
step="convert",
407+
)
408+
quantize_(model, qat_config, filter_fn=filter_fn)
403409
quantize_(model, config, filter_fn=filter_fn)
404410
converted_out = model(x)
405411
torch.testing.assert_close(converted_out, ref_out, atol=0, rtol=0)

torchao/prototype/parq/optim/quantopt.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import json
78
from collections import defaultdict
89
from collections.abc import Callable
910
from functools import partial
@@ -14,7 +15,7 @@
1415
from torch.optim import Optimizer
1516

1617
from ..quant import Quantizer
17-
from ..utils import HAS_DTENSOR, is_dtensor
18+
from ..utils import HAS_DTENSOR, instantiate_module, is_dtensor
1819
from .proxmap import ProxMap
1920

2021
if HAS_DTENSOR:
@@ -172,9 +173,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
172173
for group in self.regularized_param_groups():
173174
# Override quantizer if specified in the group
174175
if "quant_cls" in group:
175-
quant_cls = instantiate_module(
176-
f"{parq.__name__}.quant", group["quant_cls"]
177-
)
176+
quant_cls = instantiate_module("..quant", group["quant_cls"])
178177
quant_kwargs = (
179178
json.loads(group["quant_kwargs"]) if "quant_kwargs" in group else {}
180179
)
@@ -201,9 +200,9 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
201200

202201
# reshape p according to block size if specified
203202
if block_size is not None:
204-
assert (
205-
p.size(-1) % block_size == 0
206-
), f"{p.size(-1)=} is not divisible by {block_size=}"
203+
assert p.size(-1) % block_size == 0, (
204+
f"{p.size(-1)=} is not divisible by {block_size=}"
205+
)
207206
assert p.dim() <= 2, f"Invalid {p.dim()=} for {block_size=}"
208207
if p.dim() == 1:
209208
p = p.unsqueeze(0)

torchao/prototype/parq/quant/quant_api.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@
1212

1313
from torchao.dtypes import AffineQuantizedTensor, Layout, QDQLayout
1414
from torchao.quantization.granularity import PerAxis, PerGroup
15+
from torchao.quantization.qat import (
16+
FakeQuantizedEmbedding,
17+
FakeQuantizedLinear,
18+
QATConfig,
19+
QATStep,
20+
)
21+
from torchao.quantization.qat.api import _qat_config_transform
1522
from torchao.quantization.quant_api import IntxWeightOnlyConfig
1623
from torchao.quantization.quant_primitives import (
1724
_SUB_BYTE_UINT_BOUNDS,
@@ -219,3 +226,31 @@ def _stretched_intx_weight_only_transform(
219226
)
220227
module.weight = torch.nn.Parameter(weight, requires_grad=False)
221228
return module
229+
230+
231+
@dataclass
class PARQATConfig(QATConfig):
    """`QATConfig` variant that may omit both `base_config` and `weight_config`.

    The parent's `__post_init__` validation still runs; only the single
    ValueError complaining that neither `base_config` nor `weight_config`
    was given is tolerated. Any other ValueError propagates unchanged.
    """

    def __post_init__(self):
        # Exact message of the one parent validation error that is ignored.
        tolerated = "One of `base_config` or `weight_config` must be specified"
        try:
            super().__post_init__()
        except ValueError as err:
            if str(err) != tolerated:
                raise err
242+
243+
244+
@register_quantize_module_handler(PARQATConfig)
def _parq_config_transform(module: nn.Module, config: PARQATConfig) -> nn.Module:
    """Apply the PARQ flavour of the QAT prepare/convert step to `module`.

    Args:
        module: The module being transformed.
        config: PARQ QAT config; `config.step` selects the behaviour.

    Returns:
        For the prepare step, whatever `_qat_config_transform` returns.
        For the convert step, a `FakeQuantizedLinear` is replaced by the
        result of `to_linear()`, a `FakeQuantizedEmbedding` by the result
        of `to_embedding()`, and any other module is returned untouched.

    Raises:
        ValueError: If `config.step` is neither prepare nor convert.
    """
    step = config.step
    if step == QATStep.PREPARE:
        return _qat_config_transform(module, config)
    if step == QATStep.CONVERT:
        # Only unwrap the fake-quantized wrappers; no further conversion is
        # performed here.
        if isinstance(module, FakeQuantizedLinear):
            return module.to_linear()
        if isinstance(module, FakeQuantizedEmbedding):
            return module.to_embedding()
        return module
    # The f-prefix is required so the offending value is actually interpolated.
    raise ValueError(f"unexpected {step=} in QATConfig")

torchao/prototype/parq/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from importlib import import_module
8+
79
import torch
810
from torch import Tensor
911

@@ -15,6 +17,10 @@
1517
HAS_DTENSOR = False
1618

1719

20+
def instantiate_module(module_path, module_suffix, package=None):
    """Import a module and return one attribute from it.

    Args:
        module_path: Module name handed to `importlib.import_module`. May be
            absolute (`"pkg.mod"`) or relative (`"..mod"`).
        module_suffix: Name of the attribute to fetch from the imported module.
        package: Anchor package used to resolve a relative `module_path`
            (typically the caller's `__package__`). `import_module` raises
            `TypeError` for a relative path when this is left as `None`.

    Returns:
        The attribute named `module_suffix` of the imported module.

    Raises:
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute `module_suffix`.
    """
    module = import_module(module_path, package=package)
    return getattr(module, module_suffix)
22+
23+
1824
def is_dtensor(x):
    """Return True only if DTensor support is flagged and `x` is a `DTensor`."""
    # Check the flag first so `DTensor` is never referenced when HAS_DTENSOR
    # is False.
    if not HAS_DTENSOR:
        return False
    return isinstance(x, DTensor)
2026

0 commit comments

Comments
 (0)