Skip to content

Commit 4bf7e2f

Browse files
Erik-Lundellfacebook-github-bot
authored andcommitted
Input name bugfix in runner_utils (#5071)
Summary: There are cases where the names of the inputs in the edge-IR graph do not match the names of the inputs in the export-IR. Since desc.json used the names from export-IR while TOSA used the names from edge-IR there was a mismatch. I therefore changed desc.json to use the names from edge-IR if possible. I also changed the search of tosa.fbs to be cwd independent, and added a check to avoid a crash if you .dump_artifact() on a graph with no delegate. Change-Id: Iffa56e2f43910adc74608b96518717d14e0beb53 Pull Request resolved: #5071 Reviewed By: cccclai Differential Revision: D62243487 Pulled By: digantdesai fbshipit-source-id: 4c949963dbadf1ca72169e714ff0bc3ae3a6144d
1 parent a9ad3c6 commit 4bf7e2f

File tree

2 files changed

+24
-18
lines changed

2 files changed

+24
-18
lines changed

backends/arm/test/runner_utils.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def _get_input_names(program: ExportedProgram) -> list[str]:
6666

6767

6868
def _get_input_quantization_params(
69-
program: ExportedProgram, input_names: list[str]
69+
program: ExportedProgram,
7070
) -> list[QuantizationParams]:
7171
"""
7272
Get input QuantizationParams in a program, maximum one per input to the program.
@@ -79,6 +79,7 @@ def _get_input_quantization_params(
7979
"""
8080

8181
quant_params = []
82+
input_names = _get_input_names(program)
8283
num_inputs = len(input_names)
8384
for node in program.graph.nodes:
8485
if (
@@ -178,16 +179,19 @@ def __init__(
178179

179180
self._has_init_run = False
180181

181-
def init_run(self, exported_program: ExportedProgram, is_quantized: bool):
182-
self.input_names = _get_input_names(exported_program)
182+
def init_run(
183+
self,
184+
exported_program: ExportedProgram,
185+
edge_program: ExportedProgram,
186+
is_quantized: bool,
187+
):
188+
self.input_names = _get_input_names(edge_program)
183189
self.output_node = _get_output_node(exported_program)
184190
self.output_name = self.output_node.name
185191
self.is_quantized = is_quantized
186192

187193
if is_quantized:
188-
self.qp_input = _get_input_quantization_params(
189-
exported_program, self.input_names
190-
)
194+
self.qp_input = _get_input_quantization_params(exported_program)
191195
self.qp_output = _get_output_quantization_params(
192196
exported_program, self.output_node
193197
)
@@ -407,7 +411,7 @@ def prep_data_for_save(
407411

408412
if is_quantized:
409413
assert (
410-
quant_param.node_name == input_name
414+
quant_param.node_name in input_name
411415
), "These quantization params do not match the input tensor name"
412416
data_np = (
413417
((data_np / np.float32(quant_param.scale)) + quant_param.zp)
@@ -500,7 +504,10 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict:
500504
with open(tosa_input_file, "wb") as f:
501505
f.write(tosa_fb)
502506

503-
tosa_schema_file = "./backends/arm/third-party/serialization_lib/schema/tosa.fbs"
507+
arm_backend_path = os.path.realpath(os.path.dirname(__file__) + "/..")
508+
tosa_schema_file = os.path.join(
509+
arm_backend_path, "third-party/serialization_lib/schema/tosa.fbs"
510+
)
504511
assert os.path.exists(
505512
tosa_schema_file
506513
), f"tosa_schema_file: {tosa_schema_file} does not exist"

backends/arm/test/tester/arm_tester.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
)
2424

2525
from executorch.backends.arm.test.runner_utils import (
26-
_get_input_names,
2726
_get_input_quantization_params,
2827
_get_output_node,
2928
_get_output_quantization_params,
@@ -241,15 +240,18 @@ def run_method_and_compare_outputs(
241240
self.runner_util is not None
242241
), "self.tosa_test_util is not initialized, cannot use run_method()"
243242
assert (
244-
self.stages[self.stage_name(tester.Export)] is not None
245-
), "To compare outputs, at least the Export stage needs to be run."
243+
self.stages[self.stage_name(tester.ToEdge)] is not None
244+
), "To compare outputs, at least the ToEdge stage needs to be run."
246245

247246
stage = stage or self.cur
248247
test_stage = self.stages[stage]
249248
is_quantized = self.stages[self.stage_name(tester.Quantize)] is not None
250-
self.runner_util.init_run(
251-
self.stages[self.stage_name(tester.Export)].artifact, is_quantized
252-
)
249+
250+
exported_program = self.stages[self.stage_name(tester.Export)].artifact
251+
edge_program = self.stages[
252+
self.stage_name(tester.ToEdge)
253+
].artifact.exported_program()
254+
self.runner_util.init_run(exported_program, edge_program, is_quantized)
253255

254256
if is_quantized:
255257
reference_stage = self.stages[self.stage_name(tester.Quantize)]
@@ -395,11 +397,8 @@ def _compare_outputs(
395397
export_stage = self.stages.get(self.stage_name(tester.Export), None)
396398
quantize_stage = self.stages.get(self.stage_name(tester.Quantize), None)
397399
if export_stage is not None and quantize_stage is not None:
398-
input_names = _get_input_names(export_stage.artifact)
399400
output_node = _get_output_node(export_stage.artifact)
400-
qp_input = _get_input_quantization_params(
401-
export_stage.artifact, input_names
402-
)
401+
qp_input = _get_input_quantization_params(export_stage.artifact)
403402
qp_output = _get_output_quantization_params(
404403
export_stage.artifact, output_node
405404
)

0 commit comments

Comments
 (0)