
Commit 9fbcffe

Allow no quantization during QATConfig convert
**Summary:** This commit allows users to call the following, which swaps `FakeQuantized*` modules back to the corresponding `torch.nn.*` modules without performing post-training quantization:

```
QATConfig(base_config=None, step="convert")
```

This has exactly the same functionality as the deprecated config:

```
FromIntXQuantizationAwareTrainingConfig()
```

This functionality is added back since it may be useful to users who wish to save QAT-trained checkpoints from models containing only `torch.nn.*` modules (not `FakeQuantized*` modules), e.g. when training and inference need to happen on different machines:

```
quantize_(model, QATConfig(base_ptq_config, step="prepare"))
train(model)
quantize_(model, QATConfig(step="convert"))
torch.save(model.state_dict(), "my_checkpoint.pt")

# On a different machine
model.load_state_dict(torch.load("my_checkpoint.pt"))
quantize_(model, base_ptq_config)
```

**Test Plan:**

```
python test/quantization/test_qat.py -k qat_config_init
python test/quantization/test_qat.py -k qat_api_convert_no_quantization
```
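For illustration, here is a minimal end-to-end sketch of the new convert-without-quantization flow, mirroring the new test added below. The toy model and the exact import paths are assumptions made for this sketch; `quantize_`, `QATConfig`, and `Int8DynamicActivationInt4WeightConfig` are the names used in this commit.

```python
# Minimal sketch, assuming these import paths and a hypothetical toy model.
import torch
from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.qat import QATConfig

# Hypothetical toy model used only for illustration
model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.Linear(256, 128))

# Prepare: swap nn.Linear -> FakeQuantizedLinear and train with fake quantization
quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
# ... training loop would go here ...

# Convert with no base config: swap FakeQuantizedLinear back to nn.Linear
# without applying post-training quantization
quantize_(model, QATConfig(step="convert"))
assert all(type(child) is torch.nn.Linear for child in model)

# The saved checkpoint now contains only plain nn.Linear weights
torch.save(model.state_dict(), "my_checkpoint.pt")
```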
1 parent 418593c commit 9fbcffe

2 files changed (+60, -10 lines)

test/quantization/test_qat.py

Lines changed: 33 additions & 1 deletion
```diff
@@ -1278,6 +1278,7 @@ def test_qat_config_init(self):
         QATConfig(base_config, step=QATStep.CONVERT)
         QATConfig(activation_config=fq_config, weight_config=fq_config, step="prepare")
         QATConfig(weight_config=fq_config, step="prepare")
+        QATConfig(step="convert")
 
         # OK: good step values
         self.assertEqual(QATConfig(base_config).step, "prepare")
@@ -1306,7 +1307,7 @@ def test_qat_config_init(self):
         with self.assertRaisesRegex(ValueError, "Cannot specify both"):
             QATConfig(base_config, activation_config=fq_config, step="prepare")
         with self.assertRaisesRegex(
-            ValueError, "must be specified in the convert step"
+            ValueError, "Cannot specify .* in the convert step"
         ):
             QATConfig(weight_config=fq_config, step="convert")
 
@@ -1884,6 +1885,37 @@ def test_qat_api_deprecation(self):
                 str(w.message),
             )
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_api_convert_no_quantization(self):
+        """
+        Test that `QATConfig(step="convert")` swaps back to nn modules without quantization.
+        """
+        torch.manual_seed(self.SEED)
+        m = M()
+        baseline_model = copy.deepcopy(m)
+
+        # Prepare swaps to FakeQuantizedLinear
+        quantize_(m, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+        self.assertEqual(type(m.linear1), FakeQuantizedLinear)
+        self.assertEqual(type(m.sub.linear), FakeQuantizedLinear)
+        self.assertEqual(type(m.linear2), FakeQuantizedLinear)
+
+        # Convert without a `base_config` swaps back to nn.Linear
+        quantize_(m, QATConfig(step="convert"))
+        self.assertEqual(type(m.linear1), torch.nn.Linear)
+        self.assertEqual(type(m.sub.linear), torch.nn.Linear)
+        self.assertEqual(type(m.linear2), torch.nn.Linear)
+
+        # Model weights should be identical to before
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()
```

torchao/quantization/qat/api.py

Lines changed: 27 additions & 9 deletions
```diff
@@ -115,8 +115,10 @@ class QATConfig(AOBaseConfig):
         ValueError: If `base_config` and `activation_config` are both specified
         ValueError: If `base_config` and `weight_config` are both specified
         ValueError: If neither `base_config` nor `weight_config` is specified
+            and `step` is "prepare"
+        ValueError: If either `activation_config` or `weight_config` is specified
+            and `step` is "convert"
         ValueError: If `step` is not one of "prepare" or "convert"
-        ValueError: If `base_config` is None but `step` is "convert"
         ValueError: If the config is applied on a module that is not a
             `torch.nn.Linear` or `torch.nn.Embedding`, or it is applied on
             `torch.nn.Embedding` with an activation config
@@ -148,18 +150,26 @@ def __post_init__(self):
         all_step_values = [s.value for s in QATStep]
         if self.step not in all_step_values:
             raise ValueError(f"`step` must be one of {all_step_values}")
-        if self.base_config is None and self.weight_config is None:
-            raise ValueError(
-                "One of `base_config` or `weight_config` must be specified"
-            )
         if self.base_config is not None and self.activation_config is not None:
             raise ValueError(
                 "Cannot specify both `base_config` and `activation_config`"
             )
         if self.base_config is not None and self.weight_config is not None:
             raise ValueError("Cannot specify both `base_config` and `weight_config`")
-        if self.base_config is None and self.step == "convert":
-            raise ValueError("`base_config` must be specified in the convert step")
+        if (
+            self.step == QATStep.PREPARE
+            and self.base_config is None
+            and self.weight_config is None
+        ):
+            raise ValueError(
+                "One of `base_config` or `weight_config` must be specified in the prepare step"
+            )
+        if self.step == QATStep.CONVERT and (
+            self.activation_config is not None or self.weight_config is not None
+        ):
+            raise ValueError(
+                "Cannot specify `weight_config` or `activation_config` in the convert step"
+            )
         if isinstance(self.base_config, FakeQuantizeConfigBase):
             config_type = self.base_config.__class__.__name__
             raise ValueError(
@@ -196,6 +206,9 @@ def _qat_config_transform(
         else:
             act_config = config.activation_config
             weight_config = config.weight_config
+            assert config.weight_config is not None, (
+                "`base_config` and `weight_config` were both None in the prepare step"
+            )
         if isinstance(module, torch.nn.Linear):
             return FakeQuantizedLinear.from_linear(module, act_config, weight_config)
         elif isinstance(module, torch.nn.Embedding):
@@ -213,16 +226,21 @@ def _qat_config_transform(
         # Swap FakeQuantizedLinear -> nn.Linear
         # Swap FakeQuantizedEmbedding -> nn.Embedding
         # Then apply the base config's transform function to quantize the model
+        # If there is no base config, then simply perform the module swap
         assert step == QATStep.CONVERT, "unexpected step '%s' in QATConfig" % step
-        assert base_config is not None, "expected `base_config` in convert step"
+        assert config.activation_config is None, "unexpected `activation_config`"
+        assert config.weight_config is None, "unexpected `weight_config`"
         if isinstance(module, FakeQuantizedLinear):
             module = module.to_linear()
         elif isinstance(module, FakeQuantizedEmbedding):
             module = module.to_embedding()
         else:
             # Unrelated module, ignore
             return module
-        return _QUANTIZE_CONFIG_HANDLER[type(base_config)](module, base_config)
+        if base_config is not None:
+            return _QUANTIZE_CONFIG_HANDLER[type(base_config)](module, base_config)
+        else:
+            return module
 
 
 @dataclass
```
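For quick reference, here is a hedged sketch of how the updated `__post_init__` validation behaves. The `IntxFakeQuantizeConfig` import and its constructor arguments are assumptions made for this sketch (any `FakeQuantizeConfigBase` would do); the `QATConfig` combinations themselves mirror the checks in the diff above.

```python
# Sketch of the updated QATConfig validation; the IntxFakeQuantizeConfig name and
# constructor arguments are assumptions, not taken from this diff.
import torch
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

base_config = Int8DynamicActivationInt4WeightConfig()
fq_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)  # assumed constructor

# OK: these all pass __post_init__
QATConfig(base_config)                              # step defaults to "prepare"
QATConfig(weight_config=fq_config, step="prepare")
QATConfig(base_config, step="convert")              # swap back, then quantize
QATConfig(step="convert")                           # new: swap back, no quantization

# ValueError: the prepare step needs base_config or weight_config
try:
    QATConfig(step="prepare")
except ValueError as e:
    print(e)

# ValueError: cannot specify weight_config/activation_config in the convert step
try:
    QATConfig(weight_config=fq_config, step="convert")
except ValueError as e:
    print(e)
```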
