
Commit 2230d59

Borda, awaelchli, and Carlos Mocholí authored
Support loading a checkpoint with QAT (#11346)
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent fdcc09c commit 2230d59

File tree

4 files changed: +80 -16 lines

pytorch_lightning/callbacks/quantization.py

Lines changed: 42 additions & 16 deletions
@@ -91,6 +91,10 @@ class QuantizationAwareTraining(Callback):
 
     .. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.
 
+    The ``LightningModule`` is prepared for QAT training in the ``on_fit_start`` hook. Checkpoints saved during training
+    include already collected stats to perform the Quantization conversion, but it doesn't contain the quantized or
+    fused model/layers. The quantization is performed in the ``on_fit_end`` hook so the model needs to be saved after
+    training finishes if quantization is desired.
 
     Args:
@@ -178,7 +182,7 @@ def __init__(
         )
         self._collect_quantization = collect_quantization
 
-        self.modules_to_fuse = modules_to_fuse
+        self._modules_to_fuse = modules_to_fuse
         self._input_compatible = input_compatible
         self._convert_on_fit_end = quantize_on_fit_end
 
@@ -193,11 +197,12 @@ def __init__(
         self._forward_calls = 0
         self._fake_quant_to_initial_state_dict = {}
         self._last_fake_quant_to_observer_enabled = {}
+        self._module_prepared = False
 
     def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool:
-        if not self.modules_to_fuse:
+        if not self._modules_to_fuse:
             return False
-        for group in self.modules_to_fuse:
+        for group in self._modules_to_fuse:
             if not all(_recursive_hasattr(model, m) for m in group):
                 raise MisconfigurationException(
                     f"You have requested to fuse {group} but one or more of them is not your model attributes"
@@ -217,44 +222,50 @@ def _restore_last_observer_enabled(self) -> None:
         for fake_quant, observer_enabled in self._last_fake_quant_to_observer_enabled.items():
             fake_quant.observer_enabled.copy_(observer_enabled)
 
-    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+    def _prepare_model(self, model: torch.nn.Module) -> None:
+        if self._module_prepared:
+            return
         # QuantStub converts tensors from floating point to quantized
-        pl_module.quant = torch.quantization.QuantStub()
+        model.quant = torch.quantization.QuantStub()
         # DeQuantStub converts tensors from quantized to floating point
-        pl_module.dequant = torch.quantization.DeQuantStub()
+        model.dequant = torch.quantization.DeQuantStub()
         # manually specify where tensors will be converted from quantized
         # to floating point in the quantized model
-        self.__module_forward = pl_module.forward
-        pl_module.forward = wrap_qat_forward_context(
-            quant_cb=self, model=pl_module, func=pl_module.forward, trigger_condition=self._collect_quantization
+        self.__module_forward = model.forward
+        model.forward = wrap_qat_forward_context(
+            quant_cb=self, model=model, func=model.forward, trigger_condition=self._collect_quantization
         )
 
         # attach a global qconfig, which contains information about what kind
         # of observers to attach. Use 'fbgemm' for server inference
         if isinstance(self._qconfig, str):
             if self._observer_type == "histogram":
-                pl_module.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
+                model.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
             elif self._observer_type == "average":
                 # version=None corresponds to using FakeQuantize rather than
                 # FusedMovingAvgObsFakeQuantize which was introduced in PT1.10
                 # details in https://github.com/pytorch/pytorch/issues/64564
                 extra_kwargs = dict(version=None) if _TORCH_GREATER_EQUAL_1_10 else {}
-                pl_module.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)
+                model.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)
 
         elif isinstance(self._qconfig, QConfig):
-            pl_module.qconfig = self._qconfig
+            model.qconfig = self._qconfig
 
-        if self._check_feasible_fuse(pl_module):
-            torch.quantization.fuse_modules(pl_module, self.modules_to_fuse, inplace=True)
+        if self._check_feasible_fuse(model):
+            torch.quantization.fuse_modules(model, self._modules_to_fuse, inplace=True)
 
         # Prepare the model for QAT. This inserts observers and fake_quants in
         # the model that will observe weight and activation tensors during calibration.
-        torch.quantization.prepare_qat(pl_module, inplace=True)
+        torch.quantization.prepare_qat(model, inplace=True)
 
-        fake_quants = tuple(module for module in pl_module.modules() if isinstance(module, FakeQuantizeBase))
+        fake_quants = tuple(module for module in model.modules() if isinstance(module, FakeQuantizeBase))
         self._fake_quant_to_initial_state_dict = {
             fake_quant: copy.deepcopy(fake_quant.state_dict()) for fake_quant in fake_quants
         }
+        self._module_prepared = True
+
+    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
+        self._prepare_model(pl_module)
 
     def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if not self._convert_on_fit_end:
@@ -311,3 +322,18 @@ def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
     def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if "predict" in self._observer_disabled_stages:
             self._restore_last_observer_enabled()
+
+    def state_dict(self) -> Dict[str, Any]:
+        keys = {"_qconfig", "_observer_type", "_collect_quantization", "_modules_to_fuse", "_input_compatible"}
+        return {n: getattr(self, n) for n in keys}
+
+    def _load_before_model(self, model: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
+        """Special hook that gets called by the CheckpointConnector *before* the model gets loaded.
+
+        This hook replaces the :meth:`on_load_checkpoint` and :meth:`load_state_dict` callback methods which get called
+        after the model has already loaded the weights. For quantization, we need to convert the model first before that
+        happens, assuming the previous training used quantization.
+        """
+        for k, v in state_dict.items():
+            setattr(self, k, v)
+        self._prepare_model(model)
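Taken together with the connector changes below, this lets a QAT run resume from a mid-training checkpoint: the callback's new ``state_dict`` persists the qconfig/fusion settings, and ``_load_before_model`` re-prepares the raw module before the checkpointed weights (which include observer statistics) are restored. A minimal usage sketch; ``MyModel`` and ``MyDataModule`` are hypothetical placeholders for any LightningModule/LightningDataModule and are not part of this commit:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import QuantizationAwareTraining

model = MyModel()            # hypothetical LightningModule
datamodule = MyDataModule()  # hypothetical LightningDataModule

# First run: the module is prepared for QAT in `on_fit_start`; checkpoints written while
# fitting carry the observer statistics plus the callback state from `state_dict()`.
trainer = Trainer(max_epochs=2, callbacks=[QuantizationAwareTraining()])
trainer.fit(model, datamodule=datamodule)
ckpt_path = trainer.checkpoint_callback.best_model_path

# The converted (quantized) module only exists after `on_fit_end`, so save the model
# again at this point if the quantized weights themselves are needed.
trainer.save_checkpoint("qat-final.ckpt")

# Resume: `_load_before_model` restores the callback state and re-runs `_prepare_model`
# on the fresh module *before* the checkpointed weights are loaded into it.
model_resumed = MyModel()
trainer = Trainer(max_epochs=3, callbacks=[QuantizationAwareTraining()])
trainer.fit(model_resumed, datamodule=datamodule, ckpt_path=ckpt_path)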

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 27 additions & 0 deletions
@@ -15,6 +15,7 @@
 import logging
 import os
 import re
+from copy import deepcopy
 from typing import Any, Dict, Optional
 
 import torch
@@ -217,6 +218,32 @@ def restore_precision_plugin_state(self) -> None:
         ):
             prec_plugin.load_state_dict(self._loaded_checkpoint["native_amp_scaling_state"])
 
+    def _restore_quantization_callbacks(self) -> None:
+        """Restores all the ``QuantizationAwareTraining`` callbacks from the pre-loaded checkpoint.
+
+        The implementation is similar to :meth:`restore_callbacks` but calls the QAT callback with a special hook
+        `load_before_model` instead of `load_state_dict`.
+        """
+        if not self._loaded_checkpoint:
+            return
+
+        callback_states = self._loaded_checkpoint.get("callbacks")
+
+        if callback_states is None:
+            return
+
+        from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining  # avoid circular import
+
+        for callback in self.trainer.callbacks:
+            if not isinstance(callback, QuantizationAwareTraining):
+                continue
+
+            state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key))
+            if state:
+                # The Quantization callbacks have a special method that must be called before restoring the weights
+                # of the model
+                callback._load_before_model(self.trainer.model, deepcopy(state))
+
     def restore_callbacks(self) -> None:
         """Restores all callbacks from the pre-loaded checkpoint."""
         if not self._loaded_checkpoint:
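The per-callback state that ``_restore_quantization_callbacks`` forwards to ``_load_before_model`` lives in the checkpoint under the ``"callbacks"`` entry, keyed by ``callback.state_key``. A small inspection sketch, assuming a checkpoint produced by a run that used the callback; the path is hypothetical and the exact ``state_key`` string is version-dependent:

import torch

ckpt = torch.load("path/to/qat.ckpt", map_location="cpu")  # hypothetical checkpoint path
for state_key, state in ckpt.get("callbacks", {}).items():
    if "QuantizationAwareTraining" in str(state_key):
        # `state` holds what QuantizationAwareTraining.state_dict() returned:
        # _qconfig, _observer_type, _collect_quantization, _modules_to_fuse, _input_compatible
        print(state_key, sorted(state))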

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 0 deletions
@@ -1140,6 +1140,7 @@ def tune(
     def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
         # restore modules after setup
         self._checkpoint_connector.resume_start(checkpoint_path)
+        self._checkpoint_connector._restore_quantization_callbacks()
         self._checkpoint_connector.restore_model()
         self._checkpoint_connector.restore_datamodule()
         if self.state.fn == TrainerFn.FITTING:

tests/callbacks/test_quantization.py

Lines changed: 10 additions & 0 deletions
@@ -63,6 +63,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
     # test that the test score is almost the same as with pure training
     assert torch.allclose(org_score, quant_score, atol=0.45)
     model_path = trainer.checkpoint_callback.best_model_path
+    curr_epoch = trainer.current_epoch
 
     trainer_args.update(dict(max_epochs=1, enable_checkpointing=False))
     if not convert:
@@ -81,6 +82,15 @@
     quant2_score = torch.mean(torch.tensor([mape(qmodel2(x), y) for x, y in dm.test_dataloader()]))
     assert torch.allclose(org_score, quant2_score, atol=0.45)
 
+    # test without and with QAT callback
+    trainer_args.update(max_epochs=curr_epoch + 1)
+    qmodel2 = RegressionModel()
+    trainer = Trainer(callbacks=[QuantizationAwareTraining()], **trainer_args)
+    trainer.fit(qmodel2, datamodule=dm, ckpt_path=model_path)
+    quant2_score = torch.mean(torch.tensor([mape(qmodel2(x), y) for x, y in dm.test_dataloader()]))
+    # test that the test score is almost the same as with pure training
+    assert torch.allclose(org_score, quant2_score, atol=0.45)
+
 
 @RunIf(quantization=True)
 def test_quantize_torchscript(tmpdir):
