Skip to content
This repository was archived by the owner on Feb 20, 2025. It is now read-only.

Commit 5fb9244

Browse files
committed
fix: loader
1 parent c846a01 commit 5fb9244

File tree

1 file changed

+125
-25
lines changed

1 file changed

+125
-25
lines changed

hatchet_sdk/loader.py

Lines changed: 125 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import os
23
from logging import Logger, getLogger
34

45
from pydantic import Field, field_validator, model_validator
@@ -7,19 +8,36 @@
78
from hatchet_sdk.token import get_addresses_from_jwt, get_tenant_id_from_jwt
89

910

10-
def create_settings_config(env_prefix: str) -> SettingsConfigDict:
11-
return SettingsConfigDict(
12-
env_prefix=env_prefix,
13-
env_file=(".env", ".env.hatchet", ".env.dev", ".env.local"),
14-
extra="ignore",
15-
)
11+
class ClientTLSConfig(BaseModel):
12+
tls_strategy: str
13+
cert_file: str | None
14+
key_file: str | None
15+
ca_file: str | None
16+
server_name: str
1617

1718

18-
class ClientTLSConfig(BaseSettings):
19-
model_config = create_settings_config(
20-
env_prefix="HATCHET_CLIENT_TLS_",
19+
def _load_tls_config(host_port: str | None = None) -> ClientTLSConfig:
20+
server_name = os.getenv("HATCHET_CLIENT_TLS_SERVER_NAME")
21+
22+
if not server_name and host_port:
23+
server_name = host_port.split(":")[0]
24+
25+
if not server_name:
26+
server_name = "localhost"
27+
28+
return ClientTLSConfig(
29+
tls_strategy=os.getenv("HATCHET_CLIENT_TLS_STRATEGY", "tls"),
30+
cert_file=os.getenv("HATCHET_CLIENT_TLS_CERT_FILE"),
31+
key_file=os.getenv("HATCHET_CLIENT_TLS_KEY_FILE"),
32+
ca_file=os.getenv("HATCHET_CLIENT_TLS_ROOT_CA_FILE"),
33+
server_name=server_name,
2134
)
2235

36+
37+
def parse_listener_timeout(timeout: str | None) -> int | None:
38+
if timeout is None:
39+
return None
40+
2341
strategy: str = "tls"
2442
cert_file: str | None = None
2543
key_file: str | None = None
@@ -39,16 +57,19 @@ class HealthcheckConfig(BaseSettings):
3957
DEFAULT_HOST_PORT = "localhost:7070"
4058

4159

42-
class ClientConfig(BaseSettings):
43-
model_config = create_settings_config(
44-
env_prefix="HATCHET_CLIENT_",
45-
)
60+
class ClientConfig(BaseModel):
61+
model_config = ConfigDict(arbitrary_types_allowed=True, validate_default=True)
4662

47-
token: str = ""
63+
token: str = os.getenv("HATCHET_CLIENT_TOKEN", "")
4864
logger: Logger = getLogger()
65+
tenant_id: str = os.getenv("HATCHET_CLIENT_TENANT_ID", "")
66+
67+
## IMPORTANT: Order matters here. The validators run in the order that the
68+
## fields are defined in the model. So, we need to make sure that the
69+
## host_port is set before we try to load the tls_config and server_url
70+
host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", DEFAULT_HOST_PORT)
71+
tls_config: ClientTLSConfig = _load_tls_config()
4972

50-
tenant_id: str = ""
51-
host_port: str = DEFAULT_HOST_PORT
5273
server_url: str = "https://app.dev.hatchet-tools.com"
5374
namespace: str = ""
5475

@@ -59,19 +80,36 @@ class ClientConfig(BaseSettings):
5980
grpc_max_recv_message_length: int = Field(
6081
default=4 * 1024 * 1024, description="4MB default"
6182
)
62-
grpc_max_send_message_length: int = Field(
63-
default=4 * 1024 * 1024, description="4MB default"
83+
grpc_max_recv_message_length: int = int(
84+
os.getenv("HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH", 4 * 1024 * 1024)
85+
) # 4MB
86+
grpc_max_send_message_length: int = int(
87+
os.getenv("HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH", 4 * 1024 * 1024)
88+
) # 4MB
89+
otel_exporter_oltp_endpoint: str | None = os.getenv(
90+
"HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT"
91+
)
92+
otel_service_name: str | None = os.getenv("HATCHET_CLIENT_OTEL_SERVICE_NAME")
93+
otel_exporter_oltp_headers: str | None = os.getenv(
94+
"HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS"
95+
)
96+
otel_exporter_oltp_protocol: str | None = os.getenv(
97+
"HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL"
98+
)
99+
worker_healthcheck_port: int = int(
100+
os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT", 8001)
101+
)
102+
worker_healthcheck_enabled: bool = (
103+
os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED", "False") == "True"
64104
)
65105

66-
worker_preset_labels: dict[str, str] = Field(default_factory=dict)
67-
68-
@model_validator(mode="after")
69-
def validate_token_and_tenant(self) -> "ClientConfig":
70-
if not self.token:
106+
@field_validator("token", mode="after")
107+
@classmethod
108+
def validate_token(cls, token: str) -> str:
109+
if not token:
71110
raise ValueError("Token must be set")
72111

73-
if not self.tenant_id:
74-
self.tenant_id = get_tenant_id_from_jwt(self.token)
112+
return token
75113

76114
return self
77115

@@ -108,9 +146,71 @@ def validate_listener_timeout(cls, value: int | None | str) -> int | None:
108146
def validate_namespace(cls, namespace: str) -> str:
109147
if not namespace:
110148
return ""
149+
111150
if not namespace.endswith("_"):
112151
namespace = f"{namespace}_"
152+
113153
return namespace.lower()
114154

155+
@field_validator("tenant_id", mode="after")
156+
@classmethod
157+
def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str:
158+
token = cast(str | None, info.data.get("token"))
159+
160+
if not tenant_id:
161+
if not token:
162+
raise ValueError("Either the token or tenant_id must be set")
163+
164+
return get_tenant_id_from_jwt(token)
165+
166+
return tenant_id
167+
168+
@field_validator("host_port", mode="after")
169+
@classmethod
170+
def validate_host_port(cls, host_port: str, info: ValidationInfo) -> str:
171+
if host_port and host_port != DEFAULT_HOST_PORT:
172+
return host_port
173+
174+
token = cast(str, info.data.get("token"))
175+
176+
if not token:
177+
raise ValueError("Token must be set")
178+
179+
_, grpc_broadcast_address = get_addresses_from_jwt(token)
180+
181+
return grpc_broadcast_address
182+
183+
@field_validator("server_url", mode="after")
184+
@classmethod
185+
def validate_server_url(cls, server_url: str, info: ValidationInfo) -> str:
186+
## IMPORTANT: Order matters here. The validators run in the order that the
187+
## fields are defined in the model. So, we need to make sure that the
188+
## host_port is set before we try to load the server_url
189+
host_port = cast(str, info.data.get("host_port"))
190+
191+
if host_port and host_port != DEFAULT_HOST_PORT:
192+
return host_port
193+
194+
token = cast(str, info.data.get("token"))
195+
196+
if not token:
197+
raise ValueError("Token must be set")
198+
199+
_server_url, _ = get_addresses_from_jwt(token)
200+
201+
return _server_url
202+
203+
@field_validator("tls_config", mode="after")
204+
@classmethod
205+
def validate_tls_config(
206+
cls, tls_config: ClientTLSConfig, info: ValidationInfo
207+
) -> ClientTLSConfig:
208+
## IMPORTANT: Order matters here. This validator runs in the order
209+
## that the fields are defined in the model. So, we need to make sure
210+
## that the host_port is set before we try to load the tls_config
211+
host_port = cast(str, info.data.get("host_port"))
212+
213+
return _load_tls_config(host_port)
214+
115215
def __hash__(self) -> int:
116216
return hash(json.dumps(self.model_dump(), default=str))

0 commit comments

Comments
 (0)