Skip to content

Commit 63d391b

Browse files
Arm backend: Fixed numerical difference detection test (#15048)
1 parent 57a7903 commit 63d391b

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

backends/arm/test/misc/test_debug_feats.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
TosaPipelineFP,
2222
TosaPipelineINT,
2323
)
24+
from executorch.backends.test.harness.stages import StageType
2425

2526

2627
input_t1 = Tuple[torch.Tensor] # Input x
@@ -104,7 +105,7 @@ def test_INT_artifact(test_data: input_t1):
104105

105106
@common.parametrize("test_data", Linear.inputs)
106107
def test_numerical_diff_print(test_data: input_t1):
107-
pipeline = TosaPipelineFP[input_t1](
108+
pipeline = TosaPipelineINT[input_t1](
108109
Linear(),
109110
test_data,
110111
[],
@@ -119,7 +120,9 @@ def test_numerical_diff_print(test_data: input_t1):
119120
# not present.
120121
try:
121122
# Tolerate 0 difference => we want to trigger a numerical diff
122-
tester.run_method_and_compare_outputs(atol=0, rtol=0, qtol=0)
123+
tester.run_method_and_compare_outputs(
124+
stage=StageType.INITIAL_MODEL, atol=0, rtol=0, qtol=0
125+
)
123126
except AssertionError:
124127
pass # Implicit pass test
125128
else:

0 commit comments

Comments
 (0)