 
 
 import unittest
-from typing import Tuple
+from typing import Final, List, Tuple
 
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
@@ -281,25 +281,23 @@ def forward(self, x):
         )
 
     def test_no_replace_quant_permute_dequant_with_requantize(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                x = torch.ops.quantized_decomposed.quantize_per_tensor(
-                    x, 1.2, 3, 0, 127, torch.int8
-                )
-                x = torch.permute(x, [2, 0, 1, 3])
-                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
-                    x, 4.5, 6, 0, 127, torch.int8
-                )
-                return x
-
-        inputs = torch.randn(2, 12, 1, 6)
-        model = M()
-        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
-
-        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
+        quant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(x, 1.2, 3, 0, 127, torch.int8),
+        )
+        permute = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(quant, [2, 0, 1, 3])
+        )
+        dequant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(permute, 4.5, 6, 0, 127, torch.int8),
+        )
+        builder.output(dequant)
+        graph_module = FuseQuantDequantToRequantizePass(
+            force_quant_dequant_fusion=False
+        )(builder.get_graph_module()).graph_module
         self.check_op_counts(
             graph_module,
             expected_op_counts={
@@ -436,18 +434,28 @@ def forward(self, x):
         )
 
     def test_fuse_mul_into_dequant(self):
-        class M(torch.nn.Module):
-            def forward(self, x):
-                x0 = torch.ops.quantized_decomposed.dequantize_per_tensor(
-                    x, 1.5, 0, 0, 255, torch.uint8
-                )
-                x1 = torch.full([4, 32], 3, dtype=torch.float32)
-                x2 = x0 * x1
-                return x2
+        INPUT_SHAPE: Final[List[int]] = [4, 32]
+        DEQUANT_SCALE: Final[float] = 1.5
+        FULL_VALUE: Final[float] = 3
 
-        inputs = (torch.randint(0, 255, [4, 32], dtype=torch.uint8),)
-        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
-        graph_module = FuseMulTensorIntoDequantPass()(graph_module).graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*INPUT_SHAPE, dtype=torch.float32))
+        dequant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(x, DEQUANT_SCALE, 0, 0, 255, torch.uint8),
+        )
+        full = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=(INPUT_SHAPE, FULL_VALUE),
+        )
+        mul = builder.call_operator(
+            op=exir_ops.edge.aten.mul.Tensor,
+            args=(dequant, full),
+        )
+        builder.output(mul)
+        graph_module = FuseMulTensorIntoDequantPass()(
+            builder.get_graph_module()
+        ).graph_module
 
         # verify that the mul and full ops were removed
         self.check_op_counts(
@@ -466,7 +474,7 @@ def forward(self, x):
                 == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
             ):
                 deq_scale = node.args[1]
-                self.assertEqual(deq_scale, 4.5)
+                self.assertEqual(deq_scale, DEQUANT_SCALE * FULL_VALUE)
 
     def test_fuse_mul_scalar_into_dequant(self):
         dequant_scale = 0.006
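Note: the updated tests build the graph directly with GraphBuilder and run the pass on the resulting graph module, instead of exporting an nn.Module through export_to_edge. Below is a minimal standalone sketch of that flow; the import paths are assumptions based on the usual ExecuTorch layout (the test file already has the equivalent imports), and it mirrors the calls used in the hunks above rather than defining the passes' API.

# Sketch only: illustrates the GraphBuilder-based pattern adopted in this diff.
# Import locations are assumed; adjust to match the test file's actual imports.
import torch
from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.exir.dialects._ops import ops as exir_ops

# Build a small quantize -> dequantize graph directly, no nn.Module export.
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
quant = builder.call_operator(
    op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    args=(x, 1.2, 3, 0, 127, torch.int8),
)
dequant = builder.call_operator(
    op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    args=(quant, 4.5, 6, 0, 127, torch.int8),
)
builder.output(dequant)  # same call shape as in the tests above

# Run the fusion pass on the constructed graph module and inspect which
# call_function nodes remain afterwards.
graph_module = FuseQuantDequantToRequantizePass()(
    builder.get_graph_module()
).graph_module
for node in graph_module.graph.nodes:
    if node.op == "call_function":
        print(node.target)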