@@ -133,27 +133,6 @@ def _filter_fn(module: torch.nn.Module, *args) -> bool:
133
133
134
134
return _filter_fn
135
135
136
- @torch ._disable_dynamo
137
- def state_dict (self ) -> dict [str , Any ]:
138
- state_dict = self .base_optimizer .state_dict ()
139
- state_dict ["qat_state" ] = {"num_steps" : self .num_steps }
140
- # quantizer and prox_map may also need to save states, can add here
141
- return state_dict
142
-
143
- @torch ._disable_dynamo
144
- def load_state_dict (
145
- self , state_dict : dict [str , Any ], start_step : Optional [int ] = None
146
- ) -> None :
147
- qat_state = state_dict .get ("qat_state" )
148
- # resuming from a checkpoint usually does not correspond to the saved num_steps,
149
- # so allow an explicit start_step computed from epochs * steps_per_epoch
150
- if start_step is not None :
151
- self .num_steps = start_step
152
- elif qat_state is not None :
153
- # hope discrepancy in num_steps does not cause major problem!
154
- self .num_steps = qat_state ["num_steps" ]
155
- self .base_optimizer .load_state_dict (state_dict )
156
-
157
136
@torch .no_grad ()
158
137
def step (self , closure : Optional [Callable [[], float ]] = None ) -> Optional [float ]:
159
138
"""Performs a single optimization step.
@@ -191,6 +170,18 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
191
170
quant_update = False
192
171
193
172
for group in self .regularized_param_groups ():
173
+ # Override quantizer if specified in the group
174
+ if "quant_cls" in group :
175
+ quant_cls = instantiate_module (
176
+ f"{ parq .__name__ } .quant" , group ["quant_cls" ]
177
+ )
178
+ quant_kwargs = (
179
+ json .loads (group ["quant_kwargs" ]) if "quant_kwargs" in group else {}
180
+ )
181
+ quantizer = quant_cls (** quant_kwargs )
182
+ else :
183
+ quantizer = self .quantizer
184
+
194
185
# AProx in practice: ensure shrinkage coefficient >= 1
195
186
group ["cumu_lr" ] += group ["lr" ]
196
187
gamma = max (1.0 , group ["cumu_lr" ])
@@ -210,9 +201,9 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
210
201
211
202
# reshape p according to block size if specified
212
203
if block_size is not None :
213
- assert p . size ( - 1 ) % block_size == 0 , (
214
- f" { p .size (- 1 )= } is not divisible by { block_size = } "
215
- )
204
+ assert (
205
+ p .size (- 1 ) % block_size == 0
206
+ ), f" { p . size ( - 1 ) = } is not divisible by { block_size = } "
216
207
assert p .dim () <= 2 , f"Invalid { p .dim ()= } for { block_size = } "
217
208
if p .dim () == 1 :
218
209
p = p .unsqueeze (0 )
@@ -224,7 +215,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
224
215
# update quantization targets periodically
225
216
per_channel = self .quant_per_channel and p .dim () > 1
226
217
if quant_update :
227
- quant_size = self . quantizer .get_quant_size (b )
218
+ quant_size = quantizer .get_quant_size (b )
228
219
229
220
if per_channel :
230
221
quant_size = (p .size (0 ), quant_size )
@@ -242,9 +233,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
242
233
243
234
q = None
244
235
if quant_update :
245
- qfunc = partial (
246
- self .quantize_ , quantizer = self .quantizer , b = b , dim = dim
247
- )
236
+ qfunc = partial (self .quantize_ , quantizer = quantizer , b = b , dim = dim )
248
237
if is_dtensor (p ):
249
238
qfunc = local_map (
250
239
qfunc ,
0 commit comments