Skip to content
This repository was archived by the owner on Feb 20, 2025. It is now read-only.
Merged
1 change: 1 addition & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
import subprocess
import time
from io import BytesIO
Expand Down
41 changes: 31 additions & 10 deletions hatchet_sdk/loader.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import json
import os
from logging import Logger, getLogger
from typing import Any, cast
from typing import cast

from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator

from .token import get_addresses_from_jwt, get_tenant_id_from_jwt
from hatchet_sdk.token import get_addresses_from_jwt, get_tenant_id_from_jwt


class ClientTLSConfig(BaseModel):
Expand Down Expand Up @@ -41,14 +41,22 @@ def parse_listener_timeout(timeout: str | None) -> int | None:
return int(timeout)


DEFAULT_HOST_PORT = "localhost:7070"


class ClientConfig(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, validate_default=True)

token: str = os.getenv("HATCHET_CLIENT_TOKEN", "")
logger: Logger = getLogger()
tenant_id: str = os.getenv("HATCHET_CLIENT_TENANT_ID", "")
host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", "localhost:7070")

## IMPORTANT: Order matters here. The validators run in the order that the
## fields are defined in the model. So, we need to make sure that the
## host_port is set before we try to load the tls_config and server_url
host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", DEFAULT_HOST_PORT)
tls_config: ClientTLSConfig = _load_tls_config()

server_url: str = "https://app.dev.hatchet-tools.com"
namespace: str = os.getenv("HATCHET_CLIENT_NAMESPACE", "")
listener_v2_timeout: int | None = parse_listener_timeout(
Expand Down Expand Up @@ -112,17 +120,34 @@ def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str:
@field_validator("host_port", mode="after")
@classmethod
def validate_host_port(cls, host_port: str, info: ValidationInfo) -> str:
if host_port and host_port != DEFAULT_HOST_PORT:
return host_port

token = cast(str, info.data.get("token"))

if not token:
raise ValueError("Token must be set")

_, grpc_broadcast_address = get_addresses_from_jwt(token)

return grpc_broadcast_address

@field_validator("server_url", mode="after")
@classmethod
def validate_server_url(cls, server_url: str, info: ValidationInfo) -> str:
## IMPORTANT: Order matters here. The validators run in the order that the
## fields are defined in the model. So, we need to make sure that the
## host_port is set before we try to load the server_url
host_port = cast(str, info.data.get("host_port"))

if host_port and host_port != DEFAULT_HOST_PORT:
return host_port

token = cast(str, info.data.get("token"))

if not token:
raise ValueError("Token must be set")

_server_url, _ = get_addresses_from_jwt(token)

return _server_url
Expand All @@ -132,16 +157,12 @@ def validate_server_url(cls, server_url: str, info: ValidationInfo) -> str:
def validate_tls_config(
cls, tls_config: ClientTLSConfig, info: ValidationInfo
) -> ClientTLSConfig:
## IMPORTANT: Order matters here. This validator runs in the order
## that the fields are defined in the model. So, we need to make sure
## that the host_port is set before we try to load the tls_config
host_port = cast(str, info.data.get("host_port"))

return _load_tls_config(host_port)

def __hash__(self) -> int:
return hash(json.dumps(self.model_dump(), default=str))

## TODO: Fix host port overrides here
## Old code:
## if not host_port:
## ## extract host and port from token
## server_url, grpc_broadcast_address = get_addresses_from_jwt(token)
## host_port = grpc_broadcast_address
Loading