diff --git a/README.md b/README.md index cc6ca61..63e44ef 100644 --- a/README.md +++ b/README.md @@ -213,6 +213,48 @@ client = DiodeClient(target="grpcs://example.com", ...) client = DiodeClient(target="grpc://example.com", ...) ``` +### Proxy support + +The SDK automatically detects and uses HTTP/HTTPS proxies configured via standard environment variables: + +```bash +# For insecure connections +export HTTP_PROXY=http://proxy.example.com:8080 + +# For secure connections +export HTTPS_PROXY=http://proxy.example.com:8080 +# Falls back to HTTP_PROXY if HTTPS_PROXY is not set + +# Bypass proxy for specific hosts +export NO_PROXY=localhost,127.0.0.1,.example.com +``` + +**Important notes for proxy usage:** + +1. **Proxy with SKIP_TLS_VERIFY**: When using HTTP(S) proxies, the SDK **always uses secure channels** because proxies require TLS for the CONNECT tunnel. Setting `DIODE_SKIP_TLS_VERIFY=true` with a proxy will log a warning and use a secure channel anyway. + +2. **MITM proxies (like mitmproxy)**: To use an intercepting proxy, you must provide the proxy's CA certificate: + ```bash + export HTTPS_PROXY=http://127.0.0.1:8080 + export DIODE_CERT_FILE=~/.mitmproxy/mitmproxy-ca-cert.pem + ``` + +3. **Non-intercepting proxies**: Regular forwarding proxies work without additional configuration if the target server has a valid certificate trusted by system CAs. + +Example with proxy: +```python +import os + +# Configure proxy +os.environ["HTTPS_PROXY"] = "http://proxy.example.com:8080" + +client = DiodeClient( + target="grpcs://diode.example.com:443", + app_name="my-app", + app_version="1.0.0", +) +``` + #### Using custom certificates ```python diff --git a/netboxlabs/diode/sdk/client.py b/netboxlabs/diode/sdk/client.py index 56b7814..e6c8cf5 100644 --- a/netboxlabs/diode/sdk/client.py +++ b/netboxlabs/diode/sdk/client.py @@ -1,24 +1,24 @@ #!/usr/bin/env python -# Copyright 2024 NetBox Labs Inc +# Copyright 2026 NetBox Labs Inc """NetBox Labs, Diode - SDK - Client.""" import collections -import http.client import json import logging import os import platform -import ssl import sys +import tempfile import time import uuid from collections.abc import Iterable from pathlib import Path from typing import Any -from urllib.parse import urlencode, urlparse +from urllib.parse import urlparse import certifi import grpc +import requests import sentry_sdk from google.protobuf.json_format import MessageToJson, ParseDict from opentelemetry.proto.collector.logs.v1 import ( @@ -138,6 +138,143 @@ def _get_optional_config_value( return value +def _get_proxy_env_var(var_name: str) -> str | None: + """Get proxy environment variable (case-insensitive).""" + value = os.getenv(var_name.upper()) + if value: + return value + return os.getenv(var_name.lower()) + + +def _validate_proxy_url(url: str) -> bool: + """ + Validate proxy URL format. + + Args: + url: Proxy URL to validate + + Returns: + True if URL is valid, False otherwise + + """ + if not url: + return False + try: + parsed = urlparse(url) + return parsed.scheme in ("http", "https") and bool(parsed.netloc) + except Exception: + return False + + +def _matches_no_proxy_entry(host: str, entry: str) -> bool: + """Check if host matches a single NO_PROXY entry.""" + if entry == "*": + _LOGGER.debug("NO_PROXY='*' - bypassing proxy for all hosts") + return True + + if entry == host: + _LOGGER.debug(f"NO_PROXY exact match: {host}") + return True + + if not entry.startswith(".") and not entry.startswith("*"): + if host.endswith(f".{entry}"): + _LOGGER.debug(f"NO_PROXY subdomain match: {host} ends with .{entry}") + return True + + if entry.startswith("."): + if host.endswith(entry): + _LOGGER.debug(f"NO_PROXY suffix match: {host} ends with {entry}") + return True + + if entry.startswith("*."): + suffix = entry[1:] + if host.endswith(suffix): + _LOGGER.debug(f"NO_PROXY wildcard match: {host} ends with {suffix}") + return True + + return False + + +def _should_bypass_proxy(target_host: str) -> bool: + """ + Check if target should bypass proxy based on NO_PROXY. + + Implements Go net/http compatible NO_PROXY matching: + - "*" disables proxy for all hosts + - "example.com" matches example.com AND all subdomains + - ".example.com" matches only subdomains, NOT example.com itself + - Port numbers are stripped before matching + - Matching is case-insensitive + - localhost and 127.0.0.1 always bypass proxy + - NO_PROXY entries longer than 256 characters are ignored (security limit) + """ + host = target_host.split(":")[0].lower() + + if host in ("localhost", "127.0.0.1", "::1"): + return True + + no_proxy = _get_proxy_env_var("NO_PROXY") + if not no_proxy: + return False + + # Maximum reasonable length for hostname/domain (RFC 1035: 253 chars, we allow 256) + MAX_NO_PROXY_ENTRY_LENGTH = 256 + + no_proxy_list = [ + entry.strip().lower() + for entry in no_proxy.split(",") + if len(entry.strip()) <= MAX_NO_PROXY_ENTRY_LENGTH + ] + + filtered_count = len([e for e in no_proxy.split(",") if len(e.strip()) > MAX_NO_PROXY_ENTRY_LENGTH]) + if filtered_count > 0: + _LOGGER.warning( + f"Ignored {filtered_count} NO_PROXY entries exceeding {MAX_NO_PROXY_ENTRY_LENGTH} characters" + ) + + for entry in no_proxy_list: + if entry and _matches_no_proxy_entry(host, entry): + return True + + return False + + +def _get_grpc_proxy_url(target_host: str, use_tls: bool) -> str | None: + """ + Get proxy URL for gRPC target, respecting environment variables. + + Args: + target_host: gRPC target (may include port) + use_tls: Whether connection uses TLS + + Returns: + Proxy URL if proxy should be used, None otherwise + + """ + if _should_bypass_proxy(target_host): + return None + + # For HTTPS: check HTTPS_PROXY first, fall back to HTTP_PROXY + if use_tls: + proxy_url = _get_proxy_env_var("HTTPS_PROXY") + if not proxy_url: + proxy_url = _get_proxy_env_var("HTTP_PROXY") + else: + # For HTTP: only check HTTP_PROXY + proxy_url = _get_proxy_env_var("HTTP_PROXY") + + if proxy_url: + if not _validate_proxy_url(proxy_url): + _LOGGER.warning( + f"Invalid proxy URL format: {proxy_url}. " + f"Proxy URL must be http:// or https:// with valid host. Ignoring proxy." + ) + return None + _LOGGER.debug(f"Using proxy {proxy_url} for gRPC target {target_host}") + + return proxy_url + + class DiodeClient(DiodeClientInterface): """Diode Client.""" @@ -198,20 +335,36 @@ def __init__( self._authenticate(_INGEST_SCOPE) - channel_opts = ( + channel_opts = [ ( "grpc.primary_user_agent", f"{self._name}/{self._version} {self._app_name}/{self._app_version}", ), - ) + ] + + proxy_url = _get_grpc_proxy_url(self._target, self._tls_verify) + if proxy_url: + channel_opts.append(("grpc.http_proxy", proxy_url)) + _LOGGER.debug(f"Configured gRPC proxy: {proxy_url}") - if self._tls_verify and self._certificates: - _LOGGER.debug("Setting up gRPC secure channel") + channel_opts = tuple(channel_opts) + + # Channel creation logic + if self._tls_verify: + credentials = ( + grpc.ssl_channel_credentials(root_certificates=self._certificates) + if self._certificates + else grpc.ssl_channel_credentials() + ) + + _LOGGER.debug( + f"Setting up gRPC secure channel with " + f"{'custom certificates' if self._certificates else 'system certificates'}" + f"{' via proxy' if proxy_url else ''}" + ) self._channel = grpc.secure_channel( self._target, - grpc.ssl_channel_credentials( - root_certificates=self._certificates, - ), + credentials, options=channel_opts, ) else: @@ -353,7 +506,12 @@ def _authenticate(self, scope: str): self._client_id, self._client_secret, scope, + self._name, + self._version, + self._app_name, + self._app_version, self._certificates, + self._cert_file, ) access_token = authentication_client.authenticate() self._metadata = list( @@ -473,25 +631,44 @@ def __init__( else None ) - channel_opts = ( + channel_opts = [ ( "grpc.primary_user_agent", f"{self._name}/{self._version} {self._app_name}/{self._app_version}", ), - ) + ] + + proxy_url = _get_grpc_proxy_url(self._target, self._tls_verify) + if proxy_url: + channel_opts.append(("grpc.http_proxy", proxy_url)) + # Extract hostname for SSL target name override + target_host = self._target.split(":")[0] + channel_opts.append(("grpc.ssl_target_name_override", target_host)) + _LOGGER.debug(f"Configured gRPC proxy: {proxy_url}") + _LOGGER.debug(f"SSL target name override: {target_host}") + channel_opts = tuple(channel_opts) + + # Channel creation logic if self._tls_verify: credentials = ( grpc.ssl_channel_credentials(root_certificates=self._certificates) if self._certificates else grpc.ssl_channel_credentials() ) + + _LOGGER.debug( + f"Setting up gRPC secure channel with " + f"{'custom certificates' if self._certificates else 'system certificates'}" + f"{' via proxy' if proxy_url else ''}" + ) base_channel = grpc.secure_channel( self._target, credentials, options=channel_opts, ) else: + _LOGGER.debug(f"Setting up gRPC insecure channel") base_channel = grpc.insecure_channel( target=self._target, options=channel_opts, @@ -723,7 +900,12 @@ def __init__( client_id: str, client_secret: str, scope: str, + sdk_name: str, + sdk_version: str, + app_name: str, + app_version: str, certificates: bytes | None = None, + cert_file: str | None = None, ): self._target = target self._tls_verify = tls_verify @@ -731,47 +913,78 @@ def __init__( self._client_secret = client_secret self._path = path self._scope = scope + self._sdk_name = sdk_name + self._sdk_version = sdk_version + self._app_name = app_name + self._app_version = app_version self._certificates = certificates + self._cert_file = cert_file def authenticate(self) -> str: """Request an OAuth2 token using client credentials and return it.""" - if self._tls_verify and self._certificates: - context = ssl.create_default_context() - context.load_verify_locations(cadata=self._certificates.decode("utf-8")) - conn = http.client.HTTPSConnection( - self._target, - context=context, - ) - else: - conn = http.client.HTTPConnection( - self._target, - ) - headers = {"Content-type": "application/x-www-form-urlencoded"} - data = urlencode( - { + session = requests.Session() + temp_cert_file = None + + try: + # Configure SSL verification + if self._tls_verify and self._certificates: + # Use cert_file path directly if available, otherwise write to temp file + if self._cert_file: + session.verify = self._cert_file + else: + # Write certificates to temp file for requests + with tempfile.NamedTemporaryFile( + mode="wb", delete=False, suffix=".pem" + ) as f: + f.write(self._certificates) + temp_cert_file = f.name + session.verify = temp_cert_file + elif not self._tls_verify: + session.verify = False + + # Prepare auth request + url = self._get_full_auth_url() + data = { "grant_type": "client_credentials", "client_id": self._client_id, "client_secret": self._client_secret, "scope": self._scope, } - ) - url = self._get_auth_url() - try: - conn.request("POST", url, data, headers) - response = conn.getresponse() - except Exception as e: - raise DiodeConfigError(f"Failed to obtain access token: {e}") - if response.status != 200: - raise DiodeConfigError(f"Failed to obtain access token: {response.reason}") - token_info = json.loads(response.read().decode()) - access_token = token_info.get("access_token") - if not access_token: - raise DiodeConfigError( - f"Failed to obtain access token for client {self._client_id}" - ) + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": f"{self._sdk_name}/{self._sdk_version} {self._app_name}/{self._app_version}", + } + + response = session.post(url, data=data, headers=headers) + + if response.status_code != 200: + raise DiodeConfigError( + f"Failed to obtain access token: {response.reason}" + ) + + token_info = response.json() + access_token = token_info.get("access_token") + + if not access_token: + raise DiodeConfigError( + f"Failed to obtain access token for client {self._client_id}" + ) - _LOGGER.debug(f"Access token obtained for client {self._client_id}") - return access_token + _LOGGER.debug(f"Access token obtained for client {self._client_id}") + return access_token + + except requests.RequestException as e: + raise DiodeConfigError(f"Failed to obtain access token: {e}") + finally: + # Clean up temp certificate file + if temp_cert_file and os.path.exists(temp_cert_file): + try: + os.unlink(temp_cert_file) + _LOGGER.debug(f"Cleaned up temp certificate file: {temp_cert_file}") + except OSError as e: + _LOGGER.warning( + f"Failed to clean up temp certificate file {temp_cert_file}: {e}" + ) def _get_auth_url(self) -> str: """Construct the authentication URL, handling trailing slashes in the path.""" @@ -779,6 +992,23 @@ def _get_auth_url(self) -> str: path = self._path.rstrip("/") if self._path else "" return f"{path}/auth/token" + def _get_full_auth_url(self) -> str: + """Construct full authentication URL with scheme and authority.""" + # Determine the correct scheme + # If tls_verify is False, check if SKIP_TLS_VERIFY was set + # If it was set, the original scheme was likely HTTPS but verification is disabled + skip_tls_env = os.getenv(_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME, "").lower() + skip_tls_from_env = skip_tls_env in ["true", "1", "yes", "on"] + + # Use HTTPS if: + # 1. tls_verify is True, OR + # 2. tls_verify is False but SKIP_TLS_VERIFY is set (original was HTTPS) + use_https = self._tls_verify or (not self._tls_verify and skip_tls_from_env) + scheme = "https" if use_https else "http" + + path = self._path.rstrip("/") if self._path else "" + return f"{scheme}://{self._target}{path}/auth/token" + class _ClientCallDetails( collections.namedtuple( diff --git a/pyproject.toml b/pyproject.toml index 7025a26..0379ba9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "certifi>=2024.7.4", "grpcio>=1.68.1", "grpcio-status>=1.68.1", + "requests>=2.31.0", "sentry-sdk>=2.2.1", "opentelemetry-proto>=1.26.0", ] diff --git a/tests/test_client.py b/tests/test_client.py index 0371a39..bbf39b1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -220,7 +220,9 @@ def test_client_sets_up_secure_channel_when_grpcs_scheme_is_found_in_target( client_secret="123456", ) - mock_debug.assert_called_once_with("Setting up gRPC secure channel") + # Check that debug was called with the secure channel message + debug_calls = [call[0][0] for call in mock_debug.call_args_list] + assert any("Setting up gRPC secure channel with" in call for call in debug_calls) mock_secure_channel.assert_called_once() @@ -601,13 +603,17 @@ def test_diode_authentication_success(mock_diode_authentication): client_id="test_client_id", client_secret="test_client_secret", scope="diode:ingest", + sdk_name="diode-sdk-python", + sdk_version="0.1.0", + app_name="test-app", + app_version="1.0.0", ) - with mock.patch("http.client.HTTPConnection") as mock_http_conn: - mock_conn_instance = mock_http_conn.return_value - mock_conn_instance.getresponse.return_value.status = 200 - mock_conn_instance.getresponse.return_value.read.return_value = json.dumps( - {"access_token": "mocked_token"} - ).encode() + with mock.patch("requests.Session") as mock_session_class: + mock_session = mock_session_class.return_value + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"access_token": "mocked_token"} + mock_session.post.return_value = mock_response token = auth.authenticate() assert token == "mocked_token" @@ -622,11 +628,17 @@ def test_diode_authentication_failure(mock_diode_authentication): client_id="test_client_id", client_secret="test_client_secret", scope="diode:ingest", + sdk_name="diode-sdk-python", + sdk_version="0.1.0", + app_name="test-app", + app_version="1.0.0", ) - with mock.patch("http.client.HTTPConnection") as mock_http_conn: - mock_conn_instance = mock_http_conn.return_value - mock_conn_instance.getresponse.return_value.status = 401 - mock_conn_instance.getresponse.return_value.reason = "Unauthorized" + with mock.patch("requests.Session") as mock_session_class: + mock_session = mock_session_class.return_value + mock_response = mock.Mock() + mock_response.status_code = 401 + mock_response.reason = "Unauthorized" + mock_session.post.return_value = mock_response with pytest.raises(DiodeConfigError) as excinfo: auth.authenticate() @@ -653,17 +665,26 @@ def test_diode_authentication_url_with_path(mock_diode_authentication, path): client_id="test_client_id", client_secret="test_client_secret", scope="diode:ingest", + sdk_name="diode-sdk-python", + sdk_version="0.1.0", + app_name="test-app", + app_version="1.0.0", ) - with mock.patch("http.client.HTTPConnection") as mock_http_conn: - mock_conn_instance = mock_http_conn.return_value - mock_conn_instance.getresponse.return_value.status = 200 - mock_conn_instance.getresponse.return_value.read.return_value = json.dumps( - {"access_token": "mocked_token"} - ).encode() + with mock.patch("requests.Session") as mock_session_class: + mock_session = mock_session_class.return_value + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"access_token": "mocked_token"} + mock_session.post.return_value = mock_response + auth.authenticate() - mock_conn_instance.request.assert_called_once_with( - "POST", f"{(path or '').rstrip('/')}/auth/token", mock.ANY, mock.ANY - ) + + # Verify the URL in the post call + mock_session.post.assert_called_once() + call_args = mock_session.post.call_args + url = call_args[0][0] + expected_url = f"http://localhost:8081{(path or '').rstrip('/')}/auth/token" + assert url == expected_url def test_diode_authentication_request_exception(mock_diode_authentication): @@ -675,10 +696,16 @@ def test_diode_authentication_request_exception(mock_diode_authentication): client_id="test_client_id", client_secret="test_client_secret", scope="diode:ingest", + sdk_name="diode-sdk-python", + sdk_version="0.1.0", + app_name="test-app", + app_version="1.0.0", ) - with mock.patch("http.client.HTTPConnection") as mock_http_conn: - mock_conn_instance = mock_http_conn.return_value - mock_conn_instance.request.side_effect = Exception("Connection error") + with mock.patch("requests.Session") as mock_session_class: + mock_session = mock_session_class.return_value + # Import requests.RequestException for the side effect + import requests + mock_session.post.side_effect = requests.RequestException("Connection error") with pytest.raises(DiodeConfigError) as excinfo: auth.authenticate() @@ -871,39 +898,60 @@ def test_diode_authentication_with_custom_certificates(): client_id="test_client", client_secret="test_secret", scope="test_scope", + sdk_name="diode-sdk-python", + sdk_version="0.1.0", + app_name="test-app", + app_version="1.0.0", certificates=cert_content, ) with ( - mock.patch("http.client.HTTPSConnection") as mock_https_conn, - mock.patch("ssl.create_default_context") as mock_ssl_context, + mock.patch("requests.Session") as mock_session_class, + mock.patch("tempfile.NamedTemporaryFile") as mock_tempfile, + mock.patch("os.path.exists") as mock_exists, + mock.patch("os.unlink") as mock_unlink, ): - # Setup mocks - mock_context_instance = mock.Mock() - mock_ssl_context.return_value = mock_context_instance - - mock_conn_instance = mock.Mock() - mock_https_conn.return_value = mock_conn_instance - + # Setup temp file mock + mock_temp_file = mock.Mock() + mock_temp_file.name = "/tmp/test_cert.pem" + mock_temp_file.__enter__ = mock.Mock(return_value=mock_temp_file) + mock_temp_file.__exit__ = mock.Mock(return_value=False) + mock_tempfile.return_value = mock_temp_file + + # Mock os.path.exists to return True so cleanup happens + mock_exists.return_value = True + + # Setup session mock + mock_session = mock_session_class.return_value mock_response = mock.Mock() - mock_response.status = 200 - mock_response.read.return_value = b'{"access_token": "test_token"}' - mock_conn_instance.getresponse.return_value = mock_response + mock_response.status_code = 200 + mock_response.json.return_value = {"access_token": "test_token"} + mock_session.post.return_value = mock_response - # Call authenticate to trigger SSL context creation + # Call authenticate token = auth.authenticate() - # Verify SSL context was created and configured with custom certs - mock_ssl_context.assert_called_once() - mock_context_instance.load_verify_locations.assert_called_once_with( - cadata=cert_content.decode("utf-8") - ) + # Verify tempfile was created for the certificate + mock_tempfile.assert_called_once() + call_kwargs = mock_tempfile.call_args[1] + assert call_kwargs["mode"] == "wb" + assert call_kwargs["delete"] is False + assert call_kwargs["suffix"] == ".pem" - # Verify HTTPS connection was created with custom context - mock_https_conn.assert_called_once_with( - "example.com:443", - context=mock_context_instance, - ) + # Verify certificate was written + mock_temp_file.write.assert_called_once_with(cert_content) + + # Verify session.verify was set to the temp file path + assert mock_session.verify == "/tmp/test_cert.pem" + + # Verify the post request was made + mock_session.post.assert_called_once() + + # Verify os.path.exists was checked + mock_exists.assert_called_once_with("/tmp/test_cert.pem") + + # Verify temp file was cleaned up + mock_unlink.assert_called_once_with("/tmp/test_cert.pem") # Verify token was returned assert token == "test_token" @@ -1282,12 +1330,14 @@ def test_certificate_loading_efficiency(tmp_path): # Verify certificates are stored and reused assert client._certificates == cert_content - # Verify that the authentication class was created with the certificate bytes + # Verify that the authentication class was created with the certificate bytes and cert_file mock_auth_class.assert_called_once() auth_call_args = mock_auth_class.call_args - # The last argument should be the certificate bytes - assert auth_call_args[0][-1] == cert_content # certificates parameter + # The second-to-last argument should be the certificate bytes + assert auth_call_args[0][-2] == cert_content # certificates parameter + # The last argument should be the cert_file path + assert auth_call_args[0][-1] == str(cert_file) # cert_file parameter # Reset the mock to verify no additional calls during authentication mock_load_certs.reset_mock() @@ -1613,3 +1663,467 @@ def test_otlp_client_without_metadata(): # Verify no diode.metadata.* attributes are present metadata_attrs = [k for k in attributes if k.startswith("diode.metadata.")] assert len(metadata_attrs) == 0 + + +def test_get_proxy_env_var_uppercase(): + """Test _get_proxy_env_var returns uppercase environment variable.""" + from netboxlabs.diode.sdk.client import _get_proxy_env_var + + os.environ["HTTP_PROXY"] = "http://proxy.example.com:8080" + try: + assert _get_proxy_env_var("HTTP_PROXY") == "http://proxy.example.com:8080" + finally: + del os.environ["HTTP_PROXY"] + + +def test_get_proxy_env_var_lowercase(): + """Test _get_proxy_env_var returns lowercase environment variable.""" + from netboxlabs.diode.sdk.client import _get_proxy_env_var + + os.environ["http_proxy"] = "http://proxy.example.com:8080" + try: + assert _get_proxy_env_var("http_proxy") == "http://proxy.example.com:8080" + finally: + del os.environ["http_proxy"] + + +def test_get_proxy_env_var_prefers_uppercase(): + """Test _get_proxy_env_var prefers uppercase over lowercase.""" + from netboxlabs.diode.sdk.client import _get_proxy_env_var + + os.environ["HTTP_PROXY"] = "http://upper.example.com:8080" + os.environ["http_proxy"] = "http://lower.example.com:8080" + try: + assert _get_proxy_env_var("http_proxy") == "http://upper.example.com:8080" + finally: + del os.environ["HTTP_PROXY"] + del os.environ["http_proxy"] + + +def test_should_bypass_proxy_localhost(): + """Test _should_bypass_proxy returns True for localhost.""" + from netboxlabs.diode.sdk.client import _should_bypass_proxy + + assert _should_bypass_proxy("localhost") is True + assert _should_bypass_proxy("localhost:8080") is True + + +def test_should_bypass_proxy_127_0_0_1(): + """Test _should_bypass_proxy returns True for 127.0.0.1.""" + from netboxlabs.diode.sdk.client import _should_bypass_proxy + + assert _should_bypass_proxy("127.0.0.1") is True + assert _should_bypass_proxy("127.0.0.1:8080") is True + + +def test_should_bypass_proxy_with_no_proxy_asterisk(): + """Test _should_bypass_proxy returns True when NO_PROXY is '*'.""" + from netboxlabs.diode.sdk.client import _should_bypass_proxy + + os.environ["NO_PROXY"] = "*" + try: + assert _should_bypass_proxy("example.com") is True + assert _should_bypass_proxy("any.host.com") is True + finally: + del os.environ["NO_PROXY"] + + +def test_should_bypass_proxy_exact_match(): + """Test _should_bypass_proxy matches exact hostname.""" + from netboxlabs.diode.sdk.client import _should_bypass_proxy + + os.environ["NO_PROXY"] = "example.com" + try: + assert _should_bypass_proxy("example.com") is True + assert _should_bypass_proxy("example.com:443") is True + finally: + del os.environ["NO_PROXY"] + + +def test_should_bypass_proxy_subdomain_match(): + """Test _should_bypass_proxy matches subdomains.""" + from netboxlabs.diode.sdk.client import _should_bypass_proxy + + os.environ["NO_PROXY"] = "example.com" + try: + assert _should_bypass_proxy("api.example.com") is True + assert _should_bypass_proxy("www.example.com") is True + finally: + del os.environ["NO_PROXY"] + + +def test_get_grpc_proxy_url_https_proxy_for_tls(): + """Test _get_grpc_proxy_url uses HTTPS_PROXY for TLS connections.""" + from netboxlabs.diode.sdk.client import _get_grpc_proxy_url + + os.environ["HTTPS_PROXY"] = "http://https-proxy.example.com:8080" + try: + proxy_url = _get_grpc_proxy_url("example.com:443", use_tls=True) + assert proxy_url == "http://https-proxy.example.com:8080" + finally: + del os.environ["HTTPS_PROXY"] + + +def test_get_grpc_proxy_url_http_proxy_fallback_for_tls(): + """Test _get_grpc_proxy_url falls back to HTTP_PROXY for TLS connections.""" + from netboxlabs.diode.sdk.client import _get_grpc_proxy_url + + os.environ["HTTP_PROXY"] = "http://http-proxy.example.com:8080" + try: + proxy_url = _get_grpc_proxy_url("example.com:443", use_tls=True) + assert proxy_url == "http://http-proxy.example.com:8080" + finally: + del os.environ["HTTP_PROXY"] + + +def test_get_grpc_proxy_url_respects_no_proxy(): + """Test _get_grpc_proxy_url respects NO_PROXY.""" + from netboxlabs.diode.sdk.client import _get_grpc_proxy_url + + os.environ["HTTP_PROXY"] = "http://proxy.example.com:8080" + os.environ["NO_PROXY"] = "example.com" + try: + proxy_url = _get_grpc_proxy_url("example.com:443", use_tls=True) + assert proxy_url is None + finally: + del os.environ["HTTP_PROXY"] + del os.environ["NO_PROXY"] + + +def test_diode_client_configures_proxy_option(mock_diode_authentication): + """Test DiodeClient adds grpc.http_proxy option when proxy is detected.""" + os.environ["HTTP_PROXY"] = "http://proxy.example.com:8080" + try: + with mock.patch("grpc.insecure_channel") as mock_insecure_channel: + DiodeClient( + target="grpc://example.com:8081", + app_name="my-producer", + app_version="0.0.1", + client_id="abcde", + client_secret="123456", + ) + + # Should use insecure channel for grpc:// target, even with proxy + mock_insecure_channel.assert_called_once() + _, kwargs = mock_insecure_channel.call_args + options = kwargs["options"] + + # Check that grpc.http_proxy option is present + proxy_option = next( + (opt for opt in options if opt[0] == "grpc.http_proxy"), None + ) + assert proxy_option is not None + assert proxy_option[1] == "http://proxy.example.com:8080" + finally: + del os.environ["HTTP_PROXY"] + + +def test_diode_client_uses_insecure_channel_with_proxy_when_skip_tls( + mock_diode_authentication, +): + """Test DiodeClient uses insecure channel with proxy when SKIP_TLS_VERIFY is set.""" + os.environ["HTTP_PROXY"] = "http://proxy.example.com:8080" + os.environ["DIODE_SKIP_TLS_VERIFY"] = "true" + try: + with mock.patch("grpc.insecure_channel") as mock_insecure_channel: + DiodeClient( + target="grpcs://example.com:443", + app_name="my-producer", + app_version="0.0.1", + client_id="abcde", + client_secret="123456", + ) + + # Should use insecure channel when SKIP_TLS_VERIFY is set, even with proxy + mock_insecure_channel.assert_called_once() + _, kwargs = mock_insecure_channel.call_args + options = kwargs["options"] + + # Verify proxy option is set + proxy_option = next( + (opt for opt in options if opt[0] == "grpc.http_proxy"), None + ) + assert proxy_option is not None + assert proxy_option[1] == "http://proxy.example.com:8080" + finally: + del os.environ["HTTP_PROXY"] + del os.environ["DIODE_SKIP_TLS_VERIFY"] + + +def test_diode_client_respects_no_proxy_for_target(mock_diode_authentication): + """Test DiodeClient respects NO_PROXY environment variable.""" + os.environ["HTTP_PROXY"] = "http://proxy.example.com:8080" + os.environ["NO_PROXY"] = "example.com" + try: + with mock.patch("grpc.insecure_channel") as mock_insecure_channel: + DiodeClient( + target="grpc://example.com:8081", + app_name="my-producer", + app_version="0.0.1", + client_id="abcde", + client_secret="123456", + ) + + mock_insecure_channel.assert_called_once() + _, kwargs = mock_insecure_channel.call_args + options = kwargs["options"] + + # Check that grpc.http_proxy option is NOT present + proxy_option = next( + (opt for opt in options if opt[0] == "grpc.http_proxy"), None + ) + assert proxy_option is None + finally: + del os.environ["HTTP_PROXY"] + del os.environ["NO_PROXY"] + + +def test_diode_client_with_proxy_and_custom_cert(mock_diode_authentication, tmp_path): + """Test DiodeClient with proxy and custom certificate (for MITM proxies).""" + cert_content = ( + b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + ) + cert_file = tmp_path / "custom.pem" + cert_file.write_bytes(cert_content) + + os.environ["HTTPS_PROXY"] = "http://proxy.example.com:8080" + try: + with ( + mock.patch("grpc.secure_channel") as mock_secure_channel, + mock.patch("grpc.ssl_channel_credentials") as mock_ssl_creds, + ): + mock_ssl_creds.return_value = mock.Mock() + + DiodeClient( + target="grpcs://example.com:443", + app_name="my-producer", + app_version="0.0.1", + client_id="abcde", + client_secret="123456", + cert_file=str(cert_file), + ) + + # Should use secure channel + mock_secure_channel.assert_called_once() + + # Should use custom certificate + mock_ssl_creds.assert_called_once() + ssl_call_args = mock_ssl_creds.call_args + assert ssl_call_args[1]["root_certificates"] == cert_content + + # Verify proxy option is set + _, kwargs = mock_secure_channel.call_args + options = kwargs["options"] + proxy_option = next( + (opt for opt in options if opt[0] == "grpc.http_proxy"), None + ) + assert proxy_option is not None + assert proxy_option[1] == "http://proxy.example.com:8080" + finally: + del os.environ["HTTPS_PROXY"] + + +def test_validate_proxy_url_valid_http(): + """Test _validate_proxy_url with valid HTTP URL.""" + from netboxlabs.diode.sdk.client import _validate_proxy_url + + assert _validate_proxy_url("http://proxy.example.com:8080") is True + + +def test_validate_proxy_url_valid_https(): + """Test _validate_proxy_url with valid HTTPS URL.""" + from netboxlabs.diode.sdk.client import _validate_proxy_url + + assert _validate_proxy_url("https://proxy.example.com:8443") is True + + +def test_validate_proxy_url_invalid_scheme(): + """Test _validate_proxy_url with invalid scheme.""" + from netboxlabs.diode.sdk.client import _validate_proxy_url + + assert _validate_proxy_url("ftp://proxy.example.com:8080") is False + assert _validate_proxy_url("socks5://proxy.example.com:1080") is False + + +def test_validate_proxy_url_missing_netloc(): + """Test _validate_proxy_url with missing netloc.""" + from netboxlabs.diode.sdk.client import _validate_proxy_url + + assert _validate_proxy_url("http://") is False + assert _validate_proxy_url("https://") is False + + +def test_validate_proxy_url_empty_string(): + """Test _validate_proxy_url with empty string.""" + from netboxlabs.diode.sdk.client import _validate_proxy_url + + assert _validate_proxy_url("") is False + + +def test_validate_proxy_url_malformed(): + """Test _validate_proxy_url with malformed URLs.""" + from netboxlabs.diode.sdk.client import _validate_proxy_url + + assert _validate_proxy_url("not_a_url") is False + assert _validate_proxy_url("://missing-scheme") is False + + +def test_get_grpc_proxy_url_invalid_proxy_url(): + """Test _get_grpc_proxy_url with invalid proxy URL format.""" + from netboxlabs.diode.sdk.client import _get_grpc_proxy_url + + os.environ["HTTP_PROXY"] = "not_a_valid_url" + try: + with mock.patch("logging.Logger.warning") as mock_warning: + result = _get_grpc_proxy_url("example.com:443", use_tls=False) + + # Should return None for invalid proxy URL + assert result is None + + # Should log warning + mock_warning.assert_called_once() + warning_message = mock_warning.call_args[0][0] + assert "Invalid proxy URL format" in warning_message + assert "not_a_valid_url" in warning_message + finally: + del os.environ["HTTP_PROXY"] + + +def test_get_grpc_proxy_url_ftp_scheme_rejected(): + """Test _get_grpc_proxy_url rejects non-HTTP/HTTPS schemes.""" + from netboxlabs.diode.sdk.client import _get_grpc_proxy_url + + os.environ["HTTP_PROXY"] = "ftp://proxy.example.com:21" + try: + with mock.patch("logging.Logger.warning") as mock_warning: + result = _get_grpc_proxy_url("example.com:443", use_tls=False) + + # Should return None + assert result is None + + # Should log warning + mock_warning.assert_called_once() + finally: + del os.environ["HTTP_PROXY"] + + +def test_should_bypass_proxy_with_long_no_proxy_entries(): + """Test _should_bypass_proxy filters out entries exceeding max length.""" + from netboxlabs.diode.sdk.client import _should_bypass_proxy + + # Create a NO_PROXY with one valid and one excessively long entry + valid_entry = "example.com" + long_entry = "a" * 300 # 300 characters, exceeds 256 limit + + os.environ["NO_PROXY"] = f"{valid_entry},{long_entry}" + try: + with mock.patch("logging.Logger.warning") as mock_warning: + # Should match valid entry + result = _should_bypass_proxy("example.com:443") + assert result is True + + # Should warn about filtered entries + mock_warning.assert_called_once() + warning_message = mock_warning.call_args[0][0] + assert "Ignored 1 NO_PROXY entries exceeding 256 characters" in warning_message + finally: + del os.environ["NO_PROXY"] + + +def test_should_bypass_proxy_with_multiple_long_entries(): + """Test _should_bypass_proxy warns about multiple long entries.""" + from netboxlabs.diode.sdk.client import _should_bypass_proxy + + # Create multiple excessively long entries + long_entry1 = "a" * 300 + long_entry2 = "b" * 400 + long_entry3 = "c" * 500 + valid_entry = "valid.example.com" + + os.environ["NO_PROXY"] = f"{long_entry1},{valid_entry},{long_entry2},{long_entry3}" + try: + with mock.patch("logging.Logger.warning") as mock_warning: + # Should match valid entry + result = _should_bypass_proxy("valid.example.com:443") + assert result is True + + # Should warn about 3 filtered entries + mock_warning.assert_called_once() + warning_message = mock_warning.call_args[0][0] + assert "Ignored 3 NO_PROXY entries exceeding 256 characters" in warning_message + finally: + del os.environ["NO_PROXY"] + + +def test_should_bypass_proxy_max_length_entry_accepted(): + """Test _should_bypass_proxy accepts entries at max length (256 chars).""" + from netboxlabs.diode.sdk.client import _should_bypass_proxy + + # Create an entry exactly 256 characters long + max_length_entry = "a" * 256 + + os.environ["NO_PROXY"] = max_length_entry + try: + with mock.patch("logging.Logger.warning") as mock_warning: + # Should not match (hostname doesn't match) + result = _should_bypass_proxy("example.com:443") + assert result is False + + # Should NOT warn (entry is within limit) + mock_warning.assert_not_called() + finally: + del os.environ["NO_PROXY"] + + +def test_should_bypass_proxy_over_max_length_filtered(): + """Test _should_bypass_proxy filters entries over max length (257+ chars).""" + from netboxlabs.diode.sdk.client import _should_bypass_proxy + + # Create an entry just over max length + over_max_entry = "a" * 257 + + os.environ["NO_PROXY"] = over_max_entry + try: + with mock.patch("logging.Logger.warning") as mock_warning: + result = _should_bypass_proxy("example.com:443") + assert result is False + + # Should warn about filtered entry + mock_warning.assert_called_once() + finally: + del os.environ["NO_PROXY"] + + +def test_diode_client_with_invalid_proxy_url_falls_back_to_no_proxy( + mock_diode_authentication, +): + """Test DiodeClient falls back to no proxy when proxy URL is invalid.""" + os.environ["HTTP_PROXY"] = "invalid_url_format" + try: + with ( + mock.patch("grpc.insecure_channel") as mock_insecure_channel, + mock.patch("logging.Logger.warning") as mock_warning, + ): + DiodeClient( + target="grpc://example.com:8081", + app_name="my-producer", + app_version="0.0.1", + client_id="abcde", + client_secret="123456", + ) + + # Should use insecure channel without proxy + mock_insecure_channel.assert_called_once() + + # Verify no proxy option is set (invalid proxy was rejected) + _, kwargs = mock_insecure_channel.call_args + options = kwargs["options"] + proxy_option = next( + (opt for opt in options if opt[0] == "grpc.http_proxy"), None + ) + assert proxy_option is None + + # Should log warning about invalid proxy + assert any("Invalid proxy URL format" in str(call) for call in mock_warning.call_args_list) + finally: + del os.environ["HTTP_PROXY"]