Skip to content

Commit 04005b4

Browse files
eigen-kfacebook-github-bot
authored andcommitted
Make test_replace_quant_view_dequant_with_requantize use GraphBuilder. (pytorch#10928)
Summary: Use GraphBuilder to create the model for unit testing. Reviewed By: zonglinpeng, mcremon-meta Differential Revision: D74843628
1 parent de72d65 commit 04005b4

File tree

1 file changed

+18
-20
lines changed

1 file changed

+18
-20
lines changed

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -310,30 +310,28 @@ def test_no_replace_quant_permute_dequant_with_requantize(self):
310310
)
311311

312312
def test_replace_quant_view_dequant_with_requantize(self):
313-
class M(torch.nn.Module):
314-
def __init__(self):
315-
super().__init__()
316-
317-
def forward(self, x):
318-
x = torch.ops.quantized_decomposed.quantize_per_tensor(
319-
x, 1.2, 3, 0, 127, torch.int8
320-
)
321-
x = x.view(-1)
322-
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
323-
x, 4.5, 6, 0, 127, torch.int8
324-
)
325-
return x
326-
327-
inputs = torch.randn(2, 12, 1, 6)
328-
model = M()
329-
graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
330-
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
313+
builder = GraphBuilder()
314+
x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
315+
quant = builder.call_operator(
316+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
317+
args=(x, 1.2, 3, 0, 127, torch.int8),
318+
)
319+
view = builder.call_operator(
320+
op=exir_ops.edge.aten.view_copy.default, args=(quant, [-1])
321+
)
322+
dequant = builder.call_operator(
323+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
324+
args=(view, 4.5, 6, 0, 127, torch.int8),
325+
)
326+
builder.output(dequant)
327+
graph_module = FuseQuantDequantToRequantizePass()(
328+
builder.get_graph_module()
329+
).graph_module
331330

332331
self.check_op_counts(
333332
graph_module,
334333
expected_op_counts={
335-
# Verify that no dequant/quant pair was replaced with requantize.
336-
# quantize -> permute -> dequantize should not be replaced with requantize.
334+
# Verify that dequant/quant pair was replaced with requantize.
337335
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
338336
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
339337
exir_ops.edge.cadence.requantize.default: 1,

0 commit comments

Comments
 (0)