19
19
)
20
20
from executorch .backends .cadence .aot .fuse_ops import (
21
21
FuseFullThenReshapePass ,
22
- FuseMulIntoDequantPass ,
22
+ FuseMulScalarIntoDequantPass ,
23
+ FuseMulTensorIntoDequantPass ,
23
24
FuseQuantDequantToRequantizePass ,
24
25
FuseTransposeOrPermuteOpPairsPass ,
25
26
)
@@ -446,7 +447,7 @@ def forward(self, x):
446
447
447
448
inputs = (torch .randint (0 , 255 , [4 , 32 ], dtype = torch .uint8 ),)
448
449
graph_module = export_to_edge (M (), inputs ).exported_program ().graph_module
449
- graph_module = FuseMulIntoDequantPass ()(graph_module ).graph_module
450
+ graph_module = FuseMulTensorIntoDequantPass ()(graph_module ).graph_module
450
451
451
452
# verify that the mul and full ops were removed
452
453
self .check_op_counts (
@@ -467,6 +468,47 @@ def forward(self, x):
467
468
deq_scale = node .args [1 ]
468
469
self .assertEqual (deq_scale , 4.5 )
469
470
471
+ def test_fuse_mul_scalar_into_dequant (self ):
472
+ dequant_scale = 0.006
473
+ mul_value = 0.3
474
+
475
+ builder = GraphBuilder ()
476
+ x = builder .placeholder ("x" , torch .randn (2 , 3 , 4 , dtype = torch .float32 ))
477
+ quant = builder .call_operator (
478
+ op = exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
479
+ args = (x , 1 , 0 , - 128 , 127 , torch .int8 ),
480
+ )
481
+ dequant = builder .call_operator (
482
+ op = exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default ,
483
+ args = (quant , dequant_scale , 5 , - 128 , 127 , torch .int8 ),
484
+ )
485
+ mul_scalar = builder .call_operator (
486
+ op = exir_ops .edge .aten .mul .Scalar ,
487
+ args = (dequant , mul_value ),
488
+ )
489
+ builder .output (mul_scalar )
490
+ graph_module = builder .get_graph_module ()
491
+
492
+ graph_module = FuseMulScalarIntoDequantPass ()(graph_module ).graph_module
493
+
494
+ # verify that the mul and full ops were removed
495
+ self .check_op_counts (
496
+ graph_module ,
497
+ expected_op_counts = {
498
+ exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default : 1 ,
499
+ exir_ops .edge .aten .mul .Scalar : 0 ,
500
+ },
501
+ )
502
+
503
+ # verify that the dequant scale value was updated correctly
504
+ for node in graph_module .graph .nodes :
505
+ if (
506
+ node .target
507
+ == exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default
508
+ ):
509
+ deq_scale = node .args [1 ]
510
+ self .assertEqual (deq_scale , dequant_scale * mul_value )
511
+
470
512
def test_fuse_then_transpose_pass (self ):
471
513
# Create a graph with full -> transpose.
472
514
builder = GraphBuilder ()
0 commit comments