Skip to content

Commit ec0b57b

Browse files
Arm backend: Remove get_output_nodes from runner_utils. (#13417)
A graph has only one output node containing a list of output tensors. Remove the use of this function to better reflect this. Signed-off-by: Adrian Lundell <[email protected]>
1 parent 53146a4 commit ec0b57b

File tree

3 files changed

+8
-31
lines changed

3 files changed

+8
-31
lines changed

backends/arm/test/runner_utils.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -130,28 +130,8 @@ def get_input_quantization_params(
130130
return quant_params
131131

132132

133-
def get_output_nodes(program: ExportedProgram) -> list[Node]:
134-
"""
135-
Get output node to this model.
136-
137-
Args:
138-
program (ExportedProgram): The program to get the output nodes from.
139-
Returns:
140-
The nodes that are the outputs of the 'program'.
141-
"""
142-
output_nodes = []
143-
for node in program.graph.nodes:
144-
if node.op == "output":
145-
for output in node.args[0]:
146-
output_nodes.append(output)
147-
if len(output_nodes) == 0:
148-
raise RuntimeError("No output nodes found.")
149-
else:
150-
return output_nodes
151-
152-
153133
def get_output_quantization_params(
154-
output_nodes: list[Node],
134+
output_node: Node,
155135
) -> dict[Node, QuantizationParams | None]:
156136
"""
157137
Get output QuantizationParams from a program.
@@ -164,7 +144,7 @@ def get_output_quantization_params(
164144
RuntimeError if no output quantization parameters are found.
165145
"""
166146
quant_params = {}
167-
for node in output_nodes:
147+
for node in output_node.args[0]:
168148
if node.target == torch.ops.quantized_decomposed.dequantize_per_tensor.default:
169149
quant_params[node] = QuantizationParams(
170150
node_name=node.args[0].name,
@@ -411,9 +391,9 @@ def run_corstone(
411391
f"Corstone simulation failed:\ncmd: {' '.join(command_args)}\nlog: \n {result_stdout}\n{result.stderr.decode()}"
412392
)
413393

414-
output_nodes = get_output_nodes(exported_program)
415394
output_np = []
416-
for i, node in enumerate(output_nodes):
395+
output_node = exported_program.graph_module.graph.output_node()
396+
for i, node in enumerate(output_node.args[0]):
417397
output_shape = node.meta["val"].shape
418398
output_dtype = node.meta["val"].dtype
419399
tosa_ref_output = np.fromfile(

backends/arm/test/tester/analyze_output_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from executorch.backends.arm.arm_backend import get_intermediate_path
1111
from executorch.backends.arm.test.runner_utils import (
1212
get_input_quantization_params,
13-
get_output_nodes,
1413
get_output_quantization_params,
1514
)
1615

@@ -254,9 +253,9 @@ def dump_error_output(
254253
export_stage = tester.stages.get(StageType.EXPORT, None)
255254
quantize_stage = tester.stages.get(StageType.QUANTIZE, None)
256255
if export_stage is not None and quantize_stage is not None:
257-
output_nodes = get_output_nodes(export_stage.artifact)
256+
output_node = export_stage.artifact.graph_module.output_node()
258257
qp_input = get_input_quantization_params(export_stage.artifact)
259-
qp_output = get_output_quantization_params(output_nodes)
258+
qp_output = get_output_quantization_params(output_node)
260259
logger.error(f"Input QuantArgs: {qp_input}")
261260
logger.error(f"Output QuantArgs: {qp_output}")
262261

backends/arm/test/tester/arm_tester.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
from executorch.backends.arm.test.runner_utils import (
4949
dbg_tosa_fb_to_json,
5050
get_elf_path,
51-
get_output_nodes,
5251
get_output_quantization_params,
5352
get_target_board,
5453
run_target,
@@ -484,9 +483,8 @@ def run_method_and_compare_outputs(
484483
reference_stage = self.stages[StageType.INITIAL_MODEL]
485484

486485
exported_program = self.stages[StageType.EXPORT].artifact
487-
output_nodes = get_output_nodes(exported_program)
488-
489-
output_qparams = get_output_quantization_params(output_nodes)
486+
output_node = exported_program.graph_module.graph.output_node()
487+
output_qparams = get_output_quantization_params(output_node)
490488

491489
quantization_scales = []
492490
for node in output_qparams:

0 commit comments

Comments
 (0)