11import json
2+ import os
23from logging import Logger , getLogger
34
45from pydantic import Field , field_validator , model_validator
78from hatchet_sdk .token import get_addresses_from_jwt , get_tenant_id_from_jwt
89
910
10- def create_settings_config ( env_prefix : str ) -> SettingsConfigDict :
11- return SettingsConfigDict (
12- env_prefix = env_prefix ,
13- env_file = ( ".env" , ".env.hatchet" , ".env.dev" , ".env.local" ),
14- extra = "ignore" ,
15- )
11+ class ClientTLSConfig ( BaseModel ) :
12+ tls_strategy : str
13+ cert_file : str | None
14+ key_file : str | None
15+ ca_file : str | None
16+ server_name : str
1617
1718
18- class ClientTLSConfig (BaseSettings ):
19- model_config = create_settings_config (
20- env_prefix = "HATCHET_CLIENT_TLS_" ,
19+ def _load_tls_config (host_port : str | None = None ) -> ClientTLSConfig :
20+ server_name = os .getenv ("HATCHET_CLIENT_TLS_SERVER_NAME" )
21+
22+ if not server_name and host_port :
23+ server_name = host_port .split (":" )[0 ]
24+
25+ if not server_name :
26+ server_name = "localhost"
27+
28+ return ClientTLSConfig (
29+ tls_strategy = os .getenv ("HATCHET_CLIENT_TLS_STRATEGY" , "tls" ),
30+ cert_file = os .getenv ("HATCHET_CLIENT_TLS_CERT_FILE" ),
31+ key_file = os .getenv ("HATCHET_CLIENT_TLS_KEY_FILE" ),
32+ ca_file = os .getenv ("HATCHET_CLIENT_TLS_ROOT_CA_FILE" ),
33+ server_name = server_name ,
2134 )
2235
36+
37+ def parse_listener_timeout (timeout : str | None ) -> int | None :
38+ if timeout is None :
39+ return None
40+
2341 strategy : str = "tls"
2442 cert_file : str | None = None
2543 key_file : str | None = None
@@ -39,16 +57,19 @@ class HealthcheckConfig(BaseSettings):
3957DEFAULT_HOST_PORT = "localhost:7070"
4058
4159
42- class ClientConfig (BaseSettings ):
43- model_config = create_settings_config (
44- env_prefix = "HATCHET_CLIENT_" ,
45- )
60+ class ClientConfig (BaseModel ):
61+ model_config = ConfigDict (arbitrary_types_allowed = True , validate_default = True )
4662
47- token : str = ""
63+ token : str = os . getenv ( "HATCHET_CLIENT_TOKEN" , "" )
4864 logger : Logger = getLogger ()
65+ tenant_id : str = os .getenv ("HATCHET_CLIENT_TENANT_ID" , "" )
66+
67+ ## IMPORTANT: Order matters here. The validators run in the order that the
68+ ## fields are defined in the model. So, we need to make sure that the
69+ ## host_port is set before we try to load the tls_config and server_url
70+ host_port : str = os .getenv ("HATCHET_CLIENT_HOST_PORT" , DEFAULT_HOST_PORT )
71+ tls_config : ClientTLSConfig = _load_tls_config ()
4972
50- tenant_id : str = ""
51- host_port : str = DEFAULT_HOST_PORT
5273 server_url : str = "https://app.dev.hatchet-tools.com"
5374 namespace : str = ""
5475
@@ -59,19 +80,36 @@ class ClientConfig(BaseSettings):
5980 grpc_max_recv_message_length : int = Field (
6081 default = 4 * 1024 * 1024 , description = "4MB default"
6182 )
62- grpc_max_send_message_length : int = Field (
63- default = 4 * 1024 * 1024 , description = "4MB default"
83+ grpc_max_recv_message_length : int = int (
84+ os .getenv ("HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH" , 4 * 1024 * 1024 )
85+ ) # 4MB
86+ grpc_max_send_message_length : int = int (
87+ os .getenv ("HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH" , 4 * 1024 * 1024 )
88+ ) # 4MB
89+ otel_exporter_oltp_endpoint : str | None = os .getenv (
90+ "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT"
91+ )
92+ otel_service_name : str | None = os .getenv ("HATCHET_CLIENT_OTEL_SERVICE_NAME" )
93+ otel_exporter_oltp_headers : str | None = os .getenv (
94+ "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS"
95+ )
96+ otel_exporter_oltp_protocol : str | None = os .getenv (
97+ "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL"
98+ )
99+ worker_healthcheck_port : int = int (
100+ os .getenv ("HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT" , 8001 )
101+ )
102+ worker_healthcheck_enabled : bool = (
103+ os .getenv ("HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED" , "False" ) == "True"
64104 )
65105
66- worker_preset_labels : dict [str , str ] = Field (default_factory = dict )
67-
68- @model_validator (mode = "after" )
69- def validate_token_and_tenant (self ) -> "ClientConfig" :
70- if not self .token :
106+ @field_validator ("token" , mode = "after" )
107+ @classmethod
108+ def validate_token (cls , token : str ) -> str :
109+ if not token :
71110 raise ValueError ("Token must be set" )
72111
73- if not self .tenant_id :
74- self .tenant_id = get_tenant_id_from_jwt (self .token )
112+ return token
75113
76114 return self
77115
@@ -108,9 +146,71 @@ def validate_listener_timeout(cls, value: int | None | str) -> int | None:
108146 def validate_namespace (cls , namespace : str ) -> str :
109147 if not namespace :
110148 return ""
149+
111150 if not namespace .endswith ("_" ):
112151 namespace = f"{ namespace } _"
152+
113153 return namespace .lower ()
114154
155+ @field_validator ("tenant_id" , mode = "after" )
156+ @classmethod
157+ def validate_tenant_id (cls , tenant_id : str , info : ValidationInfo ) -> str :
158+ token = cast (str | None , info .data .get ("token" ))
159+
160+ if not tenant_id :
161+ if not token :
162+ raise ValueError ("Either the token or tenant_id must be set" )
163+
164+ return get_tenant_id_from_jwt (token )
165+
166+ return tenant_id
167+
168+ @field_validator ("host_port" , mode = "after" )
169+ @classmethod
170+ def validate_host_port (cls , host_port : str , info : ValidationInfo ) -> str :
171+ if host_port and host_port != DEFAULT_HOST_PORT :
172+ return host_port
173+
174+ token = cast (str , info .data .get ("token" ))
175+
176+ if not token :
177+ raise ValueError ("Token must be set" )
178+
179+ _ , grpc_broadcast_address = get_addresses_from_jwt (token )
180+
181+ return grpc_broadcast_address
182+
183+ @field_validator ("server_url" , mode = "after" )
184+ @classmethod
185+ def validate_server_url (cls , server_url : str , info : ValidationInfo ) -> str :
186+ ## IMPORTANT: Order matters here. The validators run in the order that the
187+ ## fields are defined in the model. So, we need to make sure that the
188+ ## host_port is set before we try to load the server_url
189+ host_port = cast (str , info .data .get ("host_port" ))
190+
191+ if host_port and host_port != DEFAULT_HOST_PORT :
192+ return host_port
193+
194+ token = cast (str , info .data .get ("token" ))
195+
196+ if not token :
197+ raise ValueError ("Token must be set" )
198+
199+ _server_url , _ = get_addresses_from_jwt (token )
200+
201+ return _server_url
202+
203+ @field_validator ("tls_config" , mode = "after" )
204+ @classmethod
205+ def validate_tls_config (
206+ cls , tls_config : ClientTLSConfig , info : ValidationInfo
207+ ) -> ClientTLSConfig :
208+ ## IMPORTANT: Order matters here. This validator runs in the order
209+ ## that the fields are defined in the model. So, we need to make sure
210+ ## that the host_port is set before we try to load the tls_config
211+ host_port = cast (str , info .data .get ("host_port" ))
212+
213+ return _load_tls_config (host_port )
214+
115215 def __hash__ (self ) -> int :
116216 return hash (json .dumps (self .model_dump (), default = str ))
0 commit comments