88
99
1010import  unittest 
11- from  typing  import  Tuple 
11+ from  typing  import  Final ,  List ,  Tuple 
1212
1313import  executorch .backends .cadence .aot .ops_registrations   # noqa 
1414import  torch 
@@ -281,25 +281,23 @@ def forward(self, x):
281281        )
282282
283283    def  test_no_replace_quant_permute_dequant_with_requantize (self ):
284-         class  M (torch .nn .Module ):
285-             def  __init__ (self ):
286-                 super ().__init__ ()
287- 
288-             def  forward (self , x ):
289-                 x  =  torch .ops .quantized_decomposed .quantize_per_tensor (
290-                     x , 1.2 , 3 , 0 , 127 , torch .int8 
291-                 )
292-                 x  =  torch .permute (x , [2 , 0 , 1 , 3 ])
293-                 x  =  torch .ops .quantized_decomposed .dequantize_per_tensor (
294-                     x , 4.5 , 6 , 0 , 127 , torch .int8 
295-                 )
296-                 return  x 
297- 
298-         inputs  =  torch .randn (2 , 12 , 1 , 6 )
299-         model  =  M ()
300-         graph_module  =  export_to_edge (model , (inputs ,)).exported_program ().graph_module 
301- 
302-         graph_module  =  FuseQuantDequantToRequantizePass ()(graph_module ).graph_module 
284+         builder  =  GraphBuilder ()
285+         x  =  builder .placeholder ("x" , torch .randn (2 , 12 , 1 , 6 , dtype = torch .float32 ))
286+         quant  =  builder .call_operator (
287+             op = exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
288+             args = (x , 1.2 , 3 , 0 , 127 , torch .int8 ),
289+         )
290+         permute  =  builder .call_operator (
291+             op = exir_ops .edge .aten .permute_copy .default , args = (quant , [2 , 0 , 1 , 3 ])
292+         )
293+         dequant  =  builder .call_operator (
294+             op = exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default ,
295+             args = (permute , 4.5 , 6 , 0 , 127 , torch .int8 ),
296+         )
297+         builder .output (dequant )
298+         graph_module  =  FuseQuantDequantToRequantizePass (
299+             force_quant_dequant_fusion = False 
300+         )(builder .get_graph_module ()).graph_module 
303301        self .check_op_counts (
304302            graph_module ,
305303            expected_op_counts = {
@@ -436,18 +434,28 @@ def forward(self, x):
436434        )
437435
438436    def  test_fuse_mul_into_dequant (self ):
439-         class  M (torch .nn .Module ):
440-             def  forward (self , x ):
441-                 x0  =  torch .ops .quantized_decomposed .dequantize_per_tensor (
442-                     x , 1.5 , 0 , 0 , 255 , torch .uint8 
443-                 )
444-                 x1  =  torch .full ([4 , 32 ], 3 , dtype = torch .float32 )
445-                 x2  =  x0  *  x1 
446-                 return  x2 
437+         INPUT_SHAPE : Final [List [int ]] =  [4 , 32 ]
438+         DEQUANT_SCALE : Final [float ] =  1.5 
439+         FULL_VALUE : Final [float ] =  3 
447440
448-         inputs  =  (torch .randint (0 , 255 , [4 , 32 ], dtype = torch .uint8 ),)
449-         graph_module  =  export_to_edge (M (), inputs ).exported_program ().graph_module 
450-         graph_module  =  FuseMulTensorIntoDequantPass ()(graph_module ).graph_module 
441+         builder  =  GraphBuilder ()
442+         x  =  builder .placeholder ("x" , torch .randn (* INPUT_SHAPE , dtype = torch .float32 ))
443+         dequant  =  builder .call_operator (
444+             op = exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default ,
445+             args = (x , DEQUANT_SCALE , 0 , 0 , 255 , torch .uint8 ),
446+         )
447+         full  =  builder .call_operator (
448+             op = exir_ops .edge .aten .full .default ,
449+             args = (INPUT_SHAPE , FULL_VALUE ),
450+         )
451+         mul  =  builder .call_operator (
452+             op = exir_ops .edge .aten .mul .Tensor ,
453+             args = (dequant , full ),
454+         )
455+         builder .output (mul )
456+         graph_module  =  FuseMulTensorIntoDequantPass ()(
457+             builder .get_graph_module ()
458+         ).graph_module 
451459
452460        # verify that the mul and full ops were removed 
453461        self .check_op_counts (
@@ -466,7 +474,7 @@ def forward(self, x):
466474                ==  exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default 
467475            ):
468476                deq_scale  =  node .args [1 ]
469-         self .assertEqual (deq_scale , 4.5 )
477+         self .assertEqual (deq_scale , DEQUANT_SCALE   *   FULL_VALUE )
470478
471479    def  test_fuse_mul_scalar_into_dequant (self ):
472480        dequant_scale  =  0.006 
0 commit comments