|
4 | 4 | TYPE_CHECKING, |
5 | 5 | Any, |
6 | 6 | Callable, |
| 7 | + Generic, |
7 | 8 | Optional, |
8 | 9 | ParamSpec, |
9 | 10 | Type, |
10 | 11 | TypeVar, |
| 12 | + Union, |
11 | 13 | cast, |
12 | 14 | ) |
13 | 15 |
|
@@ -55,6 +57,56 @@ def transform_desired_worker_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels |
55 | 57 | ) |
56 | 58 |
|
57 | 59 |
|
| 60 | +class Function(Generic[R, TWorkflowInput]): |
| 61 | + def __init__( |
| 62 | + self, |
| 63 | + fn: Callable[[Context], R], |
| 64 | + hatchet: "Hatchet", |
| 65 | + name: str = "", |
| 66 | + on_events: list[str] = [], |
| 67 | + on_crons: list[str] = [], |
| 68 | + version: str = "", |
| 69 | + timeout: str = "60m", |
| 70 | + schedule_timeout: str = "5m", |
| 71 | + sticky: StickyStrategy | None = None, |
| 72 | + retries: int = 0, |
| 73 | + rate_limits: list[RateLimit] = [], |
| 74 | + desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, |
| 75 | + concurrency: ConcurrencyExpression | None = None, |
| 76 | + on_failure: Union["Function[R]", None] = None, |
| 77 | + default_priority: int = 1, |
| 78 | + input_validator: Type[TWorkflowInput] | None = None, |
| 79 | + backoff_factor: float | None = None, |
| 80 | + backoff_max_seconds: int | None = None, |
| 81 | + ) -> None: |
| 82 | + def func(_: Any, context: Context) -> R: |
| 83 | + return fn(context) |
| 84 | + |
| 85 | + self.hatchet = hatchet |
| 86 | + self.step: Step[R] = hatchet.step( |
| 87 | + name=name or fn.__name__, |
| 88 | + timeout=timeout, |
| 89 | + retries=retries, |
| 90 | + rate_limits=rate_limits, |
| 91 | + desired_worker_labels=desired_worker_labels, |
| 92 | + backoff_factor=backoff_factor, |
| 93 | + backoff_max_seconds=backoff_max_seconds, |
| 94 | + )(func) |
| 95 | + self.on_failure_step = on_failure |
| 96 | + self.workflow_config = WorkflowConfig( |
| 97 | + name=name or fn.__name__, |
| 98 | + on_events=on_events, |
| 99 | + on_crons=on_crons, |
| 100 | + version=version, |
| 101 | + timeout=timeout, |
| 102 | + schedule_timeout=schedule_timeout, |
| 103 | + sticky=sticky, |
| 104 | + default_priority=default_priority, |
| 105 | + concurrency=concurrency, |
| 106 | + input_validator=input_validator or cast(Type[TWorkflowInput], EmptyModel), |
| 107 | + ) |
| 108 | + |
| 109 | + |
58 | 110 | class Hatchet: |
59 | 111 | """ |
60 | 112 | Main client for interacting with the Hatchet SDK. |
@@ -207,44 +259,38 @@ def function( |
207 | 259 | timeout: str = "60m", |
208 | 260 | schedule_timeout: str = "5m", |
209 | 261 | sticky: StickyStrategy | None = None, |
210 | | - default_priority: int = 1, |
| 262 | + retries: int = 0, |
| 263 | + rate_limits: list[RateLimit] = [], |
| 264 | + desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, |
211 | 265 | concurrency: ConcurrencyExpression | None = None, |
| 266 | + on_failure: Union["Function[Any]", None] = None, |
| 267 | + default_priority: int = 1, |
212 | 268 | input_validator: Type[TWorkflowInput] | None = None, |
213 | | - ) -> Callable[[Callable[[Context], R]], BaseWorkflowImpl]: |
214 | | - def inner(func: Callable[[Context], R]) -> BaseWorkflowImpl: |
215 | | - declaration = WorkflowDeclaration[TWorkflowInput]( |
216 | | - WorkflowConfig( |
217 | | - name=name or func.__name__, |
218 | | - on_events=on_events, |
219 | | - on_crons=on_crons, |
220 | | - version=version, |
221 | | - timeout=timeout, |
222 | | - schedule_timeout=schedule_timeout, |
223 | | - sticky=sticky, |
224 | | - default_priority=default_priority, |
225 | | - concurrency=concurrency, |
226 | | - input_validator=input_validator |
227 | | - or cast(Type[TWorkflowInput], EmptyModel), |
228 | | - ), |
229 | | - self, |
| 269 | + backoff_factor: float | None = None, |
| 270 | + backoff_max_seconds: int | None = None, |
| 271 | + ) -> Callable[[Callable[[Context], R]], Function[R, TWorkflowInput]]: |
| 272 | + def inner(func: Callable[[Context], R]) -> Function[R, TWorkflowInput]: |
| 273 | + return Function[R, TWorkflowInput]( |
| 274 | + func, |
| 275 | + hatchet=self, |
| 276 | + name=name, |
| 277 | + on_events=on_events, |
| 278 | + on_crons=on_crons, |
| 279 | + version=version, |
| 280 | + timeout=timeout, |
| 281 | + schedule_timeout=schedule_timeout, |
| 282 | + sticky=sticky, |
| 283 | + retries=retries, |
| 284 | + rate_limits=rate_limits, |
| 285 | + desired_worker_labels=desired_worker_labels, |
| 286 | + concurrency=concurrency, |
| 287 | + on_failure=on_failure, |
| 288 | + default_priority=default_priority, |
| 289 | + input_validator=input_validator, |
| 290 | + backoff_factor=backoff_factor, |
| 291 | + backoff_max_seconds=backoff_max_seconds, |
230 | 292 | ) |
231 | 293 |
|
232 | | - class Workflow(BaseWorkflowImpl): |
233 | | - config = declaration.config |
234 | | - |
235 | | - @self.step( |
236 | | - name=declaration.config.name, |
237 | | - timeout=timeout, |
238 | | - retries=0, |
239 | | - rate_limits=[], |
240 | | - backoff_factor=None, |
241 | | - backoff_max_seconds=None, |
242 | | - ) |
243 | | - def fn(self, context: Context) -> R: |
244 | | - return func(context) |
245 | | - |
246 | | - return Workflow() |
247 | | - |
248 | 294 | return inner |
249 | 295 |
|
250 | 296 | def worker( |
|
0 commit comments