Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/arm/test/ops/test_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def test_bmm_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
self._test_bmm_tosa_MI_pipeline(self.BMM(), test_data)

@parameterized.expand(BMMSingleInput.test_data_generators)
@pytest.mark.flaky  # TODO: Investigate flakiness (MLETORCH-534)
def test_bmm_single_input_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
    """Run the TOSA MI pipeline on a BMM module that multiplies a single input with itself."""
    self._test_bmm_tosa_MI_pipeline(self.BMMSingleInput(), test_data_generator())
Expand All @@ -144,6 +145,7 @@ def test_bmm_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
self._test_bmm_tosa_BI_pipeline(self.BMM(), test_data)

@parameterized.expand(BMMSingleInput.test_data_generators)
@pytest.mark.flaky  # TODO: Investigate flakiness (MLETORCH-534)
def test_bmm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
    """Run the TOSA BI (quantized) pipeline on a BMM module that multiplies a single input with itself."""
    self._test_bmm_tosa_BI_pipeline(self.BMMSingleInput(), test_data_generator())
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/test/ops/test_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def test_mm_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
self._test_mm_tosa_MI_pipeline(self.MM(), test_data)

@parameterized.expand(MMSingleInput.test_data_generators)
@pytest.mark.flaky  # TODO: Investigate flakiness (MLETORCH-534)
def test_mm_single_input_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
    """Run the TOSA MI pipeline on an MM module that multiplies a single input with itself."""
    self._test_mm_tosa_MI_pipeline(self.MMSingleInput(), test_data_generator())
Expand All @@ -126,6 +127,7 @@ def test_mm_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
self._test_mm_tosa_BI_pipeline(self.MM(), test_data)

@parameterized.expand(MMSingleInput.test_data_generators)
@pytest.mark.flaky  # TODO: Investigate flakiness (MLETORCH-534)
def test_mm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
    """Run the TOSA BI (quantized) pipeline on an MM module that multiplies a single input with itself."""
    self._test_mm_tosa_BI_pipeline(self.MMSingleInput(), test_data_generator())
Expand Down
14 changes: 13 additions & 1 deletion backends/arm/test/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ def get_output_quantization_params(
class TosaReferenceModelDispatch(TorchFunctionMode):
"""A context manager for executing call_delegate nodes using the reference model"""

def __init__(self):
    """Initialize the dispatch mode with no delegate execution recorded yet."""
    super().__init__()
    # Set to True the first time an ArmBackend call_delegate node is
    # executed via this mode; checked on exit from the context.
    self.ran_tosa_dispatch = False

def _tosa_dispatch(self, lowered_backend_module: LoweredBackendModule, inputs):
tosa_buffer = lowered_backend_module.processed_bytes
compile_specs = lowered_backend_module.compile_specs
Expand All @@ -168,13 +172,21 @@ def _tosa_dispatch(self, lowered_backend_module: LoweredBackendModule, inputs):

return run_tosa_graph(tosa_buffer, tosa_version, inputs)

def __exit__(self, exc_type, exc_val, exc_tb):
    """Tear down the dispatch mode and verify the ArmBackend delegate ran.

    Args:
        exc_type, exc_val, exc_tb: Standard context-manager exception
            information; forwarded unchanged to the superclass.

    Raises:
        RuntimeError: If the context exits cleanly but no ArmBackend
            call_delegate node was ever dispatched to the TOSA reference
            model (i.e. the model under test never hit the delegate).
    """
    super().__exit__(exc_type, exc_val, exc_tb)
    # Only complain about a missing dispatch on a clean exit. If an
    # exception is already propagating out of the with-body, raising here
    # would replace it and hide the real failure from the test report.
    if exc_type is None and not self.ran_tosa_dispatch:
        raise RuntimeError(
            "Ran model with TosaReferenceModelDispatch but never ran ArmBackend delegate."
        )

def __torch_function__(self, func, types, args=..., kwargs=None):
if func is torch._higher_order_ops.executorch_call_delegate:
lowered_backend_module = cast(LoweredBackendModule, args[0])
if lowered_backend_module.backend_id == "ArmBackend":
self.ran_tosa_dispatch = True
return self._tosa_dispatch(lowered_backend_module, args[1:])
else:
logger.warning(
raise RuntimeError(
f"Ran model with TosaReferenceModelDispatch but call_delegate with {lowered_backend_module.backend_id=} != 'ArmBackend'."
)

Expand Down
Loading