@@ -91,6 +91,10 @@ class QuantizationAwareTraining(Callback):
     .. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.
 
+        The ``LightningModule`` is prepared for quantization-aware training in the ``on_fit_start`` hook. Checkpoints
+        saved during training include the stats collected so far for the quantization conversion, but they do not
+        contain the quantized or fused model/layers. The quantization itself is performed in the ``on_fit_end`` hook,
+        so the model needs to be saved after training finishes if the quantized model is desired.
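+
+        A minimal sketch of the intended workflow (``LitModel`` is a placeholder for a user-defined
+        ``LightningModule``):
+
+        .. code-block:: python
+
+            model = LitModel()
+            trainer = Trainer(callbacks=[QuantizationAwareTraining()])
+            trainer.fit(model)  # quantization/fusion is applied to ``model`` in ``on_fit_end``
+            torch.save(model.state_dict(), "model.pt")  # save *after* ``fit`` to keep the quantized weights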
 
     Args:
@@ -178,7 +182,7 @@ def __init__(
         )
         self._collect_quantization = collect_quantization
 
-        self.modules_to_fuse = modules_to_fuse
+        self._modules_to_fuse = modules_to_fuse
         self._input_compatible = input_compatible
         self._convert_on_fit_end = quantize_on_fit_end
@@ -193,11 +197,12 @@ def __init__(
         self._forward_calls = 0
         self._fake_quant_to_initial_state_dict = {}
         self._last_fake_quant_to_observer_enabled = {}
+        self._module_prepared = False
 
     def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool:
-        if not self.modules_to_fuse:
+        if not self._modules_to_fuse:
             return False
-        for group in self.modules_to_fuse:
+        for group in self._modules_to_fuse:
             if not all(_recursive_hasattr(model, m) for m in group):
                 raise MisconfigurationException(
                     f"You have requested to fuse {group} but one or more of them is not an attribute of your model"
@@ -217,44 +222,50 @@ def _restore_last_observer_enabled(self) -> None:
         for fake_quant, observer_enabled in self._last_fake_quant_to_observer_enabled.items():
             fake_quant.observer_enabled.copy_(observer_enabled)
 
-    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+    def _prepare_model(self, model: torch.nn.Module) -> None:
+        if self._module_prepared:
+            return
         # QuantStub converts tensors from floating point to quantized
-        pl_module.quant = torch.quantization.QuantStub()
+        model.quant = torch.quantization.QuantStub()
         # DeQuantStub converts tensors from quantized to floating point
-        pl_module.dequant = torch.quantization.DeQuantStub()
+        model.dequant = torch.quantization.DeQuantStub()
         # manually specify where tensors will be converted from quantized
         # to floating point in the quantized model
-        self.__module_forward = pl_module.forward
-        pl_module.forward = wrap_qat_forward_context(
-            quant_cb=self, model=pl_module, func=pl_module.forward, trigger_condition=self._collect_quantization
+        self.__module_forward = model.forward
+        model.forward = wrap_qat_forward_context(
+            quant_cb=self, model=model, func=model.forward, trigger_condition=self._collect_quantization
         )
 
         # attach a global qconfig, which contains information about what kind
         # of observers to attach. Use 'fbgemm' for server inference
         if isinstance(self._qconfig, str):
             if self._observer_type == "histogram":
-                pl_module.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
+                model.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
             elif self._observer_type == "average":
                 # version=None corresponds to using FakeQuantize rather than
                 # FusedMovingAvgObsFakeQuantize which was introduced in PT1.10
                 # details in https://github.com/pytorch/pytorch/issues/64564
                 extra_kwargs = dict(version=None) if _TORCH_GREATER_EQUAL_1_10 else {}
-                pl_module.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)
+                model.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)
 
         elif isinstance(self._qconfig, QConfig):
-            pl_module.qconfig = self._qconfig
+            model.qconfig = self._qconfig
 
-        if self._check_feasible_fuse(pl_module):
-            torch.quantization.fuse_modules(pl_module, self.modules_to_fuse, inplace=True)
+        if self._check_feasible_fuse(model):
+            torch.quantization.fuse_modules(model, self._modules_to_fuse, inplace=True)
 
         # Prepare the model for QAT. This inserts observers and fake_quants in
         # the model that will observe weight and activation tensors during calibration.
-        torch.quantization.prepare_qat(pl_module, inplace=True)
+        torch.quantization.prepare_qat(model, inplace=True)
 
-        fake_quants = tuple(module for module in pl_module.modules() if isinstance(module, FakeQuantizeBase))
+        fake_quants = tuple(module for module in model.modules() if isinstance(module, FakeQuantizeBase))
         self._fake_quant_to_initial_state_dict = {
             fake_quant: copy.deepcopy(fake_quant.state_dict()) for fake_quant in fake_quants
         }
+        self._module_prepared = True
+
+    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
+        self._prepare_model(pl_module)
 
     def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if not self._convert_on_fit_end:
@@ -311,3 +322,18 @@ def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule
     def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if "predict" in self._observer_disabled_stages:
             self._restore_last_observer_enabled()
+
+    def state_dict(self) -> Dict[str, Any]:
+        keys = {"_qconfig", "_observer_type", "_collect_quantization", "_modules_to_fuse", "_input_compatible"}
+        return {n: getattr(self, n) for n in keys}
+
+    def _load_before_model(self, model: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
+        """Special hook that gets called by the CheckpointConnector *before* the model gets loaded.
+
+        This hook replaces the :meth:`on_load_checkpoint` and :meth:`load_state_dict` callback methods, which get
+        called after the model has already loaded the weights. For quantization, we need to convert the model first,
+        before that happens, assuming the previous training used quantization.
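+
+        A rough sketch of the resume flow this enables (assuming the ``ckpt_path`` argument of ``Trainer.fit`` and a
+        placeholder checkpoint path):
+
+        .. code-block:: python
+
+            trainer = Trainer(callbacks=[QuantizationAwareTraining()])
+            # the callback re-prepares the model before the checkpoint weights are restored
+            trainer.fit(model, ckpt_path="qat.ckpt")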
+        """
+        for k, v in state_dict.items():
+            setattr(self, k, v)
+        self._prepare_model(model)