Skip to content

Commit d308fdc

Browse files
authored
Merge branch 'main' into fix-discord-link
2 parents e86a46c + 5aa0ee2 commit d308fdc

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed

backends/arm/test/ops/test_bmm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def test_bmm_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
124124
self._test_bmm_tosa_MI_pipeline(self.BMM(), test_data)
125125

126126
@parameterized.expand(BMMSingleInput.test_data_generators)
127+
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
127128
def test_bmm_single_input_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
128129
test_data = test_data_generator()
129130
self._test_bmm_tosa_MI_pipeline(self.BMMSingleInput(), test_data)
@@ -144,6 +145,7 @@ def test_bmm_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
144145
self._test_bmm_tosa_BI_pipeline(self.BMM(), test_data)
145146

146147
@parameterized.expand(BMMSingleInput.test_data_generators)
148+
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
147149
def test_bmm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
148150
test_data = test_data_generator()
149151
self._test_bmm_tosa_BI_pipeline(self.BMMSingleInput(), test_data)

backends/arm/test/ops/test_mm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def test_mm_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
115115
self._test_mm_tosa_MI_pipeline(self.MM(), test_data)
116116

117117
@parameterized.expand(MMSingleInput.test_data_generators)
118+
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
118119
def test_mm_single_input_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
119120
test_data = test_data_generator()
120121
self._test_mm_tosa_MI_pipeline(self.MMSingleInput(), test_data)
@@ -126,6 +127,7 @@ def test_mm_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
126127
self._test_mm_tosa_BI_pipeline(self.MM(), test_data)
127128

128129
@parameterized.expand(MMSingleInput.test_data_generators)
130+
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
129131
def test_mm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
130132
test_data = test_data_generator()
131133
self._test_mm_tosa_BI_pipeline(self.MMSingleInput(), test_data)

backends/arm/test/runner_utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ def get_output_quantization_params(
157157
class TosaReferenceModelDispatch(TorchFunctionMode):
158158
"""A context manager for executing call_delegate nodes using the reference model"""
159159

160+
def __init__(self):
161+
self.ran_tosa_dispatch = False
162+
super().__init__()
163+
160164
def _tosa_dispatch(self, lowered_backend_module: LoweredBackendModule, inputs):
161165
tosa_buffer = lowered_backend_module.processed_bytes
162166
compile_specs = lowered_backend_module.compile_specs
@@ -168,13 +172,21 @@ def _tosa_dispatch(self, lowered_backend_module: LoweredBackendModule, inputs):
168172

169173
return run_tosa_graph(tosa_buffer, tosa_version, inputs)
170174

175+
def __exit__(self, exc_type, exc_val, exc_tb):
176+
super().__exit__(exc_type, exc_val, exc_tb)
177+
if not self.ran_tosa_dispatch:
178+
raise RuntimeError(
179+
"Ran model with TosaReferenceModelDispatch but never ran ArmBackend delegate."
180+
)
181+
171182
def __torch_function__(self, func, types, args=..., kwargs=None):
172183
if func is torch._higher_order_ops.executorch_call_delegate:
173184
lowered_backend_module = cast(LoweredBackendModule, args[0])
174185
if lowered_backend_module.backend_id == "ArmBackend":
186+
self.ran_tosa_dispatch = True
175187
return self._tosa_dispatch(lowered_backend_module, args[1:])
176188
else:
177-
logger.warning(
189+
raise RuntimeError(
178190
f"Ran model with TosaReferenceModelDispatch but call_delegate with {lowered_backend_module.backend_id=} != 'ArmBackend'."
179191
)
180192

0 commit comments

Comments
 (0)