Skip to content
This repository was archived by the owner on Feb 20, 2025. It is now read-only.

Commit a735a9c

Browse files
committed
fix: proper typing for call and acall of Step
1 parent 3486b27 commit a735a9c

File tree

3 files changed

+64
-22
lines changed

3 files changed

+64
-22
lines changed

examples/simple/worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import time
22

3+
from dotenv import load_dotenv
4+
35
from hatchet_sdk import Context
46
from hatchet_sdk.v2 import Hatchet, Workflow, WorkflowConfig
57

6-
from dotenv import load_dotenv
7-
88
load_dotenv()
99

1010
hatchet = Hatchet(debug=True)

hatchet_sdk/v2/workflows.py

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
import asyncio
22
from enum import Enum
3-
from typing import Any, Callable, Generic, ParamSpec, Type, TypeVar, Union, cast
3+
from typing import (
4+
Any,
5+
Awaitable,
6+
Callable,
7+
Generic,
8+
ParamSpec,
9+
Type,
10+
TypeGuard,
11+
TypeVar,
12+
Union,
13+
cast,
14+
overload,
15+
)
416

517
from pydantic import BaseModel, ConfigDict
618

@@ -16,8 +28,6 @@
1628
)
1729
from hatchet_sdk.contracts.workflows_pb2 import StickyStrategy as StickyStrategyProto
1830
from hatchet_sdk.contracts.workflows_pb2 import WorkflowConcurrencyOpts, WorkflowKind
19-
from hatchet_sdk.labels import DesiredWorkerLabel
20-
from hatchet_sdk.rate_limit import RateLimit
2131

2232
from ..logger import logger
2333

@@ -57,6 +67,7 @@ class StickyStrategy(str, Enum):
5767

5868
class WorkflowConfig(BaseModel):
5969
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
70+
6071
name: str = ""
6172
on_events: list[str] = []
6273
on_crons: list[str] = []
@@ -75,10 +86,23 @@ class StepType(str, Enum):
7586
ON_FAILURE = "on_failure"
7687

7788

89+
AsyncFunc = Callable[[Any, Context], Awaitable[R]]
90+
SyncFunc = Callable[[Any, Context], R]
91+
StepFunc = Union[AsyncFunc[R], SyncFunc[R]]
92+
93+
94+
def is_async_fn(fn: StepFunc[R]) -> TypeGuard[AsyncFunc[R]]:
95+
return asyncio.iscoroutinefunction(fn)
96+
97+
98+
def is_sync_fn(fn: StepFunc[R]) -> TypeGuard[SyncFunc[R]]:
99+
return not asyncio.iscoroutinefunction(fn)
100+
101+
78102
class Step(Generic[R]):
79103
def __init__(
80104
self,
81-
fn: Callable[[Any, Context], R],
105+
fn: Callable[[Any, Context], R] | Callable[[Any, Context], Awaitable[R]],
82106
type: StepType,
83107
name: str = "",
84108
timeout: str = "60m",
@@ -90,7 +114,7 @@ def __init__(
90114
backoff_max_seconds: int | None = None,
91115
) -> None:
92116
self.fn = fn
93-
self.is_async_function = bool(asyncio.iscoroutinefunction(fn))
117+
self.is_async_function = is_async_fn(fn)
94118

95119
self.type = type
96120
self.timeout = timeout
@@ -104,8 +128,28 @@ def __init__(
104128
self.concurrency__max_runs = 1
105129
self.concurrency__limit_strategy = ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS
106130

107-
def __call__(self, ctx: Context) -> R:
108-
return self.fn(None, ctx)
131+
def call(self, ctx: Context) -> R:
132+
if self.is_async_function:
133+
raise TypeError(f"{self.name} is not a sync function. Use `acall` instead.")
134+
135+
sync_fn = self.fn
136+
if is_sync_fn(sync_fn):
137+
return sync_fn(None, ctx)
138+
139+
raise TypeError(f"{self.name} is not a sync function. Use `acall` instead.")
140+
141+
async def acall(self, ctx: Context) -> R:
142+
if not self.is_async_function:
143+
raise TypeError(
144+
f"{self.name} is not an async function. Use `call` instead."
145+
)
146+
147+
async_fn = self.fn
148+
149+
if is_async_fn(async_fn):
150+
return await async_fn(None, ctx)
151+
152+
raise TypeError(f"{self.name} is not an async function. Use `call` instead.")
109153

110154

111155
class Workflow:

hatchet_sdk/worker/runner/runner.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
from hatchet_sdk.worker.action_listener_process import ActionEvent
4848
from hatchet_sdk.worker.runner.utils.capture_logs import copy_context_vars, sr, wr
4949

50+
T = TypeVar("T")
51+
5052
if TYPE_CHECKING:
5153
from hatchet_sdk.v2.workflows import Step
5254

@@ -65,7 +67,7 @@ def __init__(
6567
event_queue: "Queue[Any]",
6668
max_runs: int | None = None,
6769
handle_kill: bool = True,
68-
action_registry: dict[str, "Step[Any]"] = {},
70+
action_registry: dict[str, "Step[T]"] = {},
6971
validator_registry: dict[str, WorkflowValidator] = {},
7072
config: ClientConfig = ClientConfig(),
7173
labels: dict[str, str | int] = {},
@@ -77,7 +79,7 @@ def __init__(
7779
self.max_runs = max_runs
7880
self.tasks: dict[str, asyncio.Task[Any]] = {} # Store run ids and futures
7981
self.contexts: dict[str, Context] = {} # Store run ids and contexts
80-
self.action_registry: dict[str, "Step[Any]"] = action_registry
82+
self.action_registry: dict[str, "Step[T]"] = action_registry
8183
self.validator_registry = validator_registry
8284

8385
self.event_queue = event_queue
@@ -223,9 +225,7 @@ def inner_callback(task: asyncio.Task[Any]) -> None:
223225
return inner_callback
224226

225227
## TODO: Stricter type hinting here
226-
def thread_action_func(
227-
self, context: Context, action_func: Callable[..., Any], action: Action
228-
) -> Any:
228+
def thread_action_func(self, context: Context, step: Step[T], action: Action) -> T:
229229
if action.step_run_id is not None and action.step_run_id != "":
230230
self.threads[action.step_run_id] = current_thread()
231231
elif (
@@ -234,25 +234,23 @@ def thread_action_func(
234234
):
235235
self.threads[action.get_group_key_run_id] = current_thread()
236236

237-
return action_func(context)
237+
return step.call(context)
238238

239239
## TODO: Stricter type hinting here
240240
# We wrap all actions in an async func
241241
async def async_wrapped_action_func(
242242
self,
243243
context: Context,
244-
action_func: Callable[..., Any],
244+
step: Step[T],
245245
action: Action,
246246
run_id: str,
247-
) -> Any:
247+
) -> T:
248248
wr.set(context.workflow_run_id())
249249
sr.set(context.step_run_id)
250250

251251
try:
252-
if (
253-
hasattr(action_func, "is_coroutine") and action_func.is_coroutine
254-
) or asyncio.iscoroutinefunction(action_func):
255-
return await action_func(context)
252+
if step.is_async_function:
253+
return await step.acall(context)
256254
else:
257255
pfunc = functools.partial(
258256
# we must copy the context vars to the new thread, as only asyncio natively supports
@@ -261,7 +259,7 @@ async def async_wrapped_action_func(
261259
contextvars.copy_context().items(),
262260
self.thread_action_func,
263261
context,
264-
action_func,
262+
step,
265263
action,
266264
)
267265

0 commit comments

Comments
 (0)