@@ -248,32 +248,28 @@ def forward(self, x):
248248 )
249249
250250 def test_force_quant_dequant_fusion (self ):
251- class M (torch .nn .Module ):
252- def __init__ (self ):
253- super ().__init__ ()
254-
255- def forward (self , x ):
256- x = torch .ops .quantized_decomposed .quantize_per_tensor (
257- x , 1.2 , 3 , 0 , 127 , torch .int8
258- )
259- x = torch .permute (x , [2 , 0 , 1 , 3 ])
260- x = torch .ops .quantized_decomposed .dequantize_per_tensor (
261- x , 4.5 , 6 , 0 , 127 , torch .int8
262- )
263- return x
264-
265- inputs = torch .randn (2 , 12 , 1 , 6 )
266- model = M ()
267- graph_module = export_to_edge (model , (inputs ,)).exported_program ().graph_module
268-
269- graph_module = FuseQuantDequantToRequantizePass (
251+ builder = GraphBuilder ()
252+ x = builder .placeholder ("x" , torch .randn (2 , 12 , 1 , 6 , dtype = torch .float32 ))
253+ quant = builder .call_operator (
254+ op = exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
255+ args = (x , 1.2 , 3 , 0 , 127 , torch .int8 ),
256+ )
257+ permute = builder .call_operator (
258+ op = exir_ops .edge .aten .permute_copy .default , args = (quant , [2 , 0 , 1 , 3 ])
259+ )
260+ dequant = builder .call_operator (
261+ op = exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default ,
262+ args = (permute , 4.5 , 6 , 0 , 127 , torch .int8 ),
263+ )
264+ builder .output (dequant )
265+ original_graph = builder .get_graph_module ()
266+ converted_graph = FuseQuantDequantToRequantizePass (
270267 force_quant_dequant_fusion = True
271- )(graph_module ).graph_module
268+ )(original_graph ).graph_module
272269 self .check_op_counts (
273- graph_module ,
270+ converted_graph ,
274271 expected_op_counts = {
275- # Verify that no dequant/quant pair was replaced with requantize.
276- # quantize -> permute -> dequantize should not be replaced with requantize.
272+ # Verify that the quantize -> permute -> dequantize chain was fused into a single requantize op.
277273 exir_ops .edge .quantized_decomposed .quantize_per_tensor .default : 0 ,
278274 exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default : 0 ,
279275 exir_ops .edge .cadence .requantize .default : 1 ,
@@ -341,24 +337,24 @@ def forward(self, x):
341337 )
342338
343339 def test_replace_dequant_quant_with_requantize (self ):
344- class M (torch .nn .Module ):
345- def __init__ (self ):
346- super ().__init__ ()
340+ builder = GraphBuilder ()
341+ x = builder .placeholder ("x" , torch .randn (2 , 12 , 1 , 6 , dtype = torch .float32 ))
347342
348- def forward (self , x ):
349- x = torch .ops .quantized_decomposed .dequantize_per_tensor (
350- x , 1.2 , 3 , 0 , 127 , torch .int8
351- )
352- x = torch .permute (x , [2 , 0 , 1 , 3 ])
353- x = torch .ops .quantized_decomposed .quantize_per_tensor (
354- x , 4.5 , 6 , 0 , 127 , torch .int8
355- )
356- return x
343+ dequant = builder .call_operator (
344+ op = exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default ,
345+ args = (x , 1.2 , 3 , 0 , 127 , torch .int8 ),
346+ )
347+ permute = builder .call_operator (
348+ op = exir_ops .edge .aten .permute_copy .default ,
349+ args = (dequant , [2 , 0 , 1 , 3 ])
350+ )
357351
358- inputs = torch .randn (2 , 12 , 1 , 6 ).to (torch .int8 )
359- model = M ()
360- graph_module = export_to_edge (model , (inputs ,)).exported_program ().graph_module
361- graph_module = FuseQuantDequantToRequantizePass ()(graph_module ).graph_module
352+ quant = builder .call_operator (
353+ op = exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
354+ args = (permute , 4.5 , 6 , 0 , 127 , torch .int8 ),
355+ )
356+ builder .output (quant )
357+ graph_module = FuseQuantDequantToRequantizePass ()(builder .get_graph_module ()).graph_module
362358
363359 self .check_op_counts (
364360 graph_module ,
0 commit comments