@@ -281,25 +281,24 @@ def forward(self, x):
281281 )
282282
283283 def test_no_replace_quant_permute_dequant_with_requantize (self ):
284- class M (torch .nn .Module ):
285- def __init__ (self ):
286- super ().__init__ ()
287-
288- def forward (self , x ):
289- x = torch .ops .quantized_decomposed .quantize_per_tensor (
290- x , 1.2 , 3 , 0 , 127 , torch .int8
291- )
292- x = torch .permute (x , [2 , 0 , 1 , 3 ])
293- x = torch .ops .quantized_decomposed .dequantize_per_tensor (
294- x , 4.5 , 6 , 0 , 127 , torch .int8
295- )
296- return x
297-
298- inputs = torch .randn (2 , 12 , 1 , 6 )
299- model = M ()
300- graph_module = export_to_edge (model , (inputs ,)).exported_program ().graph_module
301-
302- graph_module = FuseQuantDequantToRequantizePass ()(graph_module ).graph_module
284+ builder = GraphBuilder ()
285+ x = builder .placeholder ("x" , torch .randn (2 , 12 , 1 , 6 , dtype = torch .float32 ))
286+ quant = builder .call_operator (
287+ op = exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
288+ args = (x , 1.2 , 3 , 0 , 127 , torch .int8 ),
289+ )
290+ permute = builder .call_operator (
291+ op = exir_ops .edge .aten .permute_copy .default ,
292+ args = (quant , [2 , 0 , 1 , 3 ])
293+ )
294+ dequant = builder .call_operator (
295+ op = exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default ,
296+ args = (permute , 4.5 , 6 , 0 , 127 , torch .int8 ),
297+ )
298+ builder .output (dequant )
299+ graph_module = FuseQuantDequantToRequantizePass (
300+ force_quant_dequant_fusion = False
301+ )(builder .get_graph_module ()).graph_module
303302 self .check_op_counts (
304303 graph_module ,
305304 expected_op_counts = {
0 commit comments