Skip to content

Commit b6bc421

Browse files
authored
Arm backend: Fix Arm tester issue for inplace ops (pytorch#14625)
Deep-copying the input avoids it getting mutated by the first reference run.
1 parent 8484aee commit b6bc421

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

backends/arm/test/ops/test_silu.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from typing import Optional, Tuple
1010

11-
import pytest
1211
import torch
1312
from executorch.backends.arm.test import common
1413
from executorch.backends.arm.test.tester.test_pipeline import (
@@ -125,7 +124,6 @@ def test_silu_u85_INT_inplace(test_data: input_t):
125124

126125
@common.parametrize("test_data", Silu.test_data)
127126
@common.SkipIfNoModelConverter
128-
@pytest.mark.xfail(reason="MLETORCH-1387: Output differs")
129127
def test_silu_vgf_FP(test_data: input_t):
130128
silu_data = (test_data(), False)
131129
pipeline = VgfPipeline[input_t](
@@ -136,7 +134,6 @@ def test_silu_vgf_FP(test_data: input_t):
136134

137135
@common.parametrize("test_data", Silu.test_data)
138136
@common.SkipIfNoModelConverter
139-
@pytest.mark.xfail(reason="MLETORCH-1387: Output differs")
140137
def test_silu_vgf_FP_inplace(test_data: input_t):
141138
silu_data = (test_data(), True)
142139
pipeline = VgfPipeline[input_t](
@@ -147,7 +144,6 @@ def test_silu_vgf_FP_inplace(test_data: input_t):
147144

148145
@common.parametrize("test_data", Silu.test_data)
149146
@common.SkipIfNoModelConverter
150-
@pytest.mark.xfail(reason="MLETORCH-1387: Output differs")
151147
def test_silu_vgf_INT(test_data: input_t):
152148
silu_data = (test_data(), False)
153149
pipeline = VgfPipeline[input_t](
@@ -161,7 +157,6 @@ def test_silu_vgf_INT(test_data: input_t):
161157

162158
@common.parametrize("test_data", Silu.test_data)
163159
@common.SkipIfNoModelConverter
164-
@pytest.mark.xfail(reason="MLETORCH-1387: Output differs")
165160
def test_silu_vgf_INT_inplace(test_data: input_t):
166161
silu_data = (test_data(), True)
167162
pipeline = VgfPipeline[input_t](

backends/arm/test/tester/arm_tester.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,10 @@ def run_method_and_compare_outputs(
430430
for run_iteration in range(num_runs):
431431
reference_input = inputs if inputs else next(self.generate_random_inputs())
432432

433+
# Avoid issues with inplace operators
434+
test_input = copy.deepcopy(reference_input)
435+
original_input = copy.deepcopy(reference_input)
436+
433437
input_shapes = [
434438
generated_input.shape if hasattr(generated_input, "shape") else (1,)
435439
for generated_input in reference_input
@@ -444,16 +448,16 @@ def run_method_and_compare_outputs(
444448
# Run exported module directly
445449
test_outputs, _ = pytree.tree_flatten(
446450
self._calculate_reference_output(
447-
exported_program.module(), reference_input
451+
exported_program.module(), test_input
448452
)
449453
)
450454
else:
451455
# Run lowered model with target
452456
test_outputs, _ = pytree.tree_flatten(
453-
test_stage.run_artifact(reference_input)
457+
test_stage.run_artifact(test_input)
454458
)
455459

456-
logger.info(f"\n Input: {reference_input}")
460+
logger.info(f"\n Input: {original_input}")
457461
logger.info(f"\n Ref output: {reference_outputs}")
458462
logger.info(f"\nTest output: {test_outputs}")
459463

0 commit comments

Comments
 (0)