Commit 16aac24

Update default executor runner with new optional options (#14017)
* By default, not all outputs are printed. Adds an option to print all output.
* Adds an option to save output to a file.
* In the executor runner, all inputs were hard coded to ones. Adds an optional inputs option; when supplied, tensor inputs are read from the given binary input files.

For the Arm backend, updates the Arm VKML unit test runner as a user with real inputs and file output. This enables the add/acos unit tests that depend on these options to run on the Vulkan runtime.
1 parent 465e65f commit 16aac24
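
For orientation, a hedged sketch of how the updated runner might be invoked once built; the binary and file names here are hypothetical, but the three flags are the ones this commit adds.

import subprocess

# Hypothetical paths; adjust to your build output and model.
cmd = [
    "./executor_runner",
    "-model_path", "model.pte",
    # New: comma-separated raw binary tensor files, one per model input.
    "-inputs=input_0.bin,input_1.bin",
    # New: base name; each output i is written to <base>-<i>.bin.
    "-output_file", "out",
    # New: print every element instead of only the first/last 100.
    "-print_all_output",
]
result = subprocess.run(cmd, capture_output=True, check=True)
print(result.stdout.decode())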

File tree

8 files changed, +233 -61 lines changed

backends/arm/test/ops/test_acos.py

Lines changed: 11 additions & 2 deletions
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 from typing import Tuple
 
+import pytest
 import torch
 
 from executorch.backends.arm.test import common
@@ -102,8 +103,12 @@ def test_acos_vgf_FP(test_data: Tuple):
         [],
         [],
         tosa_version="TOSA-1.0+FP",
+        run_on_vulkan_runtime=True,
     )
-    pipeline.run()
+    try:
+        pipeline.run()
+    except FileNotFoundError as e:
+        pytest.skip(f"VKML executor_runner not found - not built - skip {e}")
 
 
 @common.parametrize("test_data", test_data_suite)
@@ -115,5 +120,9 @@ def test_acos_vgf_INT(test_data: Tuple):
         [],
         [],
         tosa_version="TOSA-1.0+INT",
+        run_on_vulkan_runtime=True,
    )
-    pipeline.run()
+    try:
+        pipeline.run()
+    except FileNotFoundError as e:
+        pytest.skip(f"VKML executor_runner not found - not built - skip {e}")

backends/arm/test/ops/test_add.py

Lines changed: 2 additions & 7 deletions
@@ -200,12 +200,7 @@ def test_add_tensor_u85_INT_2(test_data: input_t2):
     pipeline.run()
 
 
-# TODO/MLETORCH-1282: remove once inputs are not hard coded to ones
-skip_keys = {"5d_float", "1d_ones", "1d_randn"}
-filtered_test_data = {k: v for k, v in Add.test_data.items() if k not in skip_keys}
-
-
-@common.parametrize("test_data", filtered_test_data)
+@common.parametrize("test_data", Add.test_data)
 @common.SkipIfNoModelConverter
 def test_add_tensor_vgf_FP(test_data: input_t1):
     pipeline = VgfPipeline[input_t1](
@@ -222,7 +217,7 @@ def test_add_tensor_vgf_FP(test_data: input_t1):
         pytest.skip(f"VKML executor_runner not found - not built - skip {e}")
 
 
-@common.parametrize("test_data", filtered_test_data)
+@common.parametrize("test_data", Add.test_data)
 @common.SkipIfNoModelConverter
 def test_add_tensor_vgf_INT(test_data: input_t1):
     pipeline = VgfPipeline[input_t1](

backends/arm/test/runner_utils.py

Lines changed: 58 additions & 35 deletions
@@ -216,13 +216,48 @@ def run_target(
     elif target_board == "vkml_emulation_layer":
         return run_vkml_emulation_layer(
             executorch_program_manager,
+            inputs,
             intermediate_path,
             elf_path,
         )
 
 
+def save_inputs_to_file(
+    exported_program: ExportedProgram,
+    inputs: Tuple[torch.Tensor],
+    intermediate_path: str | Path,
+):
+    input_file_paths = []
+    input_names = get_input_names(exported_program)
+    for input_name, input_ in zip(input_names, inputs):
+        input_path = save_bytes(intermediate_path, input_, input_name)
+        input_file_paths.append(input_path)
+
+    return input_file_paths
+
+
+def get_output_from_file(
+    exported_program: ExportedProgram,
+    intermediate_path: str | Path,
+    output_base_name: str,
+):
+    output_np = []
+    output_node = exported_program.graph_module.graph.output_node()
+    for i, node in enumerate(output_node.args[0]):
+        output_shape = node.meta["val"].shape
+        output_dtype = node.meta["val"].dtype
+        tosa_ref_output = np.fromfile(
+            os.path.join(intermediate_path, f"{output_base_name}-{i}.bin"),
+            _torch_to_numpy_dtype_dict[output_dtype],
+        )
+
+        output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape))
+    return tuple(output_np)
+
+
 def run_vkml_emulation_layer(
     executorch_program_manager: ExecutorchProgramManager,
+    inputs: Tuple[torch.Tensor],
     intermediate_path: str | Path,
     elf_path: str | Path,
 ):
@@ -232,7 +267,7 @@ def run_vkml_emulation_layer(
     `intermediate_path`: Directory to save the .pte and capture outputs.
     `elf_path`: Path to the Vulkan-capable executor_runner binary.
     """
-
+    exported_program = executorch_program_manager.exported_program()
     intermediate_path = Path(intermediate_path)
     intermediate_path.mkdir(exist_ok=True)
     elf_path = Path(elf_path)
@@ -244,26 +279,29 @@ def run_vkml_emulation_layer(
     with open(pte_path, "wb") as f:
         f.write(executorch_program_manager.buffer)
 
-    cmd_line = [str(elf_path), "-model_path", pte_path]
+    output_base_name = "out"
+    out_path = os.path.join(intermediate_path, output_base_name)
+
+    cmd_line = f"{elf_path} -model_path {pte_path} -output_file {out_path}"
+
+    input_string = None
+    input_paths = save_inputs_to_file(exported_program, inputs, intermediate_path)
+    for input_path in input_paths:
+        if input_string is None:
+            input_string = f" -inputs={input_path}"
+        else:
+            input_string += f",{input_path}"
+    if input_string is not None:
+        cmd_line += input_string
+    cmd_line = cmd_line.split()
+
     result = _run_cmd(cmd_line)
 
-    result_stdout = result.stdout.decode()  # noqa: F841
     # TODO: MLETORCH-1234: Support VGF e2e tests in VgfPipeline
     # TODO: Add regex to check for error or fault messages in stdout from Emulation Layer
-    # Regex to extract tensor values from stdout
-    output_np = []
-    matches = re.findall(
-        r"Output\s+\d+:\s+tensor\(sizes=\[(.*?)\],\s+\[(.*?)\]\)",
-        result_stdout,
-        re.DOTALL,
-    )
-
-    for shape_str, values_str in matches:
-        shape = list(map(int, shape_str.split(",")))
-        values = list(map(float, re.findall(r"[-+]?\d*\.\d+|\d+", values_str)))
-        output_np.append(torch.tensor(values).reshape(shape))
+    result_stdout = result.stdout.decode()  # noqa: F841
 
-    return tuple(output_np)
+    return get_output_from_file(exported_program, intermediate_path, output_base_name)
 
 
 def run_corstone(
@@ -305,14 +343,10 @@ def run_corstone(
     with open(pte_path, "wb") as f:
         f.write(executorch_program_manager.buffer)
 
-    # Save inputs to file
-    input_names = get_input_names(exported_program)
-    input_paths = []
-    for input_name, input_ in zip(input_names, inputs):
-        input_path = save_bytes(intermediate_path, input_, input_name)
-        input_paths.append(input_path)
+    input_paths = save_inputs_to_file(exported_program, inputs, intermediate_path)
 
-    out_path = os.path.join(intermediate_path, "out")
+    output_base_name = "out"
+    out_path = os.path.join(intermediate_path, output_base_name)
 
     cmd_line = f"executor_runner -m {pte_path} -o {out_path}"
     for input_path in input_paths:
@@ -394,18 +428,7 @@ def run_corstone(
             f"Corstone simulation failed:\ncmd: {' '.join(command_args)}\nlog: \n {result_stdout}\n{result.stderr.decode()}"
         )
 
-    output_np = []
-    output_node = exported_program.graph_module.graph.output_node()
-    for i, node in enumerate(output_node.args[0]):
-        output_shape = node.meta["val"].shape
-        output_dtype = node.meta["val"].dtype
-        tosa_ref_output = np.fromfile(
-            os.path.join(intermediate_path, f"out-{i}.bin"),
-            _torch_to_numpy_dtype_dict[output_dtype],
-        )
-
-        output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape))
-    return tuple(output_np)
+    return get_output_from_file(exported_program, intermediate_path, output_base_name)
 
 
 def prep_data_for_save(
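
Since save_inputs_to_file and get_output_from_file exchange tensors with the runner as raw, headerless bytes, here is a minimal self-contained sketch of that round trip (file names are hypothetical; float32 data assumed):

import numpy as np
import torch

# Write a tensor the way the input side does: raw bytes, no header.
x = torch.randn(2, 3)
x.numpy().tofile("input_x.bin")

# A runner output file like out-0.bin is equally headerless, so dtype and
# shape must come from the exported program's graph metadata. Simulate the
# runner's output here so the sketch runs on its own:
x.numpy().tofile("out-0.bin")

restored = torch.from_numpy(np.fromfile("out-0.bin", dtype=np.float32)).reshape(2, 3)
assert torch.equal(x, restored)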

examples/portable/executor_runner/executor_runner.cpp

Lines changed: 106 additions & 7 deletions
@@ -1,7 +1,7 @@
 /*
  * Copyright (c) Meta Platforms, Inc. and affiliates.
- * Copyright 2024-2025 Arm Limited and/or its affiliates.
  * All rights reserved.
+ * Copyright 2024-2025 Arm Limited and/or its affiliates.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
@@ -18,6 +18,7 @@
  * all fp32 tensors.
  */
 
+#include <fstream>
 #include <iostream>
 #include <memory>
 
@@ -49,6 +50,16 @@ DEFINE_string(
     model_path,
     "model.pte",
     "Model serialized in flatbuffer format.");
+DEFINE_string(inputs, "", "Comma-separated list of input files");
+DEFINE_string(
+    output_file,
+    "",
+    "Base name of output file. If not empty output will be written to the file(s).");
+
+DEFINE_bool(
+    print_all_output,
+    false,
+    "Prints all output. By default only first and last 100 elements are printed.");
 DEFINE_uint32(num_executions, 1, "Number of times to run the model.");
 #ifdef ET_EVENT_TRACER_ENABLED
 DEFINE_string(etdump_path, "model.etdump", "Write ETDump data to this path.");
@@ -58,6 +69,8 @@ DEFINE_int32(
     -1,
     "Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device.");
 
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
 using executorch::extension::FileDataLoader;
 using executorch::runtime::Error;
 using executorch::runtime::EValue;
@@ -70,6 +83,8 @@ using executorch::runtime::MethodMeta;
 using executorch::runtime::Program;
 using executorch::runtime::Result;
 using executorch::runtime::Span;
+using executorch::runtime::Tag;
+using executorch::runtime::TensorInfo;
 
 /// Helper to manage resources for ETDump generation
 class EventTraceManager {
@@ -156,6 +171,31 @@ int main(int argc, char** argv) {
       "FileDataLoader::from() failed: 0x%" PRIx32,
      (uint32_t)loader.error());
 
+  std::vector<std::string> inputs_storage;
+  std::vector<std::pair<char*, size_t>> input_buffers;
+
+  std::stringstream list_of_input_files(FLAGS_inputs);
+  std::string token;
+
+  while (std::getline(list_of_input_files, token, ',')) {
+    std::ifstream input_file_handle(token, std::ios::binary | std::ios::ate);
+    if (!input_file_handle) {
+      ET_LOG(Error, "Failed to open input file: %s\n", token.c_str());
+      return 1;
+    }
+
+    std::streamsize file_size = input_file_handle.tellg();
+    input_file_handle.seekg(0, std::ios::beg);
+
+    inputs_storage.emplace_back(file_size, '\0');
+    if (!input_file_handle.read(&inputs_storage.back()[0], file_size)) {
+      ET_LOG(Error, "Failed to read input file: %s\n", token.c_str());
+      return 1;
+    }
+
+    input_buffers.emplace_back(&inputs_storage.back()[0], file_size);
+  }
+
   // Parse the program file. This is immutable, and can also be reused between
   // multiple execution invocations across multiple threads.
   Result<Program> program = Program::load(&loader.get());
@@ -254,15 +294,17 @@ int main(int argc, char** argv) {
   // Run the model.
   for (uint32_t i = 0; i < FLAGS_num_executions; i++) {
     ET_LOG(Debug, "Preparing inputs.");
-    // Allocate input tensors and set all of their elements to 1. The `inputs`
+    // Allocate input tensors and set all of their elements to 1 or to the
+    // contents of input_buffers if available. The `inputs`
     // variable owns the allocated memory and must live past the last call to
     // `execute()`.
     //
     // NOTE: we have to re-prepare input tensors on every execution
     // because inputs whose space gets reused by memory planning (if
     // any such inputs exist) will not be preserved for the next
     // execution.
-    auto inputs = executorch::extension::prepare_input_tensors(*method);
+    auto inputs = executorch::extension::prepare_input_tensors(
+        *method, {}, input_buffers);
     ET_CHECK_MSG(
         inputs.ok(),
         "Could not prepare inputs: 0x%" PRIx32,
@@ -295,10 +337,67 @@ int main(int argc, char** argv) {
   ET_LOG(Info, "%zu outputs: ", outputs.size());
   Error status = method->get_outputs(outputs.data(), outputs.size());
   ET_CHECK(status == Error::Ok);
-  // Print the first and last 100 elements of long lists of scalars.
-  std::cout << executorch::extension::evalue_edge_items(100);
-  for (int i = 0; i < outputs.size(); ++i) {
-    std::cout << "Output " << i << ": " << outputs[i] << std::endl;
+
+  if (FLAGS_output_file.size() > 0) {
+    for (int i = 0; i < outputs.size(); ++i) {
+      if (outputs[i].isTensor()) {
+        Tensor tensor = outputs[i].toTensor();
+
+        char out_filename[255];
+        snprintf(out_filename, 255, "%s-%d.bin", FLAGS_output_file.c_str(), i);
+        ET_LOG(Info, "Writing output to file: %s", out_filename);
+        FILE* out_file = fopen(out_filename, "wb");
+        auto written_size =
+            fwrite(tensor.const_data_ptr<char>(), 1, tensor.nbytes(), out_file);
+        fclose(out_file);
+      }
+    }
+  }
+
+  if (FLAGS_print_all_output) {
+    for (int i = 0; i < outputs.size(); ++i) {
+      if (outputs[i].isTensor()) {
+        Tensor tensor = outputs[i].toTensor();
+
+        for (int j = 0; j < tensor.numel(); ++j) {
+          if (tensor.scalar_type() == ScalarType::Int) {
+            printf(
+                "Output[%d][%d]: (int) %d\n",
+                i,
+                j,
+                tensor.const_data_ptr<int>()[j]);
+          } else if (tensor.scalar_type() == ScalarType::Float) {
+            printf(
+                "Output[%d][%d]: (float) %f\n",
+                i,
+                j,
+                tensor.const_data_ptr<float>()[j]);
+          } else if (tensor.scalar_type() == ScalarType::Char) {
+            printf(
+                "Output[%d][%d]: (char) %d\n",
+                i,
+                j,
+                tensor.const_data_ptr<int8_t>()[j]);
+          } else if (tensor.scalar_type() == ScalarType::Bool) {
+            printf(
+                "Output[%d][%d]: (bool) %s (0x%x)\n",
+                i,
+                j,
+                tensor.const_data_ptr<int8_t>()[j] ? "true " : "false",
+                tensor.const_data_ptr<int8_t>()[j]);
+          }
+        }
+      } else {
+        printf("Output[%d]: Not Tensor\n", i);
+      }
+    }
+  } else {
+    // Print the first and last 100 elements of long lists of scalars.
+    std::cout << executorch::extension::evalue_edge_items(100);
+
+    for (int i = 0; i < outputs.size(); ++i) {
+      std::cout << "OutputX " << i << ": " << outputs[i] << std::endl;
+    }
   }
 
   if (tracer.get_event_tracer()) {
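
Putting the pieces together, a hedged end-to-end check in the spirit of the updated Arm test runner; the add model, paths, and shapes are hypothetical, while the flags and the out-<i>.bin naming come from this commit.

import subprocess

import numpy as np
import torch

# Prepare two raw float32 input files (headerless, as the runner expects).
x, y = torch.randn(8), torch.randn(8)
x.numpy().tofile("x.bin")
y.numpy().tofile("y.bin")

# Hypothetical binary and model paths.
subprocess.run(
    ["./executor_runner", "-model_path", "add.pte",
     "-inputs=x.bin,y.bin", "-output_file", "out"],
    check=True,
)

# out-0.bin holds the first output's raw bytes; compare to a host reference.
actual = torch.from_numpy(np.fromfile("out-0.bin", dtype=np.float32))
assert torch.allclose(actual, x + y, rtol=1e-3, atol=1e-3)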
