1+ import json
12import os
23from logging import Logger , getLogger
3- from typing import cast
4+ from typing import Any , cast
45
5- from pydantic import BaseModel , ValidationError , ValidationInfo , field_validator
6+ from pydantic import BaseModel , ConfigDict , ValidationInfo , field_validator
67
7- from .token import get_tenant_id_from_jwt
8+ from .token import get_addresses_from_jwt , get_tenant_id_from_jwt
89
910
1011class ClientTLSConfig (BaseModel ):
@@ -15,15 +16,21 @@ class ClientTLSConfig(BaseModel):
1516 server_name : str
1617
1718
18- def _load_tls_config (host_port : str ) -> ClientTLSConfig :
19+ def _load_tls_config (host_port : str | None = None ) -> ClientTLSConfig :
20+ server_name = os .getenv ("HATCHET_CLIENT_TLS_SERVER_NAME" )
21+
22+ if not server_name and host_port :
23+ server_name = host_port .split (":" )[0 ]
24+
25+ if not server_name :
26+ server_name = "localhost"
27+
1928 return ClientTLSConfig (
2029 tls_strategy = os .getenv ("HATCHET_CLIENT_TLS_STRATEGY" , "tls" ),
2130 cert_file = os .getenv ("HATCHET_CLIENT_TLS_CERT_FILE" ),
2231 key_file = os .getenv ("HATCHET_CLIENT_TLS_KEY_FILE" ),
2332 ca_file = os .getenv ("HATCHET_CLIENT_TLS_ROOT_CA_FILE" ),
24- server_name = os .getenv (
25- "HATCHET_CLIENT_TLS_SERVER_NAME" , host_port .split (":" )[0 ]
26- ),
33+ server_name = server_name ,
2734 )
2835
2936
@@ -35,11 +42,13 @@ def parse_listener_timeout(timeout: str | None) -> int | None:
3542
3643
3744class ClientConfig (BaseModel ):
45+ model_config = ConfigDict (arbitrary_types_allowed = True , validate_default = True )
46+
3847 token : str = os .getenv ("HATCHET_CLIENT_TOKEN" , "" )
3948 logger : Logger = getLogger ()
4049 tenant_id : str = os .getenv ("HATCHET_CLIENT_TENANT_ID" , "" )
4150 host_port : str = os .getenv ("HATCHET_CLIENT_HOST_PORT" , "localhost:7070" )
42- tls_config : ClientTLSConfig = _load_tls_config (host_port )
51+ tls_config : ClientTLSConfig = _load_tls_config ()
4352 server_url : str = "https://app.dev.hatchet-tools.com"
4453 namespace : str = os .getenv ("HATCHET_CLIENT_NAMESPACE" , "" )
4554 listener_v2_timeout : int | None = parse_listener_timeout (
@@ -72,7 +81,7 @@ class ClientConfig(BaseModel):
7281 @classmethod
7382 def validate_token (cls , token : str ) -> str :
7483 if not token :
75- raise ValidationError ( "Token must be set" )
84+ return ""
7685
7786 return token
7887
@@ -91,14 +100,48 @@ def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str:
91100
92101 if not tenant_id :
93102 if not token :
94- raise ValidationError (
95- "Token must be set before attempting to infer tenant ID"
96- )
103+ return ""
97104
98105 return get_tenant_id_from_jwt (token )
99106
100107 return tenant_id
101108
109+ @field_validator ("host_port" , mode = "after" )
110+ @classmethod
111+ def validate_host_port (cls , host_port : str , info : ValidationInfo ) -> str :
112+ token = cast (str | None , info .data .get ("token" ))
113+
114+ if not token :
115+ return host_port
116+
117+ _ , grpc_broadcast_address = get_addresses_from_jwt (token )
118+
119+ return grpc_broadcast_address
120+
121+ @field_validator ("server_url" , mode = "after" )
122+ @classmethod
123+ def validate_server_url (cls , server_url : str , info : ValidationInfo ) -> str :
124+ token = cast (str | None , info .data .get ("token" ))
125+
126+ if not token :
127+ return server_url
128+
129+ _server_url , _ = get_addresses_from_jwt (token )
130+
131+ return _server_url
132+
133+ @field_validator ("tls_config" , mode = "after" )
134+ @classmethod
135+ def validate_tls_config (
136+ cls , tls_config : ClientTLSConfig , info : ValidationInfo
137+ ) -> ClientTLSConfig :
138+ host_port = cast (str , info .data .get ("host_port" ))
139+
140+ return _load_tls_config (host_port )
141+
142+ def __hash__ (self ) -> int :
143+ return hash (json .dumps (self .model_dump (), default = str ))
144+
102145 ## TODO: Fix host port overrides here
103146 ## Old code:
104147 ## if not host_port:
0 commit comments