Skip to content

Commit 917b261

Browse files
pytorchbotmansnils
andauthored
Arm backend: Fix Arm tester issue for inplace ops (#14855)
Deep-copying the input avoids it getting mutated by the first reference run. cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 Co-authored-by: Måns Nilsson <[email protected]>
1 parent b3ad794 commit 917b261

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

backends/arm/test/tester/arm_tester.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,10 @@ def run_method_and_compare_outputs(
458458
for run_iteration in range(num_runs):
459459
reference_input = inputs if inputs else next(self.generate_random_inputs())
460460

461+
# Avoid issues with inplace operators
462+
test_input = copy.deepcopy(reference_input)
463+
original_input = copy.deepcopy(reference_input)
464+
461465
input_shapes = [
462466
generated_input.shape if hasattr(generated_input, "shape") else (1,)
463467
for generated_input in reference_input
@@ -472,16 +476,16 @@ def run_method_and_compare_outputs(
472476
# Run exported module directly
473477
test_outputs, _ = pytree.tree_flatten(
474478
self._calculate_reference_output(
475-
exported_program.module(), reference_input
479+
exported_program.module(), test_input
476480
)
477481
)
478482
else:
479483
# Run lowered model with target
480484
test_outputs, _ = pytree.tree_flatten(
481-
test_stage.run_artifact(reference_input)
485+
test_stage.run_artifact(test_input)
482486
)
483487

484-
logger.info(f"\n Input: {reference_input}")
488+
logger.info(f"\n Input: {original_input}")
485489
logger.info(f"\n Ref output: {reference_outputs}")
486490
logger.info(f"\nTest output: {test_outputs}")
487491

0 commit comments

Comments
 (0)