Skip to content
This repository was archived by the owner on Feb 20, 2025. It is now read-only.

Commit 2c949b7

Browse files
committed
fix: loader
1 parent 9dd8eec commit 2c949b7

File tree

1 file changed

+125
-91
lines changed

1 file changed

+125
-91
lines changed

hatchet_sdk/loader.py

Lines changed: 125 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,134 +1,168 @@
11
import json
2+
import os
23
from logging import Logger, getLogger
3-
from typing import Any
4-
5-
from pydantic import (
6-
BaseModel,
7-
ConfigDict,
8-
Field,
9-
ValidationInfo,
10-
field_validator,
11-
model_validator,
12-
)
13-
from pydantic_settings import BaseSettings, SettingsConfigDict
4+
from typing import cast
5+
6+
from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator
147

158
from hatchet_sdk.token import get_addresses_from_jwt, get_tenant_id_from_jwt
169

1710

18-
def create_settings_config(env_prefix: str) -> SettingsConfigDict:
19-
return SettingsConfigDict(
20-
env_prefix=env_prefix,
21-
env_file=(".env", ".env.hatchet", ".env.dev", ".env.local"),
22-
extra="ignore",
23-
)
11+
class ClientTLSConfig(BaseModel):
12+
tls_strategy: str
13+
cert_file: str | None
14+
key_file: str | None
15+
ca_file: str | None
16+
server_name: str
2417

2518

26-
class ClientTLSConfig(BaseSettings):
27-
model_config = create_settings_config(
28-
env_prefix="HATCHET_CLIENT_TLS_",
29-
)
19+
def _load_tls_config(host_port: str | None = None) -> ClientTLSConfig:
20+
server_name = os.getenv("HATCHET_CLIENT_TLS_SERVER_NAME")
3021

31-
strategy: str = "tls"
32-
cert_file: str | None = None
33-
key_file: str | None = None
34-
root_ca_file: str | None = None
35-
server_name: str = ""
22+
if not server_name and host_port:
23+
server_name = host_port.split(":")[0]
3624

25+
if not server_name:
26+
server_name = "localhost"
3727

38-
class OTELConfig(BaseSettings):
39-
model_config = create_settings_config(
40-
env_prefix="HATCHET_CLIENT_OTEL_",
28+
return ClientTLSConfig(
29+
tls_strategy=os.getenv("HATCHET_CLIENT_TLS_STRATEGY", "tls"),
30+
cert_file=os.getenv("HATCHET_CLIENT_TLS_CERT_FILE"),
31+
key_file=os.getenv("HATCHET_CLIENT_TLS_KEY_FILE"),
32+
ca_file=os.getenv("HATCHET_CLIENT_TLS_ROOT_CA_FILE"),
33+
server_name=server_name,
4134
)
4235

43-
service_name: str | None = None
44-
exporter_otlp_endpoint: str | None = None
45-
exporter_otlp_headers: str | None = None
46-
exporter_otlp_protocol: str | None = None
4736

37+
def parse_listener_timeout(timeout: str | None) -> int | None:
38+
if timeout is None:
39+
return None
4840

49-
class HealthcheckConfig(BaseSettings):
50-
model_config = create_settings_config(
51-
env_prefix="HATCHET_CLIENT_WORKER_HEALTHCHECK_",
52-
)
53-
54-
port: int = 8001
55-
enabled: bool = False
41+
return int(timeout)
5642

5743

5844
DEFAULT_HOST_PORT = "localhost:7070"
5945

6046

61-
class ClientConfig(BaseSettings):
62-
model_config = create_settings_config(
63-
env_prefix="HATCHET_CLIENT_",
64-
)
47+
class ClientConfig(BaseModel):
48+
model_config = ConfigDict(arbitrary_types_allowed=True, validate_default=True)
6549

66-
token: str = ""
50+
token: str = os.getenv("HATCHET_CLIENT_TOKEN", "")
6751
logger: Logger = getLogger()
52+
tenant_id: str = os.getenv("HATCHET_CLIENT_TENANT_ID", "")
6853

69-
tenant_id: str = ""
70-
host_port: str = DEFAULT_HOST_PORT
71-
server_url: str = "https://app.dev.hatchet-tools.com"
72-
namespace: str = ""
73-
74-
tls_config: ClientTLSConfig = Field(default_factory=lambda: ClientTLSConfig())
75-
otel: OTELConfig = Field(default_factory=lambda: OTELConfig())
76-
healthcheck: HealthcheckConfig = Field(default_factory=lambda: HealthcheckConfig())
54+
## IMPORTANT: Order matters here. The validators run in the order that the
55+
## fields are defined in the model. So, we need to make sure that the
56+
## host_port is set before we try to load the tls_config and server_url
57+
host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", DEFAULT_HOST_PORT)
58+
tls_config: ClientTLSConfig = _load_tls_config()
7759

78-
listener_v2_timeout: int | None = None
79-
grpc_max_recv_message_length: int = Field(
80-
default=4 * 1024 * 1024, description="4MB default"
60+
server_url: str = "https://app.dev.hatchet-tools.com"
61+
namespace: str = os.getenv("HATCHET_CLIENT_NAMESPACE", "")
62+
listener_v2_timeout: int | None = parse_listener_timeout(
63+
os.getenv("HATCHET_CLIENT_LISTENER_V2_TIMEOUT")
8164
)
82-
grpc_max_send_message_length: int = Field(
83-
default=4 * 1024 * 1024, description="4MB default"
65+
grpc_max_recv_message_length: int = int(
66+
os.getenv("HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH", 4 * 1024 * 1024)
67+
) # 4MB
68+
grpc_max_send_message_length: int = int(
69+
os.getenv("HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH", 4 * 1024 * 1024)
70+
) # 4MB
71+
otel_exporter_oltp_endpoint: str | None = os.getenv(
72+
"HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT"
73+
)
74+
otel_service_name: str | None = os.getenv("HATCHET_CLIENT_OTEL_SERVICE_NAME")
75+
otel_exporter_oltp_headers: str | None = os.getenv(
76+
"HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS"
77+
)
78+
otel_exporter_oltp_protocol: str | None = os.getenv(
79+
"HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL"
80+
)
81+
worker_healthcheck_port: int = int(
82+
os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT", 8001)
83+
)
84+
worker_healthcheck_enabled: bool = (
85+
os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED", "False") == "True"
8486
)
8587

86-
worker_preset_labels: dict[str, str] = Field(default_factory=dict)
87-
88-
@model_validator(mode="after")
89-
def validate_token_and_tenant(self) -> "ClientConfig":
90-
if not self.token:
88+
@field_validator("token", mode="after")
89+
@classmethod
90+
def validate_token(cls, token: str) -> str:
91+
if not token:
9192
raise ValueError("Token must be set")
9293

93-
if not self.tenant_id:
94-
self.tenant_id = get_tenant_id_from_jwt(self.token)
94+
return token
95+
96+
@field_validator("namespace", mode="after")
97+
@classmethod
98+
def validate_namespace(cls, namespace: str) -> str:
99+
if not namespace:
100+
return ""
101+
102+
if not namespace.endswith("_"):
103+
namespace = f"{namespace}_"
95104

96-
return self
105+
return namespace.lower()
97106

98-
@model_validator(mode="after")
99-
def validate_addresses(self) -> "ClientConfig":
100-
if self.host_port == DEFAULT_HOST_PORT:
101-
server_url, grpc_broadcast_address = get_addresses_from_jwt(self.token)
102-
self.host_port = grpc_broadcast_address
103-
self.server_url = server_url
107+
@field_validator("tenant_id", mode="after")
108+
@classmethod
109+
def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str:
110+
token = cast(str | None, info.data.get("token"))
104111

105-
if not self.tls_config.server_name:
106-
self.tls_config.server_name = self.host_port.split(":")[0]
112+
if not tenant_id:
113+
if not token:
114+
raise ValueError("Either the token or tenant_id must be set")
107115

108-
if not self.tls_config.server_name:
109-
self.tls_config.server_name = "localhost"
116+
return get_tenant_id_from_jwt(token)
110117

111-
return self
118+
return tenant_id
112119

113-
@field_validator("listener_v2_timeout")
120+
@field_validator("host_port", mode="after")
114121
@classmethod
115-
def validate_listener_timeout(cls, value: int | None | str) -> int | None:
116-
if value is None:
117-
return None
122+
def validate_host_port(cls, host_port: str, info: ValidationInfo) -> str:
123+
if host_port and host_port != DEFAULT_HOST_PORT:
124+
return host_port
125+
126+
token = cast(str, info.data.get("token"))
118127

119-
if isinstance(value, int):
120-
return value
128+
if not token:
129+
raise ValueError("Token must be set")
121130

122-
return int(value)
131+
_, grpc_broadcast_address = get_addresses_from_jwt(token)
123132

124-
@field_validator("namespace")
133+
return grpc_broadcast_address
134+
135+
@field_validator("server_url", mode="after")
125136
@classmethod
126-
def validate_namespace(cls, namespace: str) -> str:
127-
if not namespace:
128-
return ""
129-
if not namespace.endswith("_"):
130-
namespace = f"{namespace}_"
131-
return namespace.lower()
137+
def validate_server_url(cls, server_url: str, info: ValidationInfo) -> str:
138+
## IMPORTANT: Order matters here. The validators run in the order that the
139+
## fields are defined in the model. So, we need to make sure that the
140+
## host_port is set before we try to load the server_url
141+
host_port = cast(str, info.data.get("host_port"))
142+
143+
if host_port and host_port != DEFAULT_HOST_PORT:
144+
return host_port
145+
146+
token = cast(str, info.data.get("token"))
147+
148+
if not token:
149+
raise ValueError("Token must be set")
150+
151+
_server_url, _ = get_addresses_from_jwt(token)
152+
153+
return _server_url
154+
155+
@field_validator("tls_config", mode="after")
156+
@classmethod
157+
def validate_tls_config(
158+
cls, tls_config: ClientTLSConfig, info: ValidationInfo
159+
) -> ClientTLSConfig:
160+
## IMPORTANT: Order matters here. This validator runs in the order
161+
## that the fields are defined in the model. So, we need to make sure
162+
## that the host_port is set before we try to load the tls_config
163+
host_port = cast(str, info.data.get("host_port"))
164+
165+
return _load_tls_config(host_port)
132166

133167
def __hash__(self) -> int:
134168
return hash(json.dumps(self.model_dump(), default=str))

0 commit comments

Comments
 (0)