diff --git a/hatchet_sdk/loader.py b/hatchet_sdk/loader.py index 0252f33a..e6aebd8a 100644 --- a/hatchet_sdk/loader.py +++ b/hatchet_sdk/loader.py @@ -134,7 +134,9 @@ def get_config_value(key, env_var): if not tenant_id: tenant_id = get_tenant_id_from_jwt(token) - tls_config = self._load_tls_config(config_data["tls"], host_port) + tls_config = self._load_tls_config( + config_data["tls"], host_port, defaults=defaults.tls_config + ) worker_healthcheck_port = int( get_config_value( @@ -201,16 +203,15 @@ def get_config_value(key, env_var): enable_force_kill_sync_threads=enable_force_kill_sync_threads, ) - def _load_tls_config(self, tls_data: Dict, host_port) -> ClientTLSConfig: + def _load_tls_config( + self, tls_data: Dict, host_port: str, defaults: Optional[ClientTLSConfig] + ) -> ClientTLSConfig: tls_strategy = ( tls_data["tlsStrategy"] if "tlsStrategy" in tls_data else self._get_env_var("HATCHET_CLIENT_TLS_STRATEGY") ) - if not tls_strategy: - tls_strategy = "tls" - cert_file = ( tls_data["tlsCertFile"] if "tlsCertFile" in tls_data @@ -233,6 +234,24 @@ def _load_tls_config(self, tls_data: Dict, host_port) -> ClientTLSConfig: else self._get_env_var("HATCHET_CLIENT_TLS_SERVER_NAME") ) + if not tls_strategy and defaults: + tls_strategy = defaults.tls_strategy + + if not cert_file and defaults: + cert_file = defaults.cert_file + + if not key_file and defaults: + key_file = defaults.key_file + + if not ca_file and defaults: + ca_file = defaults.ca_file + + if not server_name and defaults: + server_name = defaults.server_name + + if not tls_strategy: + tls_strategy = "tls" + # if server_name is not set, use the host from the host_port if not server_name: server_name = host_port.split(":")[0]