Skip to content
This repository was archived by the owner on Feb 20, 2025. It is now read-only.

Commit 9dd8eec

Browse files
committed
feat: initial pass at hatchet function
1 parent 842283d commit 9dd8eec

File tree

4 files changed

+162
-136
lines changed

4 files changed

+162
-136
lines changed

examples/simple/worker.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,17 @@
88
hatchet = Hatchet(debug=True)
99

1010

11-
class MyWorkflow(BaseWorkflowImpl):
12-
@hatchet.step(timeout="11s", retries=3)
13-
def step1(self, context: Context) -> dict[str, str]:
14-
print("executed step1")
15-
return {
16-
"step1": "step1",
17-
}
11+
@hatchet.function(timeout="11s")
12+
def step1(context: Context) -> dict[str, str]:
13+
print("executed step1")
14+
return {
15+
"step1": "step1",
16+
}
1817

1918

2019
def main() -> None:
21-
wf = MyWorkflow()
22-
2320
worker = hatchet.worker("test-worker", max_runs=1)
24-
worker.register_workflow(wf)
21+
worker.register_function(step1)
2522
worker.start()
2623

2724

hatchet_sdk/loader.py

Lines changed: 91 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -1,168 +1,134 @@
11
import json
2-
import os
32
from logging import Logger, getLogger
4-
from typing import cast
5-
6-
from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator
3+
from typing import Any
4+
5+
from pydantic import (
6+
BaseModel,
7+
ConfigDict,
8+
Field,
9+
ValidationInfo,
10+
field_validator,
11+
model_validator,
12+
)
13+
from pydantic_settings import BaseSettings, SettingsConfigDict
714

815
from hatchet_sdk.token import get_addresses_from_jwt, get_tenant_id_from_jwt
916

1017

11-
class ClientTLSConfig(BaseModel):
12-
tls_strategy: str
13-
cert_file: str | None
14-
key_file: str | None
15-
ca_file: str | None
16-
server_name: str
18+
def create_settings_config(env_prefix: str) -> SettingsConfigDict:
19+
return SettingsConfigDict(
20+
env_prefix=env_prefix,
21+
env_file=(".env", ".env.hatchet", ".env.dev", ".env.local"),
22+
extra="ignore",
23+
)
1724

1825

19-
def _load_tls_config(host_port: str | None = None) -> ClientTLSConfig:
20-
server_name = os.getenv("HATCHET_CLIENT_TLS_SERVER_NAME")
26+
class ClientTLSConfig(BaseSettings):
27+
model_config = create_settings_config(
28+
env_prefix="HATCHET_CLIENT_TLS_",
29+
)
2130

22-
if not server_name and host_port:
23-
server_name = host_port.split(":")[0]
31+
strategy: str = "tls"
32+
cert_file: str | None = None
33+
key_file: str | None = None
34+
root_ca_file: str | None = None
35+
server_name: str = ""
2436

25-
if not server_name:
26-
server_name = "localhost"
2737

28-
return ClientTLSConfig(
29-
tls_strategy=os.getenv("HATCHET_CLIENT_TLS_STRATEGY", "tls"),
30-
cert_file=os.getenv("HATCHET_CLIENT_TLS_CERT_FILE"),
31-
key_file=os.getenv("HATCHET_CLIENT_TLS_KEY_FILE"),
32-
ca_file=os.getenv("HATCHET_CLIENT_TLS_ROOT_CA_FILE"),
33-
server_name=server_name,
38+
class OTELConfig(BaseSettings):
39+
model_config = create_settings_config(
40+
env_prefix="HATCHET_CLIENT_OTEL_",
3441
)
3542

43+
service_name: str | None = None
44+
exporter_otlp_endpoint: str | None = None
45+
exporter_otlp_headers: str | None = None
46+
exporter_otlp_protocol: str | None = None
3647

37-
def parse_listener_timeout(timeout: str | None) -> int | None:
38-
if timeout is None:
39-
return None
4048

41-
return int(timeout)
49+
class HealthcheckConfig(BaseSettings):
50+
model_config = create_settings_config(
51+
env_prefix="HATCHET_CLIENT_WORKER_HEALTHCHECK_",
52+
)
53+
54+
port: int = 8001
55+
enabled: bool = False
4256

4357

4458
DEFAULT_HOST_PORT = "localhost:7070"
4559

4660

47-
class ClientConfig(BaseModel):
48-
model_config = ConfigDict(arbitrary_types_allowed=True, validate_default=True)
61+
class ClientConfig(BaseSettings):
62+
model_config = create_settings_config(
63+
env_prefix="HATCHET_CLIENT_",
64+
)
4965

50-
token: str = os.getenv("HATCHET_CLIENT_TOKEN", "")
66+
token: str = ""
5167
logger: Logger = getLogger()
52-
tenant_id: str = os.getenv("HATCHET_CLIENT_TENANT_ID", "")
53-
54-
## IMPORTANT: Order matters here. The validators run in the order that the
55-
## fields are defined in the model. So, we need to make sure that the
56-
## host_port is set before we try to load the tls_config and server_url
57-
host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", DEFAULT_HOST_PORT)
58-
tls_config: ClientTLSConfig = _load_tls_config()
5968

69+
tenant_id: str = ""
70+
host_port: str = DEFAULT_HOST_PORT
6071
server_url: str = "https://app.dev.hatchet-tools.com"
61-
namespace: str = os.getenv("HATCHET_CLIENT_NAMESPACE", "")
62-
listener_v2_timeout: int | None = parse_listener_timeout(
63-
os.getenv("HATCHET_CLIENT_LISTENER_V2_TIMEOUT")
64-
)
65-
grpc_max_recv_message_length: int = int(
66-
os.getenv("HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH", 4 * 1024 * 1024)
67-
) # 4MB
68-
grpc_max_send_message_length: int = int(
69-
os.getenv("HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH", 4 * 1024 * 1024)
70-
) # 4MB
71-
otel_exporter_oltp_endpoint: str | None = os.getenv(
72-
"HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT"
73-
)
74-
otel_service_name: str | None = os.getenv("HATCHET_CLIENT_OTEL_SERVICE_NAME")
75-
otel_exporter_oltp_headers: str | None = os.getenv(
76-
"HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS"
77-
)
78-
otel_exporter_oltp_protocol: str | None = os.getenv(
79-
"HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL"
80-
)
81-
worker_healthcheck_port: int = int(
82-
os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT", 8001)
83-
)
84-
worker_healthcheck_enabled: bool = (
85-
os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED", "False") == "True"
86-
)
87-
88-
@field_validator("token", mode="after")
89-
@classmethod
90-
def validate_token(cls, token: str) -> str:
91-
if not token:
92-
raise ValueError("Token must be set")
93-
94-
return token
95-
96-
@field_validator("namespace", mode="after")
97-
@classmethod
98-
def validate_namespace(cls, namespace: str) -> str:
99-
if not namespace:
100-
return ""
101-
102-
if not namespace.endswith("_"):
103-
namespace = f"{namespace}_"
72+
namespace: str = ""
10473

105-
return namespace.lower()
74+
tls_config: ClientTLSConfig = Field(default_factory=lambda: ClientTLSConfig())
75+
otel: OTELConfig = Field(default_factory=lambda: OTELConfig())
76+
healthcheck: HealthcheckConfig = Field(default_factory=lambda: HealthcheckConfig())
10677

107-
@field_validator("tenant_id", mode="after")
108-
@classmethod
109-
def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str:
110-
token = cast(str | None, info.data.get("token"))
78+
listener_v2_timeout: int | None = None
79+
grpc_max_recv_message_length: int = Field(
80+
default=4 * 1024 * 1024, description="4MB default"
81+
)
82+
grpc_max_send_message_length: int = Field(
83+
default=4 * 1024 * 1024, description="4MB default"
84+
)
11185

112-
if not tenant_id:
113-
if not token:
114-
raise ValueError("Either the token or tenant_id must be set")
86+
worker_preset_labels: dict[str, str] = Field(default_factory=dict)
11587

116-
return get_tenant_id_from_jwt(token)
88+
@model_validator(mode="after")
89+
def validate_token_and_tenant(self) -> "ClientConfig":
90+
if not self.token:
91+
raise ValueError("Token must be set")
11792

118-
return tenant_id
93+
if not self.tenant_id:
94+
self.tenant_id = get_tenant_id_from_jwt(self.token)
11995

120-
@field_validator("host_port", mode="after")
121-
@classmethod
122-
def validate_host_port(cls, host_port: str, info: ValidationInfo) -> str:
123-
if host_port and host_port != DEFAULT_HOST_PORT:
124-
return host_port
96+
return self
12597

126-
token = cast(str, info.data.get("token"))
98+
@model_validator(mode="after")
99+
def validate_addresses(self) -> "ClientConfig":
100+
if self.host_port == DEFAULT_HOST_PORT:
101+
server_url, grpc_broadcast_address = get_addresses_from_jwt(self.token)
102+
self.host_port = grpc_broadcast_address
103+
self.server_url = server_url
127104

128-
if not token:
129-
raise ValueError("Token must be set")
105+
if not self.tls_config.server_name:
106+
self.tls_config.server_name = self.host_port.split(":")[0]
130107

131-
_, grpc_broadcast_address = get_addresses_from_jwt(token)
108+
if not self.tls_config.server_name:
109+
self.tls_config.server_name = "localhost"
132110

133-
return grpc_broadcast_address
111+
return self
134112

135-
@field_validator("server_url", mode="after")
113+
@field_validator("listener_v2_timeout")
136114
@classmethod
137-
def validate_server_url(cls, server_url: str, info: ValidationInfo) -> str:
138-
## IMPORTANT: Order matters here. The validators run in the order that the
139-
## fields are defined in the model. So, we need to make sure that the
140-
## host_port is set before we try to load the server_url
141-
host_port = cast(str, info.data.get("host_port"))
115+
def validate_listener_timeout(cls, value: int | None | str) -> int | None:
116+
if value is None:
117+
return None
142118

143-
if host_port and host_port != DEFAULT_HOST_PORT:
144-
return host_port
119+
if isinstance(value, int):
120+
return value
145121

146-
token = cast(str, info.data.get("token"))
122+
return int(value)
147123

148-
if not token:
149-
raise ValueError("Token must be set")
150-
151-
_server_url, _ = get_addresses_from_jwt(token)
152-
153-
return _server_url
154-
155-
@field_validator("tls_config", mode="after")
124+
@field_validator("namespace")
156125
@classmethod
157-
def validate_tls_config(
158-
cls, tls_config: ClientTLSConfig, info: ValidationInfo
159-
) -> ClientTLSConfig:
160-
## IMPORTANT: Order matters here. This validator runs in the order
161-
## that the fields are defined in the model. So, we need to make sure
162-
## that the host_port is set before we try to load the tls_config
163-
host_port = cast(str, info.data.get("host_port"))
164-
165-
return _load_tls_config(host_port)
126+
def validate_namespace(cls, namespace: str) -> str:
127+
if not namespace:
128+
return ""
129+
if not namespace.endswith("_"):
130+
namespace = f"{namespace}_"
131+
return namespace.lower()
166132

167133
def __hash__(self) -> int:
168134
return hash(json.dumps(self.model_dump(), default=str))

hatchet_sdk/v2/hatchet.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
import asyncio
22
import logging
3-
from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar, cast
3+
from typing import (
4+
TYPE_CHECKING,
5+
Any,
6+
Callable,
7+
Optional,
8+
ParamSpec,
9+
Type,
10+
TypeVar,
11+
cast,
12+
)
413

514
from hatchet_sdk.client import Client, new_client, new_client_raw
615
from hatchet_sdk.clients.admin import AdminClient
@@ -17,6 +26,7 @@
1726
from hatchet_sdk.logger import logger
1827
from hatchet_sdk.rate_limit import RateLimit
1928
from hatchet_sdk.v2.workflows import (
29+
BaseWorkflowImpl,
2030
ConcurrencyExpression,
2131
EmptyModel,
2232
Step,
@@ -30,6 +40,7 @@
3040
if TYPE_CHECKING:
3141
from hatchet_sdk.worker.worker import Worker
3242

43+
P = ParamSpec("P")
3344
R = TypeVar("R")
3445

3546

@@ -187,6 +198,55 @@ def inner(func: Callable[[Any, Context], R]) -> Step[R]:
187198

188199
return inner
189200

201+
def function(
202+
self,
203+
name: str = "",
204+
on_events: list[str] = [],
205+
on_crons: list[str] = [],
206+
version: str = "",
207+
timeout: str = "60m",
208+
schedule_timeout: str = "5m",
209+
sticky: StickyStrategy | None = None,
210+
default_priority: int = 1,
211+
concurrency: ConcurrencyExpression | None = None,
212+
input_validator: Type[TWorkflowInput] | None = None,
213+
) -> Callable[[Callable[[Context], R]], BaseWorkflowImpl]:
214+
declaration = WorkflowDeclaration[TWorkflowInput](
215+
WorkflowConfig(
216+
name=name,
217+
on_events=on_events,
218+
on_crons=on_crons,
219+
version=version,
220+
timeout=timeout,
221+
schedule_timeout=schedule_timeout,
222+
sticky=sticky,
223+
default_priority=default_priority,
224+
concurrency=concurrency,
225+
input_validator=input_validator
226+
or cast(Type[TWorkflowInput], EmptyModel),
227+
),
228+
self,
229+
)
230+
231+
def inner(func: Callable[[Context], R]) -> BaseWorkflowImpl:
232+
class Workflow(BaseWorkflowImpl):
233+
config = declaration.config
234+
235+
@self.step(
236+
name=name,
237+
timeout=timeout,
238+
retries=0,
239+
rate_limits=[],
240+
backoff_factor=None,
241+
backoff_max_seconds=None,
242+
)
243+
def fn(self, context: Context) -> R:
244+
return func(context)
245+
246+
return Workflow()
247+
248+
return inner
249+
190250
def worker(
191251
self, name: str, max_runs: int | None = None, labels: dict[str, str | int] = {}
192252
) -> "Worker":

hatchet_sdk/worker/worker.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ def register_workflow_from_opts(
110110
logger.error(e)
111111
sys.exit(1)
112112

113+
def register_function(self, function: "BaseWorkflowImpl") -> None:
114+
self.register_workflow(function)
115+
113116
def register_workflow(self, workflow: Union["BaseWorkflowImpl", Any]) -> None:
114117
namespace = self.client.config.namespace
115118

0 commit comments

Comments
 (0)