Skip to content
This repository was archived by the owner on Mar 26, 2025. It is now read-only.

Commit 47d9405

Browse files
committed
wip v2 experimentation
1 parent 008e210 commit 47d9405

File tree

4 files changed

+191
-130
lines changed

4 files changed

+191
-130
lines changed

hatchet_sdk/runtime/__init__.py

Whitespace-only changes.

hatchet_sdk/runtime/admin.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
2+
# import hatchet_sdk.v2.callable as sdk
3+
# import hatchet_sdk.clients.admin as client
4+
5+
# from hatchet_sdk.contracts.workflows_pb2 import (
6+
# CreateStepRateLimit,
7+
# CreateWorkflowJobOpts,
8+
# CreateWorkflowStepOpts,
9+
# CreateWorkflowVersionOpts,
10+
# DesiredWorkerLabels,
11+
# StickyStrategy,
12+
# WorkflowConcurrencyOpts,
13+
# WorkflowKind,
14+
# )
15+
16+
# async def put_workflow(callable: sdk.HatchetCallable, client: client.AdminClient):
17+
# options = callable._.options
18+
19+
# kind: WorkflowKind = WorkflowKind.DURABLE if options.durable else WorkflowKind.FUNCTION
20+

hatchet_sdk/runtime/registry.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from typing import List, Dict
2+
3+
4+
class ActionRegistry:
5+
6+
_registry: Dict[str, "HatchetCallable"] = dict()
7+
8+
def register(self, callable: "HatchetCallable") -> str:
9+
key = "{namespace}:{name}".format(
10+
namespace=callable._.namespace, name=callable._.name
11+
)
12+
self._registry[key] = callable
13+
return key
14+
15+
def list(self) -> List[str]:
16+
return list(self._registry.keys())
17+
18+
19+
global_registry = ActionRegistry()

hatchet_sdk/v2/callable.py

Lines changed: 152 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
import asyncio
2-
from typing import Callable, Dict, Generic, List, Optional, TypedDict, TypeVar, Union
2+
from typing import Dict, Generic, List, Optional, TypedDict, TypeVar, Union, ParamSpec
3+
from collections.abc import Callable, Awaitable
4+
from contextvars import ContextVar
5+
import inspect
6+
7+
from pydantic import BaseModel, Field, computed_field
8+
from hatchet_sdk.runtime import registry
9+
import json
310

411
from hatchet_sdk.context import Context
512
from hatchet_sdk.contracts.workflows_pb2 import (
@@ -17,151 +24,166 @@
1724
from hatchet_sdk.rate_limit import RateLimit
1825
from hatchet_sdk.v2.concurrency import ConcurrencyFunction
1926
from hatchet_sdk.workflow_run import RunRef
27+
from datetime import timedelta
28+
29+
import hatchet_sdk.v2.hatchet as hatchet
2030

2131
T = TypeVar("T")
32+
P = ParamSpec("P")
33+
34+
35+
class Options(BaseModel):
36+
durable: bool = Field(default=False)
37+
auto_register: bool = Field(default=True)
38+
on_failure: Optional["HatchetCallable"] = Field(default=None)
39+
40+
# triggering options
41+
on_events: List[str] = Field(default=[])
42+
on_crons: List[str] = Field(default=[])
43+
44+
# metadata
45+
version: str = Field(default="")
46+
47+
# timeout
48+
execution_timeout: str = Field(default="60m", alias="timeout")
49+
schedule_timeout: str = Field(default="5m")
50+
51+
# execution
52+
sticky: StickyStrategy | None = Field(default=None)
53+
retries: int = Field(default=0, ge=0)
54+
ratelimits: List[RateLimit] = Field(default=[])
55+
priority: Optional[int] = Field(default=None, alias="default_priority", ge=1, le=3)
56+
desired_worker_labels: Dict[str, DesiredWorkerLabel] = Field(default=dict())
57+
concurrency: Optional[ConcurrencyFunction] = Field(default=None)
58+
59+
@computed_field
60+
@property
61+
def ratelimits_proto(self) -> List[CreateStepRateLimit]:
62+
return [
63+
CreateStepRateLimit(key=limit.key, units=limit.units)
64+
for limit in self.ratelimits
65+
]
66+
67+
@computed_field
68+
@property
69+
def desired_worker_labels_proto(self) -> Dict[str, DesiredWorkerLabels]:
70+
labels = dict()
71+
for key, d in self.desired_worker_labels.items():
72+
value = d.get("value", None)
73+
labels[key] = DesiredWorkerLabels(
74+
strValue=str(value) if not isinstance(value, int) else None,
75+
intValue=value if isinstance(value, int) else None,
76+
required=d.get("required", None),
77+
weight=d.get("weight", None),
78+
comparator=d.get("comparator", None),
79+
)
80+
return labels
81+
82+
83+
class CallableMetadata(BaseModel):
84+
name: str
85+
namespace: str = Field(default="default")
86+
options: Options = Options()
2287

2388

24-
class HatchetCallable(Generic[T]):
89+
class HatchetCallableBase(Generic[P, T]):
90+
91+
action_name: str
92+
func: Callable[P, T] # note that T can be an Awaitable if func is a coroutine
93+
_: CallableMetadata
94+
2595
def __init__(
2696
self,
27-
func: Callable[[Context], T],
28-
durable: bool = False,
97+
*,
98+
func: Callable[P, T],
2999
name: str = "",
30-
auto_register: bool = True,
31-
on_events: list | None = None,
32-
on_crons: list | None = None,
33-
version: str = "",
34-
timeout: str = "60m",
35-
schedule_timeout: str = "5m",
36-
sticky: StickyStrategy = None,
37-
retries: int = 0,
38-
rate_limits: List[RateLimit] | None = None,
39-
concurrency: ConcurrencyFunction | None = None,
40-
on_failure: Optional["HatchetCallable"] = None,
41-
desired_worker_labels: dict[str:DesiredWorkerLabel] = {},
42-
default_priority: int | None = None,
100+
namespace: str = "default",
101+
options: Options = Options(),
43102
):
44103
self.func = func
104+
name = name.lower() or str(func.__name__).lower()
105+
self._ = CallableMetadata(name=name, namespace=namespace, options=options)
106+
self.action_name = registry.global_registry.register(self)
107+
108+
# def __call__(self, context: Context) -> T:
109+
# return self.func(context)
110+
111+
# def with_namespace(self, namespace: str):
112+
# if namespace is not None and namespace != "":
113+
# self.function_namespace = namespace
114+
# self.function_name = namespace + self.function_name
115+
116+
def _to_workflow_proto(self) -> CreateWorkflowVersionOpts:
117+
options = self._.options
118+
119+
# if self.function_on_failure is not None:
120+
# on_failure_job = CreateWorkflowJobOpts(
121+
# name=self.function_name + "-on-failure",
122+
# steps=[
123+
# self.function_on_failure.to_step(),
124+
# ],
125+
# )
126+
# # concurrency: WorkflowConcurrencyOpts | None = None
127+
# if self.function_concurrency is not None:
128+
# self.function_concurrency.set_namespace(self.function_namespace)
129+
# concurrency = WorkflowConcurrencyOpts(
130+
# action=self.function_concurrency.get_action_name(),
131+
# max_runs=self.function_concurrency.max_runs,
132+
# limit_strategy=self.function_concurrency.limit_strategy,
133+
# )
45134

46-
on_events = on_events or []
47-
on_crons = on_crons or []
48-
49-
limits = None
50-
if rate_limits:
51-
limits = [
52-
CreateStepRateLimit(key=rate_limit.key, units=rate_limit.units)
53-
for rate_limit in rate_limits or []
54-
]
55-
56-
self.function_desired_worker_labels = {}
57-
58-
for key, d in desired_worker_labels.items():
59-
value = d["value"] if "value" in d else None
60-
self.function_desired_worker_labels[key] = DesiredWorkerLabels(
61-
strValue=str(value) if not isinstance(value, int) else None,
62-
intValue=value if isinstance(value, int) else None,
63-
required=d["required"] if "required" in d else None,
64-
weight=d["weight"] if "weight" in d else None,
65-
comparator=d["comparator"] if "comparator" in d else None,
66-
)
67-
self.sticky = sticky
68-
self.default_priority = default_priority
69-
self.durable = durable
70-
self.function_name = name.lower() or str(func.__name__).lower()
71-
self.function_version = version
72-
self.function_on_events = on_events
73-
self.function_on_crons = on_crons
74-
self.function_timeout = timeout
75-
self.function_schedule_timeout = schedule_timeout
76-
self.function_retries = retries
77-
self.function_rate_limits = limits
78-
self.function_concurrency = concurrency
79-
self.function_on_failure = on_failure
80-
self.function_namespace = "default"
81-
self.function_auto_register = auto_register
82-
83-
self.is_coroutine = False
84-
85-
if asyncio.iscoroutinefunction(func):
86-
self.is_coroutine = True
87-
88-
def __call__(self, context: Context) -> T:
89-
return self.func(context)
90-
91-
def with_namespace(self, namespace: str):
92-
if namespace is not None and namespace != "":
93-
self.function_namespace = namespace
94-
self.function_name = namespace + self.function_name
95-
96-
def to_workflow_opts(self) -> CreateWorkflowVersionOpts:
97-
kind: WorkflowKind = WorkflowKind.FUNCTION
98-
99-
if self.durable:
100-
kind = WorkflowKind.DURABLE
101-
102-
on_failure_job: CreateWorkflowJobOpts | None = None
103-
104-
if self.function_on_failure is not None:
105-
on_failure_job = CreateWorkflowJobOpts(
106-
name=self.function_name + "-on-failure",
107-
steps=[
108-
self.function_on_failure.to_step(),
109-
],
110-
)
135+
workflow = CreateWorkflowVersionOpts(
136+
name=self._.name,
137+
kind=WorkflowKind.DURABLE if options.durable else WorkflowKind.FUNCTION,
138+
version=options.version,
139+
event_triggers=options.on_events,
140+
cron_triggers=options.on_crons,
141+
schedule_timeout=options.schedule_timeout,
142+
sticky=options.schedule_timeout,
143+
on_failure_job=(
144+
options.on_failure._to_job_proto() if options.on_failure else None
145+
),
146+
concurrency=None, # TODO
147+
jobs=[
148+
self._to_job_proto()
149+
], # Note that the failure job is also a HatchetCallable, and it should manage its own name.
150+
default_priority=options.priority,
151+
)
152+
return workflow
153+
154+
def _to_job_proto(self) -> CreateWorkflowJobOpts:
155+
job = CreateWorkflowJobOpts(name=self._.name, steps=[self._to_step_opts()])
156+
return job
157+
158+
def _to_step_proto(self) -> CreateWorkflowStepOpts:
159+
options = self._.options
160+
step = CreateWorkflowStepOpts(
161+
readable_id=self._.name,
162+
action=self.action_name,
163+
timeout=options.execution_timeout,
164+
inputs="{}", # TODO: not sure that this is, we're defining a step, not running a step
165+
parents=[], # this is a single step workflow, always empty
166+
retries=options.retries,
167+
rate_limits=options.ratelimits,
168+
# worker_labels=self.function_desired_worker_labels,
169+
)
170+
return step
111171

112-
concurrency: WorkflowConcurrencyOpts | None = None
113172

114-
if self.function_concurrency is not None:
115-
self.function_concurrency.set_namespace(self.function_namespace)
116-
concurrency = WorkflowConcurrencyOpts(
117-
action=self.function_concurrency.get_action_name(),
118-
max_runs=self.function_concurrency.max_runs,
119-
limit_strategy=self.function_concurrency.limit_strategy,
120-
)
173+
class HatchetCallable(HatchetCallableBase[P, T]):
121174

122-
validated_priority = (
123-
max(1, min(3, self.default_priority)) if self.default_priority else None
124-
)
125-
if validated_priority != self.default_priority:
126-
logger.warning(
127-
"Warning: Default Priority Must be between 1 and 3 -- inclusively. Adjusted to be within the range."
128-
)
175+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
176+
input = json.dumps({args: args, kwargs: kwargs})
177+
client = hatchet.Hatchet() # TODO: get the client somehow
178+
return asyncio.gather(client.admin.run(self.action_name, input).result()).result
129179

130-
return CreateWorkflowVersionOpts(
131-
name=self.function_name,
132-
kind=kind,
133-
version=self.function_version,
134-
event_triggers=self.function_on_events,
135-
cron_triggers=self.function_on_crons,
136-
schedule_timeout=self.function_schedule_timeout,
137-
sticky=self.sticky,
138-
on_failure_job=on_failure_job,
139-
concurrency=concurrency,
140-
jobs=[
141-
CreateWorkflowJobOpts(
142-
name=self.function_name,
143-
steps=[
144-
self.to_step(),
145-
],
146-
)
147-
],
148-
default_priority=validated_priority,
149-
)
150180

151-
def to_step(self) -> CreateWorkflowStepOpts:
152-
return CreateWorkflowStepOpts(
153-
readable_id=self.function_name,
154-
action=self.get_action_name(),
155-
timeout=self.function_timeout,
156-
inputs="{}",
157-
parents=[],
158-
retries=self.function_retries,
159-
rate_limits=self.function_rate_limits,
160-
worker_labels=self.function_desired_worker_labels,
161-
)
181+
class HatchetAwaitable(HatchetCallableBase[P, Awaitable[T]]):
162182

163-
def get_action_name(self) -> str:
164-
return self.function_namespace + ":" + self.function_name
183+
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
184+
input = json.dumps({args: args, kwargs: kwargs})
185+
client = hatchet.Hatchet() # TODO: get the client somehow
186+
return (await client.admin.run(self.action_name, input)).result()
165187

166188

167189
T = TypeVar("T")

0 commit comments

Comments
 (0)