
Commit 03db71c

pip install reference_model and use pybind
The reference model is now pip-installed in setup.sh, and Vela is installed the same way. Since that installation contains serialization_lib, we no longer have to include it as a package in Executorch's setup.py. The serialization_lib submodule is still needed in the arm backend to find the tosa.fbs schema for deserialization.

Change-Id: I24fff6c00a3961444de5d878ab169d5ba4c9156d
1 parent 2967302 commit 03db71c
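
In practice the change means the TOSA reference model is no longer cloned and built from source and invoked as a standalone binary; it is imported as a Python module (pybind bindings shipped with the pip-installed tosa-tools package) and called directly from the test runner. A minimal sketch of that call path, assuming a serialized TOSA flatbuffer and its input arrays are already at hand (run_reference_model is a hypothetical helper; the keyword arguments mirror those used in backends/arm/test/runner_utils.py below):

import numpy as np

import tosa_reference_model  # pybind module provided by the pip-installed tosa-tools


def run_reference_model(tosa_flatbuffer: bytes, inputs: list[np.ndarray]):
    # Hypothetical helper: tosa_flatbuffer would be the delegate's
    # processed_bytes and inputs its example input arrays.
    outputs, status = tosa_reference_model.run(
        tosa_flatbuffer,
        inputs,
        verbosity="INFO",
        tosa_profile=1,  # 0 = Base Inference, 1 = Main Inference, 2 = Main Training
    )
    assert status == tosa_reference_model.GraphStatus.TOSA_VALID, "Invalid TOSA graph"
    return outputs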

9 files changed: +108, -102 lines

backends/arm/arm_backend.py

Lines changed: 3 additions & 12 deletions
@@ -13,7 +13,7 @@
 
 import logging
 import os
-from typing import final, List, Optional
+from typing import cast, final, List, Optional
 
 import serializer.tosa_serializer as ts
 from executorch.backends.arm.arm_vela import vela_compile
@@ -32,6 +32,7 @@
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from torch.export.exported_program import ExportedProgram
+from torch.fx import Node
 
 # TOSA backend debug functionality
 logger = logging.getLogger(__name__)
@@ -267,6 +268,7 @@ def preprocess( # noqa: C901
         node_visitors = get_node_visitors(edge_program, tosa_spec)
         input_count = 0
         for node in graph_module.graph.nodes:
+            node = cast(Node, node)
             if node.op == "call_function":
                 process_call_function(node, tosa_graph, node_visitors, tosa_spec)
             elif node.op == "placeholder":
@@ -280,15 +282,6 @@ def preprocess( # noqa: C901
                 # any checking of compatibility.
                 dbg_fail(node, tosa_graph, artifact_path)
 
-        if len(input_order) > 0:
-            if input_count != len(input_order):
-                raise RuntimeError(
-                    "The rank of the input order is not equal to amount of input tensors"
-                )
-
-        # TODO: It would be awesome if this dump could somehow be done on top level and not here.
-        # Problem is that the desc.json has to be created on the tosa_graph object, which we can't
-        # access from top level.
         if artifact_path:
             tag = _get_first_delegation_tag(graph_module)
             dbg_tosa_dump(
@@ -309,6 +302,4 @@ def preprocess( # noqa: C901
         else:
             raise RuntimeError(f"Unknown format {output_format}")
 
-        # Continueing from above. Can I put tosa_graph into this function?
-        # debug_handle_map = ...
         return PreprocessResult(processed_bytes=binary)

backends/arm/test/common.py

Lines changed: 4 additions & 8 deletions
@@ -197,19 +197,15 @@ def get_tosa_compile_spec_unbuilt(
     the compile spec before calling .build() to finalize it.
     """
     if not custom_path:
-        intermediate_path = maybe_get_tosa_collate_path() or tempfile.mkdtemp(
-            prefix="arm_tosa_"
-        )
-    else:
-        intermediate_path = custom_path
+        custom_path = maybe_get_tosa_collate_path()
 
-    if not os.path.exists(intermediate_path):
-        os.makedirs(intermediate_path, exist_ok=True)
+    if custom_path is not None and not os.path.exists(custom_path):
+        os.makedirs(custom_path, exist_ok=True)
     compile_spec_builder = (
         ArmCompileSpecBuilder()
         .tosa_compile_spec(tosa_version)
         .set_permute_memory_format(permute_memory_to_nhwc)
-        .dump_intermediate_artifacts_to(intermediate_path)
+        .dump_intermediate_artifacts_to(custom_path)
     )
 
     return compile_spec_builder

backends/arm/test/misc/test_debug_feats.py

Lines changed: 3 additions & 1 deletion
@@ -107,7 +107,9 @@ def test_numerical_diff_prints(self):
                 model,
                 example_inputs=model.get_inputs(),
                 compile_spec=common.get_tosa_compile_spec(
-                    "TOSA-0.80.0+MI", permute_memory_to_nhwc=True
+                    "TOSA-0.80.0+MI",
+                    permute_memory_to_nhwc=True,
+                    custom_path=tempfile.mkdtemp("diff_print_test"),
                 ),
             )
             .export()

backends/arm/test/ops/test_cat.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def test_cat_tosa_MI(self, operands: tuple[torch.Tensor, ...], dim: int):
     def test_cat_4d_tosa_MI(self):
         square = torch.ones((2, 2, 2, 2))
         for dim in range(-3, 3):
-            test_data = ((square, square), dim)
+            test_data = ((square, square.clone()), dim)
            self._test_cat_tosa_MI_pipeline(self.Cat(), test_data)
 
     @parameterized.expand(Cat.test_parameters)
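
The only functional change here is cloning the second operand, presumably so the two cat inputs are distinct tensor objects rather than two references to the same tensor. A tiny standalone illustration of the difference:

import torch

square = torch.ones((2, 2, 2, 2))
copy = square.clone()

# The old test passed the same object twice; the clone is an equal but
# independent tensor, so the two cat operands no longer alias each other.
assert copy is not square
assert torch.equal(copy, square)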

backends/arm/test/ops/test_select.py

Lines changed: 0 additions & 2 deletions
@@ -93,8 +93,6 @@ def _test_select_tosa_BI_pipeline(
             .check(["torch.ops.quantized_decomposed"])
             .to_edge()
             .partition()
-            .dump_artifact()
-            .dump_operator_distribution()
             .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             .to_executorch()
             .run_method_and_compare_outputs(inputs=test_data)

backends/arm/test/runner_utils.py

Lines changed: 66 additions & 15 deletions
@@ -18,12 +18,14 @@
 import torch
 
 from executorch.backends.arm.test.common import arm_test_options, is_option_enabled
+import tosa_reference_model
 
 from torch.export import ExportedProgram
 from torch.fx.node import Node
+from tosa import TosaGraph
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.WARNING)
+logger.setLevel(logging.CRITICAL)
 
 
 class QuantizationParams:
@@ -169,7 +171,7 @@ def __init__(
     ):
         self.intermediate_path = intermediate_path
         self.tosa_ref_model_path = tosa_ref_model_path or "tosa_reference_model"
-        assert os.path.exists(
+        assert self.intermediate_path is None or os.path.exists(
            self.intermediate_path
         ), f"TOSA artifact path don't exist! Path: {self.intermediate_path}"
 
@@ -334,7 +336,46 @@ def run_corstone(
         tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32)
         output_shape = self.output_node.args[0][0].meta["val"].shape
         tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(output_shape)
-        return [tosa_ref_output]
+        return tosa_ref_output
+
+    def run_tosa_graph(
+        self, graph: TosaGraph, inputs: list[np.ndarray] | list[torch.Tensor]
+    ) -> torch.Tensor:
+        """Runs the TOSA reference model with inputs and returns the result."""
+        data_np = [
+            prep_data_for_save(
+                input, self.is_quantized, self.input_names[i], self.qp_input[i]
+            )
+            for i, input in enumerate(inputs)
+        ]
+        # tosa_profile: 0 = Base Inference, 1 = Main Inference, 2 = Main Training.
+        tosa_profile = 0 if self.is_quantized else 1
+        debug_mode = "ALL" if logger.level <= logging.DEBUG else None
+        outputs, status = tosa_reference_model.run(
+            graph,
+            data_np,
+            verbosity=_tosa_refmodel_loglevel(logger.level),
+            tosa_profile=tosa_profile,
+            initialize_variable_tensor_from_numpy=1,  # True
+            debug_mode=debug_mode,
+        )
+
+        assert (
+            status == tosa_reference_model.GraphStatus.TOSA_VALID
+        ), "Non-valid TOSA given to reference model."
+
+        outputs_torch = []
+        for output in outputs:
+            output = torch.from_numpy(output)
+            if self.is_quantized:
+                # Need to dequant back to FP32 for comparison with torch output
+                quant_param = self.qp_output
+                assert (
+                    quant_param is not None
+                ), "There are no quantization parameters, check output parameters"
+                output = (output.to(torch.float32) - quant_param.zp) * quant_param.scale
+            outputs_torch.append(output)
+        return tuple(outputs_torch)
 
     def run_tosa_ref_model(
         self,
@@ -419,21 +460,13 @@ def run_tosa_ref_model(
         assert (
             shutil.which(self.tosa_ref_model_path) is not None
         ), f"tosa_reference_model tool not found, did you run examples/arm/setup.sh? Path: {self.tosa_ref_model_path}"
-        loglevel_map = {
-            logging.INFO: "INFO",
-            logging.CRITICAL: "LOW",
-            logging.ERROR: "LOW",
-            logging.WARNING: "MED",
-            logging.DEBUG: "HIGH",
-            logging.NOTSET: "MED",
-        }
-        clamped_logging_level = max(min(logger.level // 10 * 10, 50), 0)
+
         cmd_ref_model = [
             self.tosa_ref_model_path,
             "--test_desc",
             desc_file_path,
             "-l",
-            loglevel_map[clamped_logging_level],
+            _tosa_refmodel_loglevel(logger.level),
         ]
         _run_cmd(cmd_ref_model)
 
@@ -469,7 +502,10 @@ def run_tosa_ref_model(
 
 
 def prep_data_for_save(
-    data, is_quantized: bool, input_name: str, quant_param: QuantizationParams
+    data: torch.Tensor,
+    is_quantized: bool,
+    input_name: str,
+    quant_param: QuantizationParams,
 ):
     data_np = np.array(data.detach(), order="C").astype(
         f"{data.dtype}".replace("torch.", "")
@@ -578,7 +614,6 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict:
     assert os.path.exists(
         tosa_schema_file
     ), f"tosa_schema_file: {tosa_schema_file} does not exist"
-
     assert shutil.which("flatc") is not None
     cmd_flatc = [
         "flatc",
@@ -613,3 +648,19 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict:
         pass
 
     return json_out
+
+
+def _tosa_refmodel_loglevel(loglevel: int) -> str:
+    """Converts a logging loglevel to tosa_reference_model logginglevel,
+    returned as string.
+    """
+    loglevel_map = {
+        logging.INFO: "INFO",
+        logging.CRITICAL: "LOW",
+        logging.ERROR: "LOW",
+        logging.WARNING: "MED",
+        logging.DEBUG: "HIGH",
+        logging.NOTSET: "MED",
+    }
+    clamped_logging_level = max(min(loglevel // 10 * 10, 50), 0)
+    return loglevel_map[clamped_logging_level]
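
The dequantization step inside the new run_tosa_graph maps the reference model's integer output back to float32 before it is compared against eager PyTorch, using the usual affine formula real = (q - zero_point) * scale. A small standalone example with made-up quantization parameters:

import torch

scale, zero_point = 0.05, 128  # made-up output quantization parameters
q = torch.tensor([120, 128, 140], dtype=torch.uint8)

# Same formula as in run_tosa_graph: (output - zp) * scale, computed in float32.
fp = (q.to(torch.float32) - zero_point) * scale
print(fp)  # tensor([-0.4000, 0.0000, 0.6000])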

backends/arm/test/tester/arm_tester.py

Lines changed: 19 additions & 7 deletions
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
+import tempfile
 
 from collections import Counter
 from pprint import pformat
@@ -40,7 +41,11 @@
 
 from executorch.backends.xnnpack.test.tester import Tester
 from executorch.devtools.backend_debug import get_delegation_info
-from executorch.exir import EdgeCompileConfig, ExecutorchProgramManager
+from executorch.exir import (
+    EdgeCompileConfig,
+    EdgeProgramManager,
+    ExecutorchProgramManager,
+)
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import Partitioner
 from executorch.exir.lowered_backend_module import LoweredBackendModule
@@ -133,10 +138,15 @@ def __init__(
         super().__init__(dynamic_shapes)
         self.tosa_test_util = tosa_test_util
 
+    def run(self, artifact: EdgeProgramManager, inputs=None):
+        self.executorch_program = artifact.to_executorch(self.config)
+        if module := getattr(
+            artifact.exported_program().graph_module, "lowered_module_0", None
+        ):
+            self.buffer = module.processed_bytes
+
     def run_artifact(self, inputs):
-        tosa_output = self.tosa_test_util.run_tosa_ref_model(
-            inputs=inputs,
-        )
+        tosa_output = self.tosa_test_util.run_tosa_graph(self.buffer, inputs)
         return tosa_output
 
 
@@ -353,7 +363,7 @@ def run_method_and_compare_outputs(
            logger.info(f"Run #{run_iteration}, input shapes: {input_shape_str}")
 
            reference_output = reference_stage.run_artifact(reference_input)
-            test_output = tuple(test_stage.run_artifact(test_input))
+            test_output = test_stage.run_artifact(test_input)
            if (
                is_nhwc
                and test_stage == self.stages[self.stage_name(tester.ToExecutorch)]
@@ -520,6 +530,8 @@ def _compare_outputs(
            banner = "=" * 40 + "TOSA debug info" + "=" * 40
            logger.error(banner)
            path_to_tosa_files = self.runner_util.intermediate_path
+            if path_to_tosa_files is None:
+                path_to_tosa_files = tempfile.mkdtemp(prefix="executorch_result_dump_")
 
            export_stage = self.stages.get(self.stage_name(tester.Export), None)
            quantize_stage = self.stages.get(self.stage_name(tester.Quantize), None)
@@ -529,8 +541,8 @@ def _compare_outputs(
                qp_output = _get_output_quantization_params(
                    export_stage.artifact, output_node
                )
-                logger.error(f"{qp_input=}")
-                logger.error(f"{qp_output=}")
+                logger.error(f"Input QuantArgs: {qp_input}")
+                logger.error(f"Output QuantArgs: {qp_output}")
 
                logger.error(f"{path_to_tosa_files=}")
                import os

examples/arm/setup.sh

Lines changed: 12 additions & 52 deletions
@@ -88,7 +88,11 @@ ethos_u_base_rev="24.08"
 
 # tosa reference model
 tosa_reference_model_url="https://review.mlplatform.org/tosa/reference_model"
-tosa_reference_model_rev="f9ea4ab7da19318fe36b1c34d68a3e40fd6e56c5"
+tosa_reference_model_rev="ef31e7222e99cb1c24b2aff9fc52b2d609612283"
+
+# vela
+vela_repo_url="https://review.mlplatform.org/ml/ethos-u/ethos-u-vela"
+vela_rev="57ce18c89ccc6f6309333dccb24ed30dc68b571f"
 
 ########
 ### Mandatory user args
@@ -198,6 +202,7 @@ function setup_ethos_u() {
     cd ethos-u
     git reset --hard ${ethos_u_base_rev}
     python3 ./fetch_externals.py -c ${ethos_u_base_rev}.json fetch
+
     pip install pyelftools
     echo "[${FUNCNAME[0]}] Done @ $(git describe --all --long 3> /dev/null) in ${root_dir}/ethos-u dir."
 }
@@ -218,64 +223,19 @@
 }
 
 function setup_tosa_reference_model() {
-    # The debug flow on the host includes running on a reference implementation of TOSA
-    # This is useful primarily for debug of quantization accuracy, but also for internal
-    # errors for the early codebase
-    cd "${root_dir}"
-    if [[ ! -e reference_model ]]; then
-        git clone ${tosa_reference_model_url}
-        cd reference_model
-        git checkout ${tosa_reference_model_rev}
-        git submodule update --init --recursive
-        cd ..
-    fi
-    cd reference_model
-    mkdir -p build
-    cd build
-    cmake ..
-
-    # make use of half the cores for building
-    if [[ "${OS}" == "Linux" ]]; then
-        n=$(( $(nproc) / 2 ))
-    elif [[ "${OS}" == "Darwin" ]]; then
-        n=$(( $(sysctl -n hw.logicalcpu) / 2 ))
-    else
-        n=1
-    fi
+
+    # reference_model flatbuffers version clashes with Vela.
+    # go with Vela's since it newer.
+    # Could cause issues down the line, beware..
+    pip install tosa-tools@git+${tosa_reference_model_url}@${tosa_reference_model_rev} --no-dependencies flatbuffers
 
-    if [[ "$n" -lt 1 ]]; then
-        n=1
-    fi
-
-    make -j"${n}"
-    cd reference_model
-    tosa_bin_path=`pwd`
-    echo "export PATH=\${PATH}:${tosa_bin_path}" >> "${setup_path_script}"
 }
 
 function setup_vela() {
     #
     # Prepare the Vela compiler for AoT to Ethos-U compilation
     #
-    cd "${root_dir}"
-    if [[ ! -e ethos-u-vela ]]; then
-        git clone https://review.mlplatform.org/ml/ethos-u/ethos-u-vela
-        repo_dir="${root_dir}/ethos-u-vela"
-        base_rev=57ce18c89ccc6f6309333dccb24ed30dc68b571f
-        patch_repo
-    fi
-    cd "${root_dir}/ethos-u-vela"
-
-    # different command for conda vs venv
-    VNV=$(python3 -c "import sys; print('venv') if (sys.prefix != sys.base_prefix) else print('not_venv')")
-    if [ ${VNV} == "venv" ]; then
-        pip install .
-    else
-        # if not venv, we need the site-path where the vela
-        vela_path=$(python -c "import site; print(site.USER_BASE+'/bin')")
-        echo "export PATH=\${PATH}:${vela_path}" >> ${setup_path_script}
-        pip install . --user
-    fi
+    pip install ethos-u-vela@git+${vela_repo_url}@${vela_rev}
 }
 
 ########
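
Once setup.sh has run, the pieces the commit message refers to should all be importable from the same Python environment. A quick smoke check (not part of the commit; it assumes the packages installed cleanly and that serializer and tosa_reference_model are the module names the pip-installed tosa-tools exposes, as the imports in arm_backend.py and runner_utils.py suggest):

# Run inside the Python environment that examples/arm/setup.sh installed into.
import serializer.tosa_serializer as ts  # serialization_lib, pulled in with tosa-tools
import tosa_reference_model              # pybind bindings to the TOSA reference model

print(tosa_reference_model.GraphStatus.TOSA_VALID)
print(ts.TosaSerializer)  # serializer class used by backends/arm/arm_backend.py

The Vela compiler is installed the same way and should expose its vela command on that environment's PATH.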
