Skip to content

Commit 727bf64

Browse files
committed
[devtools/visualization] Add visualize_graph
When working with passes, you might have access to a modified graph_module rather than an exported_program. visualize_graph allows visualization of this graph_module by combining the modified graph_module with an exported_program. Note that the graph_module can't be set directly, a new exported_program needs to be constructed. Additionally, we disable the operator validation for the newly constructed ExportedProgram. This is ok since it is only used for visualization. Signed-off-by: Erik Lundell <[email protected]> Change-Id: I4fad809bf094a1ec70e25534cc0858f9d8d3d225
1 parent f5e888b commit 727bf64

File tree

4 files changed

+72
-14
lines changed

4 files changed

+72
-14
lines changed

devtools/visualization/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@
88
ModelExplorerServer,
99
SingletonModelExplorerServer,
1010
visualize,
11+
visualize_graph,
1112
)

devtools/visualization/visualization_utils.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,13 @@
66

77
import subprocess
88
import time
9+
from typing import Any, Callable, Type
910

1011
from executorch.exir import EdgeProgramManager, ExecutorchProgramManager
12+
from executorch.exir.program._program import _update_exported_program_graph_module
13+
from torch._export.verifier import Verifier
1114
from torch.export.exported_program import ExportedProgram
15+
from torch.fx import GraphModule
1216

1317
try:
1418
from model_explorer import config, consts, visualize_from_config # type: ignore
@@ -27,7 +31,7 @@ class SingletonModelExplorerServer:
2731

2832
server: None | subprocess.Popen = None
2933
num_open: int = 0
30-
wait_after_start = 2.0
34+
wait_after_start = 3.0
3135

3236
def __init__(self, open_in_browser: bool = True, port: int | None = None):
3337
if SingletonModelExplorerServer.server is None:
@@ -124,3 +128,29 @@ def visualize(
124128
no_open_in_browser=no_open_in_browser,
125129
**kwargs,
126130
)
131+
132+
133+
def visualize_graph(
134+
graph_module: GraphModule,
135+
exported_program: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager,
136+
reuse_server: bool = True,
137+
no_open_in_browser: bool = False,
138+
**kwargs,
139+
):
140+
"""Overrides the graph_module of the supplied exported_program with 'graph_module' before visualizing.
141+
Also disables validating operators to allow visualizing graphs containing custom ops.
142+
143+
A typical example is after running passes, which returns a graph_module rather than an ExportedProgram.
144+
"""
145+
146+
class _any_op(Verifier):
147+
dialect = "ANY_OP"
148+
149+
def allowed_op_types(self) -> tuple[Type[Any], ...]:
150+
return (Callable,) # type: ignore
151+
152+
exported_program = _get_exported_program(exported_program)
153+
exported_program = _update_exported_program_graph_module(
154+
exported_program, graph_module, override_verifiers=[_any_op]
155+
)
156+
visualize(exported_program, reuse_server, no_open_in_browser, **kwargs)

devtools/visualization/visualization_utils_test.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,17 @@
88

99
import pytest
1010
import torch
11+
from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
1112
from executorch.backends.xnnpack.test.tester import Tester
1213

1314
from executorch.devtools.visualization import (
1415
ModelExplorerServer,
1516
SingletonModelExplorerServer,
1617
visualization_utils,
1718
visualize,
19+
visualize_graph,
1820
)
19-
from executorch.exir import ExportedProgram
21+
from executorch.exir import ExportedProgram, to_edge_transform_and_lower
2022

2123
try:
2224
from model_explorer.config import ModelExplorerConfig # type: ignore
@@ -145,6 +147,17 @@ def test_visualize_to_executorch(server):
145147
)
146148

147149

150+
def test_visualize_graph(server):
151+
with server():
152+
model = Linear(20, 30)
153+
exported_program = torch.export.export(model, model.get_inputs())
154+
exported_program = to_edge_transform_and_lower(
155+
exported_program
156+
).exported_program()
157+
modified_gm = DecomposeLinearPass()(exported_program.graph_module).graph_module
158+
visualize_graph(modified_gm, exported_program)
159+
160+
148161
if __name__ == "__main__":
149162
"""A test to run locally to make sure that the web browser opens up
150163
automatically as intended.
@@ -158,3 +171,7 @@ def test_visualize_to_executorch(server):
158171
test_visualize_to_edge(SingletonModelExplorerServer)
159172
test_visualize_partition(SingletonModelExplorerServer)
160173
test_visualize_to_executorch(SingletonModelExplorerServer)
174+
test_visualize_graph(SingletonModelExplorerServer)
175+
176+
# Sleep to give the server time to load the last graph before killing it.
177+
time.sleep(3.0)

exir/program/_program.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -10,7 +11,7 @@
1011
import io
1112
import logging
1213
import os
13-
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union
14+
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Type, Union
1415

1516
import torch
1617
import torch._export
@@ -66,6 +67,7 @@
6667
)
6768
from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer
6869
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
70+
from torch._export.verifier import Verifier
6971
from torch.export import ExportedProgram
7072
from torch.export._remove_auto_functionalized_pass import (
7173
unsafe_remove_auto_functionalized_pass,
@@ -213,21 +215,29 @@ def _transform(self, *passes: PassType) -> "ExportedProgram":
213215
if transformed_gm is self.graph_module and not res.modified:
214216
return self
215217

218+
return _update_exported_program_graph_module(self, transformed_gm)
219+
220+
221+
def _update_exported_program_graph_module(
222+
exported_program: ExportedProgram,
223+
gm: torch.fx.GraphModule,
224+
override_verifiers: None | list[Type[Verifier]] = None,
225+
) -> "ExportedProgram":
216226
transformed_ep = ExportedProgram(
217-
root=transformed_gm,
218-
graph=transformed_gm.graph,
227+
root=gm,
228+
graph=gm.graph,
219229
graph_signature=_get_updated_graph_signature(
220-
self.graph_signature, transformed_gm
230+
exported_program.graph_signature, gm
221231
),
222-
state_dict=self.state_dict,
223-
range_constraints=_get_updated_range_constraints(transformed_gm),
224-
module_call_graph=copy.deepcopy(self._module_call_graph),
225-
example_inputs=self.example_inputs,
226-
constants=self.constants,
227-
verifiers=[self.verifier],
232+
state_dict=exported_program.state_dict,
233+
range_constraints=_get_updated_range_constraints(gm),
234+
module_call_graph=copy.deepcopy(exported_program._module_call_graph),
235+
example_inputs=exported_program.example_inputs,
236+
constants=exported_program.constants,
237+
verifiers=override_verifiers or [exported_program.verifier],
228238
)
229-
transformed_ep.graph_module.meta.update(self.graph_module.meta)
230-
transformed_ep.graph_module.meta.update(res.graph_module.meta)
239+
transformed_ep.graph_module.meta.update(exported_program.graph_module.meta)
240+
transformed_ep.graph_module.meta.update(gm.meta)
231241
return transformed_ep
232242

233243

0 commit comments

Comments
 (0)