@@ -160,6 +160,9 @@ def driver_gen() -> Generator[task.Task[Any], Any, Any]:
160160 except StopIteration as stop :
161161 return stop .value
162162 except Exception as e :
163+ # Re-raise NonRetryableError directly to preserve its type for the runtime
164+ if isinstance (e , task .NonRetryableError ):
165+ raise
163166 raise AsyncWorkflowError (
164167 f"Workflow failed during initialization: { e } " ,
165168 workflow_name = self ._workflow_name ,
@@ -199,6 +202,15 @@ def _one_shot() -> Generator[task.Task[Any], Any, Any]:
199202 except StopIteration as stop :
200203 return stop .value
201204 except Exception as e :
205+ # Check if this is a TaskFailedError wrapping a NonRetryableError
206+ if isinstance (e , task .TaskFailedError ):
207+ details = e .details
208+ if details .error_type == "NonRetryableError" :
209+ # Reconstruct NonRetryableError to preserve its type for the runtime
210+ raise task .NonRetryableError (details .message ) from e
211+ # Re-raise NonRetryableError directly to preserve its type for the runtime
212+ if isinstance (e , task .NonRetryableError ):
213+ raise
202214 raise AsyncWorkflowError (
203215 f"Workflow failed: { e } " ,
204216 workflow_name = self ._workflow_name ,
@@ -231,6 +243,15 @@ def _one_shot() -> Generator[task.Task[Any], Any, Any]:
231243 except StopIteration as stop :
232244 return stop .value
233245 except Exception as workflow_exc :
246+ # Check if this is a TaskFailedError wrapping a NonRetryableError
247+ if isinstance (workflow_exc , task .TaskFailedError ):
248+ details = workflow_exc .details
249+ if details .error_type == "NonRetryableError" :
250+ # Reconstruct NonRetryableError to preserve its type for the runtime
251+ raise task .NonRetryableError (details .message ) from workflow_exc
252+ # Re-raise NonRetryableError directly to preserve its type for the runtime
253+ if isinstance (workflow_exc , task .NonRetryableError ):
254+ raise
234255 raise AsyncWorkflowError (
235256 f"Workflow failed: { workflow_exc } " ,
236257 workflow_name = self ._workflow_name ,
@@ -247,6 +268,15 @@ def _one_shot() -> Generator[task.Task[Any], Any, Any]:
247268 except StopIteration as stop :
248269 return stop .value
249270 except Exception as workflow_exc :
271+ # Check if this is a TaskFailedError wrapping a NonRetryableError
272+ if isinstance (workflow_exc , task .TaskFailedError ):
273+ details = workflow_exc .details
274+ if details .error_type == "NonRetryableError" :
275+ # Reconstruct NonRetryableError to preserve its type for the runtime
276+ raise task .NonRetryableError (details .message ) from workflow_exc
277+ # Re-raise NonRetryableError directly to preserve its type for the runtime
278+ if isinstance (workflow_exc , task .NonRetryableError ):
279+ raise
250280 raise AsyncWorkflowError (
251281 f"Workflow failed: { workflow_exc } " ,
252282 workflow_name = self ._workflow_name ,
@@ -277,6 +307,15 @@ def _one_shot() -> Generator[task.Task[Any], Any, Any]:
277307 except StopIteration as stop :
278308 return stop .value
279309 except Exception as e :
310+ # Check if this is a TaskFailedError wrapping a NonRetryableError
311+ if isinstance (e , task .TaskFailedError ):
312+ details = e .details
313+ if details .error_type == "NonRetryableError" :
314+ # Reconstruct NonRetryableError to preserve its type for the runtime
315+ raise task .NonRetryableError (details .message ) from e
316+ # Re-raise NonRetryableError directly to preserve its type for the runtime
317+ if isinstance (e , task .NonRetryableError ):
318+ raise
280319 raise AsyncWorkflowError (
281320 f"Workflow failed: { e } " ,
282321 workflow_name = self ._workflow_name ,
0 commit comments