Skip to content
This repository was archived by the owner on Feb 20, 2025. It is now read-only.

Commit e7740ef

Browse files
committed
feat: allow host_port overrides
1 parent b30b40f commit e7740ef

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

hatchet_sdk/loader.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,22 @@ def parse_listener_timeout(timeout: str | None) -> int | None:
4141
return int(timeout)
4242

4343

44+
DEFAULT_HOST_PORT = "localhost:7070"
45+
46+
4447
class ClientConfig(BaseModel):
4548
model_config = ConfigDict(arbitrary_types_allowed=True, validate_default=True)
4649

4750
token: str = os.getenv("HATCHET_CLIENT_TOKEN", "")
4851
logger: Logger = getLogger()
4952
tenant_id: str = os.getenv("HATCHET_CLIENT_TENANT_ID", "")
50-
host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", "localhost:7070")
53+
54+
## IMPORTANT: Order matters here. The validators run in the order that the
55+
## fields are defined in the model. So, we need to make sure that the
56+
## host_port is set before we try to load the tls_config and server_url
57+
host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", DEFAULT_HOST_PORT)
5158
tls_config: ClientTLSConfig = _load_tls_config()
59+
5260
server_url: str = "https://app.dev.hatchet-tools.com"
5361
namespace: str = os.getenv("HATCHET_CLIENT_NAMESPACE", "")
5462
listener_v2_timeout: int | None = parse_listener_timeout(
@@ -112,6 +120,9 @@ def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str:
112120
@field_validator("host_port", mode="after")
113121
@classmethod
114122
def validate_host_port(cls, host_port: str, info: ValidationInfo) -> str:
123+
if host_port and host_port != DEFAULT_HOST_PORT:
124+
return host_port
125+
115126
token = cast(str, info.data.get("token"))
116127

117128
_, grpc_broadcast_address = get_addresses_from_jwt(token)
@@ -121,6 +132,14 @@ def validate_host_port(cls, host_port: str, info: ValidationInfo) -> str:
121132
@field_validator("server_url", mode="after")
122133
@classmethod
123134
def validate_server_url(cls, server_url: str, info: ValidationInfo) -> str:
135+
## IMPORTANT: Order matters here. The validators run in the order that the
136+
## fields are defined in the model. So, we need to make sure that the
137+
## host_port is set before we try to load the server_url
138+
host_port = cast(str, info.data.get("host_port"))
139+
140+
if host_port and host_port != DEFAULT_HOST_PORT:
141+
return host_port
142+
124143
token = cast(str, info.data.get("token"))
125144

126145
_server_url, _ = get_addresses_from_jwt(token)
@@ -132,16 +151,12 @@ def validate_server_url(cls, server_url: str, info: ValidationInfo) -> str:
132151
def validate_tls_config(
133152
cls, tls_config: ClientTLSConfig, info: ValidationInfo
134153
) -> ClientTLSConfig:
154+
## IMPORTANT: Order matters here. This validator runs in the order
155+
## that the fields are defined in the model. So, we need to make sure
156+
## that the host_port is set before we try to load the tls_config
135157
host_port = cast(str, info.data.get("host_port"))
136158

137159
return _load_tls_config(host_port)
138160

139161
def __hash__(self) -> int:
140162
return hash(json.dumps(self.model_dump(), default=str))
141-
142-
## TODO: Fix host port overrides here
143-
## Old code:
144-
## if not host_port:
145-
## ## extract host and port from token
146-
## server_url, grpc_broadcast_address = get_addresses_from_jwt(token)
147-
## host_port = grpc_broadcast_address

0 commit comments

Comments
 (0)