From 79947191b316a3c5740927006180cc00942642fe Mon Sep 17 00:00:00 2001
From: Oscar Andersson
Date: Fri, 28 Feb 2025 08:50:28 +0100
Subject: [PATCH] Make passes preserve and update node metadata

When creating or updating nodes in passes, the node metadata was neither
preserved nor updated correctly. This patch adds an ArmPass base class
whose call_operator() updates the node metadata when it is called as
super().call_operator(op, args, kwargs, meta, updated=True). It also adds
a from_node argument to arm_pass_utils.create_node() so that new nodes
copy their metadata from the original node. Only the 'stack_trace' field
is updated; all other fields are preserved from the original node.

Signed-off-by: Oscar Andersson
Change-Id: I725dd057716ae5a1fac0f97b522df22196f00bdb
---
 backends/arm/_passes/__init__.py              |  1 +
 backends/arm/_passes/arm_pass.py              | 33 ++++++++++++++++
 backends/arm/_passes/arm_pass_utils.py        | 18 ++++++++-
 .../arm/_passes/decompose_layernorm_pass.py   | 38 +++++++++++++++----
 .../arm/_passes/decompose_meandim_pass.py     | 12 +++---
 .../decompose_softmax_unstable_pass.py        | 14 +++----
 backends/arm/_passes/decompose_var_pass.py    | 19 ++++++----
 backends/arm/_passes/mm_to_bmm_pass.py        |  8 ++--
 backends/arm/tosa_utils.py                    | 15 ++++++--
 9 files changed, 120 insertions(+), 38 deletions(-)
 create mode 100644 backends/arm/_passes/arm_pass.py

diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index 1142f5565c0..2d7bf722e16 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -7,6 +7,7 @@
 from . import arm_pass_utils  # noqa
 from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder  # noqa
 from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass  # noqa
+from .arm_pass import ArmPass  # noqa
 from .cast_int64_pass import CastInt64BuffersToInt32Pass  # noqa
 from .cast_to_int32_pass import CastToInt32Pass  # noqa
 from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass  # noqa
diff --git a/backends/arm/_passes/arm_pass.py b/backends/arm/_passes/arm_pass.py
new file mode 100644
index 00000000000..085267a174e
--- /dev/null
+++ b/backends/arm/_passes/arm_pass.py
@@ -0,0 +1,33 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import traceback
+from typing import Optional
+
+import torch
+from executorch.exir.pass_base import ExportPass, NodeMetadata
+
+
+class ArmPass(ExportPass):
+    """Base class for Arm passes"""
+
+    def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = None):
+        super(ArmPass, self).__init__()
+        self.exported_program = exported_program
+
+    def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
+        if not updated:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # if updated we should update metadata
+        new_meta = {}
+        keys = meta.data.keys()
+        for key in keys:
+            new_meta[key] = meta[key]
+        old_stack_trace = new_meta.get("stack_trace", "")
+        new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
+        return super().call_operator(op, args, kwargs, NodeMetadata(new_meta))
diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py
index dba8f557085..afb2d82a2fc 100644
--- a/backends/arm/_passes/arm_pass_utils.py
+++ b/backends/arm/_passes/arm_pass_utils.py
@@ -7,12 +7,12 @@
 
 # pyre-unsafe
 
+import traceback
 from inspect import isclass
 from typing import Optional, Sequence
 
 import torch
 import torch.fx
-
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 
@@ -96,6 +96,7 @@ def create_node(
     kwargs: Optional[dict] = None,
     quantize: bool = False,
     q_params: Optional[tuple] = None,
+    from_node: Optional[torch.fx.Node] = None,
 ):
     """
     Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node.
@@ -108,8 +109,18 @@ def create_node(
         args=args,
         kwargs=kwargs or {},
     )
+
+    new_meta = {}
+    if from_node:
+        keys = from_node.meta.keys()
+        for key in keys:
+            new_meta[key] = from_node.meta[key]
+    old_stack_trace = new_meta.get("stack_trace", "")
+    new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
+    node.meta = new_meta
+
     if quantize and q_params:
-        return insert_q_dq_pair(graph, node, q_params)
+        return insert_q_dq_pair(graph, node, q_params, from_node)
     return node
 
@@ -117,6 +128,7 @@ def insert_q_dq_pair(
     graph: torch.fx.Graph,
     anchor: torch.fx.Node,
     q_params: tuple,
+    from_node: Optional[torch.fx.Node] = None,
 ):
     """
     Inserts a q dq node pair after the node 'anchor'.
@@ -127,6 +139,7 @@ def insert_q_dq_pair(
             graph=graph,
             op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             args=(),  # We add the argument last
+            from_node=from_node if from_node else anchor,
         )
     q.meta = anchor.meta
     with graph.inserting_after(q):
@@ -134,6 +147,7 @@ def insert_q_dq_pair(
             graph=graph,
             op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
             args=(q,) + q_params,
+            from_node=from_node if from_node else anchor,
         )
     dq.meta = q.meta
     anchor.replace_all_uses_with(dq)
diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py
index cc4a81caae0..a92434faa7d 100644
--- a/backends/arm/_passes/decompose_layernorm_pass.py
+++ b/backends/arm/_passes/decompose_layernorm_pass.py
@@ -9,9 +9,10 @@
 import operator
 
 import torch
+from executorch.backends.arm._passes import ArmPass
 from executorch.backends.arm._passes.arm_pass_utils import create_node
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.pass_base import PassResult
 
 
 def get_layer_norm_decomposition(op) -> tuple:
@@ -40,7 +41,7 @@ def get_layer_norm_decomposition(op) -> tuple:
     raise RuntimeError(f"Can't get layer_norm composition for op {op}")
 
 
-class DecomposeLayerNormPass(ExportPass):
+class DecomposeLayerNormPass(ArmPass):
     """
     layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
     Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of:
@@ -111,24 +112,39 @@ def call(self, graph_module: torch.fx.GraphModule):
                     var_op,
                     args=(x, dims),
                     kwargs={"correction": 0, "keepdim": keepdim},
+                    from_node=node,
                 )
                 full = create_node(
                     graph_module.graph,
                     full_op,
                     args=(epsilon_reshaped_shape, epsilon),
                     kwargs={"dtype": dtype},
+                    from_node=node,
+                )
+                add0 = create_node(
+                    graph_module.graph, add_op, args=(var, full), from_node=node
+                )
+                rsqrt = create_node(
+                    graph_module.graph, rsqrt_op, args=(add0,), from_node=node
+                )
+                mul0 = create_node(
+                    graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node
                 )
-                add0 = create_node(graph_module.graph, add_op, args=(var, full))
-                rsqrt = create_node(graph_module.graph, rsqrt_op, args=(add0,))
-                mul0 = create_node(graph_module.graph, mul_op, args=(sub, rsqrt))
                 if weights is not None:
                     weights_reshaped = create_node(
                         graph_module.graph,
                         view_op,
                         args=(weights, weights_reshaped_shape),
+                        from_node=node,
                     )
                     mul1 = create_node(
-                        graph_module.graph, mul_op, args=(mul0, weights_reshaped)
+                        graph_module.graph,
+                        mul_op,
+                        args=(
+                            mul0,
+                            weights_reshaped,
+                        ),
+                        from_node=node,
                     )
                 else:
                     mul1 = mul0
@@ -136,10 +152,16 @@ def call(self, graph_module: torch.fx.GraphModule):
                 if bias is not None:
                     bias_reshaped_shape = weights_reshaped_shape
                     bias_reshaped = create_node(
-                        graph_module.graph, view_op, args=(bias, bias_reshaped_shape)
+                        graph_module.graph,
+                        view_op,
+                        args=(bias, bias_reshaped_shape),
+                        from_node=node,
                     )
                     output = create_node(
-                        graph_module.graph, add_op, args=(mul1, bias_reshaped)
+                        graph_module.graph,
+                        add_op,
+                        args=(mul1, bias_reshaped),
+                        from_node=node,
                     )
 
                 users = [user for user in node.users if node != user]
diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py
index abf5c8f363d..6af6caf0c3f 100644
--- a/backends/arm/_passes/decompose_meandim_pass.py
+++ b/backends/arm/_passes/decompose_meandim_pass.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
@@ -7,9 +7,9 @@
 # pyre-unsafe
 
 import torch
+from executorch.backends.arm._passes import ArmPass
 from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass
 
 
 def get_meandim_decomposition(op) -> tuple:
@@ -28,7 +28,7 @@ def get_meandim_decomposition(op) -> tuple:
     raise RuntimeError(f"Can't get meandim decomposition for op {op}")
 
 
-class DecomposeMeanDimPass(ExportPass):
+class DecomposeMeanDimPass(ArmPass):
     """
     This pass decomposes meandim into a sum and mul node.
 
@@ -62,8 +62,8 @@ def call_operator(self, op, args, kwargs, meta):
 
         sum_op, full_op, mul_op = get_meandim_decomposition(op)
 
-        sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta)
+        sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta, True)
         full = super().call_operator(
-            full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta
+            full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta, True
         )
-        return super().call_operator(mul_op, (sum, full), {}, meta)
+        return super().call_operator(mul_op, (sum, full), {}, meta, True)
diff --git a/backends/arm/_passes/decompose_softmax_unstable_pass.py b/backends/arm/_passes/decompose_softmax_unstable_pass.py
index 4a2ce712ab7..b6f5e11b66b 100644
--- a/backends/arm/_passes/decompose_softmax_unstable_pass.py
+++ b/backends/arm/_passes/decompose_softmax_unstable_pass.py
@@ -6,8 +6,8 @@
 # pyre-unsafe
 
 import torch
+from executorch.backends.arm._passes import ArmPass
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass
 
 # For BI case
 torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)
@@ -45,7 +45,7 @@ def get_logsoftmax_ops(op) -> tuple:
     raise RuntimeError(f"Can't get softmax decomposition ops for op {op}")
 
 
-class DecomposeSoftmaxUnstablePass(ExportPass):
+class DecomposeSoftmaxUnstablePass(ArmPass):
     """
     This pass decomposes log softmax or softmax into more primitive ops.
 
@@ -66,10 +66,10 @@ def call_operator(self, op, args, kwargs, meta):
         _input = args[0]
         dim = [args[1]]
 
-        op1 = super().call_operator(exp_op, (_input,), {}, meta)
-        op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta)
-        op3 = super().call_operator(reciprocal_op, (op2,), {}, meta)
-        op4 = super().call_operator(mul_op, (op1, op3), {}, meta)
+        op1 = super().call_operator(exp_op, (_input,), {}, meta, True)
+        op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta, True)
+        op3 = super().call_operator(reciprocal_op, (op2,), {}, meta, True)
+        op4 = super().call_operator(mul_op, (op1, op3), {}, meta, True)
         if op in log_softmax:
-            op4 = super().call_operator(log_op, (op4,), {}, meta)
+            op4 = super().call_operator(log_op, (op4,), {}, meta, True)
         return op4
diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py
index 73747d8313d..15872738f3e 100644
--- a/backends/arm/_passes/decompose_var_pass.py
+++ b/backends/arm/_passes/decompose_var_pass.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
@@ -8,9 +8,9 @@
 
 import torch
+from executorch.backends.arm._passes import ArmPass
 from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass
 
 
 def get_var_decomposition(op) -> tuple:
@@ -33,7 +33,7 @@ def get_var_decomposition(op) -> tuple:
     raise RuntimeError(f"Can't get var decomposition for op {op}")
 
 
-class DecomposeVarPass(ExportPass):
+class DecomposeVarPass(ArmPass):
     """
     This pass decomposes var.correction and var.dim into smaller ops
     (see https://pytorch.org/docs/stable/generated/torch.var.html)
@@ -77,14 +77,17 @@ def call_operator(self, op, args, kwargs, meta):
             N *= input_shape[d]
 
         mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op)
-        mean = super().call_operator(mean_op, (x, dim, True), {}, meta)
-        diff = super().call_operator(diff_op, (x, mean), {}, meta)
-        squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta)
-        sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta)
+        mean = super().call_operator(mean_op, (x, dim, True), {}, meta, True)
+        diff = super().call_operator(diff_op, (x, mean), {}, meta, True)
+        squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta, True)
+        sum = super().call_operator(
+            sum_op, (squared_diff, dim, keepdim), {}, meta, True
+        )
         full = super().call_operator(
             full_op,
             ([], 1 / max(0, N - correction)),
             {"dtype": dtype},
             meta,
+            True,
         )
-        return super().call_operator(mul_op, (sum, full), {}, meta)
+        return super().call_operator(mul_op, (sum, full), {}, meta, True)
diff --git a/backends/arm/_passes/mm_to_bmm_pass.py b/backends/arm/_passes/mm_to_bmm_pass.py
index 602d4a007a6..34ac7553212 100644
--- a/backends/arm/_passes/mm_to_bmm_pass.py
+++ b/backends/arm/_passes/mm_to_bmm_pass.py
@@ -47,7 +47,7 @@ def call(self, graph_module: torch.fx.GraphModule):
 
             with graph.inserting_before(node):
                 unsqueeze_before = create_node(
-                    graph, exir_ops.edge.aten.unsqueeze_copy.default
+                    graph, exir_ops.edge.aten.unsqueeze_copy.default, from_node=node
                 )
                 unsqueeze_before.args = (
                     input_node,  # Input is node's original input
@@ -58,13 +58,14 @@ def call(self, graph_module: torch.fx.GraphModule):
             # If Quantized we must insert unsqueeze --> q --> dq --> node
             if input_node.target == dq_op:
                 q_params = input_node.args[1:]
-                insert_q_dq_pair(graph, unsqueeze_before, q_params)
+                insert_q_dq_pair(graph, unsqueeze_before, q_params, from_node=node)
 
             # Replace mm node with bmm
             with graph.inserting_before(node):
                 bmm_node = create_node(
                     graph,
                     exir_ops.edge.aten.bmm.default,
+                    from_node=node,
                 )
                 bmm_node.args = node.args
                 node.replace_all_uses_with(bmm_node)
@@ -75,6 +76,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                 squeeze_after = create_node(
                     graph,
                     exir_ops.edge.aten.squeeze_copy.dims,
+                    from_node=node,
                 )
                 squeeze_after.args = (
                     bmm_node,
@@ -89,7 +91,7 @@ def call(self, graph_module: torch.fx.GraphModule):
             # If quantized, insert mm --> q --> dq --> squeeze
             if all(original_user.target == q_op for original_user in original_users):
                 q_params = original_users[0].args[1:]
-                insert_q_dq_pair(graph, bmm_node, q_params)
+                insert_q_dq_pair(graph, bmm_node, q_params, from_node=node)
 
             modified_graph = True
diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py
index 556e30e2b7f..f112d1c1e9d 100644
--- a/backends/arm/tosa_utils.py
+++ b/backends/arm/tosa_utils.py
@@ -42,10 +42,17 @@ def get_node_debug_info(node: torch.fx.Node) -> str:
         "  Node.meta = \n"
     )
     for k, v in node.meta.items():
-        output += f"    '{k}' = {v}\n"
-        if isinstance(v, list):
-            for i in v:
-                output += f"      {i}\n"
+        if k == "stack_trace":
+            matches = v.split("\n")
+            output += "      'stack_trace' =\n"
+            for m in matches:
+                output += f"          {m}\n"
+        else:
+            output += f"    '{k}' = {v}\n"
+
+            if isinstance(v, list):
+                for i in v:
+                    output += f"      {i}\n"
     return output
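
Usage sketch (illustrative only, not part of this patch): the two hypothetical passes below show how a pass built on this change would propagate metadata, assuming the modules above are importable as shown. DecomposeSubPass works at node level and passes True as the trailing call_operator argument so that ArmPass copies the incoming meta and appends to its 'stack_trace'; InsertAbsBeforeReluPass works at graph level and passes from_node=node to create_node for the same effect. The pass names and the operators they touch are made up for the example.

# Illustrative usage only -- these passes are not part of this patch.
import torch

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult


class DecomposeSubPass(ArmPass):
    """Node-level example: the trailing True makes ArmPass.call_operator copy
    'meta' and append the current stack frame to its 'stack_trace' field."""

    def call_operator(self, op, args, kwargs, meta):
        if op != exir_ops.edge.aten.sub.Tensor:
            return super().call_operator(op, args, kwargs, meta)
        x, y = args[0], args[1]  # the optional alpha argument is ignored for brevity
        neg = super().call_operator(exir_ops.edge.aten.neg.default, (y,), {}, meta, True)
        return super().call_operator(exir_ops.edge.aten.add.Tensor, (x, neg), {}, meta, True)


class InsertAbsBeforeReluPass(ArmPass):
    """Graph-level example: from_node=node makes create_node copy the original
    node's metadata and extend its 'stack_trace'."""

    def call(self, graph_module: torch.fx.GraphModule):
        modified = False
        for node in graph_module.graph.nodes:
            if node.target != exir_ops.edge.aten.relu.default:
                continue
            input_node = node.args[0]
            with graph_module.graph.inserting_before(node):
                abs_node = create_node(
                    graph_module.graph,
                    exir_ops.edge.aten.abs.default,
                    args=(input_node,),
                    from_node=node,  # preserves meta, updates 'stack_trace'
                )
            node.replace_input_with(input_node, abs_node)
            modified = True
        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, modified)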