Skip to content

Commit 45c42f1

Browse files
committed
Fix missing QuantOptimizer methods
1 parent 6a2d975 commit 45c42f1

File tree

3 files changed

+21
-22
lines changed

3 files changed

+21
-22
lines changed

test/prototype/test_parq.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,14 @@ def build_param_groups(
6363
model,
6464
b: int = 2,
6565
group_size: Optional[int] = None,
66-
quant_cls_name: Optional[str] = None,
66+
quantizer: Optional[callable] = None,
6767
):
6868
params_quant, params_no_quant = split_param_groups(model)
6969
quant_kwargs = {}
7070
if group_size:
7171
quant_kwargs["quant_block_size"] = group_size
72-
if quant_cls_name is not None:
73-
quant_kwargs["quant_cls"] = quant_cls_name
72+
if quantizer is not None:
73+
quant_kwargs["quantizer"] = quantizer
7474
return [
7575
{"params": params_quant, "quant_bits": b, **quant_kwargs},
7676
{"params": params_no_quant},
@@ -169,17 +169,18 @@ def setUp(self):
169169
@common_utils.parametrize("b", [0, 1, 2, 4])
170170
@common_utils.parametrize("unif_quant", [True, False])
171171
@common_utils.parametrize("hard_prox", [True, False])
172-
@common_utils.parametrize("per_group_quant_cls", [True, False])
172+
@common_utils.parametrize("per_group_quantizer", [True, False])
173173
def test_parq_train_loop(
174-
self, b: int = 2, unif_quant=True, hard_prox=True, per_group_quant_cls=False
174+
self, b: int = 2, unif_quant=True, hard_prox=True, per_group_quantizer=False
175175
):
176176
self.model.reset_parameters()
177177
if unif_quant:
178178
quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
179179
else:
180180
quantizer = LSBQuantizer()
181-
quant_cls_name = quantizer.__class__.__name__ if per_group_quant_cls else None
182-
param_groups = build_param_groups(self.model, b, quant_cls_name=quant_cls_name)
181+
param_groups = build_param_groups(
182+
self.model, b, quantizer=quantizer if per_group_quantizer else None
183+
)
183184
base_optimizer = torch.optim.AdamW(param_groups)
184185

185186
prox_map = (

torchao/prototype/parq/optim/quantopt.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import json
87
from collections import defaultdict
98
from collections.abc import Callable
109
from functools import partial
@@ -14,10 +13,8 @@
1413
from torch import Tensor
1514
from torch.optim import Optimizer
1615

17-
import torchao.prototype.parq as parq
18-
1916
from ..quant import Quantizer
20-
from ..utils import HAS_DTENSOR, instantiate_module, is_dtensor
17+
from ..utils import HAS_DTENSOR, is_dtensor
2118
from .proxmap import ProxMap
2219

2320
if HAS_DTENSOR:
@@ -136,6 +133,14 @@ def _filter_fn(module: torch.nn.Module, *args) -> bool:
136133

137134
return _filter_fn
138135

136+
@torch._disable_dynamo
137+
def state_dict(self) -> dict[str, Any]:
138+
return self.base_optimizer.state_dict()
139+
140+
@torch._disable_dynamo
141+
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
142+
self.base_optimizer.load_state_dict(state_dict)
143+
139144
@torch.no_grad()
140145
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
141146
"""Performs a single optimization step.
@@ -174,16 +179,8 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
174179

175180
for group in self.regularized_param_groups():
176181
# Override quantizer if specified in the group
177-
if "quant_cls" in group:
178-
quant_cls = instantiate_module(
179-
f"{parq.__name__}.quant", group["quant_cls"]
180-
)
181-
quant_kwargs = (
182-
json.loads(group["quant_kwargs"]) if "quant_kwargs" in group else {}
183-
)
184-
quantizer = quant_cls(**quant_kwargs)
185-
else:
186-
quantizer = self.quantizer
182+
quantizer = group.get("quantizer", self.quantizer)
183+
assert isinstance(quantizer, Quantizer), f"Invalid {quantizer=}"
187184

188185
# AProx in practice: ensure shrinkage coefficient >= 1
189186
group["cumu_lr"] += group["lr"]

torchao/prototype/parq/quant/uniform_torchao.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def quantize(
142142

143143

144144
class StretchedUnifTorchaoQuantizer(UnifTorchaoQuantizer):
145-
def __init__(self, b: int, int_shift: float = 0.5) -> None:
145+
def __init__(self, b: int, int_shift: float = 0.5, **kwargs) -> None:
146146
quant_absmax = 2 ** (b - 1) - int_shift
147147
self.quant_min = -quant_absmax
148148
self.quant_max = quant_absmax
@@ -152,6 +152,7 @@ def __init__(self, b: int, int_shift: float = 0.5) -> None:
152152
mapping_type=MappingType.ASYMMETRIC,
153153
quant_min=self.quant_min,
154154
quant_max=self.quant_max,
155+
**kwargs,
155156
)
156157

157158
self._choose_qparams = partial(choose_qparams_stretched_affine, b=b)

0 commit comments

Comments
 (0)