@@ -157,6 +157,10 @@ def get_output_quantization_params(
157157class TosaReferenceModelDispatch (TorchFunctionMode ):
158158 """A context manager for executing call_delegate nodes using the reference model"""
159159
160+ def __init__ (self ):
161+ self .ran_tosa_dispatch = False
162+ super ().__init__ ()
163+
160164 def _tosa_dispatch (self , lowered_backend_module : LoweredBackendModule , inputs ):
161165 tosa_buffer = lowered_backend_module .processed_bytes
162166 compile_specs = lowered_backend_module .compile_specs
@@ -168,13 +172,21 @@ def _tosa_dispatch(self, lowered_backend_module: LoweredBackendModule, inputs):
168172
169173 return run_tosa_graph (tosa_buffer , tosa_version , inputs )
170174
175+ def __exit__ (self , exc_type , exc_val , exc_tb ):
176+ super ().__exit__ (exc_type , exc_val , exc_tb )
177+ if not self .ran_tosa_dispatch :
178+ raise RuntimeError (
179+ "Ran model with TosaReferenceModelDispatch but never ran ArmBackend delegate."
180+ )
181+
171182 def __torch_function__ (self , func , types , args = ..., kwargs = None ):
172183 if func is torch ._higher_order_ops .executorch_call_delegate :
173184 lowered_backend_module = cast (LoweredBackendModule , args [0 ])
174185 if lowered_backend_module .backend_id == "ArmBackend" :
186+ self .ran_tosa_dispatch = True
175187 return self ._tosa_dispatch (lowered_backend_module , args [1 :])
176188 else :
177- logger . warning (
189+ raise RuntimeError (
178190 f"Ran model with TosaReferenceModelDispatch but call_delegate with { lowered_backend_module .backend_id = } != 'ArmBackend'."
179191 )
180192
0 commit comments