From f5e888bd03d405872c7abb0cb6822ae63c130c18 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Wed, 22 Jan 2025 11:40:00 +0100 Subject: [PATCH 1/2] Use default overloads when calling custom ops If a node is created without specifying an overload, A OpOverloadPacket is created, rather than an OpOverload. This works in a GraphModule, but the OpOverloadPacket is not a valid operator type in the _EXIREdgeDialectVerifier, which means that Edge ExportedPrograms can't contain a GraphModule with such ops. In short, specifying using the default overload seems to be the more correct way of creating a custom operator. Signed-off-by: Erik Lundell Change-Id: I3a1733c0ae88826d88b1e820eaacff765df7fbd2 --- .../_passes/annotate_channels_last_dim_order_pass.py | 4 ++-- backends/arm/_passes/insert_table_ops.py | 10 +++++++--- backends/arm/operators/op_table.py | 4 ++-- backends/arm/operators/op_transpose.py | 4 ++-- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py index a3d168fb870..42c9f9e492d 100644 --- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py +++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py @@ -116,7 +116,7 @@ def insert_input_transpose(node, input_node, graph_module): with graph_module.graph.inserting_before(node): permute_node = create_node( graph_module.graph, - torch.ops.passthrough_to_tosa._transpose, + torch.ops.passthrough_to_tosa._transpose.default, args=( input_node, list(AnnotateChannelsLastDimOrder.NHWC_inverse_order), @@ -135,7 +135,7 @@ def insert_output_transpose(node, graph_module): with graph_module.graph.inserting_after(node): permute_node = create_node( graph_module.graph, - torch.ops.passthrough_to_tosa._transpose, + torch.ops.passthrough_to_tosa._transpose.default, args=(node, list(AnnotateChannelsLastDimOrder.NHWC_order)), ) permute_node.meta["tosa_dim_order"] = ( diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index 57a8376d40f..314bda4ddcb 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -1,4 +1,4 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -92,7 +92,7 @@ def call(self, graph_module: GraphModule) -> PassResult: with graph_module.graph.inserting_before(node): table_node = create_node( graph=graph_module.graph, - op_target=torch.ops.tosa._table, + op_target=torch.ops.tosa._table.default, args=(node.args[0],), ) assert len(input_qparams) == 1 @@ -104,7 +104,11 @@ def call(self, graph_module: GraphModule) -> PassResult: out_quantargs=output_qparams[0], ) # Register buffer in self.exported_program.state_dict - self.register_buffer(buffer_name=table_node.name, buffer=buffer) + # When the graph is retraced, the implementation _table is used and the suffix _default disappears from the node name + # Remove it here to make it possible to find in the node_visitor + self.register_buffer( + buffer_name=table_node.name.replace("_default", ""), buffer=buffer + ) node.replace_all_uses_with(table_node) graph_module.graph.erase_node(node) table_node.meta["input_qparams"] = input_qparams diff --git a/backends/arm/operators/op_table.py b/backends/arm/operators/op_table.py index bfaaf4578ed..c5f8de609e6 100644 --- a/backends/arm/operators/op_table.py +++ b/backends/arm/operators/op_table.py @@ -1,4 +1,4 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -21,7 +21,7 @@ @register_node_visitor class TableVisitor(NodeVisitor): - target = "_table" + target = "_table.default" def define_node( self, diff --git a/backends/arm/operators/op_transpose.py b/backends/arm/operators/op_transpose.py index 42675be34b5..fea8d64f9c9 100644 --- a/backends/arm/operators/op_transpose.py +++ b/backends/arm/operators/op_transpose.py @@ -1,4 +1,4 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -25,7 +25,7 @@ class TransposeVisitor(NodeVisitor): Inserts a TOSA TRANSPOSE. """ - target = "_transpose" + target = "_transpose.default" def define_node( self, From 727bf6459aa6e3a9bc962a2fd4f71a89ac6dfac6 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Wed, 15 Jan 2025 11:23:21 +0100 Subject: [PATCH 2/2] [devtools/visualization] Add visualize_graph When working with passes, you might have access to a modified graph_module rather than an exported_program. visualize_graph allows visualization of this graph_module by combining the modified graph_module with an exported_program. Note that the graph_module can't be set directly, a new exported_program needs to be constructed. Additionally, we disable the operator validation for the newly constructed ExportedProgram. This is ok since it is only used for visualization. Signed-off-by: Erik Lundell Change-Id: I4fad809bf094a1ec70e25534cc0858f9d8d3d225 --- devtools/visualization/__init__.py | 1 + devtools/visualization/visualization_utils.py | 32 ++++++++++++++++- .../visualization/visualization_utils_test.py | 19 ++++++++++- exir/program/_program.py | 34 ++++++++++++------- 4 files changed, 72 insertions(+), 14 deletions(-) diff --git a/devtools/visualization/__init__.py b/devtools/visualization/__init__.py index 645cc5d5378..df1d74c7fae 100644 --- a/devtools/visualization/__init__.py +++ b/devtools/visualization/__init__.py @@ -8,4 +8,5 @@ ModelExplorerServer, SingletonModelExplorerServer, visualize, + visualize_graph, ) diff --git a/devtools/visualization/visualization_utils.py b/devtools/visualization/visualization_utils.py index 4d520a66366..d21d11082a3 100644 --- a/devtools/visualization/visualization_utils.py +++ b/devtools/visualization/visualization_utils.py @@ -6,9 +6,13 @@ import subprocess import time +from typing import Any, Callable, Type from executorch.exir import EdgeProgramManager, ExecutorchProgramManager +from executorch.exir.program._program import _update_exported_program_graph_module +from torch._export.verifier import Verifier from torch.export.exported_program import ExportedProgram +from torch.fx import GraphModule try: from model_explorer import config, consts, visualize_from_config # type: ignore @@ -27,7 +31,7 @@ class SingletonModelExplorerServer: server: None | subprocess.Popen = None num_open: int = 0 - wait_after_start = 2.0 + wait_after_start = 3.0 def __init__(self, open_in_browser: bool = True, port: int | None = None): if SingletonModelExplorerServer.server is None: @@ -124,3 +128,29 @@ def visualize( no_open_in_browser=no_open_in_browser, **kwargs, ) + + +def visualize_graph( + graph_module: GraphModule, + exported_program: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager, + reuse_server: bool = True, + no_open_in_browser: bool = False, + **kwargs, +): + """Overrides the graph_module of the supplied exported_program with 'graph_module' before visualizing. + Also disables validating operators to allow visualizing graphs containing custom ops. + + A typical example is after running passes, which returns a graph_module rather than an ExportedProgram. + """ + + class _any_op(Verifier): + dialect = "ANY_OP" + + def allowed_op_types(self) -> tuple[Type[Any], ...]: + return (Callable,) # type: ignore + + exported_program = _get_exported_program(exported_program) + exported_program = _update_exported_program_graph_module( + exported_program, graph_module, override_verifiers=[_any_op] + ) + visualize(exported_program, reuse_server, no_open_in_browser, **kwargs) diff --git a/devtools/visualization/visualization_utils_test.py b/devtools/visualization/visualization_utils_test.py index dafefa7dfdd..d49c6d2f72d 100644 --- a/devtools/visualization/visualization_utils_test.py +++ b/devtools/visualization/visualization_utils_test.py @@ -8,6 +8,7 @@ import pytest import torch +from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass from executorch.backends.xnnpack.test.tester import Tester from executorch.devtools.visualization import ( @@ -15,8 +16,9 @@ SingletonModelExplorerServer, visualization_utils, visualize, + visualize_graph, ) -from executorch.exir import ExportedProgram +from executorch.exir import ExportedProgram, to_edge_transform_and_lower try: from model_explorer.config import ModelExplorerConfig # type: ignore @@ -145,6 +147,17 @@ def test_visualize_to_executorch(server): ) +def test_visualize_graph(server): + with server(): + model = Linear(20, 30) + exported_program = torch.export.export(model, model.get_inputs()) + exported_program = to_edge_transform_and_lower( + exported_program + ).exported_program() + modified_gm = DecomposeLinearPass()(exported_program.graph_module).graph_module + visualize_graph(modified_gm, exported_program) + + if __name__ == "__main__": """A test to run locally to make sure that the web browser opens up automatically as intended. @@ -158,3 +171,7 @@ def test_visualize_to_executorch(server): test_visualize_to_edge(SingletonModelExplorerServer) test_visualize_partition(SingletonModelExplorerServer) test_visualize_to_executorch(SingletonModelExplorerServer) + test_visualize_graph(SingletonModelExplorerServer) + + # Sleep to give the server time to load the last graph before killing it. + time.sleep(3.0) diff --git a/exir/program/_program.py b/exir/program/_program.py index 86f111f2f98..fdf4b93e19c 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -10,7 +11,7 @@ import io import logging import os -from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Type, Union import torch import torch._export @@ -66,6 +67,7 @@ ) from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass +from torch._export.verifier import Verifier from torch.export import ExportedProgram from torch.export._remove_auto_functionalized_pass import ( unsafe_remove_auto_functionalized_pass, @@ -213,21 +215,29 @@ def _transform(self, *passes: PassType) -> "ExportedProgram": if transformed_gm is self.graph_module and not res.modified: return self + return _update_exported_program_graph_module(self, transformed_gm) + + +def _update_exported_program_graph_module( + exported_program: ExportedProgram, + gm: torch.fx.GraphModule, + override_verifiers: None | list[Type[Verifier]] = None, +) -> "ExportedProgram": transformed_ep = ExportedProgram( - root=transformed_gm, - graph=transformed_gm.graph, + root=gm, + graph=gm.graph, graph_signature=_get_updated_graph_signature( - self.graph_signature, transformed_gm + exported_program.graph_signature, gm ), - state_dict=self.state_dict, - range_constraints=_get_updated_range_constraints(transformed_gm), - module_call_graph=copy.deepcopy(self._module_call_graph), - example_inputs=self.example_inputs, - constants=self.constants, - verifiers=[self.verifier], + state_dict=exported_program.state_dict, + range_constraints=_get_updated_range_constraints(gm), + module_call_graph=copy.deepcopy(exported_program._module_call_graph), + example_inputs=exported_program.example_inputs, + constants=exported_program.constants, + verifiers=override_verifiers or [exported_program.verifier], ) - transformed_ep.graph_module.meta.update(self.graph_module.meta) - transformed_ep.graph_module.meta.update(res.graph_module.meta) + transformed_ep.graph_module.meta.update(exported_program.graph_module.meta) + transformed_ep.graph_module.meta.update(gm.meta) return transformed_ep