Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions backends/arm/test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class arm_test_options(Enum):
corstone300 = auto()
dump_path = auto()
date_format = auto()
model_explorer_host = auto()
model_explorer_port = auto()


_test_options: dict[arm_test_options, Any] = {}
Expand All @@ -41,6 +43,18 @@ def pytest_addoption(parser):
parser.addoption("--arm_run_corstone300", action="store_true")
parser.addoption("--default_dump_path", default=None)
parser.addoption("--date_format", default="%d-%b-%H:%M:%S")
parser.addoption(
"--model_explorer_host",
action="store",
default=None,
help="If set, tries to connect to existing model-explorer server rather than starting a new one.",
)
parser.addoption(
"--model_explorer_port",
action="store",
default=None,
help="Set the port of the model explorer server. If not set, tries ports between 8080 and 8099.",
)


def pytest_configure(config):
Expand All @@ -62,7 +76,19 @@ def pytest_configure(config):
raise RuntimeError(
f"Supplied argument 'default_dump_path={dump_path}' that does not exist or is not a directory."
)
if config.option.model_explorer_port:
if not str.isdecimal(config.option.model_explorer_port):
raise RuntimeError(
f"--model_explorer_port needs to be an integer, got '{config.option.model_explorer_port}'."
)
else:
_test_options[arm_test_options.model_explorer_port] = int(
config.option.model_explorer_port
)
_test_options[arm_test_options.date_format] = config.option.date_format
_test_options[arm_test_options.model_explorer_host] = (
config.option.model_explorer_host
)
logging.basicConfig(level=logging.INFO, stream=sys.stdout)


Expand Down
20 changes: 19 additions & 1 deletion backends/arm/test/tester/arm_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.

import logging

from collections import Counter
from pprint import pformat
from typing import Any, Iterable, List, Literal, Optional, Tuple, Union
Expand Down Expand Up @@ -35,6 +34,7 @@
dbg_tosa_fb_to_json,
RunnerUtil,
)
from executorch.backends.arm.test.visualize import visualize
from executorch.backends.arm.tosa_mapping import extract_tensor_meta

from executorch.backends.xnnpack.test.tester import Tester
Expand All @@ -47,6 +47,8 @@
from tabulate import tabulate
from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec
from torch.fx import Graph
from typing_extensions import Self


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -473,6 +475,22 @@ def dump_dtype_distribution(
_dump_str(to_print, path_to_dump)
return self

def visualize(self) -> Self:
exported_program = self._get_exported_program()
visualize(exported_program)
return self

def _get_exported_program(self):
match self.cur:
case "Export":
return self.get_artifact()
case "ToEdge" | "Partition":
return self.get_artifact().exported_program()
case _:
raise RuntimeError(
"Can only get the exported program for the Export, ToEdge, or Partition stage."
)

@staticmethod
def _calculate_reference_output(
module: Union[torch.fx.GraphModule, torch.nn.Module], inputs
Expand Down
62 changes: 62 additions & 0 deletions backends/arm/test/visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Optional

from executorch.backends.arm.test.common import arm_test_options, get_option
from torch.export import ExportedProgram

logger = logging.getLogger(__name__)
_model_explorer_installed = False

try:
# pyre-ignore[21]: We keep track of whether import succeeded manually.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you just test locally?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this has only been tested locally. I am not sure how to test in ci since it requires an external dependency

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Olivia-liu - what do you think of this vs. our SDK plans for visualization?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bump @Olivia-liu, some kind of visualization would be great, so would really appreciate merging unless you have other plans soon

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Erik-Lundell Sorry about the delay! This is great. Glad to see integration with model-explorer going on!

from model_explorer import config, visualize_from_config, visualize_pytorch

_model_explorer_installed = True
except ImportError:
logger.warning("model-explorer is not installed, can't visualize models.")


def is_model_explorer_installed() -> bool:
return _model_explorer_installed


def get_pytest_option_host() -> str | None:
host = get_option(arm_test_options.model_explorer_host)
return str(host) if host else None


def get_pytest_option_port() -> int | None:
port = get_option(arm_test_options.model_explorer_port)
return int(port) if port else None


def visualize(
exported_program: ExportedProgram,
host: Optional[str] = None,
port: Optional[int] = None,
):
"""Attempt visualizing exported_program using model-explorer."""

host = host if host else get_pytest_option_host()
port = port if port else get_pytest_option_port()

if not is_model_explorer_installed():
logger.warning("Can't visualize model since model-explorer is not installed.")
return

# If a host is provided, we attempt connecting to an already running server.
# Note that this needs a modified model-explorer
if host:
explorer_config = (
config()
.add_model_from_pytorch("ExportedProgram", exported_program)
.set_reuse_server(server_host=host, server_port=port)
)
visualize_from_config(explorer_config)
else:
visualize_pytorch(exported_program)
Loading