|
49 | 49 | from hatchet_sdk.runnables.task import Task |
50 | 50 | from hatchet_sdk.runnables.types import R, TWorkflowInput |
51 | 51 | from hatchet_sdk.serde import HATCHET_PYDANTIC_SENTINEL |
| 52 | +from hatchet_sdk.utils.cache import BoundedDict |
52 | 53 | from hatchet_sdk.utils.serde import remove_null_unicode_character |
53 | 54 | from hatchet_sdk.utils.typing import DataclassInstance |
54 | 55 | from hatchet_sdk.worker.action_listener_process import ActionEvent |
@@ -86,6 +87,7 @@ def __init__( |
86 | 87 | self.slots = slots |
87 | 88 | self.tasks: dict[ActionKey, asyncio.Task[Any]] = {} # Store run ids and futures |
88 | 89 | self.contexts: dict[ActionKey, Context] = {} # Store run ids and contexts |
| 90 | + self.cancellations = BoundedDict[str, bool](maxsize=1000) |
89 | 91 | self.action_registry = action_registry or {} |
90 | 92 |
|
91 | 93 | self.event_queue = event_queue |
@@ -156,8 +158,9 @@ def step_run_callback( |
156 | 158 | ) -> Callable[[asyncio.Task[Any]], None]: |
157 | 159 | def inner_callback(task: asyncio.Task[Any]) -> None: |
158 | 160 | self.cleanup_run_id(action.key) |
| 161 | + was_cancelled = self.cancellations.pop(action.key, False) |
159 | 162 |
|
160 | | - if task.cancelled(): |
| 163 | + if was_cancelled or task.cancelled(): |
161 | 164 | return |
162 | 165 |
|
163 | 166 | try: |
@@ -348,6 +351,9 @@ def cleanup_run_id(self, key: ActionKey) -> None: |
348 | 351 | del self.threads[key] |
349 | 352 |
|
350 | 353 | if key in self.contexts: |
| 354 | + if self.contexts[key].exit_flag: |
| 355 | + self.cancellations[key] = True |
| 356 | + |
351 | 357 | del self.contexts[key] |
352 | 358 |
|
353 | 359 | @overload |
@@ -467,6 +473,7 @@ async def handle_cancel_action(self, action: Action) -> None: |
467 | 473 | # call cancel to signal the context to stop |
468 | 474 | if key in self.contexts: |
469 | 475 | self.contexts[key]._set_cancellation_flag() |
| 476 | + self.cancellations[key] = True |
470 | 477 |
|
471 | 478 | await asyncio.sleep(1) |
472 | 479 |
|
|
0 commit comments