8
8
9
9
10
10
import unittest
11
- from typing import Tuple
11
+ from typing import Final , List , Tuple
12
12
13
13
import executorch .backends .cadence .aot .ops_registrations # noqa
14
14
import torch
@@ -281,25 +281,23 @@ def forward(self, x):
281
281
)
282
282
283
283
def test_no_replace_quant_permute_dequant_with_requantize (self ):
284
- class M (torch .nn .Module ):
285
- def __init__ (self ):
286
- super ().__init__ ()
287
-
288
- def forward (self , x ):
289
- x = torch .ops .quantized_decomposed .quantize_per_tensor (
290
- x , 1.2 , 3 , 0 , 127 , torch .int8
291
- )
292
- x = torch .permute (x , [2 , 0 , 1 , 3 ])
293
- x = torch .ops .quantized_decomposed .dequantize_per_tensor (
294
- x , 4.5 , 6 , 0 , 127 , torch .int8
295
- )
296
- return x
297
-
298
- inputs = torch .randn (2 , 12 , 1 , 6 )
299
- model = M ()
300
- graph_module = export_to_edge (model , (inputs ,)).exported_program ().graph_module
301
-
302
- graph_module = FuseQuantDequantToRequantizePass ()(graph_module ).graph_module
284
+ builder = GraphBuilder ()
285
+ x = builder .placeholder ("x" , torch .randn (2 , 12 , 1 , 6 , dtype = torch .float32 ))
286
+ quant = builder .call_operator (
287
+ op = exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
288
+ args = (x , 1.2 , 3 , 0 , 127 , torch .int8 ),
289
+ )
290
+ permute = builder .call_operator (
291
+ op = exir_ops .edge .aten .permute_copy .default , args = (quant , [2 , 0 , 1 , 3 ])
292
+ )
293
+ dequant = builder .call_operator (
294
+ op = exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default ,
295
+ args = (permute , 4.5 , 6 , 0 , 127 , torch .int8 ),
296
+ )
297
+ builder .output (dequant )
298
+ graph_module = FuseQuantDequantToRequantizePass (
299
+ force_quant_dequant_fusion = False
300
+ )(builder .get_graph_module ()).graph_module
303
301
self .check_op_counts (
304
302
graph_module ,
305
303
expected_op_counts = {
@@ -436,18 +434,28 @@ def forward(self, x):
436
434
)
437
435
438
436
def test_fuse_mul_into_dequant (self ):
439
- class M (torch .nn .Module ):
440
- def forward (self , x ):
441
- x0 = torch .ops .quantized_decomposed .dequantize_per_tensor (
442
- x , 1.5 , 0 , 0 , 255 , torch .uint8
443
- )
444
- x1 = torch .full ([4 , 32 ], 3 , dtype = torch .float32 )
445
- x2 = x0 * x1
446
- return x2
437
+ INPUT_SHAPE : Final [List [int ]] = [4 , 32 ]
438
+ DEQUANT_SCALE : Final [float ] = 1.5
439
+ FULL_VALUE : Final [float ] = 3
447
440
448
- inputs = (torch .randint (0 , 255 , [4 , 32 ], dtype = torch .uint8 ),)
449
- graph_module = export_to_edge (M (), inputs ).exported_program ().graph_module
450
- graph_module = FuseMulTensorIntoDequantPass ()(graph_module ).graph_module
441
+ builder = GraphBuilder ()
442
+ x = builder .placeholder ("x" , torch .randn (* INPUT_SHAPE , dtype = torch .float32 ))
443
+ dequant = builder .call_operator (
444
+ op = exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default ,
445
+ args = (x , DEQUANT_SCALE , 0 , 0 , 255 , torch .uint8 ),
446
+ )
447
+ full = builder .call_operator (
448
+ op = exir_ops .edge .aten .full .default ,
449
+ args = (INPUT_SHAPE , FULL_VALUE ),
450
+ )
451
+ mul = builder .call_operator (
452
+ op = exir_ops .edge .aten .mul .Tensor ,
453
+ args = (dequant , full ),
454
+ )
455
+ builder .output (mul )
456
+ graph_module = FuseMulTensorIntoDequantPass ()(
457
+ builder .get_graph_module ()
458
+ ).graph_module
451
459
452
460
# verify that the mul and full ops were removed
453
461
self .check_op_counts (
@@ -466,7 +474,7 @@ def forward(self, x):
466
474
== exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default
467
475
):
468
476
deq_scale = node .args [1 ]
469
- self .assertEqual (deq_scale , 4.5 )
477
+ self .assertEqual (deq_scale , DEQUANT_SCALE * FULL_VALUE )
470
478
471
479
def test_fuse_mul_scalar_into_dequant (self ):
472
480
dequant_scale = 0.006
0 commit comments