Skip to content

Commit d98cfe6

Browse files
committed
Add support for model-explorer in ArmTester
If model-explorer is installed, run it on the exported_graph using ArmTester.visualize(), or use the api the visualize module directly from the debug console. Introduces two pytest configurations: --model_explore_host : if set, tries connecting to to a running server rather than starting a new one. --model_explore_port : set the port of the above host Signed-off-by: Erik Lundell <[email protected]> Change-Id: I00ada14f27e6a7ad3994a439ba4c1e39b1560e2c
1 parent 7fcd0af commit d98cfe6

File tree

3 files changed

+107
-1
lines changed

3 files changed

+107
-1
lines changed

backends/arm/test/common.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class arm_test_options(Enum):
2929
corstone300 = auto()
3030
dump_path = auto()
3131
date_format = auto()
32+
model_explorer_host = auto()
33+
model_explorer_port = auto()
3234

3335

3436
_test_options: dict[arm_test_options, Any] = {}
@@ -41,6 +43,18 @@ def pytest_addoption(parser):
4143
parser.addoption("--arm_run_corstone300", action="store_true")
4244
parser.addoption("--default_dump_path", default=None)
4345
parser.addoption("--date_format", default="%d-%b-%H:%M:%S")
46+
parser.addoption(
47+
"--model_explorer_host",
48+
action="store",
49+
default=None,
50+
help="If set, tries to connect to existing model-explorer server rather than starting a new one.",
51+
)
52+
parser.addoption(
53+
"--model_explorer_port",
54+
action="store",
55+
default=None,
56+
help="Set the port of the model explorer server. If not set, tries ports between 8080 and 8099.",
57+
)
4458

4559

4660
def pytest_configure(config):
@@ -62,7 +76,19 @@ def pytest_configure(config):
6276
raise RuntimeError(
6377
f"Supplied argument 'default_dump_path={dump_path}' that does not exist or is not a directory."
6478
)
79+
if config.option.model_explorer_port:
80+
if not str.isdecimal(config.option.model_explorer_port):
81+
raise RuntimeError(
82+
f"--model_explorer_port needs to be an integer, got '{config.option.model_explorer_port}'."
83+
)
84+
else:
85+
_test_options[arm_test_options.model_explorer_port] = int(
86+
config.option.model_explorer_port
87+
)
6588
_test_options[arm_test_options.date_format] = config.option.date_format
89+
_test_options[arm_test_options.model_explorer_host] = (
90+
config.option.model_explorer_host
91+
)
6692
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
6793

6894

backends/arm/test/tester/arm_tester.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7-
87
from collections import Counter
98
from pprint import pformat
109
from typing import Any, Iterable, List, Literal, Optional, Tuple, Union
@@ -35,6 +34,7 @@
3534
dbg_tosa_fb_to_json,
3635
RunnerUtil,
3736
)
37+
from executorch.backends.arm.test.visualize import visualize
3838
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
3939

4040
from executorch.backends.xnnpack.test.tester import Tester
@@ -47,6 +47,8 @@
4747
from tabulate import tabulate
4848
from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec
4949
from torch.fx import Graph
50+
from typing_extensions import Self
51+
5052

5153
logger = logging.getLogger(__name__)
5254

@@ -473,6 +475,22 @@ def dump_dtype_distribution(
473475
_dump_str(to_print, path_to_dump)
474476
return self
475477

478+
def visualize(self) -> Self:
479+
exported_program = self._get_exported_program()
480+
visualize(exported_program)
481+
return self
482+
483+
def _get_exported_program(self):
484+
match self.cur:
485+
case "Export":
486+
return self.get_artifact()
487+
case "ToEdge" | "Partition":
488+
return self.get_artifact().exported_program()
489+
case _:
490+
raise RuntimeError(
491+
"Can only get the exported program for the Export, ToEdge, or Partition stage."
492+
)
493+
476494
@staticmethod
477495
def _calculate_reference_output(
478496
module: Union[torch.fx.GraphModule, torch.nn.Module], inputs

backends/arm/test/visualize.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import logging
7+
from typing import Optional
8+
9+
from executorch.backends.arm.test.common import arm_test_options, get_option
10+
from torch.export import ExportedProgram
11+
12+
logger = logging.getLogger(__name__)
13+
_model_explorer_installed = False
14+
15+
try:
16+
# pyre-ignore[21]: We keep track of whether import succeeded manually.
17+
from model_explorer import config, visualize_from_config, visualize_pytorch
18+
19+
_model_explorer_installed = True
20+
except ImportError:
21+
logger.warning("model-explorer is not installed, can't visualize models.")
22+
23+
24+
def is_model_explorer_installed() -> bool:
25+
return _model_explorer_installed
26+
27+
28+
def get_pytest_option_host() -> str | None:
29+
host = get_option(arm_test_options.model_explorer_host)
30+
return str(host) if host else None
31+
32+
33+
def get_pytest_option_port() -> int | None:
34+
port = get_option(arm_test_options.model_explorer_port)
35+
return int(port) if port else None
36+
37+
38+
def visualize(
39+
exported_program: ExportedProgram,
40+
host: Optional[str] = None,
41+
port: Optional[int] = None,
42+
):
43+
"""Attempt visualizing exported_program using model-explorer."""
44+
45+
host = host if host else get_pytest_option_host()
46+
port = port if port else get_pytest_option_port()
47+
48+
if not is_model_explorer_installed():
49+
logger.warning("Can't visualize model since model-explorer is not installed.")
50+
return
51+
52+
# If a host is provided, we attempt connecting to an already running server.
53+
# Note that this needs a modified model-explorer
54+
if host:
55+
explorer_config = (
56+
config()
57+
.add_model_from_pytorch("ExportedProgram", exported_program)
58+
.set_reuse_server(server_host=host, server_port=port)
59+
)
60+
visualize_from_config(explorer_config)
61+
else:
62+
visualize_pytorch(exported_program)

0 commit comments

Comments
 (0)