Skip to content
This repository was archived by the owner on Feb 20, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions examples/simple/worker.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
from hatchet_sdk import BaseWorkflow, Context, Hatchet
from hatchet_sdk import Context, Hatchet

hatchet = Hatchet(debug=True)


class MyWorkflow(BaseWorkflow):
@hatchet.step(timeout="11s", retries=3)
def step1(self, context: Context) -> dict[str, str]:
print("executed step1")
return {
"step1": "step1",
}
@hatchet.function()
def step1(context: Context) -> dict[str, str]:
message = "Hello from Hatchet!"

context.log(message)

return {"message": message}

def main() -> None:
wf = MyWorkflow()

def main() -> None:
worker = hatchet.worker("test-worker", max_runs=1)
worker.register_workflow(wf)
worker.register_function(step1)
worker.start()


Expand Down
95 changes: 94 additions & 1 deletion hatchet_sdk/hatchet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import logging
from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar, cast
from typing import TYPE_CHECKING, Any, Callable, Generic, Type, TypeVar, Union, cast

from hatchet_sdk.client import Client, new_client, new_client_raw
from hatchet_sdk.clients.admin import AdminClient
Expand Down Expand Up @@ -44,6 +44,56 @@ def transform_desired_worker_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels
)


class Function(Generic[R, TWorkflowInput]):
def __init__(
self,
fn: Callable[[Context], R],
hatchet: "Hatchet",
name: str = "",
on_events: list[str] = [],
on_crons: list[str] = [],
version: str = "",
timeout: str = "60m",
schedule_timeout: str = "5m",
sticky: StickyStrategy | None = None,
retries: int = 0,
rate_limits: list[RateLimit] = [],
desired_worker_labels: dict[str, DesiredWorkerLabel] = {},
concurrency: ConcurrencyExpression | None = None,
on_failure: Union["Function[R]", None] = None,
default_priority: int = 1,
input_validator: Type[TWorkflowInput] | None = None,
backoff_factor: float | None = None,
backoff_max_seconds: int | None = None,
) -> None:
def func(_: Any, context: Context) -> R:
return fn(context)

self.hatchet = hatchet
self.step: Step[R] = hatchet.step(
name=name or fn.__name__,
timeout=timeout,
retries=retries,
rate_limits=rate_limits,
desired_worker_labels=desired_worker_labels,
backoff_factor=backoff_factor,
backoff_max_seconds=backoff_max_seconds,
)(func)
self.on_failure_step = on_failure
self.workflow_config = WorkflowConfig(
name=name or fn.__name__,
on_events=on_events,
on_crons=on_crons,
version=version,
timeout=timeout,
schedule_timeout=schedule_timeout,
sticky=sticky,
default_priority=default_priority,
concurrency=concurrency,
input_validator=input_validator or cast(Type[TWorkflowInput], EmptyModel),
)


class Hatchet:
"""
Main client for interacting with the Hatchet SDK.
Expand Down Expand Up @@ -187,6 +237,49 @@ def inner(func: Callable[[Any, Context], R]) -> Step[R]:

return inner

def function(
self,
name: str = "",
on_events: list[str] = [],
on_crons: list[str] = [],
version: str = "",
timeout: str = "60m",
schedule_timeout: str = "5m",
sticky: StickyStrategy | None = None,
retries: int = 0,
rate_limits: list[RateLimit] = [],
desired_worker_labels: dict[str, DesiredWorkerLabel] = {},
concurrency: ConcurrencyExpression | None = None,
on_failure: Union["Function[Any]", None] = None,
default_priority: int = 1,
input_validator: Type[TWorkflowInput] | None = None,
backoff_factor: float | None = None,
backoff_max_seconds: int | None = None,
) -> Callable[[Callable[[Context], R]], Function[R, TWorkflowInput]]:
def inner(func: Callable[[Context], R]) -> Function[R, TWorkflowInput]:
return Function[R, TWorkflowInput](
func,
hatchet=self,
name=name,
on_events=on_events,
on_crons=on_crons,
version=version,
timeout=timeout,
schedule_timeout=schedule_timeout,
sticky=sticky,
retries=retries,
rate_limits=rate_limits,
desired_worker_labels=desired_worker_labels,
concurrency=concurrency,
on_failure=on_failure,
default_priority=default_priority,
input_validator=input_validator,
backoff_factor=backoff_factor,
backoff_max_seconds=backoff_max_seconds,
)

return inner

def worker(
self, name: str, max_runs: int | None = None, labels: dict[str, str | int] = {}
) -> "Worker":
Expand Down
30 changes: 28 additions & 2 deletions hatchet_sdk/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
STOP_LOOP_TYPE,
WorkerActionRunLoopManager,
)
from hatchet_sdk.workflow import Step
from hatchet_sdk.workflow import BaseWorkflow, Step, StepType

if TYPE_CHECKING:
from hatchet_sdk.workflow import BaseWorkflow
from hatchet_sdk.hatchet import Function

T = TypeVar("T")

Expand Down Expand Up @@ -108,6 +108,32 @@ def register_workflow_from_opts(
logger.error(e)
sys.exit(1)

def register_function(self, function: "Function[Any]") -> None:
from hatchet_sdk.workflow import BaseWorkflow

declaration = function.hatchet.declare_workflow(
**function.workflow_config.model_dump()
)

class Workflow(BaseWorkflow):
config = declaration.config

@property
def default_steps(self) -> list[Step[Any]]:
return [function.step]

@property
def on_failure_steps(self) -> list[Step[Any]]:
if not function.on_failure_step:
return []

step = function.on_failure_step.step
step.type = StepType.ON_FAILURE

return [step]

self.register_workflow(Workflow())

def register_workflow(self, workflow: Union["BaseWorkflow", Any]) -> None:
namespace = self.client.config.namespace

Expand Down