diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py index 6b27599a..d4492a7e 100644 --- a/infrahub_sdk/client.py +++ b/infrahub_sdk/client.py @@ -146,6 +146,7 @@ def __init__( self.group_context: InfrahubGroupContext | InfrahubGroupContextSync self._initialize() self._request_context: RequestContext | None = None + _ = self.config.tls_context # Early load of the TLS context to catch errors def _initialize(self) -> None: """Sets the properties for each version of the client""" @@ -1024,7 +1025,7 @@ async def _default_request_method( async with httpx.AsyncClient( **proxy_config, # type: ignore[arg-type] - verify=self.config.tls_ca_file if self.config.tls_ca_file else not self.config.tls_insecure, + verify=self.config.tls_context, ) as client: try: response = await client.request( @@ -2748,7 +2749,7 @@ def _default_request_method( with httpx.Client( **proxy_config, # type: ignore[arg-type] - verify=self.config.tls_ca_file if self.config.tls_ca_file else not self.config.tls_insecure, + verify=self.config.tls_context, ) as client: try: response = client.request( diff --git a/infrahub_sdk/config.py b/infrahub_sdk/config.py index b0a2402a..47b3b4e8 100644 --- a/infrahub_sdk/config.py +++ b/infrahub_sdk/config.py @@ -1,9 +1,10 @@ from __future__ import annotations +import ssl from copy import deepcopy from typing import Any -from pydantic import Field, field_validator, model_validator +from pydantic import Field, PrivateAttr, field_validator, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict from typing_extensions import Self @@ -78,6 +79,7 @@ class ConfigBase(BaseSettings): Can be useful to test with self-signed certificates.""", ) tls_ca_file: str | None = Field(default=None, description="File path to CA cert or bundle in PEM format") + _ssl_context: ssl.SSLContext | None = PrivateAttr(default=None) @model_validator(mode="before") @classmethod @@ -133,6 +135,25 @@ def default_infrahub_branch(self) -> str: def password_authentication(self) -> bool: return bool(self.username) + @property + def tls_context(self) -> bool | ssl.SSLContext: + if self._ssl_context: + return self._ssl_context + + if self.tls_insecure: + return False + + if self.tls_ca_file: + self._ssl_context = ssl.create_default_context(cafile=self.tls_ca_file) + + if self._ssl_context is None: + self._ssl_context = ssl.create_default_context() + + return self._ssl_context + + def set_ssl_context(self, context: ssl.SSLContext) -> None: + self._ssl_context = context + class Config(ConfigBase): recorder: RecorderType = Field(default=RecorderType.NONE, description="Select builtin recorder for later replay.") @@ -174,4 +195,7 @@ def clone(self, branch: str | None = None) -> Config: if field not in covered_keys: config[field] = deepcopy(getattr(self, field)) - return Config(**config) + new_config = Config(**config) + if self._ssl_context: + new_config.set_ssl_context(self._ssl_context) + return new_config diff --git a/tests/unit/sdk/test_client.py b/tests/unit/sdk/test_client.py index 8660f046..b3c98926 100644 --- a/tests/unit/sdk/test_client.py +++ b/tests/unit/sdk/test_client.py @@ -1,9 +1,11 @@ import inspect +import ssl +from pathlib import Path import pytest from pytest_httpx import HTTPXMock -from infrahub_sdk import InfrahubClient, InfrahubClientSync +from infrahub_sdk import Config, InfrahubClient, InfrahubClientSync from infrahub_sdk.exceptions import NodeNotFoundError from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync from tests.unit.sdk.conftest import BothClients @@ -28,6 +30,84 @@ client_types = ["standard", "sync"] +CURRENT_DIRECTORY = Path(__file__).parent + + +async def test_verify_config_caches_default_ssl_context(monkeypatch) -> None: + contexts: list[tuple[str | None, object]] = [] + + def fake_create_default_context(*args: object, **kwargs: object) -> object: + context = object() + contexts.append((kwargs.get("cafile"), context)) + return context + + monkeypatch.setattr("ssl.create_default_context", fake_create_default_context) + + client = InfrahubClient(config=Config(address="http://mock")) + + first = client.config.tls_context + second = client.config.tls_context + + assert first is second + assert contexts == [(None, first)] + + +async def test_verify_config_caches_tls_ca_file_context(monkeypatch) -> None: + contexts: list[tuple[str | None, object]] = [] + + def fake_create_default_context(*args: object, **kwargs: object) -> object: + context = object() + contexts.append((kwargs.get("cafile"), context)) + return context + + monkeypatch.setattr("ssl.create_default_context", fake_create_default_context) + + client = InfrahubClient( + config=Config(address="http://mock", tls_ca_file=str(CURRENT_DIRECTORY / "test_data/path-1.pem")) + ) + + first = client.config.tls_context + second = client.config.tls_context + + assert first is second + assert contexts == [(str(CURRENT_DIRECTORY / "test_data/path-1.pem"), first)] + + client.config.tls_ca_file = str(CURRENT_DIRECTORY / "test_data/path-2.pem") + third = client.config.tls_context + + assert third is first + assert contexts == [ + (str(CURRENT_DIRECTORY / "test_data/path-1.pem"), first), + ] + + +async def test_verify_config_respects_tls_insecure(monkeypatch) -> None: + def fake_create_default_context(*args: object, **kwargs: object) -> object: + raise AssertionError("create_default_context should not be called when TLS is insecure") + + monkeypatch.setattr("ssl.create_default_context", fake_create_default_context) + + client = InfrahubClient(config=Config(address="http://mock", tls_insecure=True)) + + verify_value = client.config.tls_context + + assert verify_value is False + + +async def test_verify_config_uses_custom_tls_context(monkeypatch) -> None: + def fake_create_default_context(*args: object, **kwargs: object) -> object: + raise AssertionError("create_default_context should not be called when custom context is provided") + + monkeypatch.setattr("ssl.create_default_context", fake_create_default_context) + + config = Config(address="http://mock") + custom_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) + config.set_ssl_context(custom_context) + + client = InfrahubClient(config=config) + + assert client.config.tls_context is custom_context + async def test_method_sanity() -> None: """Validate that there is at least one public method and that both clients look the same.""" diff --git a/tests/unit/sdk/test_data/path-1.pem b/tests/unit/sdk/test_data/path-1.pem new file mode 100644 index 00000000..29f0dfda --- /dev/null +++ b/tests/unit/sdk/test_data/path-1.pem @@ -0,0 +1,9 @@ +-----BEGIN CERTIFICATE----- +MIIBQDCB86ADAgECAhR6y429KiST51bZy+t330M7dN5SbzAFBgMrZXAwFjEUMBIG +A1UEAwwLZXhhbXBsZS5jb20wHhcNMjUxMDE1MTE0MjUwWhcNMzUxMDEzMTE0MjUw +WjAWMRQwEgYDVQQDDAtleGFtcGxlLmNvbTAqMAUGAytlcAMhAPIl8y8AXSWF33vX +JT2YwhMJzarOuSdPif01Gxr3Rr6Lo1MwUTAdBgNVHQ4EFgQU4heN1ZhyXpOujgcJ +WZ4LQk2m7RAwHwYDVR0jBBgwFoAU4heN1ZhyXpOujgcJWZ4LQk2m7RAwDwYDVR0T +AQH/BAUwAwEB/zAFBgMrZXADQQBoEf+8R+KWwGdaoeqinWOvrqbVZatMis0eUMvA +o+vABSPU7LIYGxLT6fpUwFSTvempzNqGZMVJ9UvVH+hYDU4D +-----END CERTIFICATE----- diff --git a/tests/unit/sdk/test_data/path-2.pem b/tests/unit/sdk/test_data/path-2.pem new file mode 100644 index 00000000..e3d2b646 --- /dev/null +++ b/tests/unit/sdk/test_data/path-2.pem @@ -0,0 +1,9 @@ +-----BEGIN CERTIFICATE----- +MIIBQDCB86ADAgECAhQTRmRZxUSA5L7VfYJb3/t+dRK0ETAFBgMrZXAwFjEUMBIG +A1UEAwwLZXhhbXBsZS5jb20wHhcNMjUxMDE1MTE0MzM0WhcNMzUxMDEzMTE0MzM0 +WjAWMRQwEgYDVQQDDAtleGFtcGxlLmNvbTAqMAUGAytlcAMhAK1O3ZhE5qzfT7Qx ++0My3ToDVDi5wwpllkKn0X50zXFao1MwUTAdBgNVHQ4EFgQUH+qBMU+h4t1vdLbO +jMSSgXdURewwHwYDVR0jBBgwFoAUH+qBMU+h4t1vdLbOjMSSgXdURewwDwYDVR0T +AQH/BAUwAwEB/zAFBgMrZXADQQB3Z03f3gQcktxk4h/v8pVi5soz8viPx17TSPXf +1WYG+Jlk4C5GQ+tyjZgZUE9LL2BFRYBv28V/NPT/0TjPGtcC +-----END CERTIFICATE-----