diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py index 4e7c37a1635..8f888c4c8bf 100644 --- a/backends/cadence/aot/tests/test_fusion_ops_passes.py +++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py @@ -248,32 +248,28 @@ def forward(self, x): ) def test_force_quant_dequant_fusion(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - x = torch.ops.quantized_decomposed.quantize_per_tensor( - x, 1.2, 3, 0, 127, torch.int8 - ) - x = torch.permute(x, [2, 0, 1, 3]) - x = torch.ops.quantized_decomposed.dequantize_per_tensor( - x, 4.5, 6, 0, 127, torch.int8 - ) - return x - - inputs = torch.randn(2, 12, 1, 6) - model = M() - graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module - - graph_module = FuseQuantDequantToRequantizePass( + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) + quant = builder.call_operator( + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + args=(x, 1.2, 3, 0, 127, torch.int8), + ) + permute = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(quant, [2, 0, 1, 3]) + ) + dequant = builder.call_operator( + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + args=(permute, 4.5, 6, 0, 127, torch.int8), + ) + builder.output(dequant) + original_graph = builder.get_graph_module() + converted_graph = FuseQuantDequantToRequantizePass( force_quant_dequant_fusion=True - )(graph_module).graph_module + )(original_graph).graph_module self.check_op_counts( - graph_module, + converted_graph, expected_op_counts={ - # Verify that no dequant/quant pair was replaced with requantize. - # quantize -> permute -> dequantize should not be replaced with requantize. + # Verify that dequant/quant pair was replaced with requantize. exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0, exir_ops.edge.cadence.requantize.default: 1,