diff --git a/torchao/quantization/pt2e/__init__.py b/torchao/quantization/pt2e/__init__.py index 3e4352dabd..b6b8a728a3 100644 --- a/torchao/quantization/pt2e/__init__.py +++ b/torchao/quantization/pt2e/__init__.py @@ -39,6 +39,8 @@ FusedMovingAvgObsFakeQuantize, default_dynamic_fake_quant, default_fake_quant, + disable_fake_quant, + disable_observer, enable_fake_quant, enable_observer, ) @@ -114,6 +116,8 @@ # utils "enable_fake_quant", "enable_observer", + "disable_fake_quant", + "disable_observer", # export_utils "move_exported_model_to_eval", "move_exported_model_to_train", diff --git a/torchao/testing/pt2e/utils.py b/torchao/testing/pt2e/utils.py index 41be460a40..ad49fec014 100644 --- a/torchao/testing/pt2e/utils.py +++ b/torchao/testing/pt2e/utils.py @@ -78,6 +78,7 @@ def _test_quantizer( m, example_inputs, dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, + strict=True, ).module() if is_qat: