Skip to content

Commit 2ebf7f9

Browse files
committed
Add support for model-explorer in ArmTester
If model-explorer is installed, run it on the exported_graph using ArmTester.visualize(), or use the api the visualize module directly from the debug console. Introduces two pytest configurations: --model_explore_host : if set, tries connecting to to a running server rather than starting a new one. --model_explore_port : set the port of the above host Signed-off-by: Erik Lundell <[email protected]> Change-Id: I00ada14f27e6a7ad3994a439ba4c1e39b1560e2c
1 parent 03b1ef2 commit 2ebf7f9

File tree

3 files changed

+107
-1
lines changed

3 files changed

+107
-1
lines changed

backends/arm/test/common.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class arm_test_options(Enum):
2929
corstone300 = auto()
3030
dump_path = auto()
3131
date_format = auto()
32+
model_explorer_host = auto()
33+
model_explorer_port = auto()
3234

3335

3436
_test_options: dict[arm_test_options, Any] = {}
@@ -41,6 +43,18 @@ def pytest_addoption(parser):
4143
parser.addoption("--arm_run_corstone300", action="store_true")
4244
parser.addoption("--default_dump_path", default=None)
4345
parser.addoption("--date_format", default="%d-%b-%H:%M:%S")
46+
parser.addoption(
47+
"--model_explorer_host",
48+
action="store",
49+
default=None,
50+
help="If set, tries to connect to existing model-explorer server rather than starting a new one.",
51+
)
52+
parser.addoption(
53+
"--model_explorer_port",
54+
action="store",
55+
default=None,
56+
help="Set the port of the model explorer server. If not set, tries ports between 8080 and 8099.",
57+
)
4458

4559

4660
def pytest_configure(config):
@@ -62,7 +76,19 @@ def pytest_configure(config):
6276
raise RuntimeError(
6377
f"Supplied argument 'default_dump_path={dump_path}' that does not exist or is not a directory."
6478
)
79+
if config.option.model_explorer_port:
80+
if not str.isdecimal(config.option.model_explorer_port):
81+
raise RuntimeError(
82+
f"--model_explorer_port needs to be an integer, got '{config.option.model_explorer_port}'."
83+
)
84+
else:
85+
_test_options[arm_test_options.model_explorer_port] = int(
86+
config.option.model_explorer_port
87+
)
6588
_test_options[arm_test_options.date_format] = config.option.date_format
89+
_test_options[arm_test_options.model_explorer_host] = (
90+
config.option.model_explorer_host
91+
)
6692
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
6793

6894

backends/arm/test/tester/arm_tester.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7-
87
from collections import Counter
98
from pprint import pformat
109
from typing import Any, Iterable, List, Literal, Optional, Tuple, Union
@@ -35,6 +34,7 @@
3534
dbg_tosa_fb_to_json,
3635
RunnerUtil,
3736
)
37+
from executorch.backends.arm.test.visualize import visualize
3838
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
3939

4040
from executorch.backends.xnnpack.test.tester import Tester
@@ -46,6 +46,8 @@
4646
from tabulate import tabulate
4747
from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec
4848
from torch.fx import Graph
49+
from typing_extensions import Self
50+
4951

5052
logger = logging.getLogger(__name__)
5153

@@ -430,6 +432,22 @@ def dump_dtype_distribution(
430432
_dump_str(to_print, path_to_dump)
431433
return self
432434

435+
def visualize(self) -> Self:
436+
exported_program = self._get_exported_program()
437+
visualize(exported_program)
438+
return self
439+
440+
def _get_exported_program(self):
441+
match self.cur:
442+
case "Export":
443+
return self.get_artifact()
444+
case "ToEdge" | "Partition":
445+
return self.get_artifact().exported_program()
446+
case _:
447+
raise RuntimeError(
448+
"Can only get the exported program for the Export, ToEdge, or Partition stage."
449+
)
450+
433451
@staticmethod
434452
def _calculate_reference_output(
435453
module: Union[torch.fx.GraphModule, torch.nn.Module], inputs

backends/arm/test/visualize.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import logging
7+
from typing import Optional
8+
9+
from executorch.backends.arm.test.common import arm_test_options, get_option
10+
from torch.export import ExportedProgram
11+
12+
logger = logging.getLogger(__name__)
13+
_model_explorer_installed = False
14+
15+
try:
16+
# pyre-ignore[21]: We keep track of whether import succeeded manually.
17+
from model_explorer import config, visualize_from_config, visualize_pytorch
18+
19+
_model_explorer_installed = True
20+
except ImportError:
21+
logger.warning("model-explorer is not installed, can't visualize models.")
22+
23+
24+
def is_model_explorer_installed() -> bool:
25+
return _model_explorer_installed
26+
27+
28+
def get_pytest_option_host() -> str | None:
29+
host = get_option(arm_test_options.model_explorer_host)
30+
return str(host) if host else None
31+
32+
33+
def get_pytest_option_port() -> int | None:
34+
port = get_option(arm_test_options.model_explorer_port)
35+
return int(port) if port else None
36+
37+
38+
def visualize(
39+
exported_program: ExportedProgram,
40+
host: Optional[str] = None,
41+
port: Optional[int] = None,
42+
):
43+
"""Attempt visualizing exported_program using model-explorer."""
44+
45+
host = host if host else get_pytest_option_host()
46+
port = port if port else get_pytest_option_port()
47+
48+
if not is_model_explorer_installed():
49+
logger.warning("Can't visualize model since model-explorer is not installed.")
50+
return
51+
52+
# If a host is provided, we attempt connecting to an already running server.
53+
# Note that this needs a modified model-explorer
54+
if host:
55+
explorer_config = (
56+
config()
57+
.add_model_from_pytorch("ExportedProgram", exported_program)
58+
.set_reuse_server(server_host=host, server_port=port)
59+
)
60+
visualize_from_config(explorer_config)
61+
else:
62+
visualize_pytorch(exported_program)

0 commit comments

Comments
 (0)