49 changes: 26 additions & 23 deletions test/prototype/test_parq.py
@@ -29,12 +29,9 @@
from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig
from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
from torchao.quantization.granularity import PerGroup
from torchao.quantization.qat import (
FromIntXQuantizationAwareTrainingConfig,
IntxFakeQuantizeConfig,
IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.qat import QATConfig
from torchao.quantization.quant_api import (
Int8DynActOnlyConfig,
Int8DynamicActivationIntxWeightConfig,
IntxWeightOnlyConfig,
_is_linear,
@@ -63,9 +60,18 @@ def get_param_groups(model):
return params_quant, params_no_quant


def build_param_groups(model, b: int = 2, group_size: Optional[int] = None):
def build_param_groups(
model,
b: int = 2,
group_size: Optional[int] = None,
quant_cls_name: Optional[str] = None,
):
params_quant, params_no_quant = split_param_groups(model)
quant_kwargs = {"quant_block_size": group_size} if group_size else {}
quant_kwargs = {}
if group_size:
quant_kwargs["quant_block_size"] = group_size
if quant_cls_name is not None:
quant_kwargs["quant_cls"] = quant_cls_name
return [
{"params": params_quant, "quant_bits": b, **quant_kwargs},
{"params": params_no_quant},
@@ -164,15 +170,19 @@ def setUp(self):
@common_utils.parametrize("b", [0, 1, 2, 4])
@common_utils.parametrize("unif_quant", [True, False])
@common_utils.parametrize("hard_prox", [True, False])
def test_parq_train_loop(self, b: int = 2, unif_quant=True, hard_prox=True):
@common_utils.parametrize("per_group_quant_cls", [True, False])
def test_parq_train_loop(
self, b: int = 2, unif_quant=True, hard_prox=True, per_group_quant_cls=False
):
self.model.reset_parameters()
param_groups = build_param_groups(self.model, b)
base_optimizer = torch.optim.AdamW(param_groups)

if unif_quant:
quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
else:
quantizer = LSBQuantizer()
quant_cls_name = quantizer.__class__.__name__ if per_group_quant_cls else None
param_groups = build_param_groups(self.model, b, quant_cls_name=quant_cls_name)
base_optimizer = torch.optim.AdamW(param_groups)

prox_map = (
ProxHardQuant() if hard_prox else ProxPARQ(anneal_start=0, anneal_end=2)
)
@@ -283,7 +293,7 @@ def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32
quantizer_ref = UnifQuantizer()
quantizer = StretchedUnifTorchaoQuantizer(b)

for n, module in model.named_children():
for module in model.children():
if not _is_linear(module):
continue

@@ -382,24 +392,17 @@ def test_int8_dynamic_activation_intx_e2e(
optimizer.step()

# apply torchao quantized activations on top
activation_config = IntxFakeQuantizeConfig(
torch.int8,
granularity="per_token",
mapping_type=config.act_mapping_type,
)
qat_config = QATConfig(Int8DynActOnlyConfig(), step="prepare")
filter_fn = optimizer.get_filter_fn(model)
quantize_(
model,
IntXQuantizationAwareTrainingConfig(activation_config=activation_config),
filter_fn=filter_fn,
)
quantize_(model, qat_config, filter_fn=filter_fn)
out = model(x)
torch.testing.assert_close(out, ref_out, atol=0, rtol=0)

# equivalent to torchao's convert step
model.eval()
optimizer.restore_latent_params()
quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn)
qat_config = QATConfig(step="convert")
quantize_(model, qat_config, filter_fn=filter_fn)
quantize_(model, config, filter_fn=filter_fn)
converted_out = model(x)
torch.testing.assert_close(converted_out, ref_out, atol=0, rtol=0)
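For context, a minimal sketch of how the new per-group quantizer override is meant to be used outside the test harness. The toy model, hyperparameters, and the parameter split are illustrative assumptions; the constructor calls follow the existing PARQ prototype API as exercised in the test above.

import torch

from torchao.prototype.parq.optim import ProxHardQuant, QuantOptimizer
from torchao.prototype.parq.quant import UnifQuantizer

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))
params_quant = [p for p in model.parameters() if p.dim() > 1]     # weights to quantize
params_no_quant = [p for p in model.parameters() if p.dim() == 1]  # biases left in full precision

param_groups = [
    # "quant_cls" names a class in torchao.prototype.parq.quant and overrides
    # the optimizer-level quantizer for this group only (see quantopt.py below).
    {"params": params_quant, "quant_bits": 2, "quant_cls": "UnifQuantizer"},
    {"params": params_no_quant},
]
base_optimizer = torch.optim.AdamW(param_groups)
optimizer = QuantOptimizer(base_optimizer, UnifQuantizer(), ProxHardQuant())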
44 changes: 18 additions & 26 deletions torchao/prototype/parq/optim/quantopt.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import json
from collections import defaultdict
from collections.abc import Callable
from functools import partial
@@ -13,8 +14,10 @@
from torch import Tensor
from torch.optim import Optimizer

import torchao.prototype.parq as parq

from ..quant import Quantizer
from ..utils import HAS_DTENSOR, is_dtensor
from ..utils import HAS_DTENSOR, instantiate_module, is_dtensor
from .proxmap import ProxMap

if HAS_DTENSOR:
@@ -133,27 +136,6 @@ def _filter_fn(module: torch.nn.Module, *args) -> bool:

return _filter_fn

@torch._disable_dynamo
def state_dict(self) -> dict[str, Any]:
state_dict = self.base_optimizer.state_dict()
state_dict["qat_state"] = {"num_steps": self.num_steps}
# quantizer and prox_map may also need to save state; it can be added here
return state_dict

@torch._disable_dynamo
def load_state_dict(
self, state_dict: dict[str, Any], start_step: Optional[int] = None
) -> None:
qat_state = state_dict.get("qat_state")
# resuming from a checkpoint usually does not correspond to the saved num_steps,
# so allow an explicit start_step computed from epochs * steps_per_epoch
if start_step is not None:
self.num_steps = start_step
elif qat_state is not None:
# hopefully a discrepancy in num_steps does not cause major problems!
self.num_steps = qat_state["num_steps"]
self.base_optimizer.load_state_dict(state_dict)

@torch.no_grad()
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""Performs a single optimization step.
@@ -191,6 +173,18 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
quant_update = False

for group in self.regularized_param_groups():
# Override quantizer if specified in the group
if "quant_cls" in group:
quant_cls = instantiate_module(
f"{parq.__name__}.quant", group["quant_cls"]
)
quant_kwargs = (
json.loads(group["quant_kwargs"]) if "quant_kwargs" in group else {}
)
quantizer = quant_cls(**quant_kwargs)
else:
quantizer = self.quantizer

# AProx in practice: ensure shrinkage coefficient >= 1
group["cumu_lr"] += group["lr"]
gamma = max(1.0, group["cumu_lr"])
@@ -224,7 +218,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
# update quantization targets periodically
per_channel = self.quant_per_channel and p.dim() > 1
if quant_update:
quant_size = self.quantizer.get_quant_size(b)
quant_size = quantizer.get_quant_size(b)

if per_channel:
quant_size = (p.size(0), quant_size)
@@ -242,9 +236,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]

q = None
if quant_update:
qfunc = partial(
self.quantize_, quantizer=self.quantizer, b=b, dim=dim
)
qfunc = partial(self.quantize_, quantizer=quantizer, b=b, dim=dim)
if is_dtensor(p):
qfunc = local_map(
qfunc,
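In the step above, the per-group quantizer is resolved by name at runtime, and quant_kwargs, when present, must be a JSON-encoded string of constructor arguments. A hedged sketch of that resolution in isolation (the kwarg name "b" for StretchedUnifTorchaoQuantizer is an assumption, based on how the test constructs it positionally):

import json

import torchao.prototype.parq as parq
from torchao.prototype.parq.utils import instantiate_module

group = {
    "quant_bits": 4,
    "quant_cls": "StretchedUnifTorchaoQuantizer",
    "quant_kwargs": '{"b": 4}',  # assumption: the bit-width argument is named "b"
}

# Mirrors the branch in QuantOptimizer.step: look the class up by name under
# torchao.prototype.parq.quant, then construct it with the decoded kwargs.
quant_cls = instantiate_module(f"{parq.__name__}.quant", group["quant_cls"])
quant_kwargs = json.loads(group["quant_kwargs"]) if "quant_kwargs" in group else {}
quantizer = quant_cls(**quant_kwargs)  # groups without "quant_cls" fall back to self.quantizer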
6 changes: 6 additions & 0 deletions torchao/prototype/parq/utils.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from importlib import import_module

import torch
from torch import Tensor

@@ -15,6 +17,10 @@
HAS_DTENSOR = False


def instantiate_module(module_path, module_suffix):
return getattr(import_module(module_path), module_suffix)


def is_dtensor(x):
return HAS_DTENSOR and isinstance(x, DTensor)

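The new instantiate_module helper is a thin wrapper around importlib: it imports a module by dotted path and returns the named attribute. A quick illustration against the standard library, with the target chosen only for demonstration:

from torchao.prototype.parq.utils import instantiate_module

# Equivalent to getattr(import_module("collections"), "OrderedDict").
OrderedDict = instantiate_module("collections", "OrderedDict")
print(OrderedDict(a=1, b=2))  # an ordered dict built via the dynamically resolved class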
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -54,6 +54,7 @@
GemliteUIntXWeightOnlyConfig,
Int4DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynActOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8DynamicActivationIntxWeightConfig,
@@ -144,6 +145,7 @@
"Int8DynamicActivationIntxWeightConfig",
"Int4WeightOnlyConfig",
"Float8DynamicActivationInt4WeightConfig",
"Int8DynActOnlyConfig",
"Int8WeightOnlyConfig",
"Float8WeightOnlyConfig",
"Float8DynamicActivationFloat8WeightConfig",
8 changes: 8 additions & 0 deletions torchao/quantization/qat/fake_quantize_config.py
@@ -292,6 +292,7 @@ def _infer_fake_quantize_configs(
# avoid circular imports
from torchao.quantization import (
Int4WeightOnlyConfig,
Int8DynActOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
)

@@ -315,5 +316,12 @@ def _infer_fake_quantize_configs(
zero_point_domain=base_config.zero_point_domain,
)
return (None, weight_config)
elif isinstance(base_config, Int8DynActOnlyConfig):
act_config = IntxFakeQuantizeConfig(
dtype=torch.int8,
granularity="per_token",
is_symmetric=base_config.is_symmetric,
)
return (act_config, None)
else:
raise ValueError("Unexpected base config: %s" % base_config)
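With this branch in place, passing Int8DynActOnlyConfig as the base config of a QATConfig yields a per-token int8 activation fake-quantizer and no weight fake-quantizer, matching the updated test above. A minimal sketch of the prepare/convert flow (the toy model is an assumption; the calls mirror the test):

import torch

from torchao.quantization import Int8DynActOnlyConfig, quantize_
from torchao.quantization.qat import QATConfig

model = torch.nn.Sequential(torch.nn.Linear(16, 16))

# Prepare: inserts int8 per-token fake quantization on activations only.
quantize_(model, QATConfig(Int8DynActOnlyConfig(is_symmetric=True), step="prepare"))

# ... a QAT training loop would run here ...

# Convert: removes the fake quantizers so a real quantization config can be applied.
quantize_(model, QATConfig(step="convert"))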
26 changes: 26 additions & 0 deletions torchao/quantization/quant_api.py
@@ -148,6 +148,7 @@
"gemlite_uintx_weight_only",
"float8_dynamic_activation_float8_weight",
"float8_static_activation_float8_weight",
"Int8DynActOnlyConfig",
"Int8DynActInt4WeightQuantizer",
"Float8DynamicActivationFloat8SemiSparseWeightConfig",
"ModuleFqnToConfig",
@@ -1312,6 +1313,31 @@ def _float8_cutlass_quant_sparse(
)


@dataclass
class Int8DynActOnlyConfig(AOBaseConfig):
"""
Configuration for applying int8 dynamic per-token activation quantization (activation-only, no weight quantization) to linear layers.
Args:
is_symmetric: bool = False - Whether to use symmetric quantization for activations.
"""

is_symmetric: bool = False


@register_quantize_module_handler(Int8DynActOnlyConfig)
def _int8_dynamic_activation_transform(
module: torch.nn.Module, config: Int8DynActOnlyConfig
) -> torch.nn.Module:
weight = module.weight
if config.is_symmetric:
input_quant_func = _int8_symm_per_token_reduced_range_quant
else:
input_quant_func = _int8_asymm_per_token_quant
weight = to_linear_activation_quantized(weight, input_quant_func)
module.weight = torch.nn.Parameter(weight, requires_grad=False)
return module


@dataclass
class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
"""
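Outside of QAT, the handler above wraps each linear weight with to_linear_activation_quantized so that inputs are dynamically quantized to int8 per token at runtime while the weight values themselves are left untouched. A minimal usage sketch, assuming the default asymmetric activation mapping:

import torch

from torchao.quantization import Int8DynActOnlyConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(32, 32))
quantize_(model, Int8DynActOnlyConfig())  # is_symmetric=False -> asymmetric per-token activations

x = torch.randn(4, 32)
y = model(x)  # activations are quantized on the fly; weights stay in their original dtype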