Skip to content
This repository was archived by the owner on Feb 20, 2025. It is now read-only.

Commit b579ee6

Browse files
committed
fix: step needs to be registered to a workflow before it can be run
1 parent 842283d commit b579ee6

File tree

4 files changed

+25
-12
lines changed

4 files changed

+25
-12
lines changed

hatchet_sdk/v2/workflows.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
from abc import abstractmethod
23
from enum import Enum
34
from typing import (
45
TYPE_CHECKING,
@@ -141,13 +142,23 @@ def __init__(
141142
self.concurrency__max_runs = concurrency__max_runs
142143
self.concurrency__limit_strategy = concurrency__limit_strategy
143144

145+
146+
class RegisteredStep(Step[R]):
147+
def __init__(
148+
self,
149+
workflow: "BaseWorkflowImpl",
150+
step: Step[R],
151+
) -> None:
152+
self.workflow = workflow
153+
self.step = step
154+
144155
def call(self, ctx: Context) -> R:
145156
if self.is_async_function:
146157
raise TypeError(f"{self.name} is not a sync function. Use `acall` instead.")
147158

148159
sync_fn = self.fn
149160
if is_sync_fn(sync_fn):
150-
return sync_fn(None, ctx)
161+
return sync_fn(self.workflow, ctx)
151162

152163
raise TypeError(f"{self.name} is not a sync function. Use `acall` instead.")
153164

@@ -160,7 +171,7 @@ async def acall(self, ctx: Context) -> R:
160171
async_fn = self.fn
161172

162173
if is_async_fn(async_fn):
163-
return await async_fn(None, ctx)
174+
return await async_fn(self.workflow, ctx)
164175

165176
raise TypeError(f"{self.name} is not an async function. Use `call` instead.")
166177

hatchet_sdk/worker/runner/run_loop_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from hatchet_sdk.worker.runner.utils.capture_logs import capture_logs
1616

1717
if TYPE_CHECKING:
18-
from hatchet_sdk.v2.workflows import Step
18+
from hatchet_sdk.v2.workflows import RegisteredStep
1919

2020
STOP_LOOP_TYPE = Literal["STOP_LOOP"]
2121
STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP"
@@ -27,7 +27,7 @@
2727
@dataclass
2828
class WorkerActionRunLoopManager:
2929
name: str
30-
action_registry: dict[str, "Step[Any]"]
30+
action_registry: dict[str, "RegisteredStep[Any]"]
3131
validator_registry: dict[str, WorkflowValidator]
3232
max_runs: int | None
3333
config: ClientConfig

hatchet_sdk/worker/runner/runner.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
T = TypeVar("T")
4141

4242
if TYPE_CHECKING:
43-
from hatchet_sdk.v2.workflows import Step
43+
from hatchet_sdk.v2.workflows import RegisteredStep
4444

4545

4646
class WorkerStatus(Enum):
@@ -57,7 +57,7 @@ def __init__(
5757
event_queue: "Queue[Any]",
5858
max_runs: int | None = None,
5959
handle_kill: bool = True,
60-
action_registry: dict[str, "Step[T]"] = {},
60+
action_registry: dict[str, "RegisteredStep[T]"] = {},
6161
validator_registry: dict[str, WorkflowValidator] = {},
6262
config: ClientConfig = ClientConfig(),
6363
labels: dict[str, str | int] = {},
@@ -69,7 +69,7 @@ def __init__(
6969
self.max_runs = max_runs
7070
self.tasks: dict[str, asyncio.Task[Any]] = {} # Store run ids and futures
7171
self.contexts: dict[str, Context] = {} # Store run ids and contexts
72-
self.action_registry: dict[str, "Step[T]"] = action_registry
72+
self.action_registry: dict[str, "RegisteredStep[T]"] = action_registry
7373
self.validator_registry = validator_registry
7474

7575
self.event_queue = event_queue
@@ -216,7 +216,7 @@ def inner_callback(task: asyncio.Task[Any]) -> None:
216216

217217
## TODO: Stricter type hinting here
218218
def thread_action_func(
219-
self, context: Context, step: "Step[T]", action: Action
219+
self, context: Context, step: "RegisteredStep[T]", action: Action
220220
) -> T:
221221
if action.step_run_id is not None and action.step_run_id != "":
222222
self.threads[action.step_run_id] = current_thread()
@@ -233,7 +233,7 @@ def thread_action_func(
233233
async def async_wrapped_action_func(
234234
self,
235235
context: Context,
236-
step: "Step[T]",
236+
step: "RegisteredStep[T]",
237237
action: Action,
238238
run_id: str,
239239
) -> T:

hatchet_sdk/worker/worker.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from hatchet_sdk.logger import logger
2626
from hatchet_sdk.utils.types import WorkflowValidator
2727
from hatchet_sdk.utils.typing import is_basemodel_subclass
28-
from hatchet_sdk.v2.workflows import Step, StepType
28+
from hatchet_sdk.v2.workflows import RegisteredStep, StepType
2929
from hatchet_sdk.worker.action_listener_process import (
3030
ActionEvent,
3131
worker_action_listener_process,
@@ -74,7 +74,7 @@ def __init__(
7474

7575
self.client: Client
7676

77-
self.action_registry: dict[str, Step[Any]] = {}
77+
self.action_registry: dict[str, RegisteredStep[Any]] = {}
7878
self.validator_registry: dict[str, WorkflowValidator] = {}
7979

8080
self.killing: bool = False
@@ -124,7 +124,9 @@ def register_workflow(self, workflow: Union["BaseWorkflowImpl", Any]) -> None:
124124

125125
for step in workflow.steps:
126126
action_name = workflow.create_action_name(namespace, step)
127-
self.action_registry[action_name] = step
127+
self.action_registry[action_name] = RegisteredStep(
128+
workflow=workflow, step=step
129+
)
128130
return_type = get_type_hints(step.fn).get("return")
129131

130132
self.validator_registry[action_name] = WorkflowValidator(

0 commit comments

Comments
 (0)