|
1 | 1 | import asyncio |
2 | | -from typing import Callable, Dict, Generic, List, Optional, TypedDict, TypeVar, Union |
| 2 | +from typing import Dict, Generic, List, Optional, TypedDict, TypeVar, Union, ParamSpec |
| 3 | +from collections.abc import Callable, Awaitable |
| 4 | +from contextvars import ContextVar |
| 5 | +import inspect |
| 6 | + |
| 7 | +from pydantic import BaseModel, Field, computed_field |
| 8 | +from hatchet_sdk.runtime import registry |
| 9 | +import json |
3 | 10 |
|
4 | 11 | from hatchet_sdk.context import Context |
5 | 12 | from hatchet_sdk.contracts.workflows_pb2 import ( |
|
17 | 24 | from hatchet_sdk.rate_limit import RateLimit |
18 | 25 | from hatchet_sdk.v2.concurrency import ConcurrencyFunction |
19 | 26 | from hatchet_sdk.workflow_run import RunRef |
| 27 | +from datetime import timedelta |
| 28 | + |
| 29 | +import hatchet_sdk.v2.hatchet as hatchet |
20 | 30 |
|
21 | 31 | T = TypeVar("T") |
| 32 | +P = ParamSpec("P") |
| 33 | + |
| 34 | + |
class Options(BaseModel):
    """Declarative configuration for a Hatchet callable.

    Collected at decoration time and rendered into workflow/step protobuf
    messages via the ``*_proto`` computed fields.
    """

    durable: bool = Field(default=False)
    auto_register: bool = Field(default=True)
    on_failure: Optional["HatchetCallable"] = Field(default=None)

    # triggering options
    # fix: default_factory, not a shared mutable default literal
    on_events: List[str] = Field(default_factory=list)
    on_crons: List[str] = Field(default_factory=list)

    # metadata
    version: str = Field(default="")

    # timeout
    execution_timeout: str = Field(default="60m", alias="timeout")
    schedule_timeout: str = Field(default="5m")

    # execution
    sticky: StickyStrategy | None = Field(default=None)
    retries: int = Field(default=0, ge=0)
    ratelimits: List[RateLimit] = Field(default_factory=list)
    priority: Optional[int] = Field(default=None, alias="default_priority", ge=1, le=3)
    desired_worker_labels: Dict[str, DesiredWorkerLabel] = Field(default_factory=dict)
    concurrency: Optional[ConcurrencyFunction] = Field(default=None)

    @computed_field
    @property
    def ratelimits_proto(self) -> List[CreateStepRateLimit]:
        """The SDK-level rate limits rendered as CreateStepRateLimit protos."""
        return [
            CreateStepRateLimit(key=limit.key, units=limit.units)
            for limit in self.ratelimits
        ]

    @computed_field
    @property
    def desired_worker_labels_proto(self) -> Dict[str, DesiredWorkerLabels]:
        """The desired worker labels rendered as DesiredWorkerLabels protos."""
        labels: Dict[str, DesiredWorkerLabels] = {}
        for key, d in self.desired_worker_labels.items():
            value = d.get("value", None)
            labels[key] = DesiredWorkerLabels(
                # fix: guard None — str(None) would emit the literal "None"
                strValue=(
                    str(value)
                    if value is not None and not isinstance(value, int)
                    else None
                ),
                intValue=value if isinstance(value, int) else None,
                required=d.get("required", None),
                weight=d.get("weight", None),
                comparator=d.get("comparator", None),
            )
        return labels
| 81 | + |
| 82 | + |
class CallableMetadata(BaseModel):
    """Identity (name + namespace) and options of a registered callable."""

    name: str
    namespace: str = Field(default="default")
    # fix: default_factory so each instance gets its own Options, instead of
    # one Options() shared as a class-level default
    options: Options = Field(default_factory=Options)
22 | 87 |
|
23 | 88 |
|
class HatchetCallableBase(Generic[P, T]):
    """Common machinery for sync and async Hatchet callables.

    Wraps a user function, registers it with the global action registry, and
    renders itself into the workflow/job/step protobuf messages.
    """

    action_name: str  # fully-qualified action name assigned by the registry
    func: Callable[P, T]  # note that T can be an Awaitable if func is a coroutine
    _: CallableMetadata

    def __init__(
        self,
        *,
        func: Callable[P, T],
        name: str = "",
        namespace: str = "default",
        options: Options = Options(),
    ):
        self.func = func
        # fall back to the wrapped function's own name when none was given
        name = name.lower() or str(func.__name__).lower()
        self._ = CallableMetadata(name=name, namespace=namespace, options=options)
        self.action_name = registry.global_registry.register(self)

    def _to_workflow_proto(self) -> CreateWorkflowVersionOpts:
        """Render this callable as a single-job workflow definition."""
        options = self._.options
        workflow = CreateWorkflowVersionOpts(
            name=self._.name,
            kind=WorkflowKind.DURABLE if options.durable else WorkflowKind.FUNCTION,
            version=options.version,
            event_triggers=options.on_events,
            cron_triggers=options.on_crons,
            schedule_timeout=options.schedule_timeout,
            # fix: was options.schedule_timeout (copy/paste error) — a timeout
            # string where a StickyStrategy belongs
            sticky=options.sticky,
            on_failure_job=(
                options.on_failure._to_job_proto() if options.on_failure else None
            ),
            concurrency=None,  # TODO: render ConcurrencyFunction -> WorkflowConcurrencyOpts
            jobs=[
                self._to_job_proto()
            ],  # Note that the failure job is also a HatchetCallable, and it should manage its own name.
            default_priority=options.priority,
        )
        return workflow

    def _to_job_proto(self) -> CreateWorkflowJobOpts:
        """Render this callable as a one-step job."""
        # fix: called the non-existent self._to_step_opts(); the method below
        # is named _to_step_proto
        job = CreateWorkflowJobOpts(name=self._.name, steps=[self._to_step_proto()])
        return job

    def _to_step_proto(self) -> CreateWorkflowStepOpts:
        """Render this callable as a single workflow step."""
        options = self._.options
        step = CreateWorkflowStepOpts(
            readable_id=self._.name,
            action=self.action_name,
            timeout=options.execution_timeout,
            inputs="{}",  # step *definitions* carry no inputs; inputs arrive at run time
            parents=[],  # this is a single-step workflow, always empty
            retries=options.retries,
            # fix: the proto expects CreateStepRateLimit messages, not the
            # SDK-level RateLimit objects — use the computed proto field
            rate_limits=options.ratelimits_proto,
            # fix: wire up the (previously unused) proto-rendered labels
            worker_labels=options.desired_worker_labels_proto,
        )
        return step
111 | 171 |
|
112 | | - concurrency: WorkflowConcurrencyOpts | None = None |
113 | 172 |
|
114 | | - if self.function_concurrency is not None: |
115 | | - self.function_concurrency.set_namespace(self.function_namespace) |
116 | | - concurrency = WorkflowConcurrencyOpts( |
117 | | - action=self.function_concurrency.get_action_name(), |
118 | | - max_runs=self.function_concurrency.max_runs, |
119 | | - limit_strategy=self.function_concurrency.limit_strategy, |
120 | | - ) |
class HatchetCallable(HatchetCallableBase[P, T]):
    """Synchronous entry point: calling it triggers a run and blocks for the result."""

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        # fix: keys must be the string literals "args"/"kwargs" — the original
        # used the tuple/dict objects themselves as keys, which json.dumps rejects
        payload = json.dumps({"args": args, "kwargs": kwargs})
        client = hatchet.Hatchet()  # TODO: get the client somehow
        # fix: asyncio.gather(...).result returned the bound method object,
        # never the run's result; drive the coroutine to completion instead.
        # NOTE(review): mirrors HatchetAwaitable — assumes client.admin.run is
        # a coroutine yielding a RunRef; confirm against the admin client API.
        ref = asyncio.get_event_loop().run_until_complete(
            client.admin.run(self.action_name, payload)
        )
        return ref.result()
129 | 179 |
|
130 | | - return CreateWorkflowVersionOpts( |
131 | | - name=self.function_name, |
132 | | - kind=kind, |
133 | | - version=self.function_version, |
134 | | - event_triggers=self.function_on_events, |
135 | | - cron_triggers=self.function_on_crons, |
136 | | - schedule_timeout=self.function_schedule_timeout, |
137 | | - sticky=self.sticky, |
138 | | - on_failure_job=on_failure_job, |
139 | | - concurrency=concurrency, |
140 | | - jobs=[ |
141 | | - CreateWorkflowJobOpts( |
142 | | - name=self.function_name, |
143 | | - steps=[ |
144 | | - self.to_step(), |
145 | | - ], |
146 | | - ) |
147 | | - ], |
148 | | - default_priority=validated_priority, |
149 | | - ) |
150 | 180 |
|
151 | | - def to_step(self) -> CreateWorkflowStepOpts: |
152 | | - return CreateWorkflowStepOpts( |
153 | | - readable_id=self.function_name, |
154 | | - action=self.get_action_name(), |
155 | | - timeout=self.function_timeout, |
156 | | - inputs="{}", |
157 | | - parents=[], |
158 | | - retries=self.function_retries, |
159 | | - rate_limits=self.function_rate_limits, |
160 | | - worker_labels=self.function_desired_worker_labels, |
161 | | - ) |
class HatchetAwaitable(HatchetCallableBase[P, Awaitable[T]]):
    """Asynchronous entry point: awaiting it triggers a run and yields the result."""

    async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        # fix: JSON object keys must be strings, not the args/kwargs objects;
        # also avoid shadowing the builtin `input`
        payload = json.dumps({"args": args, "kwargs": kwargs})
        client = hatchet.Hatchet()  # TODO: get the client somehow
        return (await client.admin.run(self.action_name, payload)).result()
165 | 187 |
|
166 | 188 |
|
167 | 189 | T = TypeVar("T") |
|
0 commit comments