Skip to content

Commit 59a3469

Browse files
eigen-kfacebook-github-bot
authored andcommitted
Make test_fuse_mul_into_dequant use GraphBuilder. (#10925)
Summary: Use GraphBuilder to create the model for unit testing. Differential Revision: D74840805
1 parent 8d53a28 commit 59a3469

File tree

1 file changed

+23
-13
lines changed

1 file changed

+23
-13
lines changed

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
import unittest
11-
from typing import Tuple
11+
from typing import Final, List, Tuple
1212

1313
import executorch.backends.cadence.aot.ops_registrations # noqa
1414
import torch
@@ -436,18 +436,28 @@ def forward(self, x):
436436
)
437437

438438
def test_fuse_mul_into_dequant(self):
439-
class M(torch.nn.Module):
440-
def forward(self, x):
441-
x0 = torch.ops.quantized_decomposed.dequantize_per_tensor(
442-
x, 1.5, 0, 0, 255, torch.uint8
443-
)
444-
x1 = torch.full([4, 32], 3, dtype=torch.float32)
445-
x2 = x0 * x1
446-
return x2
439+
INPUT_SHAPE: Final[List[int]] = [4, 32]
440+
DEQUANT_SCALE: Final[float] = 1.5
441+
FULL_VALUE: Final[float] = 3
447442

448-
inputs = (torch.randint(0, 255, [4, 32], dtype=torch.uint8),)
449-
graph_module = export_to_edge(M(), inputs).exported_program().graph_module
450-
graph_module = FuseMulTensorIntoDequantPass()(graph_module).graph_module
443+
builder = GraphBuilder()
444+
x = builder.placeholder("x", torch.randn(*INPUT_SHAPE, dtype=torch.float32))
445+
dequant = builder.call_operator(
446+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
447+
args=(x, DEQUANT_SCALE, 0, 0, 255, torch.uint8),
448+
)
449+
full = builder.call_operator(
450+
op=exir_ops.edge.aten.full.default,
451+
args=(INPUT_SHAPE, FULL_VALUE),
452+
)
453+
mul = builder.call_operator(
454+
op=exir_ops.edge.aten.mul.Tensor,
455+
args=(dequant, full),
456+
)
457+
builder.output(mul)
458+
graph_module = FuseMulTensorIntoDequantPass()(
459+
builder.get_graph_module()
460+
).graph_module
451461

452462
# verify that the mul and full ops were removed
453463
self.check_op_counts(
@@ -466,7 +476,7 @@ def forward(self, x):
466476
== exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
467477
):
468478
deq_scale = node.args[1]
469-
self.assertEqual(deq_scale, 4.5)
479+
self.assertEqual(deq_scale, DEQUANT_SCALE * FULL_VALUE)
470480

471481
def test_fuse_mul_scalar_into_dequant(self):
472482
dequant_scale = 0.006

0 commit comments

Comments
 (0)