import unittest
-from typing import Tuple
+from typing import Final, List, Tuple

import executorch.backends.cadence.aot.ops_registrations  # noqa
import torch
@@ -436,18 +436,28 @@ def forward(self, x):
        )

    def test_fuse_mul_into_dequant(self):
-        class M(torch.nn.Module):
-            def forward(self, x):
-                x0 = torch.ops.quantized_decomposed.dequantize_per_tensor(
-                    x, 1.5, 0, 0, 255, torch.uint8
-                )
-                x1 = torch.full([4, 32], 3, dtype=torch.float32)
-                x2 = x0 * x1
-                return x2
+        INPUT_SHAPE: Final[List[int]] = [4, 32]
+        DEQUANT_SCALE: Final[float] = 1.5
+        FULL_VALUE: Final[float] = 3

-        inputs = (torch.randint(0, 255, [4, 32], dtype=torch.uint8),)
-        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
-        graph_module = FuseMulTensorIntoDequantPass()(graph_module).graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*INPUT_SHAPE, dtype=torch.float32))
+        dequant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(x, DEQUANT_SCALE, 0, 0, 255, torch.uint8),
+        )
+        full = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=(INPUT_SHAPE, FULL_VALUE),
+        )
+        mul = builder.call_operator(
+            op=exir_ops.edge.aten.mul.Tensor,
+            args=(dequant, full),
+        )
+        builder.output(mul)
+        graph_module = FuseMulTensorIntoDequantPass()(
+            builder.get_graph_module()
+        ).graph_module

        # verify that the mul and full ops were removed
        self.check_op_counts(
@@ -466,7 +476,7 @@ def forward(self, x):
                == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
            ):
                deq_scale = node.args[1]
-                self.assertEqual(deq_scale, 4.5)
+                self.assertEqual(deq_scale, DEQUANT_SCALE * FULL_VALUE)

    def test_fuse_mul_scalar_into_dequant(self):
        dequant_scale = 0.006