Skip to content

Commit eac6c3d

Browse files
committed
Allow per-group quantizers in QuantOptimizer
1 parent 46ba24c commit eac6c3d

File tree

1 file changed

+17
-28
lines changed

1 file changed

+17
-28
lines changed

torchao/prototype/parq/optim/quantopt.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -133,27 +133,6 @@ def _filter_fn(module: torch.nn.Module, *args) -> bool:
133133

134134
return _filter_fn
135135

136-
@torch._disable_dynamo
137-
def state_dict(self) -> dict[str, Any]:
138-
state_dict = self.base_optimizer.state_dict()
139-
state_dict["qat_state"] = {"num_steps": self.num_steps}
140-
# quantizer and prox_map may also need to save states, can add here
141-
return state_dict
142-
143-
@torch._disable_dynamo
144-
def load_state_dict(
145-
self, state_dict: dict[str, Any], start_step: Optional[int] = None
146-
) -> None:
147-
qat_state = state_dict.get("qat_state")
148-
# resuming from checkpoints usually does not correspond to saved num_steps
149-
# so allow explicit start_step computed from epochs * steps_per_epoch
150-
if start_step is not None:
151-
self.num_steps = start_step
152-
elif qat_state is not None:
153-
# hope the discrepancy in num_steps does not cause a major problem!
154-
self.num_steps = qat_state["num_steps"]
155-
self.base_optimizer.load_state_dict(state_dict)
156-
157136
@torch.no_grad()
158137
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
159138
"""Performs a single optimization step.
@@ -191,6 +170,18 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
191170
quant_update = False
192171

193172
for group in self.regularized_param_groups():
173+
# Override quantizer if specified in the group
174+
if "quant_cls" in group:
175+
quant_cls = instantiate_module(
176+
f"{parq.__name__}.quant", group["quant_cls"]
177+
)
178+
quant_kwargs = (
179+
json.loads(group["quant_kwargs"]) if "quant_kwargs" in group else {}
180+
)
181+
quantizer = quant_cls(**quant_kwargs)
182+
else:
183+
quantizer = self.quantizer
184+
194185
# AProx in practice: ensure shrinkage coefficient >= 1
195186
group["cumu_lr"] += group["lr"]
196187
gamma = max(1.0, group["cumu_lr"])
@@ -210,9 +201,9 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
210201

211202
# reshape p according to block size if specified
212203
if block_size is not None:
213-
assert p.size(-1) % block_size == 0, (
214-
f"{p.size(-1)=} is not divisible by {block_size=}"
215-
)
204+
assert (
205+
p.size(-1) % block_size == 0
206+
), f"{p.size(-1)=} is not divisible by {block_size=}"
216207
assert p.dim() <= 2, f"Invalid {p.dim()=} for {block_size=}"
217208
if p.dim() == 1:
218209
p = p.unsqueeze(0)
@@ -224,7 +215,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
224215
# update quantization targets periodically
225216
per_channel = self.quant_per_channel and p.dim() > 1
226217
if quant_update:
227-
quant_size = self.quantizer.get_quant_size(b)
218+
quant_size = quantizer.get_quant_size(b)
228219

229220
if per_channel:
230221
quant_size = (p.size(0), quant_size)
@@ -242,9 +233,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
242233

243234
q = None
244235
if quant_update:
245-
qfunc = partial(
246-
self.quantize_, quantizer=self.quantizer, b=b, dim=dim
247-
)
236+
qfunc = partial(self.quantize_, quantizer=quantizer, b=b, dim=dim)
248237
if is_dtensor(p):
249238
qfunc = local_map(
250239
qfunc,

0 commit comments

Comments
 (0)