Commit d9d2465
Update default executor runner with output options

By default, not all output is printed. This change adds an option to print all output and an option to write output to file. It also updates the Arm VKML unit test runner to write output to file, which enables the acos unit test to run on the Vulkan runtime.

Change-Id: If61c1fe89c9da004fa9db4524e1413893549abce
1 parent 58998b0 commit d9d2465
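
A rough usage sketch of the new runner options, mirroring how the Python test utilities in this commit build the command line. The paths below are hypothetical; the -model_path, -output_file, and -print_all_output flags are the ones used or added by this change.

# Sketch: invoke the updated executor_runner with the new output options.
# elf_path, pte_path, and the intermediates directory are hypothetical example paths.
import os
import subprocess

elf_path = "cmake-out/executor_runner"           # hypothetical build output
pte_path = "intermediates/model.pte"             # serialized program to run
out_path = os.path.join("intermediates", "out")  # base name; runner writes out-0.bin, out-1.bin, ...

cmd_line = f"{elf_path} -model_path {pte_path} -output_file {out_path} -print_all_output"
subprocess.run(cmd_line.split(), check=True)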

File tree

5 files changed: +126, -44 lines changed

backends/arm/test/ops/test_acos.py

Lines changed: 11 additions & 2 deletions

@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 from typing import Tuple
 
+import pytest
 import torch
 
 from executorch.backends.arm.test import common
@@ -102,8 +103,12 @@ def test_acos_vgf_FP(test_data: Tuple):
         [],
         [],
         tosa_version="TOSA-1.0+FP",
+        run_on_vulkan_runtime=True,
     )
-    pipeline.run()
+    try:
+        pipeline.run()
+    except FileNotFoundError as e:
+        pytest.skip(f"VKML executor_runner not found - not built - skip {e}")
 
 
 @common.parametrize("test_data", test_data_suite)
@@ -115,5 +120,9 @@ def test_acos_vgf_INT(test_data: Tuple):
         [],
         [],
         tosa_version="TOSA-1.0+INT",
+        run_on_vulkan_runtime=True,
     )
-    pipeline.run()
+    try:
+        pipeline.run()
+    except FileNotFoundError as e:
+        pytest.skip(f"VKML executor_runner not found - not built - skip {e}")

backends/arm/test/ops/test_add.py

Lines changed: 9 additions & 2 deletions

@@ -7,6 +7,7 @@
 
 from typing import Tuple
 
+import pytest
 import torch
 from executorch.backends.arm.quantizer import arm_quantizer
 from executorch.backends.arm.test import common, conftest
@@ -196,7 +197,10 @@ def test_add_tensor_vgf_FP(test_data: input_t1):
         tosa_version="TOSA-1.0+FP",
         run_on_vulkan_runtime=True,
     )
-    pipeline.run()
+    try:
+        pipeline.run()
+    except FileNotFoundError as e:
+        pytest.skip(f"VKML executor_runner not found - not built - skip {e}")
 
 
 @common.parametrize("test_data", Add.test_data)
@@ -210,4 +214,7 @@ def test_add_tensor_vgf_INT(test_data: input_t1):
         tosa_version="TOSA-1.0+INT",
         run_on_vulkan_runtime=True,
     )
-    pipeline.run()
+    try:
+        pipeline.run()
+    except FileNotFoundError as e:
+        pytest.skip(f"VKML executor_runner not found - not built - skip {e}")

backends/arm/test/runner_utils.py

Lines changed: 29 additions & 29 deletions

@@ -243,6 +243,25 @@ def save_inputs_to_file(
     return input_file_paths
 
 
+def get_output_from_file(
+    exported_program: ExportedProgram,
+    intermediate_path: str | Path,
+    output_base_name: str,
+):
+    output_np = []
+    output_node = exported_program.graph_module.graph.output_node()
+    for i, node in enumerate(output_node.args[0]):
+        output_shape = node.meta["val"].shape
+        output_dtype = node.meta["val"].dtype
+        tosa_ref_output = np.fromfile(
+            os.path.join(intermediate_path, f"{output_base_name}-{i}.bin"),
+            _torch_to_numpy_dtype_dict[output_dtype],
+        )
+
+        output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape))
+    return tuple(output_np)
+
+
 def run_vkml_emulation_layer(
     executorch_program_manager: ExecutorchProgramManager,
     inputs: Tuple[torch.Tensor],
@@ -267,10 +286,13 @@ def run_vkml_emulation_layer(
     with open(pte_path, "wb") as f:
         f.write(executorch_program_manager.buffer)
 
-    input_paths = save_inputs_to_file(exported_program, inputs, intermediate_path)
+    output_base_name = "out"
+    out_path = os.path.join(intermediate_path, output_base_name)
+
+    cmd_line = f"{elf_path} -model_path {pte_path} -output_file {out_path}"
 
-    cmd_line = f"{elf_path} -model_path {pte_path}"
     input_string = None
+    input_paths = save_inputs_to_file(exported_program, inputs, intermediate_path)
     for input_path in input_paths:
         if input_string is None:
             input_string = f" -inputs={input_path}"
@@ -282,23 +304,11 @@
 
     result = _run_cmd(cmd_line)
 
-    result_stdout = result.stdout.decode()  # noqa: F841
     # TODO: MLETORCH-1234: Support VGF e2e tests in VgfPipeline
     # TODO: Add regex to check for error or fault messages in stdout from Emulation Layer
-    # Regex to extract tensor values from stdout
-    output_np = []
-    matches = re.findall(
-        r"Output\s+\d+:\s+tensor\(sizes=\[(.*?)\],\s+\[(.*?)\]\)",
-        result_stdout,
-        re.DOTALL,
-    )
-
-    for shape_str, values_str in matches:
-        shape = list(map(int, shape_str.split(",")))
-        values = list(map(float, re.findall(r"[-+]?\d*\.\d+|\d+", values_str)))
-        output_np.append(torch.tensor(values).reshape(shape))
+    result_stdout = result.stdout.decode()  # noqa: F841
 
-    return tuple(output_np)
+    return get_output_from_file(exported_program, intermediate_path, output_base_name)
 
 
 def run_corstone(
@@ -342,7 +352,8 @@ def run_corstone(
 
     input_paths = save_inputs_to_file(exported_program, inputs, intermediate_path)
 
-    out_path = os.path.join(intermediate_path, "out")
+    output_base_name = "out"
+    out_path = os.path.join(intermediate_path, output_base_name)
 
     cmd_line = f"executor_runner -m {pte_path} -o {out_path}"
     for input_path in input_paths:
@@ -424,18 +435,7 @@ def run_corstone(
            f"Corstone simulation failed:\ncmd: {' '.join(command_args)}\nlog: \n {result_stdout}\n{result.stderr.decode()}"
        )
 
-    output_np = []
-    output_node = exported_program.graph_module.graph.output_node()
-    for i, node in enumerate(output_node.args[0]):
-        output_shape = node.meta["val"].shape
-        output_dtype = node.meta["val"].dtype
-        tosa_ref_output = np.fromfile(
-            os.path.join(intermediate_path, f"out-{i}.bin"),
-            _torch_to_numpy_dtype_dict[output_dtype],
-        )
-
-        output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape))
-    return tuple(output_np)
+    return get_output_from_file(exported_program, intermediate_path, output_base_name)
 
 
 def prep_data_for_save(
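
For reference, a minimal sketch of the on-disk contract that get_output_from_file relies on: both the Corstone runner and the VKML runner name each output tensor file <output_base_name>-<i>.bin, containing raw tensor bytes. The dtype and shape are normally taken from the output node's meta["val"]; float32 and (1, 10) below are illustrative assumptions.

# Sketch: read back one raw output file written via the runner's -output_file option.
import numpy as np
import torch

raw = np.fromfile("intermediates/out-0.bin", dtype=np.float32)  # dtype would come from the exported graph
output = torch.from_numpy(raw).reshape(1, 10)                   # shape would come from the exported graph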

examples/portable/executor_runner/executor_runner.cpp

Lines changed: 70 additions & 4 deletions

@@ -51,6 +51,15 @@ DEFINE_string(
     "model.pte",
     "Model serialized in flatbuffer format.");
 DEFINE_string(inputs, "", "Comma-separated list of input files");
+DEFINE_string(
+    output_file,
+    "",
+    "Base name of output file. If not empty output will be written to the file(s).");
+
+DEFINE_bool(
+    print_all_output,
+    false,
+    "Prints all output. By default only first and last 100 elements are printed.");
 DEFINE_uint32(num_executions, 1, "Number of times to run the model.");
 #ifdef ET_EVENT_TRACER_ENABLED
 DEFINE_string(etdump_path, "model.etdump", "Write ETDump data to this path.");
@@ -328,10 +337,67 @@ int main(int argc, char** argv) {
   ET_LOG(Info, "%zu outputs: ", outputs.size());
   Error status = method->get_outputs(outputs.data(), outputs.size());
   ET_CHECK(status == Error::Ok);
-  // Print the first and last 100 elements of long lists of scalars.
-  std::cout << executorch::extension::evalue_edge_items(100);
-  for (int i = 0; i < outputs.size(); ++i) {
-    std::cout << "Output " << i << ": " << outputs[i] << std::endl;
+
+  if (FLAGS_output_file.size() > 0) {
+    for (int i = 0; i < outputs.size(); ++i) {
+      if (outputs[i].isTensor()) {
+        Tensor tensor = outputs[i].toTensor();
+
+        char out_filename[255];
+        snprintf(out_filename, 255, "%s-%d.bin", FLAGS_output_file.c_str(), i);
+        ET_LOG(Info, "Writing output to file: %s", out_filename);
+        FILE* out_file = fopen(out_filename, "wb");
+        auto written_size =
+            fwrite(tensor.const_data_ptr<char>(), 1, tensor.nbytes(), out_file);
+        fclose(out_file);
+      }
+    }
+  }
+
+  if (FLAGS_print_all_output) {
+    for (int i = 0; i < outputs.size(); ++i) {
+      if (outputs[i].isTensor()) {
+        Tensor tensor = outputs[i].toTensor();
+
+        for (int j = 0; j < tensor.numel(); ++j) {
+          if (tensor.scalar_type() == ScalarType::Int) {
+            printf(
+                "Output[%d][%d]: (int) %d\n",
+                i,
+                j,
+                tensor.const_data_ptr<int>()[j]);
+          } else if (tensor.scalar_type() == ScalarType::Float) {
+            printf(
+                "Output[%d][%d]: (float) %f\n",
+                i,
+                j,
+                tensor.const_data_ptr<float>()[j]);
+          } else if (tensor.scalar_type() == ScalarType::Char) {
+            printf(
+                "Output[%d][%d]: (char) %d\n",
+                i,
+                j,
+                tensor.const_data_ptr<int8_t>()[j]);
+          } else if (tensor.scalar_type() == ScalarType::Bool) {
+            printf(
+                "Output[%d][%d]: (bool) %s (0x%x)\n",
+                i,
+                j,
+                tensor.const_data_ptr<int8_t>()[j] ? "true " : "false",
+                tensor.const_data_ptr<int8_t>()[j]);
+          }
+        }
+      } else {
+        printf("Output[%d]: Not Tensor\n", i);
+      }
+    }
+  } else {
+    // Print the first and last 100 elements of long lists of scalars.
+    std::cout << executorch::extension::evalue_edge_items(100);
+
+    for (int i = 0; i < outputs.size(); ++i) {
+      std::cout << "OutputX " << i << ": " << outputs[i] << std::endl;
+    }
   }
 
   if (tracer.get_event_tracer()) {
extension/runner_util/inputs.cpp

Lines changed: 7 additions & 7 deletions

@@ -31,15 +31,15 @@ Result<BufferCleanup> prepare_input_tensors(
   size_t num_inputs = method_meta.num_inputs();
   bool hard_code_inputs_to_ones = true;
 
-  ET_CHECK_OR_RETURN_ERROR(
-      input_buffers.size() > 0 && num_inputs == input_buffers.size(),
-      InvalidArgument,
-      "Wrong number of inputs allocated compared to method %zu ? %zu",
-      num_inputs,
-      input_buffers.size());
-
   if (input_buffers.size() > 0) {
     hard_code_inputs_to_ones = false;
+
+    ET_CHECK_OR_RETURN_ERROR(
+        num_inputs == input_buffers.size(),
+        InvalidArgument,
+        "Wrong number of inputs allocated compared to method %zu ? %zu",
+        num_inputs,
+        input_buffers.size());
   }
 
   // A large number of small allocations could exhaust the heap even if the
