Skip to content

Commit 851b373

Browse files
authored
Make test_fuse_mul_into_dequant use GraphBuilder.
Differential Revision: D74840805 Pull Request resolved: #10925
1 parent 309faf8 commit 851b373

File tree

1 file changed

+40
-32
lines changed

1 file changed

+40
-32
lines changed

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
import unittest
11-
from typing import Tuple
11+
from typing import Final, List, Tuple
1212

1313
import executorch.backends.cadence.aot.ops_registrations # noqa
1414
import torch
@@ -281,25 +281,23 @@ def forward(self, x):
281281
)
282282

283283
def test_no_replace_quant_permute_dequant_with_requantize(self):
284-
class M(torch.nn.Module):
285-
def __init__(self):
286-
super().__init__()
287-
288-
def forward(self, x):
289-
x = torch.ops.quantized_decomposed.quantize_per_tensor(
290-
x, 1.2, 3, 0, 127, torch.int8
291-
)
292-
x = torch.permute(x, [2, 0, 1, 3])
293-
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
294-
x, 4.5, 6, 0, 127, torch.int8
295-
)
296-
return x
297-
298-
inputs = torch.randn(2, 12, 1, 6)
299-
model = M()
300-
graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
301-
302-
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
284+
builder = GraphBuilder()
285+
x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
286+
quant = builder.call_operator(
287+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
288+
args=(x, 1.2, 3, 0, 127, torch.int8),
289+
)
290+
permute = builder.call_operator(
291+
op=exir_ops.edge.aten.permute_copy.default, args=(quant, [2, 0, 1, 3])
292+
)
293+
dequant = builder.call_operator(
294+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
295+
args=(permute, 4.5, 6, 0, 127, torch.int8),
296+
)
297+
builder.output(dequant)
298+
graph_module = FuseQuantDequantToRequantizePass(
299+
force_quant_dequant_fusion=False
300+
)(builder.get_graph_module()).graph_module
303301
self.check_op_counts(
304302
graph_module,
305303
expected_op_counts={
@@ -436,18 +434,28 @@ def forward(self, x):
436434
)
437435

438436
def test_fuse_mul_into_dequant(self):
439-
class M(torch.nn.Module):
440-
def forward(self, x):
441-
x0 = torch.ops.quantized_decomposed.dequantize_per_tensor(
442-
x, 1.5, 0, 0, 255, torch.uint8
443-
)
444-
x1 = torch.full([4, 32], 3, dtype=torch.float32)
445-
x2 = x0 * x1
446-
return x2
437+
INPUT_SHAPE: Final[List[int]] = [4, 32]
438+
DEQUANT_SCALE: Final[float] = 1.5
439+
FULL_VALUE: Final[float] = 3
447440

448-
inputs = (torch.randint(0, 255, [4, 32], dtype=torch.uint8),)
449-
graph_module = export_to_edge(M(), inputs).exported_program().graph_module
450-
graph_module = FuseMulTensorIntoDequantPass()(graph_module).graph_module
441+
builder = GraphBuilder()
442+
x = builder.placeholder("x", torch.randn(*INPUT_SHAPE, dtype=torch.float32))
443+
dequant = builder.call_operator(
444+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
445+
args=(x, DEQUANT_SCALE, 0, 0, 255, torch.uint8),
446+
)
447+
full = builder.call_operator(
448+
op=exir_ops.edge.aten.full.default,
449+
args=(INPUT_SHAPE, FULL_VALUE),
450+
)
451+
mul = builder.call_operator(
452+
op=exir_ops.edge.aten.mul.Tensor,
453+
args=(dequant, full),
454+
)
455+
builder.output(mul)
456+
graph_module = FuseMulTensorIntoDequantPass()(
457+
builder.get_graph_module()
458+
).graph_module
451459

452460
# verify that the mul and full ops were removed
453461
self.check_op_counts(
@@ -466,7 +474,7 @@ def forward(self, x):
466474
== exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
467475
):
468476
deq_scale = node.args[1]
469-
self.assertEqual(deq_scale, 4.5)
477+
self.assertEqual(deq_scale, DEQUANT_SCALE * FULL_VALUE)
470478

471479
def test_fuse_mul_scalar_into_dequant(self):
472480
dequant_scale = 0.006

0 commit comments

Comments
 (0)