Skip to content

Commit e74c265

Browse files
Fixed numerical difference detection test
Change-Id: I5fcd32b128d215c30f50b6e8b4c819312f4e9301
1 parent d00279d commit e74c265

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

backends/arm/test/misc/test_debug_feats.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
TosaPipelineFP,
2222
TosaPipelineINT,
2323
)
24+
from executorch.backends.test.harness.stages import StageType
2425

2526

2627
input_t1 = Tuple[torch.Tensor] # Input x
@@ -104,7 +105,7 @@ def test_INT_artifact(test_data: input_t1):
104105

105106
@common.parametrize("test_data", Linear.inputs)
106107
def test_numerical_diff_print(test_data: input_t1):
107-
pipeline = TosaPipelineFP[input_t1](
108+
pipeline = TosaPipelineINT[input_t1](
108109
Linear(),
109110
test_data,
110111
[],
@@ -119,7 +120,7 @@ def test_numerical_diff_print(test_data: input_t1):
119120
# not present.
120121
try:
121122
# Tolerate 0 difference => we want to trigger a numerical diff
122-
tester.run_method_and_compare_outputs(atol=0, rtol=0, qtol=0)
123+
tester.run_method_and_compare_outputs(stage=StageType.INITIAL_MODEL, atol=0, rtol=0, qtol=0)
123124
except AssertionError:
124125
pass # Implicit pass test
125126
else:

0 commit comments

Comments
 (0)