Skip to content

Commit 3ae2b41

Browse files
committed
Add visualization to devtools
- Add ai-edge-model-explorer as development requirement - Add lightweight wrapper around model-explorer in devtools - Use it in .visualize() call in XNNPack tester - Add tests. The actual call to model-explorer is mocked, since getting it to run in ci did not work out. However, asserting that the model actually loads properly needs visual inspection anyways. Instead, visualization_utils_test.py can be run locally. - Add two context managers for starting model-explorer servers, currently only used for testing but might be useful. Signed-off-by: Erik Lundell <[email protected]> Change-Id: I3fc0efb3a2012b038f699f7737c4d97c5039df40
1 parent 08770b7 commit 3ae2b41

File tree

5 files changed

+295
-0
lines changed

5 files changed

+295
-0
lines changed

backends/xnnpack/test/tester/tester.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
23
# All rights reserved.
34
#
45
# This source code is licensed under the BSD-style license found in the
@@ -627,6 +628,15 @@ def check_node_count(self, input: Dict[Any, int]):
627628

628629
return self
629630

631+
def visualize(
632+
self, reuse_server: bool = True, stage: Optional[str] = None, **kwargs
633+
):
634+
# import here to avoid importing model_explorer when it is not needed which is most of the time.
635+
from executorch.devtools.visualization import visualize
636+
637+
visualize(self.get_artifact(stage), reuse_server=reuse_server, **kwargs)
638+
return self
639+
630640
def run_method_and_compare_outputs(
631641
self,
632642
stage: Optional[str] = None,

devtools/visualization/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from executorch.devtools.visualization.visualization_utils import ( # noqa: F401
8+
ModelExplorerServer,
9+
SingletonModelExplorerServer,
10+
visualize,
11+
)
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
import subprocess
8+
import time
9+
10+
from executorch.exir import EdgeProgramManager, ExecutorchProgramManager
11+
from model_explorer import config, consts, visualize_from_config # type: ignore
12+
from torch.export.exported_program import ExportedProgram
13+
14+
15+
class SingletonModelExplorerServer:
16+
"""Singleton context manager for starting a model-explorer server.
17+
If multiple ModelExplorerServer contexts are nested, a single
18+
server is still used.
19+
"""
20+
21+
server: None | subprocess.Popen = None
22+
num_open: int = 0
23+
wait_after_start = 2.0
24+
25+
def __init__(self, open_in_browser: bool = True, port: int | None = None):
26+
if SingletonModelExplorerServer.server is None:
27+
command = ["model-explorer"]
28+
if not open_in_browser:
29+
command.append("--no_open_in_browser")
30+
if port is not None:
31+
command.append("--port")
32+
command.append(str(port))
33+
SingletonModelExplorerServer.server = subprocess.Popen(command)
34+
35+
def __enter__(self):
36+
SingletonModelExplorerServer.num_open = (
37+
SingletonModelExplorerServer.num_open + 1
38+
)
39+
time.sleep(SingletonModelExplorerServer.wait_after_start)
40+
return self
41+
42+
def __exit__(self, type, value, traceback):
43+
SingletonModelExplorerServer.num_open = (
44+
SingletonModelExplorerServer.num_open - 1
45+
)
46+
if SingletonModelExplorerServer.num_open == 0:
47+
if SingletonModelExplorerServer.server is not None:
48+
SingletonModelExplorerServer.server.kill()
49+
try:
50+
SingletonModelExplorerServer.server.wait(
51+
SingletonModelExplorerServer.wait_after_start
52+
)
53+
except subprocess.TimeoutExpired:
54+
SingletonModelExplorerServer.server.terminate()
55+
SingletonModelExplorerServer.server = None
56+
57+
58+
class ModelExplorerServer:
59+
"""Context manager for starting a model-explorer server."""
60+
61+
wait_after_start = 2.0
62+
63+
def __init__(self, open_in_browser: bool = True, port: int | None = None):
64+
command = ["model-explorer"]
65+
if not open_in_browser:
66+
command.append("--no_open_in_browser")
67+
if port is not None:
68+
command.append("--port")
69+
command.append(str(port))
70+
self.server = subprocess.Popen(command)
71+
72+
def __enter__(self):
73+
time.sleep(self.wait_after_start)
74+
75+
def __exit__(self, type, value, traceback):
76+
self.server.kill()
77+
try:
78+
self.server.wait(self.wait_after_start)
79+
except subprocess.TimeoutExpired:
80+
self.server.terminate()
81+
82+
83+
def _get_exported_program(
84+
visualizable: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager,
85+
) -> ExportedProgram:
86+
if isinstance(visualizable, ExportedProgram):
87+
return visualizable
88+
if isinstance(visualizable, (EdgeProgramManager, ExecutorchProgramManager)):
89+
return visualizable.exported_program()
90+
raise RuntimeError(f"Cannot get ExportedProgram from {visualizable}")
91+
92+
93+
def visualize(
94+
visualizable: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager,
95+
reuse_server: bool = True,
96+
no_open_in_browser: bool = False,
97+
**kwargs,
98+
):
99+
"""Wraps the visualize_from_config call from model_explorer.
100+
For convenicence, figures out how to find the exported_program
101+
from EdgeProgramManager and ExecutorchProgramManager for you.
102+
103+
See https://github.com/google-ai-edge/model-explorer/wiki/4.-API-Guide#visualize-pytorch-models
104+
for full documentation.
105+
"""
106+
cur_config = config()
107+
settings = consts.DEFAULT_SETTINGS
108+
cur_config.add_model_from_pytorch(
109+
"Executorch",
110+
exported_program=_get_exported_program(visualizable),
111+
settings=settings,
112+
)
113+
if reuse_server:
114+
cur_config.set_reuse_server()
115+
visualize_from_config(
116+
cur_config,
117+
no_open_in_browser=no_open_in_browser,
118+
**kwargs,
119+
)
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
import time
8+
9+
import pytest
10+
import torch
11+
from executorch.backends.xnnpack.test.tester import Tester
12+
13+
from executorch.devtools.visualization import (
14+
ModelExplorerServer,
15+
SingletonModelExplorerServer,
16+
visualization_utils,
17+
visualize,
18+
)
19+
from executorch.exir import ExportedProgram
20+
from model_explorer.config import ModelExplorerConfig # type: ignore
21+
22+
23+
@pytest.fixture
24+
def server():
25+
"""Mock relevant calls in visualization.visualize and check that parameters have their expected value."""
26+
monkeypatch = pytest.MonkeyPatch()
27+
with monkeypatch.context():
28+
_called_reuse_server = False
29+
30+
def mock_set_reuse_server(self):
31+
nonlocal _called_reuse_server
32+
_called_reuse_server = True
33+
34+
def mock_add_model_from_pytorch(self, name, exported_program, settings):
35+
assert isinstance(exported_program, ExportedProgram)
36+
37+
def mock_visualize_from_config(cur_config, no_open_in_browser):
38+
pass
39+
40+
monkeypatch.setattr(
41+
ModelExplorerConfig, "set_reuse_server", mock_set_reuse_server
42+
)
43+
monkeypatch.setattr(
44+
ModelExplorerConfig, "add_model_from_pytorch", mock_add_model_from_pytorch
45+
)
46+
monkeypatch.setattr(
47+
visualization_utils, "visualize_from_config", mock_visualize_from_config
48+
)
49+
yield monkeypatch.context
50+
assert _called_reuse_server, "Did not call reuse_server"
51+
52+
53+
class Linear(torch.nn.Module):
54+
def __init__(
55+
self,
56+
in_features: int,
57+
out_features: int = 3,
58+
bias: bool = True,
59+
):
60+
super().__init__()
61+
self.inputs = (torch.randn(5, 10, 25, in_features),)
62+
self.fc = torch.nn.Linear(
63+
in_features=in_features,
64+
out_features=out_features,
65+
bias=bias,
66+
)
67+
68+
def get_inputs(self) -> tuple[torch.Tensor]:
69+
return self.inputs
70+
71+
def forward(self, x: torch.Tensor) -> torch.Tensor:
72+
return self.fc(x)
73+
74+
75+
def test_visualize_manual_export(server):
76+
with server():
77+
model = Linear(20, 30)
78+
exported_program = torch.export.export(model, model.get_inputs())
79+
visualize(exported_program)
80+
time.sleep(3.0)
81+
82+
83+
def test_visualize_exported_program(server):
84+
with server():
85+
model = Linear(20, 30)
86+
(
87+
Tester(
88+
model,
89+
example_inputs=model.get_inputs(),
90+
)
91+
.export()
92+
.visualize()
93+
)
94+
95+
96+
def test_visualize_to_edge(server):
97+
with server():
98+
model = Linear(20, 30)
99+
(
100+
Tester(
101+
model,
102+
example_inputs=model.get_inputs(),
103+
)
104+
.export()
105+
.to_edge()
106+
.visualize()
107+
)
108+
109+
110+
def test_visualize_partition(server):
111+
with server():
112+
model = Linear(20, 30)
113+
(
114+
Tester(
115+
model,
116+
example_inputs=model.get_inputs(),
117+
)
118+
.export()
119+
.to_edge()
120+
.partition()
121+
.visualize()
122+
)
123+
124+
125+
def test_visualize_to_executorch(server):
126+
with server():
127+
model = Linear(20, 30)
128+
(
129+
Tester(
130+
model,
131+
example_inputs=model.get_inputs(),
132+
)
133+
.export()
134+
.to_edge()
135+
.partition()
136+
.to_executorch()
137+
.visualize()
138+
)
139+
140+
141+
if __name__ == "__main__":
142+
"""A test to run locally to make sure that the web browser opens up
143+
automatically as intended.
144+
"""
145+
146+
test_visualize_manual_export(ModelExplorerServer)
147+
148+
with SingletonModelExplorerServer():
149+
test_visualize_manual_export(SingletonModelExplorerServer)
150+
test_visualize_exported_program(SingletonModelExplorerServer)
151+
test_visualize_to_edge(SingletonModelExplorerServer)
152+
test_visualize_partition(SingletonModelExplorerServer)
153+
test_visualize_to_executorch(SingletonModelExplorerServer)

install_requirements.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2024-25 Arm Limited and/or its affiliates.
23
# All rights reserved.
34
#
45
# This source code is licensed under the BSD-style license found in the
@@ -169,6 +170,7 @@ def python_is_compatible():
169170
"tomli", # Imported by extract_sources.py when using python < 3.11.
170171
"wheel", # For building the pip package archive.
171172
"zstd", # Imported by resolve_buck.py.
173+
"ai-edge-model-explorer>=0.1.16", # For visualizing ExportedPrograms
172174
]
173175

174176
# Assemble the list of requirements to actually install.

0 commit comments

Comments
 (0)