|
1 | 1 | import json |
2 | | -import os |
3 | 2 | from logging import Logger, getLogger |
4 | | -from typing import cast |
5 | | - |
6 | | -from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator |
| 3 | +from typing import Any |
| 4 | + |
| 5 | +from pydantic import ( |
| 6 | + BaseModel, |
| 7 | + ConfigDict, |
| 8 | + Field, |
| 9 | + ValidationInfo, |
| 10 | + field_validator, |
| 11 | + model_validator, |
| 12 | +) |
| 13 | +from pydantic_settings import BaseSettings, SettingsConfigDict |
7 | 14 |
|
8 | 15 | from hatchet_sdk.token import get_addresses_from_jwt, get_tenant_id_from_jwt |
9 | 16 |
|
10 | 17 |
|
11 | | -class ClientTLSConfig(BaseModel): |
12 | | - tls_strategy: str |
13 | | - cert_file: str | None |
14 | | - key_file: str | None |
15 | | - ca_file: str | None |
16 | | - server_name: str |
| 18 | +def create_settings_config(env_prefix: str) -> SettingsConfigDict: |
| 19 | + return SettingsConfigDict( |
| 20 | + env_prefix=env_prefix, |
| 21 | + env_file=(".env", ".env.hatchet", ".env.dev", ".env.local"), |
| 22 | + extra="ignore", |
| 23 | + ) |
17 | 24 |
|
18 | 25 |
|
19 | | -def _load_tls_config(host_port: str | None = None) -> ClientTLSConfig: |
20 | | - server_name = os.getenv("HATCHET_CLIENT_TLS_SERVER_NAME") |
| 26 | +class ClientTLSConfig(BaseSettings): |
| 27 | + model_config = create_settings_config( |
| 28 | + env_prefix="HATCHET_CLIENT_TLS_", |
| 29 | + ) |
21 | 30 |
|
22 | | - if not server_name and host_port: |
23 | | - server_name = host_port.split(":")[0] |
| 31 | + strategy: str = "tls" |
| 32 | + cert_file: str | None = None |
| 33 | + key_file: str | None = None |
| 34 | + root_ca_file: str | None = None |
| 35 | + server_name: str = "" |
24 | 36 |
|
25 | | - if not server_name: |
26 | | - server_name = "localhost" |
27 | 37 |
|
28 | | - return ClientTLSConfig( |
29 | | - tls_strategy=os.getenv("HATCHET_CLIENT_TLS_STRATEGY", "tls"), |
30 | | - cert_file=os.getenv("HATCHET_CLIENT_TLS_CERT_FILE"), |
31 | | - key_file=os.getenv("HATCHET_CLIENT_TLS_KEY_FILE"), |
32 | | - ca_file=os.getenv("HATCHET_CLIENT_TLS_ROOT_CA_FILE"), |
33 | | - server_name=server_name, |
| 38 | +class OTELConfig(BaseSettings): |
| 39 | + model_config = create_settings_config( |
| 40 | + env_prefix="HATCHET_CLIENT_OTEL_", |
34 | 41 | ) |
35 | 42 |
|
| 43 | + service_name: str | None = None |
| 44 | + exporter_otlp_endpoint: str | None = None |
| 45 | + exporter_otlp_headers: str | None = None |
| 46 | + exporter_otlp_protocol: str | None = None |
36 | 47 |
|
37 | | -def parse_listener_timeout(timeout: str | None) -> int | None: |
38 | | - if timeout is None: |
39 | | - return None |
40 | 48 |
|
41 | | - return int(timeout) |
| 49 | +class HealthcheckConfig(BaseSettings): |
| 50 | + model_config = create_settings_config( |
| 51 | + env_prefix="HATCHET_CLIENT_WORKER_HEALTHCHECK_", |
| 52 | + ) |
| 53 | + |
| 54 | + port: int = 8001 |
| 55 | + enabled: bool = False |
42 | 56 |
|
43 | 57 |
|
44 | 58 | DEFAULT_HOST_PORT = "localhost:7070" |
45 | 59 |
|
46 | 60 |
|
47 | | -class ClientConfig(BaseModel): |
48 | | - model_config = ConfigDict(arbitrary_types_allowed=True, validate_default=True) |
| 61 | +class ClientConfig(BaseSettings): |
| 62 | + model_config = create_settings_config( |
| 63 | + env_prefix="HATCHET_CLIENT_", |
| 64 | + ) |
49 | 65 |
|
50 | | - token: str = os.getenv("HATCHET_CLIENT_TOKEN", "") |
| 66 | + token: str = "" |
51 | 67 | logger: Logger = getLogger() |
52 | | - tenant_id: str = os.getenv("HATCHET_CLIENT_TENANT_ID", "") |
53 | | - |
54 | | - ## IMPORTANT: Order matters here. The validators run in the order that the |
55 | | - ## fields are defined in the model. So, we need to make sure that the |
56 | | - ## host_port is set before we try to load the tls_config and server_url |
57 | | - host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", DEFAULT_HOST_PORT) |
58 | | - tls_config: ClientTLSConfig = _load_tls_config() |
59 | 68 |
|
| 69 | + tenant_id: str = "" |
| 70 | + host_port: str = DEFAULT_HOST_PORT |
60 | 71 | server_url: str = "https://app.dev.hatchet-tools.com" |
61 | | - namespace: str = os.getenv("HATCHET_CLIENT_NAMESPACE", "") |
62 | | - listener_v2_timeout: int | None = parse_listener_timeout( |
63 | | - os.getenv("HATCHET_CLIENT_LISTENER_V2_TIMEOUT") |
64 | | - ) |
65 | | - grpc_max_recv_message_length: int = int( |
66 | | - os.getenv("HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH", 4 * 1024 * 1024) |
67 | | - ) # 4MB |
68 | | - grpc_max_send_message_length: int = int( |
69 | | - os.getenv("HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH", 4 * 1024 * 1024) |
70 | | - ) # 4MB |
71 | | - otel_exporter_oltp_endpoint: str | None = os.getenv( |
72 | | - "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT" |
73 | | - ) |
74 | | - otel_service_name: str | None = os.getenv("HATCHET_CLIENT_OTEL_SERVICE_NAME") |
75 | | - otel_exporter_oltp_headers: str | None = os.getenv( |
76 | | - "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS" |
77 | | - ) |
78 | | - otel_exporter_oltp_protocol: str | None = os.getenv( |
79 | | - "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL" |
80 | | - ) |
81 | | - worker_healthcheck_port: int = int( |
82 | | - os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT", 8001) |
83 | | - ) |
84 | | - worker_healthcheck_enabled: bool = ( |
85 | | - os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED", "False") == "True" |
86 | | - ) |
87 | | - |
88 | | - @field_validator("token", mode="after") |
89 | | - @classmethod |
90 | | - def validate_token(cls, token: str) -> str: |
91 | | - if not token: |
92 | | - raise ValueError("Token must be set") |
93 | | - |
94 | | - return token |
95 | | - |
96 | | - @field_validator("namespace", mode="after") |
97 | | - @classmethod |
98 | | - def validate_namespace(cls, namespace: str) -> str: |
99 | | - if not namespace: |
100 | | - return "" |
101 | | - |
102 | | - if not namespace.endswith("_"): |
103 | | - namespace = f"{namespace}_" |
| 72 | + namespace: str = "" |
104 | 73 |
|
105 | | - return namespace.lower() |
| 74 | + tls_config: ClientTLSConfig = Field(default_factory=lambda: ClientTLSConfig()) |
| 75 | + otel: OTELConfig = Field(default_factory=lambda: OTELConfig()) |
| 76 | + healthcheck: HealthcheckConfig = Field(default_factory=lambda: HealthcheckConfig()) |
106 | 77 |
|
107 | | - @field_validator("tenant_id", mode="after") |
108 | | - @classmethod |
109 | | - def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str: |
110 | | - token = cast(str | None, info.data.get("token")) |
| 78 | + listener_v2_timeout: int | None = None |
| 79 | + grpc_max_recv_message_length: int = Field( |
| 80 | + default=4 * 1024 * 1024, description="4MB default" |
| 81 | + ) |
| 82 | + grpc_max_send_message_length: int = Field( |
| 83 | + default=4 * 1024 * 1024, description="4MB default" |
| 84 | + ) |
111 | 85 |
|
112 | | - if not tenant_id: |
113 | | - if not token: |
114 | | - raise ValueError("Either the token or tenant_id must be set") |
| 86 | + worker_preset_labels: dict[str, str] = Field(default_factory=dict) |
115 | 87 |
|
116 | | - return get_tenant_id_from_jwt(token) |
| 88 | + @model_validator(mode="after") |
| 89 | + def validate_token_and_tenant(self) -> "ClientConfig": |
| 90 | + if not self.token: |
| 91 | + raise ValueError("Token must be set") |
117 | 92 |
|
118 | | - return tenant_id |
| 93 | + if not self.tenant_id: |
| 94 | + self.tenant_id = get_tenant_id_from_jwt(self.token) |
119 | 95 |
|
120 | | - @field_validator("host_port", mode="after") |
121 | | - @classmethod |
122 | | - def validate_host_port(cls, host_port: str, info: ValidationInfo) -> str: |
123 | | - if host_port and host_port != DEFAULT_HOST_PORT: |
124 | | - return host_port |
| 96 | + return self |
125 | 97 |
|
126 | | - token = cast(str, info.data.get("token")) |
| 98 | + @model_validator(mode="after") |
| 99 | + def validate_addresses(self) -> "ClientConfig": |
| 100 | + if self.host_port == DEFAULT_HOST_PORT: |
| 101 | + server_url, grpc_broadcast_address = get_addresses_from_jwt(self.token) |
| 102 | + self.host_port = grpc_broadcast_address |
| 103 | + self.server_url = server_url |
127 | 104 |
|
128 | | - if not token: |
129 | | - raise ValueError("Token must be set") |
| 105 | + if not self.tls_config.server_name: |
| 106 | + self.tls_config.server_name = self.host_port.split(":")[0] |
130 | 107 |
|
131 | | - _, grpc_broadcast_address = get_addresses_from_jwt(token) |
| 108 | + if not self.tls_config.server_name: |
| 109 | + self.tls_config.server_name = "localhost" |
132 | 110 |
|
133 | | - return grpc_broadcast_address |
| 111 | + return self |
134 | 112 |
|
135 | | - @field_validator("server_url", mode="after") |
| 113 | + @field_validator("listener_v2_timeout") |
136 | 114 | @classmethod |
137 | | - def validate_server_url(cls, server_url: str, info: ValidationInfo) -> str: |
138 | | - ## IMPORTANT: Order matters here. The validators run in the order that the |
139 | | - ## fields are defined in the model. So, we need to make sure that the |
140 | | - ## host_port is set before we try to load the server_url |
141 | | - host_port = cast(str, info.data.get("host_port")) |
| 115 | + def validate_listener_timeout(cls, value: int | None | str) -> int | None: |
| 116 | + if value is None: |
| 117 | + return None |
142 | 118 |
|
143 | | - if host_port and host_port != DEFAULT_HOST_PORT: |
144 | | - return host_port |
| 119 | + if isinstance(value, int): |
| 120 | + return value |
145 | 121 |
|
146 | | - token = cast(str, info.data.get("token")) |
| 122 | + return int(value) |
147 | 123 |
|
148 | | - if not token: |
149 | | - raise ValueError("Token must be set") |
150 | | - |
151 | | - _server_url, _ = get_addresses_from_jwt(token) |
152 | | - |
153 | | - return _server_url |
154 | | - |
155 | | - @field_validator("tls_config", mode="after") |
| 124 | + @field_validator("namespace") |
156 | 125 | @classmethod |
157 | | - def validate_tls_config( |
158 | | - cls, tls_config: ClientTLSConfig, info: ValidationInfo |
159 | | - ) -> ClientTLSConfig: |
160 | | - ## IMPORTANT: Order matters here. This validator runs in the order |
161 | | - ## that the fields are defined in the model. So, we need to make sure |
162 | | - ## that the host_port is set before we try to load the tls_config |
163 | | - host_port = cast(str, info.data.get("host_port")) |
164 | | - |
165 | | - return _load_tls_config(host_port) |
| 126 | + def validate_namespace(cls, namespace: str) -> str: |
| 127 | + if not namespace: |
| 128 | + return "" |
| 129 | + if not namespace.endswith("_"): |
| 130 | + namespace = f"{namespace}_" |
| 131 | + return namespace.lower() |
166 | 132 |
|
167 | 133 | def __hash__(self) -> int: |
168 | 134 | return hash(json.dumps(self.model_dump(), default=str)) |
0 commit comments