diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py
index 5c7f10729cc..ef42f399943 100644
--- a/backends/cadence/aot/fuse_ops.py
+++ b/backends/cadence/aot/fuse_ops.py
@@ -856,19 +856,32 @@ class FuseMulTensorIntoQuantPass(ExportPass):
     def attempt_fusion(
         self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
     ) -> None:
-        full_nodes = [
-            arg
-            for arg in mul_node.args
-            if isinstance(arg, torch.fx.Node)
-            and arg.target == exir_ops.edge.aten.full.default
-        ]
+        if len(mul_node.args) != 2 or len(mul_node.users) != 1:
+            return
+
+        first_arg = cast(torch.fx.Node, mul_node.args[0])
+        second_arg = cast(torch.fx.Node, mul_node.args[1])
+
+        input_node = first_arg
+        full_node = second_arg
+        if second_arg.target == exir_ops.edge.aten.full.default:
+            # Most common case, nothing to change.
+            pass
+        elif first_arg.target == exir_ops.edge.aten.full.default:
+            # Input and full nodes are swapped.
+            full_node = first_arg
+            input_node = second_arg
+        else:
+            # No full node found, skip.
+            return

-        if len(full_nodes) != 1 or len(mul_node.users) != 1:
+        # Ensure that the mul op does not do any broadcasting.
+        if input_node.meta["val"].shape != mul_node.meta["val"].shape:
             return

-        full_node = full_nodes[0]
         mul_user = list(mul_node.users.keys())[0]

+        # Ensure only the expected quant ops are using the current mul op.
         if mul_user.target not in {
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             exir_ops.edge.cadence.quantize_per_tensor.default,
@@ -878,33 +891,28 @@ def attempt_fusion(
         quant_node = mul_user

         # Calculate the new scale value.
-        prev_scale = quant_node.args[1]
-        assert isinstance(prev_scale, (int, float))
+        old_scale = quant_node.args[1]
+        assert isinstance(old_scale, (int, float))
         mul_scalar = full_node.args[1]
         assert isinstance(mul_scalar, (int, float))
-        new_scale = float(prev_scale) * float(mul_scalar)
+        """Why we divide the old scale by the mul scalar to get the new scale:
+        y = x * mul_scalar
+        q = zp + y / old_scale
+        q = zp + x * mul_scalar / old_scale
+        new_scale = old_scale / mul_scalar
+        q = zp + x / new_scale
+        """
+        new_scale = float(old_scale) / float(mul_scalar)

         logging.debug(
             f"Fused {mul_node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
         )

-        # Replace the input first
-        quant_node.replace_input_with(
-            cast(torch.fx.Node, quant_node.args[0]),
-            cast(torch.fx.Node, mul_node.args[0]),
-        )
-
-        # Now update the scale in the args
-        new_quant_args = list(quant_node.args)
-        new_quant_args[1] = new_scale
-        quant_node.args = tuple(new_quant_args)
-
-        # Clean up the mul_node
-        mul_node.args = ()
-        mul_node.users = {}
-
-        graph_module.graph.erase_node(mul_node)
-        graph_module.graph.erase_node(full_node)
+        # Update quant node input and scale.
+        old_quant_input = cast(torch.fx.Node, quant_node.args[0])
+        new_quant_input = cast(torch.fx.Node, mul_node.args[0])
+        quant_node.replace_input_with(old_quant_input, new_quant_input)
+        quant_node.update_arg(1, new_scale)

     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         for node in graph_module.graph.find_nodes(
diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py
index ead8b46f775..0f52d3a726f 100644
--- a/backends/cadence/aot/tests/test_fusion_ops_passes.py
+++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py
@@ -598,7 +598,7 @@ def test_fuse_mul_scalar_into_dequant(self) -> None:
         self.assertEqual(deq_scale, dequant_scale * mul_value)

     def test_fuse_mul_into_quant(self) -> None:
-        quant_scale = 1.5
+        quant_scale = 5
         mul_value = 10

         builder = GraphBuilder()
@@ -613,7 +613,7 @@ def test_fuse_mul_into_quant(self) -> None:
         )
         quant = builder.call_operator(
             op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
-            args=(mul, quant_scale, 0, 0, 255, torch.uint8),
+            args=(mul, quant_scale, 7, 0, 255, torch.uint8),
         )
         builder.output([quant])
         original_graph = builder.get_graph_module()
@@ -631,14 +631,18 @@ def test_fuse_mul_into_quant(self) -> None:
         )

         # verify that the quant scale value was updated correctly
-        deq_scale = -1
-        for node in converted_graph.graph.nodes:
-            if (
-                node.target
-                == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
-            ):
-                deq_scale = node.args[1]
-        self.assertEqual(deq_scale, quant_scale * mul_value)
+        for node in converted_graph.graph.find_nodes(
+            op="call_function",
+            target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+        ):
+            new_quant_scale = node.args[1]
+            self.assertEqual(new_quant_scale, quant_scale / mul_value)
+
+        # verify that the fused graph produces the same output as the original
+        inp = torch.randn(4, 32, dtype=torch.float32)
+        original_out = original_graph(inp)[0]
+        new_out = converted_graph(inp)[0]
+        self.assertTrue(torch.equal(original_out, new_out))

     def test_fuse_then_transpose_pass(self) -> None:
         # Create a graph with full -> transpose.
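The scale-folding identity the pass relies on can be checked standalone. Below is a minimal sketch, not part of the diff: `quantize_per_tensor` here is a hypothetical reference implementation of affine quantization (`q = clamp(round(x / scale) + zp, qmin, qmax)`), not the `exir_ops` operator, and the power-of-two scale values are chosen so the float math is exact and the comparison is bitwise.

```python
# Minimal sketch (not part of the diff): folding a scalar mul into the
# quantization scale. quantize_per_tensor is a hypothetical reference
# implementation, not the exir_ops operator.
import torch


def quantize_per_tensor(x, scale, zp, qmin=0, qmax=255):
    # q = clamp(round(x / scale) + zp, qmin, qmax)
    return torch.clamp(torch.round(x / scale) + zp, qmin, qmax).to(torch.uint8)


x = torch.randn(4, 32)
old_scale, mul_scalar, zp = 0.5, 2.0, 7
new_scale = old_scale / mul_scalar

# Unfused: explicit mul, then quantize with the old scale.
unfused = quantize_per_tensor(x * mul_scalar, old_scale, zp)
# Fused: the mul is folded into the quantization scale.
fused = quantize_per_tensor(x, new_scale, zp)

# Power-of-two values keep the float math exact, so the outputs match bitwise.
assert torch.equal(unfused, fused)
```

With the updated test's values (scale 5, mul 10), the fused scale becomes 0.5, which is what the `quant_scale / mul_value` assertion in `test_fuse_mul_into_quant` checks.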