Skip to content
This repository was archived by the owner on Feb 20, 2025. It is now read-only.

Commit baa8567

Browse files
committed
debug: host port and server validation
1 parent 8983536 commit baa8567

File tree

3 files changed

+61
-17
lines changed

3 files changed

+61
-17
lines changed

hatchet_sdk/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .clients.dispatcher.dispatcher import DispatcherClient, new_dispatcher
1313
from .clients.events import EventClient, new_event
1414
from .clients.rest_client import RestApi
15-
from .loader import ClientConfig, ConfigLoader
15+
from .loader import ClientConfig
1616

1717

1818
class Client:
@@ -102,7 +102,7 @@ def __init__(
102102
self.config = config
103103
self.listener = RunEventListenerClient(config)
104104
self.workflow_listener = workflow_listener
105-
self.logInterceptor = config.logInterceptor
105+
self.logInterceptor = config.logger
106106
self.debug = debug
107107

108108

hatchet_sdk/loader.py

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import json
12
import os
23
from logging import Logger, getLogger
3-
from typing import cast
4+
from typing import Any, cast
45

5-
from pydantic import BaseModel, ValidationError, ValidationInfo, field_validator
6+
from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator
67

7-
from .token import get_tenant_id_from_jwt
8+
from .token import get_addresses_from_jwt, get_tenant_id_from_jwt
89

910

1011
class ClientTLSConfig(BaseModel):
@@ -15,15 +16,21 @@ class ClientTLSConfig(BaseModel):
1516
server_name: str
1617

1718

18-
def _load_tls_config(host_port: str) -> ClientTLSConfig:
19+
def _load_tls_config(host_port: str | None = None) -> ClientTLSConfig:
20+
server_name = os.getenv("HATCHET_CLIENT_TLS_SERVER_NAME")
21+
22+
if not server_name and host_port:
23+
server_name = host_port.split(":")[0]
24+
25+
if not server_name:
26+
server_name = "localhost"
27+
1928
return ClientTLSConfig(
2029
tls_strategy=os.getenv("HATCHET_CLIENT_TLS_STRATEGY", "tls"),
2130
cert_file=os.getenv("HATCHET_CLIENT_TLS_CERT_FILE"),
2231
key_file=os.getenv("HATCHET_CLIENT_TLS_KEY_FILE"),
2332
ca_file=os.getenv("HATCHET_CLIENT_TLS_ROOT_CA_FILE"),
24-
server_name=os.getenv(
25-
"HATCHET_CLIENT_TLS_SERVER_NAME", host_port.split(":")[0]
26-
),
33+
server_name=server_name,
2734
)
2835

2936

@@ -35,11 +42,13 @@ def parse_listener_timeout(timeout: str | None) -> int | None:
3542

3643

3744
class ClientConfig(BaseModel):
45+
model_config = ConfigDict(arbitrary_types_allowed=True, validate_default=True)
46+
3847
token: str = os.getenv("HATCHET_CLIENT_TOKEN", "")
3948
logger: Logger = getLogger()
4049
tenant_id: str = os.getenv("HATCHET_CLIENT_TENANT_ID", "")
4150
host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", "localhost:7070")
42-
tls_config: ClientTLSConfig = _load_tls_config(host_port)
51+
tls_config: ClientTLSConfig = _load_tls_config()
4352
server_url: str = "https://app.dev.hatchet-tools.com"
4453
namespace: str = os.getenv("HATCHET_CLIENT_NAMESPACE", "")
4554
listener_v2_timeout: int | None = parse_listener_timeout(
@@ -72,7 +81,7 @@ class ClientConfig(BaseModel):
7281
@classmethod
7382
def validate_token(cls, token: str) -> str:
7483
if not token:
75-
raise ValidationError("Token must be set")
84+
return ""
7685

7786
return token
7887

@@ -91,14 +100,48 @@ def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str:
91100

92101
if not tenant_id:
93102
if not token:
94-
raise ValidationError(
95-
"Token must be set before attempting to infer tenant ID"
96-
)
103+
return ""
97104

98105
return get_tenant_id_from_jwt(token)
99106

100107
return tenant_id
101108

109+
@field_validator("host_port", mode="after")
110+
@classmethod
111+
def validate_host_port(cls, host_port: str, info: ValidationInfo) -> str:
112+
token = cast(str | None, info.data.get("token"))
113+
114+
if not token:
115+
return host_port
116+
117+
_, grpc_broadcast_address = get_addresses_from_jwt(token)
118+
119+
return grpc_broadcast_address
120+
121+
@field_validator("server_url", mode="after")
122+
@classmethod
123+
def validate_server_url(cls, server_url: str, info: ValidationInfo) -> str:
124+
token = cast(str | None, info.data.get("token"))
125+
126+
if not token:
127+
return server_url
128+
129+
_server_url, _ = get_addresses_from_jwt(token)
130+
131+
return _server_url
132+
133+
@field_validator("tls_config", mode="after")
134+
@classmethod
135+
def validate_tls_config(
136+
cls, tls_config: ClientTLSConfig, info: ValidationInfo
137+
) -> ClientTLSConfig:
138+
host_port = cast(str, info.data.get("host_port"))
139+
140+
return _load_tls_config(host_port)
141+
142+
def __hash__(self) -> int:
143+
return hash(json.dumps(self.model_dump(), default=str))
144+
102145
## TODO: Fix host port overrides here
103146
## Old code:
104147
## if not host_port:

hatchet_sdk/token.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import base64
22
import json
3+
from typing import Any
34

45

56
def get_tenant_id_from_jwt(token: str) -> str:
@@ -8,13 +9,13 @@ def get_tenant_id_from_jwt(token: str) -> str:
89
return claims.get("sub")
910

1011

11-
def get_addresses_from_jwt(token: str) -> (str, str):
12+
def get_addresses_from_jwt(token: str) -> tuple[str, str]:
1213
claims = extract_claims_from_jwt(token)
1314

14-
return claims.get("server_url"), claims.get("grpc_broadcast_address")
15+
return claims["server_url"], claims["grpc_broadcast_address"]
1516

1617

17-
def extract_claims_from_jwt(token: str):
18+
def extract_claims_from_jwt(token: str) -> dict[str, Any]:
1819
parts = token.split(".")
1920
if len(parts) != 3:
2021
raise ValueError("Invalid token format")

0 commit comments

Comments
 (0)