1919)
2020from executorch .backends .cadence .aot .fuse_ops import (
2121 FuseFullThenReshapePass ,
22- FuseMulIntoDequantPass ,
22+ FuseMulScalarIntoDequantPass ,
23+ FuseMulTensorIntoDequantPass ,
2324 FuseQuantDequantToRequantizePass ,
2425 FuseTransposeOrPermuteOpPairsPass ,
2526)
@@ -446,7 +447,7 @@ def forward(self, x):
446447
447448 inputs = (torch .randint (0 , 255 , [4 , 32 ], dtype = torch .uint8 ),)
448449 graph_module = export_to_edge (M (), inputs ).exported_program ().graph_module
449- graph_module = FuseMulIntoDequantPass ()(graph_module ).graph_module
450+ graph_module = FuseMulTensorIntoDequantPass ()(graph_module ).graph_module
450451
451452 # verify that the mul and full ops were removed
452453 self .check_op_counts (
@@ -467,6 +468,47 @@ def forward(self, x):
467468 deq_scale = node .args [1 ]
468469 self .assertEqual (deq_scale , 4.5 )
469470
def test_fuse_mul_scalar_into_dequant(self):
    """Check that a mul.Scalar following a dequantize op is fused into it.

    Builds the graph quantize_per_tensor -> dequantize_per_tensor ->
    mul.Scalar, runs FuseMulScalarIntoDequantPass on it, and then expects:
      * the mul.Scalar node to be gone while one dequantize node remains;
      * the dequantize scale (its args[1]) to be the original scale
        multiplied by the scalar operand of the removed mul.
    """
    dequant_scale = 0.006
    mul_value = 0.3

    builder = GraphBuilder()
    x = builder.placeholder("x", torch.randn(2, 3, 4, dtype=torch.float32))
    quant = builder.call_operator(
        op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        args=(x, 1, 0, -128, 127, torch.int8),
    )
    dequant = builder.call_operator(
        op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        args=(quant, dequant_scale, 5, -128, 127, torch.int8),
    )
    mul_scalar = builder.call_operator(
        op=exir_ops.edge.aten.mul.Scalar,
        args=(dequant, mul_value),
    )
    builder.output(mul_scalar)
    graph_module = builder.get_graph_module()

    graph_module = FuseMulScalarIntoDequantPass()(graph_module).graph_module

    # verify that the mul op was removed and a single dequant op remains
    self.check_op_counts(
        graph_module,
        expected_op_counts={
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
            exir_ops.edge.aten.mul.Scalar: 0,
        },
    )

    # verify that the dequant scale value was updated correctly
    expected_scale = dequant_scale * mul_value
    for node in graph_module.graph.nodes:
        if (
            node.target
            == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
        ):
            deq_scale = node.args[1]
            # The expected value is a floating-point product, so compare
            # with a tolerance instead of requiring bit-identical results.
            self.assertAlmostEqual(deq_scale, expected_scale)
470512 def test_fuse_then_transpose_pass (self ):
471513 # Create a graph with full -> transpose.
472514 builder = GraphBuilder ()
0 commit comments