Skip to content

Commit d5dcc26

Browse files
committed
pip install reference_model and use pybind
The reference model is pip installed in setup.sh. Vela is installed similarly. Since the installation contains serialization_lib, we don't have to include it as a package in ExecuTorch's setup.py. The serialization_lib is still needed as a submodule in the Arm backend to find the tosa.fbs for deserialization. Change-Id: I24fff6c00a3961444de5d878ab169d5ba4c9156d
1 parent 22a75be commit d5dcc26

File tree

10 files changed

+104
-103
lines changed

10 files changed

+104
-103
lines changed

backends/arm/arm_backend.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import logging
1515
import os
16-
from typing import final, List, Optional
16+
from typing import cast, final, List, Optional
1717

1818
import serializer.tosa_serializer as ts
1919
from executorch.backends.arm.arm_vela import vela_compile
@@ -32,6 +32,7 @@
3232
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
3333
from executorch.exir.backend.compile_spec_schema import CompileSpec
3434
from torch.export.exported_program import ExportedProgram
35+
from torch.fx import Node
3536

3637
# TOSA backend debug functionality
3738
logger = logging.getLogger(__name__)
@@ -269,6 +270,7 @@ def preprocess( # noqa: C901
269270
node_visitors = get_node_visitors(edge_program, tosa_spec)
270271
input_count = 0
271272
for node in graph_module.graph.nodes:
273+
node = cast(Node, node)
272274
if node.op == "call_function":
273275
process_call_function(node, tosa_graph, node_visitors, tosa_spec)
274276
elif node.op == "placeholder":
@@ -282,15 +284,6 @@ def preprocess( # noqa: C901
282284
# any checking of compatibility.
283285
dbg_fail(node, tosa_graph, artifact_path)
284286

285-
if len(input_order) > 0:
286-
if input_count != len(input_order):
287-
raise RuntimeError(
288-
"The rank of the input order is not equal to amount of input tensors"
289-
)
290-
291-
# TODO: It would be awesome if this dump could somehow be done on top level and not here.
292-
# Problem is that the desc.json has to be created on the tosa_graph object, which we can't
293-
# access from top level.
294287
if artifact_path:
295288
tag = _get_first_delegation_tag(graph_module)
296289
dbg_tosa_dump(
@@ -311,6 +304,4 @@ def preprocess( # noqa: C901
311304
else:
312305
raise RuntimeError(f"Unknown format {output_format}")
313306

314-
# Continueing from above. Can I put tosa_graph into this function?
315-
# debug_handle_map = ...
316307
return PreprocessResult(processed_bytes=binary)

backends/arm/test/common.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,19 +74,15 @@ def get_tosa_compile_spec_unbuilt(
7474
the compile spec before calling .build() to finalize it.
7575
"""
7676
if not custom_path:
77-
intermediate_path = maybe_get_tosa_collate_path() or tempfile.mkdtemp(
78-
prefix="arm_tosa_"
79-
)
80-
else:
81-
intermediate_path = custom_path
77+
custom_path = maybe_get_tosa_collate_path()
8278

83-
if not os.path.exists(intermediate_path):
84-
os.makedirs(intermediate_path, exist_ok=True)
79+
if custom_path is not None and not os.path.exists(custom_path):
80+
os.makedirs(custom_path, exist_ok=True)
8581
compile_spec_builder = (
8682
ArmCompileSpecBuilder()
8783
.tosa_compile_spec(tosa_version)
8884
.set_permute_memory_format(permute_memory_to_nhwc)
89-
.dump_intermediate_artifacts_to(intermediate_path)
85+
.dump_intermediate_artifacts_to(custom_path)
9086
)
9187

9288
return compile_spec_builder

backends/arm/test/misc/test_debug_feats.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@ def test_numerical_diff_prints(self):
111111
model,
112112
example_inputs=model.get_inputs(),
113113
compile_spec=common.get_tosa_compile_spec(
114-
"TOSA-0.80.0+MI", permute_memory_to_nhwc=True
114+
"TOSA-0.80.0+MI",
115+
permute_memory_to_nhwc=True,
116+
custom_path=tempfile.mkdtemp("diff_print_test"),
115117
),
116118
)
117119
.export()

backends/arm/test/ops/test_cat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def test_cat_tosa_MI(self, operands: tuple[torch.Tensor, ...], dim: int):
124124
def test_cat_4d_tosa_MI(self):
125125
square = torch.ones((2, 2, 2, 2))
126126
for dim in range(-3, 3):
127-
test_data = ((square, square), dim)
127+
test_data = ((square, square.clone()), dim)
128128
self._test_cat_tosa_MI_pipeline(self.Cat(), test_data)
129129

130130
@parameterized.expand(Cat.test_parameters)

backends/arm/test/ops/test_scalars.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple):
157157
def test_MI(self, test_name: str, op: torch.nn.Module, x, y):
158158
expected_exception = None
159159
if any(token in test_name for token in ("Sub_int", "Sub__int")):
160-
expected_exception = RuntimeError
160+
expected_exception = ValueError
161161
elif test_name.endswith("_st"):
162162
expected_exception = AttributeError
163163

backends/arm/test/ops/test_select.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,6 @@ def _test_select_tosa_BI_pipeline(
9393
.check(["torch.ops.quantized_decomposed"])
9494
.to_edge()
9595
.partition()
96-
.dump_artifact()
97-
.dump_operator_distribution()
9896
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
9997
.to_executorch()
10098
.run_method_and_compare_outputs(inputs=test_data)

backends/arm/test/runner_utils.py

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,16 @@
1616

1717
import numpy as np
1818
import torch
19+
import tosa_reference_model
1920

2021
from executorch.backends.arm.test.conftest import arm_test_options, is_option_enabled
2122

2223
from torch.export import ExportedProgram
2324
from torch.fx.node import Node
25+
from tosa import TosaGraph
2426

2527
logger = logging.getLogger(__name__)
26-
logger.setLevel(logging.WARNING)
28+
logger.setLevel(logging.CRITICAL)
2729

2830

2931
class QuantizationParams:
@@ -169,7 +171,7 @@ def __init__(
169171
):
170172
self.intermediate_path = intermediate_path
171173
self.tosa_ref_model_path = tosa_ref_model_path or "tosa_reference_model"
172-
assert os.path.exists(
174+
assert self.intermediate_path is None or os.path.exists(
173175
self.intermediate_path
174176
), f"TOSA artifact path don't exist! Path: {self.intermediate_path}"
175177

@@ -332,7 +334,46 @@ def run_corstone(
332334
tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32)
333335
output_shape = self.output_node.args[0][0].meta["val"].shape
334336
tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(output_shape)
335-
return [tosa_ref_output]
337+
return tosa_ref_output
338+
339+
def run_tosa_graph(
340+
self, graph: TosaGraph, inputs: list[np.ndarray] | list[torch.Tensor]
341+
) -> torch.Tensor:
342+
"""Runs the TOSA reference model with inputs and returns the result."""
343+
data_np = [
344+
prep_data_for_save(
345+
input, self.is_quantized, self.input_names[i], self.qp_input[i]
346+
)
347+
for i, input in enumerate(inputs)
348+
]
349+
# tosa_profile: 0 = Base Inference, 1 = Main Inference, 2 = Main Training.
350+
tosa_profile = 0 if self.is_quantized else 1
351+
debug_mode = "ALL" if logger.level <= logging.DEBUG else None
352+
outputs, status = tosa_reference_model.run(
353+
graph,
354+
data_np,
355+
verbosity=_tosa_refmodel_loglevel(logger.level),
356+
tosa_profile=tosa_profile,
357+
initialize_variable_tensor_from_numpy=1, # True
358+
debug_mode=debug_mode,
359+
)
360+
361+
assert (
362+
status == tosa_reference_model.GraphStatus.TOSA_VALID
363+
), "Non-valid TOSA given to reference model."
364+
365+
outputs_torch = []
366+
for output in outputs:
367+
output = torch.from_numpy(output)
368+
if self.is_quantized:
369+
# Need to dequant back to FP32 for comparison with torch output
370+
quant_param = self.qp_output
371+
assert (
372+
quant_param is not None
373+
), "There are no quantization parameters, check output parameters"
374+
output = (output.to(torch.float32) - quant_param.zp) * quant_param.scale
375+
outputs_torch.append(output)
376+
return tuple(outputs_torch)
336377

337378
def run_tosa_ref_model(
338379
self,
@@ -417,21 +458,13 @@ def run_tosa_ref_model(
417458
assert (
418459
shutil.which(self.tosa_ref_model_path) is not None
419460
), f"tosa_reference_model tool not found, did you run examples/arm/setup.sh? Path: {self.tosa_ref_model_path}"
420-
loglevel_map = {
421-
logging.INFO: "INFO",
422-
logging.CRITICAL: "LOW",
423-
logging.ERROR: "LOW",
424-
logging.WARNING: "MED",
425-
logging.DEBUG: "HIGH",
426-
logging.NOTSET: "MED",
427-
}
428-
clamped_logging_level = max(min(logger.level // 10 * 10, 50), 0)
461+
429462
cmd_ref_model = [
430463
self.tosa_ref_model_path,
431464
"--test_desc",
432465
desc_file_path,
433466
"-l",
434-
loglevel_map[clamped_logging_level],
467+
_tosa_refmodel_loglevel(logger.level),
435468
]
436469
_run_cmd(cmd_ref_model)
437470

@@ -467,7 +500,10 @@ def run_tosa_ref_model(
467500

468501

469502
def prep_data_for_save(
470-
data, is_quantized: bool, input_name: str, quant_param: QuantizationParams
503+
data: torch.Tensor,
504+
is_quantized: bool,
505+
input_name: str,
506+
quant_param: QuantizationParams,
471507
):
472508
data_np = np.array(data.detach(), order="C").astype(
473509
f"{data.dtype}".replace("torch.", "")
@@ -576,7 +612,6 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict:
576612
assert os.path.exists(
577613
tosa_schema_file
578614
), f"tosa_schema_file: {tosa_schema_file} does not exist"
579-
580615
assert shutil.which("flatc") is not None
581616
cmd_flatc = [
582617
"flatc",
@@ -611,3 +646,19 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict:
611646
pass
612647

613648
return json_out
649+
650+
651+
def _tosa_refmodel_loglevel(loglevel: int) -> str:
652+
"""Converts a logging loglevel to tosa_reference_model logginglevel,
653+
returned as string.
654+
"""
655+
loglevel_map = {
656+
logging.INFO: "INFO",
657+
logging.CRITICAL: "LOW",
658+
logging.ERROR: "LOW",
659+
logging.WARNING: "MED",
660+
logging.DEBUG: "HIGH",
661+
logging.NOTSET: "MED",
662+
}
663+
clamped_logging_level = max(min(loglevel // 10 * 10, 50), 0)
664+
return loglevel_map[clamped_logging_level]

backends/arm/test/tester/arm_tester.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7+
import tempfile
78

89
from collections import Counter
910
from pprint import pformat
@@ -35,7 +36,11 @@
3536

3637
from executorch.backends.xnnpack.test.tester import Tester
3738
from executorch.devtools.backend_debug import get_delegation_info
38-
from executorch.exir import EdgeCompileConfig, ExecutorchProgramManager
39+
from executorch.exir import (
40+
EdgeCompileConfig,
41+
EdgeProgramManager,
42+
ExecutorchProgramManager,
43+
)
3944
from executorch.exir.backend.compile_spec_schema import CompileSpec
4045
from executorch.exir.backend.partitioner import Partitioner
4146
from executorch.exir.lowered_backend_module import LoweredBackendModule
@@ -128,10 +133,15 @@ def __init__(
128133
super().__init__(dynamic_shapes)
129134
self.tosa_test_util = tosa_test_util
130135

136+
def run(self, artifact: EdgeProgramManager, inputs=None):
137+
self.executorch_program = artifact.to_executorch(self.config)
138+
if module := getattr(
139+
artifact.exported_program().graph_module, "lowered_module_0", None
140+
):
141+
self.buffer = module.processed_bytes
142+
131143
def run_artifact(self, inputs):
132-
tosa_output = self.tosa_test_util.run_tosa_ref_model(
133-
inputs=inputs,
134-
)
144+
tosa_output = self.tosa_test_util.run_tosa_graph(self.buffer, inputs)
135145
return tosa_output
136146

137147

@@ -348,7 +358,7 @@ def run_method_and_compare_outputs(
348358
logger.info(f"Run #{run_iteration}, input shapes: {input_shape_str}")
349359

350360
reference_output = reference_stage.run_artifact(reference_input)
351-
test_output = tuple(test_stage.run_artifact(test_input))
361+
test_output = test_stage.run_artifact(test_input)
352362
if (
353363
is_nhwc
354364
and test_stage == self.stages[self.stage_name(tester.ToExecutorch)]
@@ -515,6 +525,8 @@ def _compare_outputs(
515525
banner = "=" * 40 + "TOSA debug info" + "=" * 40
516526
logger.error(banner)
517527
path_to_tosa_files = self.runner_util.intermediate_path
528+
if path_to_tosa_files is None:
529+
path_to_tosa_files = tempfile.mkdtemp(prefix="executorch_result_dump_")
518530

519531
export_stage = self.stages.get(self.stage_name(tester.Export), None)
520532
quantize_stage = self.stages.get(self.stage_name(tester.Quantize), None)
@@ -524,8 +536,8 @@ def _compare_outputs(
524536
qp_output = _get_output_quantization_params(
525537
export_stage.artifact, output_node
526538
)
527-
logger.error(f"{qp_input=}")
528-
logger.error(f"{qp_output=}")
539+
logger.error(f"Input QuantArgs: {qp_input}")
540+
logger.error(f"Output QuantArgs: {qp_output}")
529541

530542
logger.error(f"{path_to_tosa_files=}")
531543
import os

examples/arm/setup.sh

Lines changed: 7 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ ethos_u_base_rev="24.08"
8888

8989
# tosa reference model
9090
tosa_reference_model_url="https://review.mlplatform.org/tosa/reference_model"
91-
tosa_reference_model_rev="f9ea4ab7da19318fe36b1c34d68a3e40fd6e56c5"
91+
tosa_reference_model_rev="c5570b79e90c3a36ab8c4ddb8ee3fbc2cd3f7c38"
9292

9393
# vela
9494
vela_repo_url="https://review.mlplatform.org/ml/ethos-u/ethos-u-vela"
@@ -223,64 +223,19 @@ function patch_repo() {
223223
}
224224

225225
function setup_tosa_reference_model() {
226-
# The debug flow on the host includes running on a reference implementation of TOSA
227-
# This is useful primarily for debug of quantization accuracy, but also for internal
228-
# errors for the early codebase
229-
cd "${root_dir}"
230-
if [[ ! -e reference_model ]]; then
231-
git clone ${tosa_reference_model_url}
232-
cd reference_model
233-
git checkout ${tosa_reference_model_rev}
234-
git submodule update --init --recursive
235-
cd ..
236-
fi
237-
cd reference_model
238-
mkdir -p build
239-
cd build
240-
cmake ..
241-
242-
# make use of half the cores for building
243-
if [[ "${OS}" == "Linux" ]]; then
244-
n=$(( $(nproc) / 2 ))
245-
elif [[ "${OS}" == "Darwin" ]]; then
246-
n=$(( $(sysctl -n hw.logicalcpu) / 2 ))
247-
else
248-
n=1
249-
fi
250-
251-
if [[ "$n" -lt 1 ]]; then
252-
n=1
253-
fi
226+
227+
# reference_model flatbuffers version clashes with Vela.
228+
# go with Vela's since it is newer.
229+
# Could cause issues down the line, beware..
230+
pip install tosa-tools@git+${tosa_reference_model_url}@${tosa_reference_model_rev} --no-dependencies flatbuffers
254231

255-
make -j"${n}"
256-
cd reference_model
257-
tosa_bin_path=`pwd`
258-
echo "export PATH=\${PATH}:${tosa_bin_path}" >> "${setup_path_script}"
259232
}
260233

261234
function setup_vela() {
262235
#
263236
# Prepare the Vela compiler for AoT to Ethos-U compilation
264237
#
265-
cd "${root_dir}"
266-
if [[ ! -e ethos-u-vela ]]; then
267-
git clone ${vela_repo_url}
268-
repo_dir="${root_dir}/ethos-u-vela"
269-
base_rev=${vela_rev}
270-
patch_repo
271-
fi
272-
cd "${root_dir}/ethos-u-vela"
273-
274-
# different command for conda vs venv
275-
VNV=$(python3 -c "import sys; print('venv') if (sys.prefix != sys.base_prefix) else print('not_venv')")
276-
if [ ${VNV} == "venv" ]; then
277-
pip install .
278-
else
279-
# if not venv, we need the site-path where the vela
280-
vela_path=$(python -c "import site; print(site.USER_BASE+'/bin')")
281-
echo "export PATH=\${PATH}:${vela_path}" >> ${setup_path_script}
282-
pip install . --user
283-
fi
238+
pip install ethos-u-vela@git+${vela_repo_url}@${vela_rev}
284239
}
285240

286241
########

0 commit comments

Comments
 (0)