Skip to content

Commit 6a48cab

Browse files
committed
Make stage ids unique in the Arm TestPipeline
Some stages in the TestPipeline may be added multiple times, such as .check(). To be able to target these by id, give them an unique suffix -> 'id.suffix' Refering to stages in terms of id instead of an index is more self documenting and future proof. This change modifies the add/pop_stage interface: - pos arg in add_stage is now an optional kwarg, appending to the pipline as default - suffix is added to add_stage as an optional_kwarg. If a suffix is not given to a non unique stage, a number is added instead. - pop_stage now allows to use ids for referring to stages. Additionally adds .visualize(stage) for quickly adding visualizing stages to the pipeline. Change-Id: If649a19096ddee6b2eca2c8aa735b54ca7eea3e8
1 parent 524ec78 commit 6a48cab

File tree

2 files changed

+103
-41
lines changed

2 files changed

+103
-41
lines changed

backends/arm/test/ops/test_conv2d.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
@@ -371,7 +370,7 @@ def test_conv2d_tosa_BI(test_module):
371370
pipeline = TosaPipelineBI[input_t](
372371
test_module, test_module.get_inputs(), aten_op, exir_op
373372
)
374-
pipeline.change_args("run_method_and_compare_outputs", qtol=1)
373+
pipeline.change_args("run_method_and_compare_outputs.0", qtol=1)
375374
pipeline.run()
376375

377376

backends/arm/test/tester/test_pipeline.py

Lines changed: 102 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ class PipelineStage:
4646
is_called: keeps track of if the function has been called
4747
"""
4848

49-
def __init__(self, func, *args, **kwargs):
50-
self.id: str = func.__name__
49+
def __init__(self, func: Callable, id: str, *args, **kwargs):
50+
self.id: str = id
5151
self.func: Callable = func
5252
self.args = args
5353
self.kwargs = kwargs
@@ -86,72 +86,130 @@ def __init__(
8686
self.test_data = test_data
8787
self._stages = []
8888

89-
self.add_stage(-1, self.tester.export)
90-
self.add_stage(-1, self.tester.check, self.aten_ops)
89+
self.add_stage(self.tester.export)
90+
self.add_stage(self.tester.check, self.aten_ops, suffix="aten")
9191
if use_to_edge_transform_and_lower:
92-
self.add_stage(-1, self.tester.to_edge_transform_and_lower)
93-
92+
self.add_stage(self.tester.to_edge_transform_and_lower)
9493
else:
95-
self.add_stage(-1, self.tester.to_edge)
96-
self.add_stage(-1, self.tester.check, self.exir_ops)
97-
self.add_stage(-1, self.tester.partition)
98-
self.add_stage(-1, self.tester.check_not, self.exir_ops)
94+
self.add_stage(self.tester.to_edge)
95+
self.add_stage(self.tester.check, self.exir_ops, suffix="exir")
96+
self.add_stage(self.tester.partition)
97+
self.add_stage(self.tester.check_not, self.exir_ops, suffix="exir")
9998
self.add_stage(
100-
-1,
10199
self.tester.check_count,
102100
{"torch.ops.higher_order.executorch_call_delegate": 1},
101+
suffix="exir",
103102
)
104-
self.add_stage(-1, self.tester.to_executorch)
103+
self.add_stage(self.tester.to_executorch)
104+
105+
def add_stage(self, func: Callable, *args, **kwargs):
106+
"""
107+
Adds a stage defined by a function with args and kwargs. By default appends to the pipeline.
108+
For stages which may be added multiple times to a pipeline, s.a. checks and debug stages,
109+
a suffix is appended with a dot to make sure every id is unique, e.g. check becomes check.0
105110
106-
def add_stage(self, pos: int, func: Callable, *args, **kwargs):
107-
"""Adds a stage defined by a function with arguments to the pipeline at index pos. Pos wraps around the list for negative values."""
108-
pipeline_stage = self.PipelineStage(func, *args, **kwargs)
111+
Special kwargs:
112+
pos : specifies position in pipeline to add stage at.
113+
suffix : specifies a custom suffix to identify non unique stages, instead of a number.
114+
"""
109115
pipeline_length = len(self._stages)
110116

117+
pos = -1
118+
if "pos" in kwargs:
119+
pos = kwargs.pop("pos")
120+
111121
if pos < 0:
112122
pos = pipeline_length + (pos + 1)
113-
114123
if not -pipeline_length <= pos <= pipeline_length:
115124
raise ValueError(
116125
f"Pos must be between [-{pipeline_length}, {pipeline_length}]"
117126
)
118127

128+
suffix = None
129+
if "suffix" in kwargs:
130+
suffix = kwargs.pop("suffix")
131+
132+
stage_id = func.__name__
133+
unique_stages = [
134+
"quantize",
135+
"export",
136+
"to_edge_transform_and_lower",
137+
"to_edge",
138+
"partition",
139+
"to_executorch",
140+
"serialize",
141+
]
142+
id_list = [stage.id for stage in self._stages]
143+
if stage_id in unique_stages:
144+
if stage_id in id_list:
145+
raise RuntimeError(f"Tried adding {stage_id} to pipeline twice.")
146+
else:
147+
if suffix is None:
148+
stages_containing_stage_id = [
149+
id for id in id_list if stage_id == id.split(".")[0]
150+
]
151+
152+
suffix = str(len(stages_containing_stage_id))
153+
154+
stage_id = stage_id + "." + suffix
155+
156+
if stage_id in id_list:
157+
raise ValueError("Suffix must be unique in pipeline")
158+
159+
pipeline_stage = self.PipelineStage(func, stage_id, *args, **kwargs)
119160
self._stages.insert(pos, pipeline_stage)
120161

121-
logger.debug(f"Added stage {func.__name__} to {type(self).__name__}")
162+
logger.debug(f"Added stage {stage_id} to {type(self).__name__}")
122163

123164
return self
124165

125-
def pop_stage(self, pos: int):
166+
def pop_stage(self, identifier: int | str):
126167
"""Removes and returns the stage at postion pos"""
127-
return self._stages.pop(pos)
168+
if isinstance(identifier, int):
169+
stage = self._stages.pop(identifier)
170+
elif isinstance(identifier, str):
171+
pos = self.find_pos(identifier)
172+
stage = self._stages.pop(pos)
173+
174+
logger.debug(f"Removed stage {stage.id} from {type(self).__name__}")
175+
176+
return stage
128177

129178
def find_pos(self, stage_id: str):
130-
"""Returns the position of the stage id. Note that this only finds the first stage with the given id, i.e. it should only be used with unique stages."""
179+
"""Returns the position of the stage id."""
131180
for i, stage in enumerate(self._stages):
132181
if stage.id == stage_id:
133182
return i
134183

135184
raise Exception(f"Stage id {stage_id} not found in pipeline")
136185

137186
def add_stage_after(self, stage_id: str, func: Callable, *args, **kwargs):
138-
"""Adds a stage after the given stage id. Note that this only finds the first stage with the given id, i.e. it should only be used with unique stages."""
139-
pos = self.find_pos(stage_id)
140-
self.add_stage(pos + 1, func, *args, **kwargs)
187+
"""Adds a stage after the given stage id."""
188+
pos = self.find_pos(stage_id) + 1
189+
kwargs["pos"] = pos
190+
191+
self.add_stage(func, *args, **kwargs)
192+
return self
193+
194+
def dump_artifact(self, stage_id: str, suffix: str = None):
195+
"""Adds a dump_artifact stage after the given stage id."""
196+
self.add_stage_after(stage_id, self.tester.dump_artifact, suffix=suffix)
141197
return self
142198

143-
def dump_artifact(self, stage_id: str):
144-
"""Adds a dump_artifact stage after the given stage id. Note that this only finds the first stage with the given id, i.e. it should only be used with unique stages."""
145-
self.add_stage_after(stage_id, self.tester.dump_artifact)
199+
def dump_operator_distribution(self, stage_id: str, suffix: str = None):
200+
"""Adds a dump_operator_distribution stage after the given stage id."""
201+
self.add_stage_after(
202+
stage_id, self.tester.dump_operator_distribution, suffix=suffix
203+
)
146204
return self
147205

148-
def dump_operator_distribution(self, stage_id: str):
149-
"""Adds a dump_operator_distribution stage after the given stage id. Note that this only finds the first stage with the given id, i.e. it should only be used with unique stages."""
150-
self.add_stage_after(stage_id, self.tester.dump_operator_distribution)
206+
def visualize(self, stage_id: str, suffix: str = None):
207+
"""Adds a dump_operator_distribution stage after the given stage id."""
208+
self.add_stage_after(stage_id, self.tester.visualize, suffix=suffix)
151209
return self
152210

153211
def change_args(self, stage_id: str, *args, **kwargs):
154-
"""Updates the args to the given stage id. Note that this only finds the first stage with the given id, i.e. it should only be used with unique stages."""
212+
"""Updates the args to the given stage id."""
155213
pos = self.find_pos(stage_id)
156214
pipeline_stage = self._stages[pos]
157215
pipeline_stage.update(*args, **kwargs)
@@ -193,14 +251,15 @@ def __init__(
193251
compile_spec,
194252
use_to_edge_transform_and_lower,
195253
)
196-
self.add_stage(0, self.tester.quantize)
254+
self.add_stage(self.tester.quantize, pos=0)
197255
self.add_stage_after(
198256
"quantize",
199257
self.tester.check,
200258
[
201259
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
202260
"torch.ops.quantized_decomposed.quantize_per_tensor.default",
203261
],
262+
suffix="quant_nodes",
204263
)
205264

206265
remove_quant_nodes_stage = (
@@ -215,10 +274,11 @@ def __init__(
215274
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
216275
"torch.ops.quantized_decomposed.quantize_per_tensor.default",
217276
],
277+
suffix="quant_nodes",
218278
)
219279

220280
self.add_stage(
221-
-1, self.tester.run_method_and_compare_outputs, inputs=self.test_data
281+
self.tester.run_method_and_compare_outputs, inputs=self.test_data
222282
)
223283

224284

@@ -252,10 +312,11 @@ def __init__(
252312
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
253313
"torch.ops.quantized_decomposed.quantize_per_tensor.default",
254314
],
315+
suffix="quant_nodes",
255316
)
256317

257318
self.add_stage(
258-
-1, self.tester.run_method_and_compare_outputs, inputs=self.test_data
319+
self.tester.run_method_and_compare_outputs, inputs=self.test_data
259320
)
260321

261322

@@ -280,14 +341,15 @@ def __init__(
280341
compile_spec,
281342
use_to_edge_transform_and_lower,
282343
)
283-
self.add_stage(0, self.tester.quantize)
344+
self.add_stage(self.tester.quantize, pos=0)
284345
self.add_stage_after(
285346
"quantize",
286347
self.tester.check,
287348
[
288349
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
289350
"torch.ops.quantized_decomposed.quantize_per_tensor.default",
290351
],
352+
suffix="quant_nodes",
291353
)
292354

293355
remove_quant_nodes_stage = (
@@ -302,12 +364,12 @@ def __init__(
302364
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
303365
"torch.ops.quantized_decomposed.quantize_per_tensor.default",
304366
],
367+
suffix="quant_nodes",
305368
)
306369

307370
if run_on_fvp:
308-
self.add_stage(-1, self.tester.serialize)
371+
self.add_stage(self.tester.serialize)
309372
self.add_stage(
310-
-1,
311373
self.tester.run_method_and_compare_outputs,
312374
qtol=1,
313375
inputs=self.test_data,
@@ -335,14 +397,15 @@ def __init__(
335397
compile_spec,
336398
use_to_edge_transform_and_lower,
337399
)
338-
self.add_stage(0, self.tester.quantize)
400+
self.add_stage(self.tester.quantize, pos=0)
339401
self.add_stage_after(
340402
"quantize",
341403
self.tester.check,
342404
[
343405
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
344406
"torch.ops.quantized_decomposed.quantize_per_tensor.default",
345407
],
408+
suffix="quant_nodes",
346409
)
347410

348411
remove_quant_nodes_stage = (
@@ -357,12 +420,12 @@ def __init__(
357420
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
358421
"torch.ops.quantized_decomposed.quantize_per_tensor.default",
359422
],
423+
suffix="quant_nodes",
360424
)
361425

362426
if run_on_fvp:
363-
self.add_stage(-1, self.tester.serialize)
427+
self.add_stage(self.tester.serialize)
364428
self.add_stage(
365-
-1,
366429
self.tester.run_method_and_compare_outputs,
367430
qtol=1,
368431
inputs=self.test_data,

0 commit comments

Comments
 (0)