Skip to content

Commit df3c697

Browse files
eigen-kfacebook-github-bot
authored andcommitted
Make test_no_replace_quant_permute_dequant_with_requantize use GraphBuilder. (pytorch#10927)
Summary: Use GraphBuilder to create the model for unit testing. Differential Revision: D74842294
1 parent 8953279 commit df3c697

File tree

1 file changed

+18
-19
lines changed

1 file changed

+18
-19
lines changed

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -281,25 +281,24 @@ def forward(self, x):
281281
)
282282

283283
def test_no_replace_quant_permute_dequant_with_requantize(self):
284-
class M(torch.nn.Module):
285-
def __init__(self):
286-
super().__init__()
287-
288-
def forward(self, x):
289-
x = torch.ops.quantized_decomposed.quantize_per_tensor(
290-
x, 1.2, 3, 0, 127, torch.int8
291-
)
292-
x = torch.permute(x, [2, 0, 1, 3])
293-
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
294-
x, 4.5, 6, 0, 127, torch.int8
295-
)
296-
return x
297-
298-
inputs = torch.randn(2, 12, 1, 6)
299-
model = M()
300-
graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
301-
302-
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
284+
builder = GraphBuilder()
285+
x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
286+
quant = builder.call_operator(
287+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
288+
args=(x, 1.2, 3, 0, 127, torch.int8),
289+
)
290+
permute = builder.call_operator(
291+
op=exir_ops.edge.aten.permute_copy.default,
292+
args=(quant, [2, 0, 1, 3])
293+
)
294+
dequant = builder.call_operator(
295+
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
296+
args=(permute, 4.5, 6, 0, 127, torch.int8),
297+
)
298+
builder.output(dequant)
299+
graph_module = FuseQuantDequantToRequantizePass(
300+
force_quant_dequant_fusion=False
301+
)(builder.get_graph_module()).graph_module
303302
self.check_op_counts(
304303
graph_module,
305304
expected_op_counts={

0 commit comments

Comments
 (0)