@@ -1133,62 +1133,6 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
         baseline_out = embedding_forward_4w(x2, fq_embedding.weight)
         torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)
 
-    @unittest.skipIf(
-        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
-    )
-    def test_qat_prototype_bc(self):
-        """
-        Just to make sure we can import all the old prototype paths.
-        We will remove this test in the near future when we actually break BC.
-        """
-        from torchao.quantization.prototype.qat import (  # noqa: F401, F811, I001
-            disable_4w_fake_quant,
-            disable_8da4w_fake_quant,
-            enable_4w_fake_quant,
-            enable_8da4w_fake_quant,
-            ComposableQATQuantizer,
-            Int8DynActInt4WeightQATLinear,
-            Int4WeightOnlyEmbeddingQATQuantizer,
-            Int4WeightOnlyQATQuantizer,
-            Int8DynActInt4WeightQATQuantizer,
-        )
-        from torchao.quantization.prototype.qat._module_swap_api import (  # noqa: F401, F811
-            disable_4w_fake_quant_module_swap,
-            enable_4w_fake_quant_module_swap,
-            disable_8da4w_fake_quant_module_swap,
-            enable_8da4w_fake_quant_module_swap,
-            Int4WeightOnlyQATQuantizerModuleSwap,
-            Int8DynActInt4WeightQATQuantizerModuleSwap,
-        )
-        from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (  # noqa: F401, F811
-            AffineFakeQuantizedTensor,
-            to_affine_fake_quantized,
-        )
-        from torchao.quantization.prototype.qat.api import (  # noqa: F401, F811
-            ComposableQATQuantizer,
-            FakeQuantizeConfig,
-        )
-        from torchao.quantization.prototype.qat.embedding import (  # noqa: F401, F811
-            FakeQuantizedEmbedding,
-            Int4WeightOnlyEmbeddingQATQuantizer,
-            Int4WeightOnlyEmbedding,
-            Int4WeightOnlyQATEmbedding,
-        )
-        from torchao.quantization.prototype.qat.fake_quantizer import (  # noqa: F401, F811
-            FakeQuantizer,
-        )
-        from torchao.quantization.prototype.qat.linear import (  # noqa: F401, F811
-            disable_4w_fake_quant,
-            disable_8da4w_fake_quant,
-            enable_4w_fake_quant,
-            enable_8da4w_fake_quant,
-            FakeQuantizedLinear,
-            Int4WeightOnlyQATLinear,
-            Int4WeightOnlyQATQuantizer,
-            Int8DynActInt4WeightQATLinear,
-            Int8DynActInt4WeightQATQuantizer,
-        )
-
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )