diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 48a9f780b6..aa3358d73a 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -1278,6 +1278,7 @@ def test_qat_config_init(self):
         QATConfig(base_config, step=QATStep.CONVERT)
         QATConfig(activation_config=fq_config, weight_config=fq_config, step="prepare")
         QATConfig(weight_config=fq_config, step="prepare")
+        QATConfig(step="convert")
 
         # OK: good step values
         self.assertEqual(QATConfig(base_config).step, "prepare")
@@ -1306,7 +1307,7 @@ def test_qat_config_init(self):
         with self.assertRaisesRegex(ValueError, "Cannot specify both"):
             QATConfig(base_config, activation_config=fq_config, step="prepare")
         with self.assertRaisesRegex(
-            ValueError, "must be specified in the convert step"
+            ValueError, "Cannot specify .* in the convert step"
         ):
             QATConfig(weight_config=fq_config, step="convert")
 
@@ -1884,6 +1885,37 @@ def test_qat_api_deprecation(self):
                 str(w.message),
             )
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_api_convert_no_quantization(self):
+        """
+        Test that `QATConfig(step="convert")` swaps back to nn modules without quantization.
+        """
+        torch.manual_seed(self.SEED)
+        m = M()
+        baseline_model = copy.deepcopy(m)
+
+        # Prepare swaps to FakeQuantizedLinear
+        quantize_(m, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+        self.assertEqual(type(m.linear1), FakeQuantizedLinear)
+        self.assertEqual(type(m.sub.linear), FakeQuantizedLinear)
+        self.assertEqual(type(m.linear2), FakeQuantizedLinear)
+
+        # Convert without a `base_config` swaps back to nn.Linear
+        quantize_(m, QATConfig(step="convert"))
+        self.assertEqual(type(m.linear1), torch.nn.Linear)
+        self.assertEqual(type(m.sub.linear), torch.nn.Linear)
+        self.assertEqual(type(m.linear2), torch.nn.Linear)
+
+        # Model weights should be identical to before
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py
index 0d69f44bd9..6be58868e4 100644
--- a/torchao/quantization/qat/api.py
+++ b/torchao/quantization/qat/api.py
@@ -115,8 +115,10 @@ class QATConfig(AOBaseConfig):
         ValueError: If `base_config` and `activation_config` are both specified
         ValueError: If `base_config` and `weight_config` are both specified
         ValueError: If neither `base_config` nor `weight_config` is specified
+            and `step` is "prepare"
+        ValueError: If either `activation_config` or `weight_config` is specified
+            and `step` is "convert"
         ValueError: If `step` is not one of "prepare" or "convert"
-        ValueError: If `base_config` is None but `step` is "convert"
         ValueError: If the config is applied on a module that is not a
             `torch.nn.Linear` or `torch.nn.Embedding`, or it is applied on
             `torch.nn.Embedding` with an activation config
@@ -148,18 +150,26 @@ def __post_init__(self):
         all_step_values = [s.value for s in QATStep]
         if self.step not in all_step_values:
             raise ValueError(f"`step` must be one of {all_step_values}")
-        if self.base_config is None and self.weight_config is None:
-            raise ValueError(
-                "One of `base_config` or `weight_config` must be specified"
-            )
         if self.base_config is not None and self.activation_config is not None:
             raise ValueError(
                 "Cannot specify both `base_config` and `activation_config`"
             )
         if self.base_config is not None and self.weight_config is not None:
             raise ValueError("Cannot specify both `base_config` and `weight_config`")
-        if self.base_config is None and self.step == "convert":
-            raise ValueError("`base_config` must be specified in the convert step")
+        if (
+            self.step == QATStep.PREPARE
+            and self.base_config is None
+            and self.weight_config is None
+        ):
+            raise ValueError(
+                "One of `base_config` or `weight_config` must be specified in the prepare step"
+            )
+        if self.step == QATStep.CONVERT and (
+            self.activation_config is not None or self.weight_config is not None
+        ):
+            raise ValueError(
+                "Cannot specify `weight_config` or `activation_config` in the convert step"
+            )
         if isinstance(self.base_config, FakeQuantizeConfigBase):
             config_type = self.base_config.__class__.__name__
             raise ValueError(
@@ -196,6 +206,9 @@ def _qat_config_transform(
         else:
             act_config = config.activation_config
             weight_config = config.weight_config
+            assert config.weight_config is not None, (
+                "`base_config` and `weight_config` were both None in the prepare step"
+            )
         if isinstance(module, torch.nn.Linear):
             return FakeQuantizedLinear.from_linear(module, act_config, weight_config)
         elif isinstance(module, torch.nn.Embedding):
@@ -213,8 +226,10 @@ def _qat_config_transform(
         # Swap FakeQuantizedLinear -> nn.Linear
         # Swap FakeQuantizedEmbedding -> nn.Embedding
        # Then apply the base config's transform function to quantize the model
+        # If there is no base config, then simply perform the module swap
         assert step == QATStep.CONVERT, "unexpected step '%s' in QATConfig" % step
-        assert base_config is not None, "expected `base_config` in convert step"
+        assert config.activation_config is None, "unexpected `activation_config`"
+        assert config.weight_config is None, "unexpected `weight_config`"
         if isinstance(module, FakeQuantizedLinear):
             module = module.to_linear()
         elif isinstance(module, FakeQuantizedEmbedding):
@@ -222,7 +237,10 @@
         else:
             # Unrelated module, ignore
             return module
-        return _QUANTIZE_CONFIG_HANDLER[type(base_config)](module, base_config)
+        if base_config is not None:
+            return _QUANTIZE_CONFIG_HANDLER[type(base_config)](module, base_config)
+        else:
+            return module
 
 
 @dataclass
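For reviewers, here is a minimal usage sketch of the convert-only flow this change enables, mirroring the new test above. The import paths and the toy model are assumptions inferred from the file paths in the diff (`torchao.quantization` and `torchao.quantization.qat`), not taken from this PR.

```python
# Sketch of the new flow: prepare with a base config, then convert without one,
# which now simply swaps FakeQuantized* modules back to plain nn modules.
# NOTE: import locations are assumed from the diff's file paths.
import torch
from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.qat import QATConfig

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 16))

# Prepare step: insert fake quantization (nn.Linear -> FakeQuantizedLinear)
quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))

# ... QAT fine-tuning would happen here ...

# Convert step with no `base_config`: previously a ValueError, now it just
# swaps FakeQuantizedLinear back to nn.Linear without quantizing the weights
quantize_(model, QATConfig(step="convert"))
```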