Skip to content

Commit 485a5df

Browse files
authored
Revert "Run tosa_reference_model using python binding" (#6729)
Revert "Run tosa_reference_model using python binding (#6658)" This reverts commit 4bbe994.
1 parent 39e5b91 commit 485a5df

File tree

8 files changed

+59
-96
lines changed

8 files changed

+59
-96
lines changed

backends/arm/arm_backend.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import logging
1515
import os
16-
from typing import cast, final, List, Optional
16+
from typing import final, List, Optional
1717

1818
import serializer.tosa_serializer as ts
1919
from executorch.backends.arm.arm_vela import vela_compile
@@ -31,7 +31,6 @@
3131
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
3232
from executorch.exir.backend.compile_spec_schema import CompileSpec
3333
from torch.export.exported_program import ExportedProgram
34-
from torch.fx import Node
3534

3635
# TOSA backend debug functionality
3736
logger = logging.getLogger(__name__)
@@ -226,7 +225,6 @@ def preprocess( # noqa: C901
226225
node_visitors = get_node_visitors(edge_program)
227226

228227
for node in graph_module.graph.nodes:
229-
node = cast(Node, node)
230228
if node.op == "call_function":
231229
process_call_function(node, tosa_graph, node_visitors)
232230
elif node.op == "placeholder":
@@ -238,6 +236,9 @@ def preprocess( # noqa: C901
238236
# any checking of compatibility.
239237
dbg_fail(node, tosa_graph, artifact_path)
240238

239+
# TODO: It would be awesome if this dump could somehow be done on top level and not here.
240+
# Problem is that the desc.json has to be created on the tosa_graph object, which we can't
241+
# access from top level.
241242
if artifact_path:
242243
tag = _get_first_delegation_tag(graph_module)
243244
dbg_tosa_dump(
@@ -258,4 +259,6 @@ def preprocess( # noqa: C901
258259
else:
259260
raise RuntimeError(f"Unknown format {output_format}")
260261

262+
# Continueing from above. Can I put tosa_graph into this function?
263+
# debug_handle_map = ...
261264
return PreprocessResult(processed_bytes=binary)

backends/arm/test/common.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,15 +192,19 @@ def get_tosa_compile_spec_unbuilt(
192192
the compile spec before calling .build() to finalize it.
193193
"""
194194
if not custom_path:
195-
custom_path = maybe_get_tosa_collate_path()
195+
intermediate_path = maybe_get_tosa_collate_path() or tempfile.mkdtemp(
196+
prefix="arm_tosa_"
197+
)
198+
else:
199+
intermediate_path = custom_path
196200

197-
if custom_path is not None and not os.path.exists(custom_path):
198-
os.makedirs(custom_path, exist_ok=True)
201+
if not os.path.exists(intermediate_path):
202+
os.makedirs(intermediate_path, exist_ok=True)
199203
compile_spec_builder = (
200204
ArmCompileSpecBuilder()
201205
.tosa_compile_spec()
202206
.set_permute_memory_format(permute_memory_to_nhwc)
203-
.dump_intermediate_artifacts_to(custom_path)
207+
.dump_intermediate_artifacts_to(intermediate_path)
204208
)
205209

206210
return compile_spec_builder

backends/arm/test/misc/test_debug_feats.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,7 @@ def test_numerical_diff_prints(self):
107107
ArmTester(
108108
model,
109109
example_inputs=model.get_inputs(),
110-
compile_spec=common.get_tosa_compile_spec(
111-
permute_memory_to_nhwc=True,
112-
custom_path=tempfile.mkdtemp("diff_print_test"),
113-
),
110+
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
114111
)
115112
.export()
116113
.to_edge()

backends/arm/test/ops/test_cat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def test_cat_tosa_MI(self, operands: tuple[torch.Tensor, ...], dim: int):
121121
def test_cat_4d_tosa_MI(self):
122122
square = torch.ones((2, 2, 2, 2))
123123
for dim in range(-3, 3):
124-
test_data = ((square, square.clone()), dim)
124+
test_data = ((square, square), dim)
125125
self._test_cat_tosa_MI_pipeline(self.Cat(), test_data)
126126

127127
@parameterized.expand(Cat.test_parameters)

backends/arm/test/ops/test_select.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ def _test_select_tosa_BI_pipeline(
9393
.check(["torch.ops.quantized_decomposed"])
9494
.to_edge()
9595
.partition()
96+
.dump_artifact()
97+
.dump_operator_distribution()
9698
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
9799
.to_executorch()
98100
.run_method_and_compare_outputs(inputs=test_data)
@@ -160,14 +162,12 @@ def test_select_int_tosa_MI(self, test_data: test_data_t):
160162
)
161163

162164
@parameterized.expand(test_data_suite)
163-
@unittest.skip
164165
def test_select_copy_tosa_BI(self, test_data: test_data_t):
165166
self._test_select_tosa_BI_pipeline(
166167
self.SelectCopy(), test_data, export_target="torch.ops.aten.select_copy.int"
167168
)
168169

169170
@parameterized.expand(test_data_suite)
170-
@unittest.skip
171171
def test_select_int_tosa_BI(self, test_data: test_data_t):
172172
self._test_select_tosa_BI_pipeline(
173173
self.SelectInt(), test_data, export_target="torch.ops.aten.select.int"

backends/arm/test/runner_utils.py

Lines changed: 14 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,11 @@
1717
import numpy as np
1818
import torch
1919

20-
import tosa_reference_model
21-
2220
from torch.export import ExportedProgram
2321
from torch.fx.node import Node
24-
from tosa import TosaGraph
2522

2623
logger = logging.getLogger(__name__)
27-
logger.setLevel(logging.CRITICAL)
24+
logger.setLevel(logging.WARNING)
2825

2926

3027
class QuantizationParams:
@@ -170,7 +167,7 @@ def __init__(
170167
):
171168
self.intermediate_path = intermediate_path
172169
self.tosa_ref_model_path = tosa_ref_model_path or "tosa_reference_model"
173-
assert self.intermediate_path is None or os.path.exists(
170+
assert os.path.exists(
174171
self.intermediate_path
175172
), f"TOSA artifact path don't exist! Path: {self.intermediate_path}"
176173

@@ -326,46 +323,7 @@ def run_corstone(
326323
tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32)
327324
output_shape = self.output_node.args[0][0].meta["val"].shape
328325
tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(output_shape)
329-
return tosa_ref_output
330-
331-
def run_tosa_graph(
332-
self, graph: TosaGraph, inputs: list[np.ndarray] | list[torch.Tensor]
333-
) -> torch.Tensor:
334-
"""Runs the TOSA reference model with inputs and returns the result."""
335-
data_np = [
336-
prep_data_for_save(
337-
input, self.is_quantized, self.input_names[i], self.qp_input[i]
338-
)
339-
for i, input in enumerate(inputs)
340-
]
341-
# tosa_profile: 0 = Base Inference, 1 = Main Inference, 2 = Main Training.
342-
tosa_profile = 0 if self.is_quantized else 1
343-
debug_mode = "ALL" if logger.level <= logging.DEBUG else None
344-
outputs, status = tosa_reference_model.run(
345-
graph,
346-
data_np,
347-
verbosity=_tosa_refmodel_loglevel(logger.level),
348-
tosa_profile=tosa_profile,
349-
initialize_variable_tensor_from_numpy=1, # True
350-
debug_mode=debug_mode,
351-
)
352-
353-
assert (
354-
status == tosa_reference_model.GraphStatus.TOSA_VALID
355-
), "Non-valid TOSA given to reference model."
356-
357-
outputs_torch = []
358-
for output in outputs:
359-
output = output.astype(np.float32)
360-
if self.is_quantized:
361-
# Need to dequant back to FP32 for comparison with torch output
362-
quant_param = self.qp_output
363-
assert (
364-
quant_param is not None
365-
), "There are no quantization parameters, check output parameters"
366-
output = (output - quant_param.zp) * quant_param.scale
367-
outputs_torch.append(torch.from_numpy(output))
368-
return tuple(outputs_torch)
326+
return [tosa_ref_output]
369327

370328
def run_tosa_ref_model(
371329
self,
@@ -450,13 +408,21 @@ def run_tosa_ref_model(
450408
assert (
451409
shutil.which(self.tosa_ref_model_path) is not None
452410
), f"tosa_reference_model tool not found, did you run examples/arm/setup.sh? Path: {self.tosa_ref_model_path}"
453-
411+
loglevel_map = {
412+
logging.INFO: "INFO",
413+
logging.CRITICAL: "LOW",
414+
logging.ERROR: "LOW",
415+
logging.WARNING: "MED",
416+
logging.DEBUG: "HIGH",
417+
logging.NOTSET: "MED",
418+
}
419+
clamped_logging_level = max(min(logger.level // 10 * 10, 50), 0)
454420
cmd_ref_model = [
455421
self.tosa_ref_model_path,
456422
"--test_desc",
457423
desc_file_path,
458424
"-l",
459-
_tosa_refmodel_loglevel(logger.level),
425+
loglevel_map[clamped_logging_level],
460426
]
461427
_run_cmd(cmd_ref_model)
462428

@@ -492,10 +458,7 @@ def run_tosa_ref_model(
492458

493459

494460
def prep_data_for_save(
495-
data: torch.Tensor,
496-
is_quantized: bool,
497-
input_name: str,
498-
quant_param: QuantizationParams,
461+
data, is_quantized: bool, input_name: str, quant_param: QuantizationParams
499462
):
500463
data_np = np.array(data.detach(), order="C").astype(
501464
f"{data.dtype}".replace("torch.", "")
@@ -639,19 +602,3 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict:
639602
pass
640603

641604
return json_out
642-
643-
644-
def _tosa_refmodel_loglevel(loglevel: int) -> str:
645-
"""Converts a logging loglevel to tosa_reference_model logginglevel,
646-
returned as string.
647-
"""
648-
loglevel_map = {
649-
logging.INFO: "INFO",
650-
logging.CRITICAL: "LOW",
651-
logging.ERROR: "LOW",
652-
logging.WARNING: "MED",
653-
logging.DEBUG: "HIGH",
654-
logging.NOTSET: "MED",
655-
}
656-
clamped_logging_level = max(min(loglevel // 10 * 10, 50), 0)
657-
return loglevel_map[clamped_logging_level]

backends/arm/test/tester/arm_tester.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
from executorch.backends.xnnpack.test.tester import Tester
4141
from executorch.devtools.backend_debug import get_delegation_info
42-
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
42+
from executorch.exir import EdgeCompileConfig
4343
from executorch.exir.backend.compile_spec_schema import CompileSpec
4444

4545
from executorch.exir.lowered_backend_module import LoweredBackendModule
@@ -120,15 +120,10 @@ def __init__(
120120
super().__init__(dynamic_shapes)
121121
self.tosa_test_util = tosa_test_util
122122

123-
def run(self, artifact: EdgeProgramManager, inputs=None):
124-
self.executorch_program = artifact.to_executorch(self.config)
125-
if module := getattr(
126-
artifact.exported_program().graph_module, "lowered_module_0", None
127-
):
128-
self.buffer = module.processed_bytes
129-
130123
def run_artifact(self, inputs):
131-
tosa_output = self.tosa_test_util.run_tosa_graph(self.buffer, inputs)
124+
tosa_output = self.tosa_test_util.run_tosa_ref_model(
125+
inputs=inputs,
126+
)
132127
return tosa_output
133128

134129

@@ -321,7 +316,7 @@ def run_method_and_compare_outputs(
321316
logger.info(f"Run #{run_iteration}, input shapes: {input_shape_str}")
322317

323318
reference_output = reference_stage.run_artifact(reference_input)
324-
test_output = test_stage.run_artifact(test_input)
319+
test_output = tuple(test_stage.run_artifact(test_input))
325320
if (
326321
is_nhwc
327322
and test_stage == self.stages[self.stage_name(tester.ToExecutorch)]

examples/arm/setup.sh

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ ethos_u_base_rev="24.08"
8888

8989
# tosa reference model
9090
tosa_reference_model_url="https://review.mlplatform.org/tosa/reference_model"
91-
tosa_reference_model_rev="ef31e7222e99cb1c24b2aff9fc52b2d609612283"
91+
tosa_reference_model_rev="f9ea4ab7da19318fe36b1c34d68a3e40fd6e56c5"
9292

9393
########
9494
### Mandatory user args
@@ -227,13 +227,30 @@ function setup_tosa_reference_model() {
227227
cd reference_model
228228
git checkout ${tosa_reference_model_rev}
229229
git submodule update --init --recursive
230+
cd ..
231+
fi
232+
cd reference_model
233+
mkdir -p build
234+
cd build
235+
cmake ..
236+
237+
# make use of half the cores for building
238+
if [[ "${OS}" == "Linux" ]]; then
239+
n=$(( $(nproc) / 2 ))
240+
elif [[ "${OS}" == "Darwin" ]]; then
241+
n=$(( $(sysctl -n hw.logicalcpu) / 2 ))
242+
else
243+
n=1
230244
fi
231245

232-
echo "pip installing reference_model..."
233-
repo_dir="${root_dir}/reference_model"
234-
cd $repo_dir
235-
pip install .
246+
if [[ "$n" -lt 1 ]]; then
247+
n=1
248+
fi
236249

250+
make -j"${n}"
251+
cd reference_model
252+
tosa_bin_path=`pwd`
253+
echo "export PATH=\${PATH}:${tosa_bin_path}" >> "${setup_path_script}"
237254
}
238255

239256
function setup_vela() {

0 commit comments

Comments
 (0)