Skip to content
This repository was archived by the owner on Feb 20, 2025. It is now read-only.

Commit e6843ec

Browse files
committed
fix: clean up step registration
1 parent 562bd3c commit e6843ec

File tree

4 files changed

+33
-37
lines changed

4 files changed

+33
-37
lines changed

hatchet_sdk/v2/workflows.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def __init__(
116116
self,
117117
fn: Callable[[Any, Context], R] | Callable[[Any, Context], Awaitable[R]],
118118
type: StepType,
119+
workflow: Union["BaseWorkflow", None] = None,
119120
name: str = "",
120121
timeout: str = "60m",
121122
parents: list[str] = [],
@@ -129,6 +130,7 @@ def __init__(
129130
) -> None:
130131
self.fn = fn
131132
self.is_async_function = is_async_fn(fn)
133+
self.workflow = workflow
132134

133135
self.type = type
134136
self.timeout = timeout
@@ -142,44 +144,38 @@ def __init__(
142144
self.concurrency__max_runs = concurrency__max_runs
143145
self.concurrency__limit_strategy = concurrency__limit_strategy
144146

145-
146-
class RegisteredStep(Generic[R]):
147-
def __init__(
148-
self,
149-
workflow: "BaseWorkflow",
150-
step: Step[R],
151-
) -> None:
152-
self.workflow = workflow
153-
self.step = step
154-
155147
def call(self, ctx: Context) -> R:
156-
if self.step.is_async_function:
157-
raise TypeError(
158-
f"{self.step.name} is not a sync function. Use `acall` instead."
159-
)
148+
if not self.is_registered:
149+
raise ValueError("Only steps that have been registered can be called.")
150+
151+
if self.is_async_function:
152+
raise TypeError(f"{self.name} is not a sync function. Use `acall` instead.")
160153

161-
sync_fn = self.step.fn
154+
sync_fn = self.fn
162155
if is_sync_fn(sync_fn):
163156
return sync_fn(self.workflow, ctx)
164157

165-
raise TypeError(
166-
f"{self.step.name} is not a sync function. Use `acall` instead."
167-
)
158+
raise TypeError(f"{self.name} is not a sync function. Use `acall` instead.")
168159

169160
async def acall(self, ctx: Context) -> R:
170-
if not self.step.is_async_function:
161+
if not self.is_registered:
162+
raise ValueError("Only steps that have been registered can be called.")
163+
164+
if not self.is_async_function:
171165
raise TypeError(
172-
f"{self.step.name} is not an async function. Use `call` instead."
166+
f"{self.name} is not an async function. Use `call` instead."
173167
)
174168

175-
async_fn = self.step.fn
169+
async_fn = self.fn
176170

177171
if is_async_fn(async_fn):
178172
return await async_fn(self.workflow, ctx)
179173

180-
raise TypeError(
181-
f"{self.step.name} is not an async function. Use `call` instead."
182-
)
174+
raise TypeError(f"{self.name} is not an async function. Use `call` instead.")
175+
176+
@property
177+
def is_registered(self) -> bool:
178+
return self.workflow is not None
183179

184180

185181
class WorkflowDeclaration(Generic[TWorkflowInput]):

hatchet_sdk/worker/runner/run_loop_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from hatchet_sdk.worker.runner.utils.capture_logs import capture_logs
1616

1717
if TYPE_CHECKING:
18-
from hatchet_sdk.v2.workflows import RegisteredStep
18+
from hatchet_sdk.v2.workflows import Step
1919

2020
STOP_LOOP_TYPE = Literal["STOP_LOOP"]
2121
STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP"
@@ -27,7 +27,7 @@
2727
@dataclass
2828
class WorkerActionRunLoopManager:
2929
name: str
30-
action_registry: dict[str, "RegisteredStep[Any]"]
30+
action_registry: dict[str, "Step[Any]"]
3131
validator_registry: dict[str, WorkflowValidator]
3232
max_runs: int | None
3333
config: ClientConfig

hatchet_sdk/worker/runner/runner.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
T = TypeVar("T")
4141

4242
if TYPE_CHECKING:
43-
from hatchet_sdk.v2.workflows import RegisteredStep
43+
from hatchet_sdk.v2.workflows import Step
4444

4545

4646
class WorkerStatus(Enum):
@@ -57,7 +57,7 @@ def __init__(
5757
event_queue: "Queue[Any]",
5858
max_runs: int | None = None,
5959
handle_kill: bool = True,
60-
action_registry: dict[str, "RegisteredStep[T]"] = {},
60+
action_registry: dict[str, "Step[T]"] = {},
6161
validator_registry: dict[str, WorkflowValidator] = {},
6262
config: ClientConfig = ClientConfig(),
6363
labels: dict[str, str | int] = {},
@@ -69,7 +69,7 @@ def __init__(
6969
self.max_runs = max_runs
7070
self.tasks: dict[str, asyncio.Task[Any]] = {} # Store run ids and futures
7171
self.contexts: dict[str, Context] = {} # Store run ids and contexts
72-
self.action_registry: dict[str, "RegisteredStep[T]"] = action_registry
72+
self.action_registry: dict[str, "Step[T]"] = action_registry
7373
self.validator_registry = validator_registry
7474

7575
self.event_queue = event_queue
@@ -216,7 +216,7 @@ def inner_callback(task: asyncio.Task[Any]) -> None:
216216

217217
## TODO: Stricter type hinting here
218218
def thread_action_func(
219-
self, context: Context, step: "RegisteredStep[T]", action: Action
219+
self, context: Context, step: "Step[T]", action: Action
220220
) -> T:
221221
if action.step_run_id is not None and action.step_run_id != "":
222222
self.threads[action.step_run_id] = current_thread()
@@ -233,15 +233,15 @@ def thread_action_func(
233233
async def async_wrapped_action_func(
234234
self,
235235
context: Context,
236-
step: "RegisteredStep[T]",
236+
step: "Step[T]",
237237
action: Action,
238238
run_id: str,
239239
) -> T:
240240
wr.set(context.workflow_run_id())
241241
sr.set(context.step_run_id)
242242

243243
try:
244-
if step.step.is_async_function:
244+
if step.is_async_function:
245245
return await step.acall(context)
246246
else:
247247
pfunc = functools.partial(

hatchet_sdk/worker/worker.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from hatchet_sdk.logger import logger
2626
from hatchet_sdk.utils.types import WorkflowValidator
2727
from hatchet_sdk.utils.typing import is_basemodel_subclass
28-
from hatchet_sdk.v2.workflows import RegisteredStep, StepType
28+
from hatchet_sdk.v2.workflows import Step, StepType
2929
from hatchet_sdk.worker.action_listener_process import (
3030
ActionEvent,
3131
worker_action_listener_process,
@@ -74,7 +74,7 @@ def __init__(
7474

7575
self.client: Client
7676

77-
self.action_registry: dict[str, RegisteredStep[Any]] = {}
77+
self.action_registry: dict[str, Step[Any]] = {}
7878
self.validator_registry: dict[str, WorkflowValidator] = {}
7979

8080
self.killing: bool = False
@@ -123,10 +123,10 @@ def register_workflow(self, workflow: Union["BaseWorkflow", Any]) -> None:
123123
sys.exit(1)
124124

125125
for step in workflow.steps:
126+
step.workflow = workflow
127+
126128
action_name = workflow.create_action_name(namespace, step)
127-
self.action_registry[action_name] = RegisteredStep(
128-
workflow=workflow, step=step
129-
)
129+
self.action_registry[action_name] = step
130130
return_type = get_type_hints(step.fn).get("return")
131131

132132
self.validator_registry[action_name] = WorkflowValidator(

0 commit comments

Comments
 (0)