7070
7171
7272class QuantizationTest (INCTestMixin ):
73- SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (
73+ SUPPORTED_ARCHITECTURES_STATIC = (
74+ ("text-generation" , "gpt_neo" , 17 ),
7475 ("text-classification" , "bert" , 21 ),
75- # ("text-generation", "bloom", 21),
76+ ("text-generation" , "bloom" , 21 ),
7677 )
7778
78- SUPPORTED_ARCHITECTURES_DYNAMIC = SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS + (
79+ SUPPORTED_ARCHITECTURES_DYNAMIC = SUPPORTED_ARCHITECTURES_STATIC + (
7980 ("fill-mask" , "bert" , 22 ),
8081 ("token-classification" , "albert" , 26 ),
8182 )
@@ -88,12 +89,14 @@ class QuantizationTest(INCTestMixin):
8889 @parameterized .expand (SUPPORTED_ARCHITECTURES_DYNAMIC )
8990 def test_dynamic_quantization (self , task , model_arch , expected_quantized_matmuls ):
9091 model_name = MODEL_NAMES [model_arch ]
91- quantization_config = PostTrainingQuantConfig (approach = "dynamic" )
9292 model_class = ORT_SUPPORTED_TASKS [task ]["class" ][0 ]
9393 tokenizer = AutoTokenizer .from_pretrained (model_name )
94- save_onnx_model = False
94+
9595 quantized_model = None
96+ save_onnx_model = False
9697 model_kwargs = {"use_cache" : False , "use_io_binding" : False } if task == "text-generation" else {}
98+ quantization_config = PostTrainingQuantConfig (approach = "dynamic" )
99+
97100 with tempfile .TemporaryDirectory () as tmp_dir :
98101 for backend in ["torch" , "ort" ]:
99102 if backend == "torch" :
@@ -104,8 +107,8 @@ def test_dynamic_quantization(self, task, model_arch, expected_quantized_matmuls
104107 quantizer = INCQuantizer .from_pretrained (model , task = task )
105108 quantizer .quantize (
106109 quantization_config = quantization_config ,
107- save_directory = tmp_dir ,
108110 save_onnx_model = save_onnx_model ,
111+ save_directory = tmp_dir ,
109112 )
110113 if backend == "torch" :
111114 quantized_model = quantizer ._quantized_model
@@ -121,7 +124,7 @@ def test_dynamic_quantization(self, task, model_arch, expected_quantized_matmuls
121124 load_inc_model = True ,
122125 )
123126
124- @parameterized .expand (SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS )
127+ @parameterized .expand (SUPPORTED_ARCHITECTURES_STATIC )
125128 def test_static_quantization (self , task , model_arch , expected_quantized_matmuls ):
126129 num_samples = 10
127130 model_name = MODEL_NAMES [model_arch ]
@@ -130,28 +133,26 @@ def test_static_quantization(self, task, model_arch, expected_quantized_matmuls)
130133 if tokenizer .pad_token is None :
131134 tokenizer .pad_token = tokenizer .eos_token
132135
133- save_onnx_model = False
134- op_type_dict = (
135- {"Embedding" : {"weight" : {"dtype" : ["fp32" ]}, "activation" : {"dtype" : ["fp32" ]}}}
136- if save_onnx_model
137- else None
138- )
139- quantization_config = PostTrainingQuantConfig (approach = "static" , op_type_dict = op_type_dict )
140136 quantized_model = None
137+ save_onnx_model = False
138+ quantization_config = PostTrainingQuantConfig (approach = "static" )
139+ model_kwargs = {"use_cache" : False , "use_io_binding" : False } if task == "text-generation" else {}
141140
142141 with tempfile .TemporaryDirectory () as tmp_dir :
143142 for backend in ["torch" , "ort" ]:
144143 if backend == "torch" :
145144 model = model_class .auto_model_class .from_pretrained (model_name )
146145 else :
147- model = model_class .from_pretrained (model_name , export = True )
146+ model = model_class .from_pretrained (model_name , export = True , ** model_kwargs )
147+
148148 quantizer = INCQuantizer .from_pretrained (model , task = task )
149149 calibration_dataset = _generate_dataset (quantizer , tokenizer , num_samples = num_samples )
150+
150151 quantizer .quantize (
151152 quantization_config = quantization_config ,
152153 calibration_dataset = calibration_dataset ,
153- save_directory = tmp_dir ,
154154 save_onnx_model = save_onnx_model ,
155+ save_directory = tmp_dir ,
155156 )
156157 if backend == "torch" :
157158 quantized_model = quantizer ._quantized_model
0 commit comments