Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def insert_input_transpose(node, input_node, graph_module):
with graph_module.graph.inserting_before(node):
permute_node = create_node(
graph_module.graph,
torch.ops.passthrough_to_tosa._transpose,
torch.ops.passthrough_to_tosa._transpose.default,
args=(
input_node,
list(AnnotateChannelsLastDimOrder.NHWC_inverse_order),
Expand All @@ -135,7 +135,7 @@ def insert_output_transpose(node, graph_module):
with graph_module.graph.inserting_after(node):
permute_node = create_node(
graph_module.graph,
torch.ops.passthrough_to_tosa._transpose,
torch.ops.passthrough_to_tosa._transpose.default,
args=(node, list(AnnotateChannelsLastDimOrder.NHWC_order)),
)
permute_node.meta["tosa_dim_order"] = (
Expand Down
10 changes: 7 additions & 3 deletions backends/arm/_passes/insert_table_ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -92,7 +92,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
with graph_module.graph.inserting_before(node):
table_node = create_node(
graph=graph_module.graph,
op_target=torch.ops.tosa._table,
op_target=torch.ops.tosa._table.default,
args=(node.args[0],),
)
assert len(input_qparams) == 1
Expand All @@ -104,7 +104,11 @@ def call(self, graph_module: GraphModule) -> PassResult:
out_quantargs=output_qparams[0],
)
# Register buffer in self.exported_program.state_dict
self.register_buffer(buffer_name=table_node.name, buffer=buffer)
# When the graph is retraced, the implementation _table is used and the suffix _default disappears from the node name
# Remove it here to make it possible to find in the node_visitor
self.register_buffer(
buffer_name=table_node.name.replace("_default", ""), buffer=buffer
)
node.replace_all_uses_with(table_node)
graph_module.graph.erase_node(node)
table_node.meta["input_qparams"] = input_qparams
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/operators/op_table.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -21,7 +21,7 @@

@register_node_visitor
class TableVisitor(NodeVisitor):
target = "_table"
target = "_table.default"

def define_node(
self,
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/operators/op_transpose.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -25,7 +25,7 @@ class TransposeVisitor(NodeVisitor):
Inserts a TOSA TRANSPOSE.
"""

target = "_transpose"
target = "_transpose.default"

def define_node(
self,
Expand Down
1 change: 1 addition & 0 deletions devtools/visualization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
ModelExplorerServer,
SingletonModelExplorerServer,
visualize,
visualize_graph,
)
32 changes: 31 additions & 1 deletion devtools/visualization/visualization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@

import subprocess
import time
from typing import Any, Callable, Type

from executorch.exir import EdgeProgramManager, ExecutorchProgramManager
from executorch.exir.program._program import _update_exported_program_graph_module
from torch._export.verifier import Verifier
from torch.export.exported_program import ExportedProgram
from torch.fx import GraphModule

try:
from model_explorer import config, consts, visualize_from_config # type: ignore
Expand All @@ -27,7 +31,7 @@ class SingletonModelExplorerServer:

server: None | subprocess.Popen = None
num_open: int = 0
wait_after_start = 2.0
wait_after_start = 3.0

def __init__(self, open_in_browser: bool = True, port: int | None = None):
if SingletonModelExplorerServer.server is None:
Expand Down Expand Up @@ -124,3 +128,29 @@ def visualize(
no_open_in_browser=no_open_in_browser,
**kwargs,
)


def visualize_graph(
graph_module: GraphModule,
exported_program: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager,
reuse_server: bool = True,
no_open_in_browser: bool = False,
**kwargs,
):
"""Overrides the graph_module of the supplied exported_program with 'graph_module' before visualizing.
Also disables validating operators to allow visualizing graphs containing custom ops.

A typical example is after running passes, which returns a graph_module rather than an ExportedProgram.
"""

class _any_op(Verifier):
dialect = "ANY_OP"

def allowed_op_types(self) -> tuple[Type[Any], ...]:
return (Callable,) # type: ignore

exported_program = _get_exported_program(exported_program)
exported_program = _update_exported_program_graph_module(
exported_program, graph_module, override_verifiers=[_any_op]
)
visualize(exported_program, reuse_server, no_open_in_browser, **kwargs)
19 changes: 18 additions & 1 deletion devtools/visualization/visualization_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@

import pytest
import torch
from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
from executorch.backends.xnnpack.test.tester import Tester

from executorch.devtools.visualization import (
ModelExplorerServer,
SingletonModelExplorerServer,
visualization_utils,
visualize,
visualize_graph,
)
from executorch.exir import ExportedProgram
from executorch.exir import ExportedProgram, to_edge_transform_and_lower

try:
from model_explorer.config import ModelExplorerConfig # type: ignore
Expand Down Expand Up @@ -145,6 +147,17 @@ def test_visualize_to_executorch(server):
)


def test_visualize_graph(server):
with server():
model = Linear(20, 30)
exported_program = torch.export.export(model, model.get_inputs())
exported_program = to_edge_transform_and_lower(
exported_program
).exported_program()
modified_gm = DecomposeLinearPass()(exported_program.graph_module).graph_module
visualize_graph(modified_gm, exported_program)


if __name__ == "__main__":
"""A test to run locally to make sure that the web browser opens up
automatically as intended.
Expand All @@ -158,3 +171,7 @@ def test_visualize_to_executorch(server):
test_visualize_to_edge(SingletonModelExplorerServer)
test_visualize_partition(SingletonModelExplorerServer)
test_visualize_to_executorch(SingletonModelExplorerServer)
test_visualize_graph(SingletonModelExplorerServer)

# Sleep to give the server time to load the last graph before killing it.
time.sleep(3.0)
34 changes: 22 additions & 12 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -10,7 +11,7 @@
import io
import logging
import os
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Type, Union

import torch
import torch._export
Expand Down Expand Up @@ -66,6 +67,7 @@
)
from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
from torch._export.verifier import Verifier
from torch.export import ExportedProgram
from torch.export._remove_auto_functionalized_pass import (
unsafe_remove_auto_functionalized_pass,
Expand Down Expand Up @@ -213,21 +215,29 @@ def _transform(self, *passes: PassType) -> "ExportedProgram":
if transformed_gm is self.graph_module and not res.modified:
return self

return _update_exported_program_graph_module(self, transformed_gm)


def _update_exported_program_graph_module(
exported_program: ExportedProgram,
gm: torch.fx.GraphModule,
override_verifiers: None | list[Type[Verifier]] = None,
) -> "ExportedProgram":
transformed_ep = ExportedProgram(
root=transformed_gm,
graph=transformed_gm.graph,
root=gm,
graph=gm.graph,
graph_signature=_get_updated_graph_signature(
self.graph_signature, transformed_gm
exported_program.graph_signature, gm
),
state_dict=self.state_dict,
range_constraints=_get_updated_range_constraints(transformed_gm),
module_call_graph=copy.deepcopy(self._module_call_graph),
example_inputs=self.example_inputs,
constants=self.constants,
verifiers=[self.verifier],
state_dict=exported_program.state_dict,
range_constraints=_get_updated_range_constraints(gm),
module_call_graph=copy.deepcopy(exported_program._module_call_graph),
example_inputs=exported_program.example_inputs,
constants=exported_program.constants,
verifiers=override_verifiers or [exported_program.verifier],
)
transformed_ep.graph_module.meta.update(self.graph_module.meta)
transformed_ep.graph_module.meta.update(res.graph_module.meta)
transformed_ep.graph_module.meta.update(exported_program.graph_module.meta)
transformed_ep.graph_module.meta.update(gm.meta)
return transformed_ep


Expand Down