From cef6a2b34d827260d913597b8259248fbd493938 Mon Sep 17 00:00:00 2001 From: nbayati <99771966+nbayati@users.noreply.github.com> Date: Tue, 9 Dec 2025 14:52:19 -0800 Subject: [PATCH 1/2] feat: Add support for Agent Identity bound tokens (#1821) This change introduces support for requesting certificate-bound access tokens for Agent Identities on GKE and Cloud Run. The design doc: [go/sdk-agent-identity](http://goto.google.com/sdk-agent-identity) --- google/auth/_agent_identity_utils.py | 262 ++++++++++++++++++++ google/auth/_oauth2client.py | 2 +- google/auth/compute_engine/_metadata.py | 13 +- google/auth/compute_engine/credentials.py | 2 +- google/auth/environment_vars.py | 9 + google/auth/external_account.py | 15 +- google/auth/identity_pool.py | 22 ++ google/auth/transport/_mtls_helper.py | 4 +- mypy.ini | 2 +- noxfile.py | 1 + setup.py | 1 + tests/compute_engine/test__metadata.py | 100 ++++++++ tests/compute_engine/test_credentials.py | 72 ++++++ tests/test_agent_identity_utils.py | 288 ++++++++++++++++++++++ tests/test_external_account.py | 18 ++ tests/test_identity_pool.py | 56 +++++ 16 files changed, 855 insertions(+), 12 deletions(-) create mode 100644 google/auth/_agent_identity_utils.py create mode 100644 tests/test_agent_identity_utils.py diff --git a/google/auth/_agent_identity_utils.py b/google/auth/_agent_identity_utils.py new file mode 100644 index 000000000..665f2aa50 --- /dev/null +++ b/google/auth/_agent_identity_utils.py @@ -0,0 +1,262 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for Agent Identity credentials.""" + +import base64 +import hashlib +import logging +import os +import re +import time +from urllib.parse import urlparse, quote + +from google.auth import environment_vars +from google.auth import exceptions + + +_LOGGER = logging.getLogger(__name__) + +CRYPTOGRAPHY_NOT_FOUND_ERROR = ( + "The cryptography library is required for certificate-based authentication." + "Please install it with `pip install google-auth[cryptography]`." +) + +# SPIFFE trust domain patterns for Agent Identities. +_AGENT_IDENTITY_SPIFFE_TRUST_DOMAIN_PATTERNS = [ + r"^agents\.global\.org-\d+\.system\.id\.goog$", + r"^agents\.global\.proj-\d+\.system\.id\.goog$", +] + +_WELL_KNOWN_CERT_PATH = "/var/run/secrets/workload-spiffe-credentials/certificates.pem" + +# Constants for polling the certificate file. +_FAST_POLL_CYCLES = 50 +_FAST_POLL_INTERVAL = 0.1 # 100ms +_SLOW_POLL_INTERVAL = 0.5 # 500ms +_TOTAL_TIMEOUT = 30 # seconds + +# Calculate the number of slow poll cycles based on the total timeout. +_SLOW_POLL_CYCLES = int( + (_TOTAL_TIMEOUT - (_FAST_POLL_CYCLES * _FAST_POLL_INTERVAL)) / _SLOW_POLL_INTERVAL +) + +_POLLING_INTERVALS = ([_FAST_POLL_INTERVAL] * _FAST_POLL_CYCLES) + ( + [_SLOW_POLL_INTERVAL] * _SLOW_POLL_CYCLES +) + + +def _is_certificate_file_ready(path): + """Checks if a file exists and is not empty.""" + return path and os.path.exists(path) and os.path.getsize(path) > 0 + + +def get_agent_identity_certificate_path(): + """Gets the certificate path from the certificate config file. + + The path to the certificate config file is read from the + GOOGLE_API_CERTIFICATE_CONFIG environment variable. This function + implements a retry mechanism to handle cases where the environment + variable is set before the files are available on the filesystem. + + Returns: + str: The path to the leaf certificate file. + + Raises: + google.auth.exceptions.RefreshError: If the certificate config file + or the certificate file cannot be found after retries. + """ + import json + + cert_config_path = os.environ.get(environment_vars.GOOGLE_API_CERTIFICATE_CONFIG) + if not cert_config_path: + return None + + has_logged_warning = False + + for interval in _POLLING_INTERVALS: + try: + with open(cert_config_path, "r") as f: + cert_config = json.load(f) + cert_path = ( + cert_config.get("cert_configs", {}) + .get("workload", {}) + .get("cert_path") + ) + if _is_certificate_file_ready(cert_path): + return cert_path + except (IOError, ValueError, KeyError): + if not has_logged_warning: + _LOGGER.warning( + "Certificate config file not found at %s (from %s environment " + "variable). Retrying for up to %s seconds.", + cert_config_path, + environment_vars.GOOGLE_API_CERTIFICATE_CONFIG, + _TOTAL_TIMEOUT, + ) + has_logged_warning = True + pass + + # As a fallback, check the well-known certificate path. + if _is_certificate_file_ready(_WELL_KNOWN_CERT_PATH): + return _WELL_KNOWN_CERT_PATH + + # A sleep is required in two cases: + # 1. The config file is not found (the except block). + # 2. The config file is found, but the certificate is not yet available. + # In both cases, we need to poll, so we sleep on every iteration + # that doesn't return a certificate. + time.sleep(interval) + + raise exceptions.RefreshError( + "Certificate config or certificate file not found after multiple retries. " + f"Token binding protection is failing. You can turn off this protection by setting " + f"{environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES} to false " + "to fall back to unbound tokens." + ) + + +def get_and_parse_agent_identity_certificate(): + """Gets and parses the agent identity certificate if not opted out. + + Checks if the user has opted out of certificate-bound tokens. If not, + it gets the certificate path, reads the file, and parses it. + + Returns: + The parsed certificate object if found and not opted out, otherwise None. + """ + # If the user has opted out of cert bound tokens, there is no need to + # look up the certificate. + is_opted_out = ( + os.environ.get( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES, + "true", + ).lower() + == "false" + ) + if is_opted_out: + return None + + cert_path = get_agent_identity_certificate_path() + if not cert_path: + return None + + with open(cert_path, "rb") as cert_file: + cert_bytes = cert_file.read() + + return parse_certificate(cert_bytes) + + +def parse_certificate(cert_bytes): + """Parses a PEM-encoded certificate. + + Args: + cert_bytes (bytes): The PEM-encoded certificate bytes. + + Returns: + cryptography.x509.Certificate: The parsed certificate object. + """ + try: + from cryptography import x509 + + return x509.load_pem_x509_certificate(cert_bytes) + except ImportError as e: + raise ImportError(CRYPTOGRAPHY_NOT_FOUND_ERROR) from e + + +def _is_agent_identity_certificate(cert): + """Checks if a certificate is an Agent Identity certificate. + + This is determined by checking the Subject Alternative Name (SAN) for a + SPIFFE ID with a trust domain matching Agent Identity patterns. + + Args: + cert (cryptography.x509.Certificate): The parsed certificate object. + + Returns: + bool: True if the certificate is an Agent Identity certificate, + False otherwise. + """ + try: + from cryptography import x509 + from cryptography.x509.oid import ExtensionOID + + try: + ext = cert.extensions.get_extension_for_oid( + ExtensionOID.SUBJECT_ALTERNATIVE_NAME + ) + except x509.ExtensionNotFound: + return False + uris = ext.value.get_values_for_type(x509.UniformResourceIdentifier) + + for uri in uris: + parsed_uri = urlparse(uri) + if parsed_uri.scheme == "spiffe": + trust_domain = parsed_uri.netloc + for pattern in _AGENT_IDENTITY_SPIFFE_TRUST_DOMAIN_PATTERNS: + if re.match(pattern, trust_domain): + return True + return False + except ImportError as e: + raise ImportError(CRYPTOGRAPHY_NOT_FOUND_ERROR) from e + + +def calculate_certificate_fingerprint(cert): + """Calculates the URL-encoded, unpadded, base64-encoded SHA256 hash of a + DER-encoded certificate. + + Args: + cert (cryptography.x509.Certificate): The parsed certificate object. + + Returns: + str: The URL-encoded, unpadded, base64-encoded SHA256 fingerprint. + """ + try: + from cryptography.hazmat.primitives import serialization + + der_cert = cert.public_bytes(serialization.Encoding.DER) + fingerprint = hashlib.sha256(der_cert).digest() + # The certificate fingerprint is generated in two steps to align with GFE's + # expectations and ensure proper URL transmission: + # 1. Standard base64 encoding is applied, and padding ('=') is removed. + # 2. The resulting string is then URL-encoded to handle special characters + # ('+', '/') that would otherwise be misinterpreted in URL parameters. + base64_fingerprint = base64.b64encode(fingerprint).decode("utf-8") + unpadded_base64_fingerprint = base64_fingerprint.rstrip("=") + return quote(unpadded_base64_fingerprint) + except ImportError as e: + raise ImportError(CRYPTOGRAPHY_NOT_FOUND_ERROR) from e + + +def should_request_bound_token(cert): + """Determines if a bound token should be requested. + + This is based on the GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES + environment variable and whether the certificate is an agent identity cert. + + Args: + cert (cryptography.x509.Certificate): The parsed certificate object. + + Returns: + bool: True if a bound token should be requested, False otherwise. + """ + is_agent_cert = _is_agent_identity_certificate(cert) + is_opted_in = ( + os.environ.get( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES, + "true", + ).lower() + == "true" + ) + return is_agent_cert and is_opted_in diff --git a/google/auth/_oauth2client.py b/google/auth/_oauth2client.py index 8b83ff23c..8032b26ad 100644 --- a/google/auth/_oauth2client.py +++ b/google/auth/_oauth2client.py @@ -127,7 +127,7 @@ def _convert_appengine_app_assertion_credentials(credentials): oauth2client.contrib.gce.AppAssertionCredentials: _convert_gce_app_assertion_credentials, } -if _HAS_APPENGINE: +if _HAS_APPENGINE: # pragma: no cover _CLASS_CONVERSION_MAP[ oauth2client.contrib.appengine.AppAssertionCredentials ] = _convert_appengine_app_assertion_credentials diff --git a/google/auth/compute_engine/_metadata.py b/google/auth/compute_engine/_metadata.py index 96f1ff526..35b6c4495 100644 --- a/google/auth/compute_engine/_metadata.py +++ b/google/auth/compute_engine/_metadata.py @@ -451,12 +451,19 @@ def get_service_account_token(request, service_account="default", scopes=None): google.auth.exceptions.TransportError: if an error occurred while retrieving metadata. """ + from google.auth import _agent_identity_utils + + params = {} if scopes: if not isinstance(scopes, str): scopes = ",".join(scopes) - params = {"scopes": scopes} - else: - params = None + params["scopes"] = scopes + + cert = _agent_identity_utils.get_and_parse_agent_identity_certificate() + if cert: + if _agent_identity_utils.should_request_bound_token(cert): + fingerprint = _agent_identity_utils.calculate_certificate_fingerprint(cert) + params["bindCertificateFingerprint"] = fingerprint metrics_header = { metrics.API_CLIENT_HEADER: metrics.token_request_access_token_mds() diff --git a/google/auth/compute_engine/credentials.py b/google/auth/compute_engine/credentials.py index 0f518166a..554547619 100644 --- a/google/auth/compute_engine/credentials.py +++ b/google/auth/compute_engine/credentials.py @@ -135,9 +135,9 @@ def _refresh_token(self, request): service can't be reached if if the instance has not credentials. """ - scopes = self._scopes if self._scopes is not None else self._default_scopes try: self._retrieve_info(request) + scopes = self._scopes if self._scopes is not None else self._default_scopes # Always fetch token with default service account email. self.token, self.expiry = _metadata.get_service_account_token( request, service_account="default", scopes=scopes diff --git a/google/auth/environment_vars.py b/google/auth/environment_vars.py index 5da3a7382..1e6557272 100644 --- a/google/auth/environment_vars.py +++ b/google/auth/environment_vars.py @@ -92,3 +92,12 @@ GOOGLE_AUTH_TRUST_BOUNDARY_ENABLED = "GOOGLE_AUTH_TRUST_BOUNDARY_ENABLED" """Environment variable controlling whether to enable trust boundary feature. The default value is false. Users have to explicitly set this value to true.""" + +GOOGLE_API_CERTIFICATE_CONFIG = "GOOGLE_API_CERTIFICATE_CONFIG" +"""Environment variable defining the location of Google API certificate config +file.""" + +GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES = ( + "GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES" +) +"""Environment variable to prevent agent token sharing for GCP services.""" diff --git a/google/auth/external_account.py b/google/auth/external_account.py index 8eba0d249..8cf928778 100644 --- a/google/auth/external_account.py +++ b/google/auth/external_account.py @@ -420,6 +420,9 @@ def refresh(self, request): credentials, it will refresh the access token and the trust boundary. """ self._refresh_token(request) + self._handle_trust_boundary(request) + + def _handle_trust_boundary(self, request): # If we are impersonating, the trust boundary is handled by the # impersonated credentials object. We need to get it from there. if self._service_account_impersonation_url: @@ -428,7 +431,7 @@ def refresh(self, request): # Otherwise, refresh the trust boundary for the external account. self._refresh_trust_boundary(request) - def _refresh_token(self, request): + def _refresh_token(self, request, cert_fingerprint=None): scopes = self._scopes if self._scopes is not None else self._default_scopes # Inject client certificate into request. @@ -446,11 +449,15 @@ def _refresh_token(self, request): self.expiry = self._impersonated_credentials.expiry else: now = _helpers.utcnow() - additional_options = None + additional_options = {} # Do not pass workforce_pool_user_project when client authentication # is used. The client ID is sufficient for determining the user project. if self._workforce_pool_user_project and not self._client_id: - additional_options = {"userProject": self._workforce_pool_user_project} + additional_options["userProject"] = self._workforce_pool_user_project + + if cert_fingerprint: + additional_options["bindCertFingerprint"] = cert_fingerprint + additional_headers = { metrics.API_CLIENT_HEADER: metrics.byoid_metrics_header( self._metrics_options @@ -464,7 +471,7 @@ def _refresh_token(self, request): audience=self._audience, scopes=scopes, requested_token_type=_STS_REQUESTED_TOKEN_TYPE, - additional_options=additional_options, + additional_options=additional_options if additional_options else None, additional_headers=additional_headers, ) self.token = response_data.get("access_token") diff --git a/google/auth/identity_pool.py b/google/auth/identity_pool.py index 79b7de920..d2ed8c85a 100644 --- a/google/auth/identity_pool.py +++ b/google/auth/identity_pool.py @@ -550,3 +550,25 @@ def from_file(cls, filename, **kwargs): credentials. """ return super(Credentials, cls).from_file(filename, **kwargs) + + def refresh(self, request): + """Refreshes the access token. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + """ + from google.auth import _agent_identity_utils + + cert_fingerprint = None + # Check if the credential is X.509 based. + if self._credential_source_certificate is not None: + cert_bytes = self._get_cert_bytes() + cert = _agent_identity_utils.parse_certificate(cert_bytes) + if _agent_identity_utils.should_request_bound_token(cert): + cert_fingerprint = _agent_identity_utils.calculate_certificate_fingerprint( + cert + ) + + self._refresh_token(request, cert_fingerprint=cert_fingerprint) + self._handle_trust_boundary(request) diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py index f5d6b6724..497613a4f 100644 --- a/google/auth/transport/_mtls_helper.py +++ b/google/auth/transport/_mtls_helper.py @@ -20,11 +20,11 @@ import re import subprocess +from google.auth import environment_vars from google.auth import exceptions CONTEXT_AWARE_METADATA_PATH = "~/.secureConnect/context_aware_metadata.json" CERTIFICATE_CONFIGURATION_DEFAULT_PATH = "~/.config/gcloud/certificate_config.json" -_CERTIFICATE_CONFIGURATION_ENV = "GOOGLE_API_CERTIFICATE_CONFIG" _CERT_PROVIDER_COMMAND = "cert_provider_command" _CERT_REGEX = re.compile( b"-----BEGIN CERTIFICATE-----.+-----END CERTIFICATE-----\r?\n?", re.DOTALL @@ -146,7 +146,7 @@ def _get_cert_config_path(certificate_config_path=None): """ if certificate_config_path is None: - env_path = environ.get(_CERTIFICATE_CONFIGURATION_ENV, None) + env_path = environ.get(environment_vars.GOOGLE_API_CERTIFICATE_CONFIG, None) if env_path is not None and env_path != "": certificate_config_path = env_path else: diff --git a/mypy.ini b/mypy.ini index 574c5aed3..c129006db 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,3 +1,3 @@ [mypy] -python_version = 3.7 +python_version = 3.9 namespace_packages = True diff --git a/noxfile.py b/noxfile.py index 728e8c7cc..3a2c0e883 100644 --- a/noxfile.py +++ b/noxfile.py @@ -105,6 +105,7 @@ def mypy(session): "types-requests", "types-setuptools", "types-mock", + "pytest", ) session.run("mypy", "-p", "google", "-p", "tests", "-p", "tests_async") diff --git a/setup.py b/setup.py index 20f79ce66..56c451fdb 100644 --- a/setup.py +++ b/setup.py @@ -81,6 +81,7 @@ ] extras = { + "cryptography": cryptography_base_require, "aiohttp": aiohttp_extra_require, "enterprise_cert": enterprise_cert_extra_require, "pyopenssl": pyopenssl_extra_require, diff --git a/tests/compute_engine/test__metadata.py b/tests/compute_engine/test__metadata.py index adb63f667..5bb85c264 100644 --- a/tests/compute_engine/test__metadata.py +++ b/tests/compute_engine/test__metadata.py @@ -40,6 +40,29 @@ DATA_DIR, "smbios_product_name_non_google" ) +# A mock PEM-encoded certificate without an Agent Identity SPIFFE ID. +NON_AGENT_IDENTITY_CERT_BYTES = ( + b"-----BEGIN CERTIFICATE-----\n" + b"MIIDIzCCAgugAwIBAgIJAMfISuBQ5m+5MA0GCSqGSIb3DQEBBQUAMBUxEzARBgNV\n" + b"BAMTCnVuaXQtdGVzdHMwHhcNMTExMjA2MTYyNjAyWhcNMjExMjAzMTYyNjAyWjAV\n" + b"MRMwEQYDVQQDEwp1bml0LXRlc3RzMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB\n" + b"CgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZgkdmM\n" + b"7oVK2OfgrSj/FCTkInKPqaCR0gD7K80q+mLBrN3PUkDrJQZpvRZIff3/xmVU1Wer\n" + b"uQLFJjnFb2dqu0s/FY/2kWiJtBCakXvXEOb7zfbINuayL+MSsCGSdVYsSliS5qQp\n" + b"gyDap+8b5fpXZVJkq92hrcNtbkg7hCYUJczt8n9hcCTJCfUpApvaFQ18pe+zpyl4\n" + b"+WzkP66I28hniMQyUlA1hBiskT7qiouq0m8IOodhv2fagSZKjOTTU2xkSBc//fy3\n" + b"ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQABo3YwdDAdBgNVHQ4EFgQU2RQ8yO+O\n" + b"gN8oVW2SW7RLrfYd9jEwRQYDVR0jBD4wPIAU2RQ8yO+OgN8oVW2SW7RLrfYd9jGh\n" + b"GaQXMBUxEzARBgNVBAMTCnVuaXQtdGVzdHOCCQDHyErgUOZvuTAMBgNVHRMEBTAD\n" + b"AQH/MA0GCSqGSIb3DQEBBQUAA4IBAQBRv+M/6+FiVu7KXNjFI5pSN17OcW5QUtPr\n" + b"odJMlWrJBtynn/TA1oJlYu3yV5clc/71Vr/AxuX5xGP+IXL32YDF9lTUJXG/uUGk\n" + b"+JETpKmQviPbRsvzYhz4pf6ZIOZMc3/GIcNq92ECbseGO+yAgyWUVKMmZM0HqXC9\n" + b"ovNslqe0M8C1sLm1zAR5z/h/litE7/8O2ietija3Q/qtl2TOXJdCA6sgjJX2WUql\n" + b"ybrC55ct18NKf3qhpcEkGQvFU40rVYApJpi98DiZPYFdx1oBDp/f4uZ3ojpxRVFT\n" + b"cDwcJLfNRCPUhormsY7fDS9xSyThiHsW9mjJYdcaKQkwYZ0F11yB\n" + b"-----END CERTIFICATE-----\n" +) + ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE = ( "gl-python/3.7 auth/1.1 auth-request-type/at cred-type/mds" ) @@ -670,6 +693,83 @@ def test_get_service_account_token_with_scopes_string( assert expiry == utcnow() + datetime.timedelta(seconds=ttl) +@mock.patch("google.auth._agent_identity_utils.calculate_certificate_fingerprint") +@mock.patch("google.auth._agent_identity_utils.should_request_bound_token") +@mock.patch( + "google.auth._agent_identity_utils.get_and_parse_agent_identity_certificate" +) +@mock.patch( + "google.auth.metrics.token_request_access_token_mds", + return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, +) +@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) +def test_get_service_account_token_with_bound_token( + utcnow, + mock_metrics_header_value, + mock_get_and_parse, + mock_should_request, + mock_calculate_fingerprint, +): + # Test the successful path where a certificate is found and a bound token + # is requested. + mock_cert = mock.sentinel.cert + mock_get_and_parse.return_value = mock_cert + mock_should_request.return_value = True + mock_calculate_fingerprint.return_value = "fake_fingerprint" + + token_response = json.dumps({"access_token": "token", "expires_in": 3600}) + request = make_request(token_response, headers={"content-type": "application/json"}) + + _metadata.get_service_account_token(request) + + mock_get_and_parse.assert_called_once() + mock_should_request.assert_called_once_with(mock_cert) + mock_calculate_fingerprint.assert_called_once_with(mock_cert) + + request.assert_called_once() + _, kwargs = request.call_args + url = kwargs["url"] + assert "bindCertificateFingerprint=fake_fingerprint" in url + + +@mock.patch( + "google.auth._agent_identity_utils.get_and_parse_agent_identity_certificate" +) +def test_get_service_account_token_no_cert(mock_get_and_parse): + # Test that no fingerprint is added when no certificate is found. + mock_get_and_parse.return_value = None + token_response = json.dumps({"access_token": "token", "expires_in": 3600}) + request = make_request(token_response, headers={"content-type": "application/json"}) + + _metadata.get_service_account_token(request) + + request.assert_called_once() + _, kwargs = request.call_args + url = kwargs["url"] + assert "bindCertificateFingerprint" not in url + + +@mock.patch("google.auth._agent_identity_utils.should_request_bound_token") +@mock.patch( + "google.auth._agent_identity_utils.get_and_parse_agent_identity_certificate" +) +def test_get_service_account_token_should_not_bind( + mock_get_and_parse, mock_should_request +): + # Test that no fingerprint is added when a cert is found but should not be used. + mock_get_and_parse.return_value = mock.sentinel.cert + mock_should_request.return_value = False + token_response = json.dumps({"access_token": "token", "expires_in": 3600}) + request = make_request(token_response, headers={"content-type": "application/json"}) + + _metadata.get_service_account_token(request) + + request.assert_called_once() + _, kwargs = request.call_args + url = kwargs["url"] + assert "bindCertificateFingerprint" not in url + + def test_get_service_account_info(): key, value = "foo", "bar" request = make_request( diff --git a/tests/compute_engine/test_credentials.py b/tests/compute_engine/test_credentials.py index 1c7706993..9ef671425 100644 --- a/tests/compute_engine/test_credentials.py +++ b/tests/compute_engine/test_credentials.py @@ -658,6 +658,78 @@ def test_build_trust_boundary_lookup_url_no_email( assert excinfo.match(r"missing 'email' field") + @mock.patch("google.auth.compute_engine._metadata.get") + @mock.patch("google.auth._agent_identity_utils.get_agent_identity_certificate_path") + @mock.patch("google.auth._agent_identity_utils.parse_certificate") + @mock.patch( + "google.auth._agent_identity_utils.should_request_bound_token", + return_value=True, + ) + @mock.patch( + "google.auth._agent_identity_utils.calculate_certificate_fingerprint", + return_value="fingerprint", + ) + def test_refresh_with_agent_identity( + self, + mock_calculate_fingerprint, + mock_should_request, + mock_parse_certificate, + mock_get_path, + mock_metadata_get, + tmpdir, + ): + cert_path = tmpdir.join("cert.pem") + cert_path.write(b"cert_content") + mock_get_path.return_value = str(cert_path) + + mock_metadata_get.side_effect = [ + {"email": "service-account@example.com", "scopes": ["one", "two"]}, + {"access_token": "token", "expires_in": 500}, + ] + + self.credentials.refresh(None) + + assert self.credentials.token == "token" + mock_parse_certificate.assert_called_once_with(b"cert_content") + mock_should_request.assert_called_once_with(mock_parse_certificate.return_value) + kwargs = mock_metadata_get.call_args[1] + assert kwargs["params"] == { + "scopes": "one,two", + "bindCertificateFingerprint": "fingerprint", + } + + @mock.patch("google.auth.compute_engine._metadata.get") + @mock.patch("google.auth._agent_identity_utils.get_agent_identity_certificate_path") + @mock.patch("google.auth._agent_identity_utils.parse_certificate") + @mock.patch( + "google.auth._agent_identity_utils.should_request_bound_token", + return_value=False, + ) + def test_refresh_with_agent_identity_opt_out_or_not_agent( + self, + mock_should_request, + mock_parse_certificate, + mock_get_path, + mock_metadata_get, + tmpdir, + ): + cert_path = tmpdir.join("cert.pem") + cert_path.write(b"cert_content") + mock_get_path.return_value = str(cert_path) + + mock_metadata_get.side_effect = [ + {"email": "service-account@example.com", "scopes": ["one", "two"]}, + {"access_token": "token", "expires_in": 500}, + ] + + self.credentials.refresh(None) + + assert self.credentials.token == "token" + mock_parse_certificate.assert_called_once_with(b"cert_content") + mock_should_request.assert_called_once_with(mock_parse_certificate.return_value) + kwargs = mock_metadata_get.call_args[1] + assert "bindCertificateFingerprint" not in kwargs.get("params", {}) + class TestIDTokenCredentials(object): credentials = None diff --git a/tests/test_agent_identity_utils.py b/tests/test_agent_identity_utils.py new file mode 100644 index 000000000..65c6bf144 --- /dev/null +++ b/tests/test_agent_identity_utils.py @@ -0,0 +1,288 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import hashlib +import json + +import urllib.parse + +from cryptography import x509 +import mock +import pytest + +from google.auth import _agent_identity_utils +from google.auth import environment_vars +from google.auth import exceptions + +# A mock PEM-encoded certificate without an Agent Identity SPIFFE ID. +NON_AGENT_IDENTITY_CERT_BYTES = ( + b"-----BEGIN CERTIFICATE-----\n" + b"MIIDIzCCAgugAwIBAgIJAMfISuBQ5m+5MA0GCSqGSIb3DQEBBQUAMBUxEzARBgNV\n" + b"BAMTCnVuaXQtdGVzdHMwHhcNMTExMjA2MTYyNjAyWhcNMjExMjAzMTYyNjAyWjAV\n" + b"MRMwEQYDVQQDEwp1bml0LXRlc3RzMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB\n" + b"CgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZgkdmM\n" + b"7oVK2OfgrSj/FCTkInKPqaCR0gD7K80q+mLBrN3PUkDrJQZpvRZIff3/xmVU1Wer\n" + b"uQLFJjnFb2dqu0s/FY/2kWiJtBCakXvXEOb7zfbINuayL+MSsCGSdVYsSliS5qQp\n" + b"gyDap+8b5fpXZVJkq92hrcNtbkg7hCYUJczt8n9hcCTJCfUpApvaFQ18pe+zpyl4\n" + b"+WzkP66I28hniMQyUlA1hBiskT7qiouq0m8IOodhv2fagSZKjOTTU2xkSBc//fy3\n" + b"ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQABo3YwdDAdBgNVHQ4EFgQU2RQ8yO+O\n" + b"gN8oVW2SW7RLrfYd9jEwRQYDVR0jBD4wPIAU2RQ8yO+OgN8oVW2SW7RLrfYd9jGh\n" + b"GaQXMBUxEzARBgNVBAMTCnVuaXQtdGVzdHOCCQDHyErgUOZvuTAMBgNVHRMEBTAD\n" + b"AQH/MA0GCSqGSIb3DQEBBQUAA4IBAQBRv+M/6+FiVu7KXNjFI5pSN17OcW5QUtPr\n" + b"odJMlWrJBtynn/TA1oJlYu3yV5clc/71Vr/AxuX5xGP+IXL32YDF9lTUJXG/uUGk\n" + b"+JETpKmQviPbRsvzYhz4pf6ZIOZMc3/GIcNq92ECbseGO+yAgyWUVKMmZM0HqXC9\n" + b"ovNslqe0M8C1sLm1zAR5z/h/litE7/8O2ietija3Q/qtl2TOXJdCA6sgjJX2WUql\n" + b"ybrC55ct18NKf3qhpcEkGQvFU40rVYApJpi98DiZPYFdx1oBDp/f4uZ3ojpxRVFT\n" + b"cDwcJLfNRCPUhormsY7fDS9xSyThiHsW9mjJYdcaKQkwYZ0F11yB\n" + b"-----END CERTIFICATE-----\n" +) + + +class TestAgentIdentityUtils: + @mock.patch("cryptography.x509.load_pem_x509_certificate") + def test_parse_certificate(self, mock_load_cert): + result = _agent_identity_utils.parse_certificate(b"cert_bytes") + mock_load_cert.assert_called_once_with(b"cert_bytes") + assert result == mock_load_cert.return_value + + def test__is_agent_identity_certificate_invalid(self): + cert = _agent_identity_utils.parse_certificate(NON_AGENT_IDENTITY_CERT_BYTES) + assert not _agent_identity_utils._is_agent_identity_certificate(cert) + + def test__is_agent_identity_certificate_valid_spiffe(self): + mock_cert = mock.MagicMock() + mock_ext = mock.MagicMock() + mock_san_value = mock.MagicMock() + mock_cert.extensions.get_extension_for_oid.return_value = mock_ext + mock_ext.value = mock_san_value + mock_san_value.get_values_for_type.return_value = [ + "spiffe://agents.global.proj-12345.system.id.goog/workload" + ] + assert _agent_identity_utils._is_agent_identity_certificate(mock_cert) + + def test__is_agent_identity_certificate_non_matching_spiffe(self): + mock_cert = mock.MagicMock() + mock_ext = mock.MagicMock() + mock_san_value = mock.MagicMock() + mock_cert.extensions.get_extension_for_oid.return_value = mock_ext + mock_ext.value = mock_san_value + mock_san_value.get_values_for_type.return_value = [ + "spiffe://other.domain.com/workload" + ] + assert not _agent_identity_utils._is_agent_identity_certificate(mock_cert) + + def test__is_agent_identity_certificate_no_san(self): + mock_cert = mock.MagicMock() + mock_cert.extensions.get_extension_for_oid.side_effect = x509.ExtensionNotFound( + "Test extension not found", None + ) + assert not _agent_identity_utils._is_agent_identity_certificate(mock_cert) + + def test__is_agent_identity_certificate_not_spiffe_uri(self): + mock_cert = mock.MagicMock() + mock_ext = mock.MagicMock() + mock_san_value = mock.MagicMock() + mock_cert.extensions.get_extension_for_oid.return_value = mock_ext + mock_ext.value = mock_san_value + mock_san_value.get_values_for_type.return_value = ["https://example.com"] + assert not _agent_identity_utils._is_agent_identity_certificate(mock_cert) + + def test_calculate_certificate_fingerprint(self): + mock_cert = mock.MagicMock() + mock_cert.public_bytes.return_value = b"der-bytes" + + # Expected: base64 (standard), unpadded, then URL-encoded + base64_fingerprint = base64.b64encode(hashlib.sha256(b"der-bytes").digest()).decode("utf-8") + unpadded_base64_fingerprint = base64_fingerprint.rstrip("=") + expected_fingerprint = urllib.parse.quote(unpadded_base64_fingerprint) + + fingerprint = _agent_identity_utils.calculate_certificate_fingerprint(mock_cert) + + assert fingerprint == expected_fingerprint + + @mock.patch("google.auth._agent_identity_utils._is_agent_identity_certificate") + def test_should_request_bound_token(self, mock_is_agent, monkeypatch): + # Agent cert, default env var (opt-in) + mock_is_agent.return_value = True + monkeypatch.delenv( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES, + raising=False, + ) + assert _agent_identity_utils.should_request_bound_token(mock.sentinel.cert) + + # Agent cert, explicit opt-in + monkeypatch.setenv( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES, + "true", + ) + assert _agent_identity_utils.should_request_bound_token(mock.sentinel.cert) + + # Agent cert, explicit opt-out + monkeypatch.setenv( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES, + "false", + ) + assert not _agent_identity_utils.should_request_bound_token(mock.sentinel.cert) + + # Non-agent cert, opt-in + mock_is_agent.return_value = False + monkeypatch.setenv( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES, + "true", + ) + assert not _agent_identity_utils.should_request_bound_token(mock.sentinel.cert) + + def test_get_agent_identity_certificate_path_success(self, tmpdir, monkeypatch): + cert_path = tmpdir.join("cert.pem") + cert_path.write("cert_content") + config_path = tmpdir.join("config.json") + config_path.write( + json.dumps({"cert_configs": {"workload": {"cert_path": str(cert_path)}}}) + ) + monkeypatch.setenv( + environment_vars.GOOGLE_API_CERTIFICATE_CONFIG, str(config_path) + ) + + result = _agent_identity_utils.get_agent_identity_certificate_path() + assert result == str(cert_path) + + @mock.patch("time.sleep") + def test_get_agent_identity_certificate_path_retry( + self, mock_sleep, tmpdir, monkeypatch + ): + config_path = tmpdir.join("config.json") + monkeypatch.setenv( + environment_vars.GOOGLE_API_CERTIFICATE_CONFIG, str(config_path) + ) + + # File doesn't exist initially + with pytest.raises(exceptions.RefreshError): + _agent_identity_utils.get_agent_identity_certificate_path() + + assert mock_sleep.call_count == 100 + + @mock.patch("time.sleep") + def test_get_agent_identity_certificate_path_failure( + self, mock_sleep, tmpdir, monkeypatch + ): + config_path = tmpdir.join("non_existent_config.json") + monkeypatch.setenv( + environment_vars.GOOGLE_API_CERTIFICATE_CONFIG, str(config_path) + ) + + with pytest.raises(exceptions.RefreshError) as excinfo: + _agent_identity_utils.get_agent_identity_certificate_path() + + assert "not found after multiple retries" in str(excinfo.value) + assert ( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES + in str(excinfo.value) + ) + assert mock_sleep.call_count == 100 + + @mock.patch("time.sleep") + @mock.patch("os.path.exists") + def test_get_agent_identity_certificate_path_cert_not_found( + self, mock_exists, mock_sleep, tmpdir, monkeypatch + ): + cert_path_str = str(tmpdir.join("cert.pem")) + config_path = tmpdir.join("config.json") + config_path.write( + json.dumps({"cert_configs": {"workload": {"cert_path": cert_path_str}}}) + ) + monkeypatch.setenv( + environment_vars.GOOGLE_API_CERTIFICATE_CONFIG, str(config_path) + ) + + def exists_side_effect(path): + return path == str(config_path) + + mock_exists.side_effect = exists_side_effect + + with pytest.raises(exceptions.RefreshError): + _agent_identity_utils.get_agent_identity_certificate_path() + + assert mock_sleep.call_count == 100 + + @mock.patch("google.auth._agent_identity_utils.get_agent_identity_certificate_path") + def test_get_and_parse_agent_identity_certificate_opted_out( + self, mock_get_path, monkeypatch + ): + monkeypatch.setenv( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES, + "false", + ) + result = _agent_identity_utils.get_and_parse_agent_identity_certificate() + assert result is None + mock_get_path.assert_not_called() + + @mock.patch("google.auth._agent_identity_utils.get_agent_identity_certificate_path") + def test_get_and_parse_agent_identity_certificate_no_path( + self, mock_get_path, monkeypatch + ): + monkeypatch.setenv( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES, + "true", + ) + mock_get_path.return_value = None + result = _agent_identity_utils.get_and_parse_agent_identity_certificate() + assert result is None + mock_get_path.assert_called_once() + + @mock.patch("google.auth._agent_identity_utils.parse_certificate") + @mock.patch("google.auth._agent_identity_utils.get_agent_identity_certificate_path") + def test_get_and_parse_agent_identity_certificate_success( + self, mock_get_path, mock_parse_certificate, monkeypatch + ): + monkeypatch.setenv( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES, + "true", + ) + mock_get_path.return_value = "/fake/cert.pem" + mock_open = mock.mock_open(read_data=b"cert_bytes") + + with mock.patch("builtins.open", mock_open): + result = _agent_identity_utils.get_and_parse_agent_identity_certificate() + + mock_open.assert_called_once_with("/fake/cert.pem", "rb") + mock_parse_certificate.assert_called_once_with(b"cert_bytes") + assert result == mock_parse_certificate.return_value + + +class TestAgentIdentityUtilsNoCryptography: + @pytest.fixture(autouse=True) + def mock_cryptography_import(self): + with mock.patch.dict( + "sys.modules", + { + "cryptography": None, + "cryptography.hazmat": None, + "cryptography.hazmat.primitives": None, + "cryptography.hazmat.primitives.serialization": None, + }, + ): + yield + + def test_parse_certificate_raises_import_error(self): + with pytest.raises(ImportError, match="The cryptography library is required"): + _agent_identity_utils.parse_certificate(b"cert_bytes") + + def test_is_agent_identity_certificate_raises_import_error(self): + with pytest.raises(ImportError, match="The cryptography library is required"): + _agent_identity_utils._is_agent_identity_certificate(mock.sentinel.cert) + + def test_calculate_certificate_fingerprint_raises_import_error(self): + with pytest.raises(ImportError, match="The cryptography library is required"): + _agent_identity_utils.calculate_certificate_fingerprint(mock.sentinel.cert) diff --git a/tests/test_external_account.py b/tests/test_external_account.py index 2fa64361d..a56b54a43 100644 --- a/tests/test_external_account.py +++ b/tests/test_external_account.py @@ -737,6 +737,24 @@ def test_refresh_skips_trust_boundary_lookup_when_disabled( credentials.apply(headers_applied) assert "x-allowed-locations" not in headers_applied + def test_refresh_token_with_cert_fingerprint(self): + credentials = self.make_credentials() + credentials._sts_client = mock.MagicMock() + credentials._sts_client.exchange_token.return_value = { + "access_token": "token", + "expires_in": 3600, + } + credentials.retrieve_subject_token = mock.MagicMock( + return_value="subject_token" + ) + + credentials._refresh_token( + request=mock.sentinel.request, cert_fingerprint="my-fingerprint" + ) + + _, kwargs = credentials._sts_client.exchange_token.call_args + assert kwargs["additional_options"]["bindCertFingerprint"] == "my-fingerprint" + def test_refresh_skips_sending_allowed_locations_header_with_trust_boundary(self): # This test verifies that the x-allowed-locations header is not sent with # the STS request even if a trust boundary is cached. diff --git a/tests/test_identity_pool.py b/tests/test_identity_pool.py index dbbdbf53a..529d83d65 100644 --- a/tests/test_identity_pool.py +++ b/tests/test_identity_pool.py @@ -1772,3 +1772,59 @@ def test_get_mtls_certs_invalid(self): assert excinfo.match( 'The credential is not configured to use mtls requests. The credential should include a "certificate" section in the credential source.' ) + + @mock.patch("google.auth._agent_identity_utils.parse_certificate") + @mock.patch( + "google.auth._agent_identity_utils.should_request_bound_token", + return_value=True, + ) + @mock.patch( + "google.auth._agent_identity_utils.calculate_certificate_fingerprint", + return_value="fingerprint", + ) + @mock.patch.object( + identity_pool.Credentials, "_get_cert_bytes", return_value=b"cert" + ) + @mock.patch.object(external_account.Credentials, "_refresh_token") + def test_refresh_with_agent_identity( + self, + mock_refresh_token, + mock_get_cert_bytes, + mock_calculate_fingerprint, + mock_should_request, + mock_parse_certificate, + ): + mock_parse_certificate.return_value = mock.sentinel.cert + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_CERTIFICATE.copy() + ) + credentials.refresh(None) + mock_parse_certificate.assert_called_once_with(b"cert") + mock_should_request.assert_called_once_with(mock.sentinel.cert) + mock_calculate_fingerprint.assert_called_once_with(mock.sentinel.cert) + mock_refresh_token.assert_called_once_with(None, cert_fingerprint="fingerprint") + + @mock.patch("google.auth._agent_identity_utils.parse_certificate") + @mock.patch( + "google.auth._agent_identity_utils.should_request_bound_token", + return_value=False, + ) + @mock.patch.object( + identity_pool.Credentials, "_get_cert_bytes", return_value=b"cert" + ) + @mock.patch.object(external_account.Credentials, "_refresh_token") + def test_refresh_with_agent_identity_opt_out_or_not_agent( + self, + mock_refresh_token, + mock_get_cert_bytes, + mock_should_request, + mock_parse_certificate, + ): + mock_parse_certificate.return_value = mock.sentinel.cert + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_CERTIFICATE.copy() + ) + credentials.refresh(None) + mock_parse_certificate.assert_called_once_with(b"cert") + mock_should_request.assert_called_once_with(mock.sentinel.cert) + mock_refresh_token.assert_called_once_with(None, cert_fingerprint=None) From 8a9cb7bfbf8de7d3f26758f82bdaf89cd23a303c Mon Sep 17 00:00:00 2001 From: agrawalradhika-cell Date: Mon, 15 Dec 2025 23:25:01 +0530 Subject: [PATCH 2/2] feat: Add retry logic when certificate mismatch for existing credentials & Agent Identity workloads (#1841) feat: Add retry logic when certificate mismatch for existing credentials & Agent Identity workloads This change introduces retry support when requests are created for existing credentials and Agent Identities on GKE and Cloud Run Workloads. When 401(Unauthorized) error is created, due to certificate at time of configuration of mTLS channel being different from the current certificate, a retry is added to the request by configuring the mTLS channel with the current certificate. --------- Signed-off-by: Radhika Agrawal Co-authored-by: nbayati <99771966+nbayati@users.noreply.github.com> Co-authored-by: Andy Zhao --- google/auth/_agent_identity_utils.py | 21 ++- google/auth/transport/_mtls_helper.py | 27 ++++ google/auth/transport/requests.py | 33 +++- google/auth/transport/urllib3.py | 40 +++++ noxfile.py | 3 +- tests/test_agent_identity_utils.py | 53 +++++- tests/transport/test__mtls_helper.py | 64 +++++++- tests/transport/test_requests.py | 224 ++++++++++++++++++++++++++ tests/transport/test_urllib3.py | 208 ++++++++++++++++++++++++ 9 files changed, 667 insertions(+), 6 deletions(-) diff --git a/google/auth/_agent_identity_utils.py b/google/auth/_agent_identity_utils.py index 665f2aa50..54fcc60fd 100644 --- a/google/auth/_agent_identity_utils.py +++ b/google/auth/_agent_identity_utils.py @@ -20,10 +20,11 @@ import os import re import time -from urllib.parse import urlparse, quote +from urllib.parse import quote, urlparse from google.auth import environment_vars from google.auth import exceptions +from google.auth.transport import _mtls_helper _LOGGER = logging.getLogger(__name__) @@ -260,3 +261,21 @@ def should_request_bound_token(cert): == "true" ) return is_agent_cert and is_opted_in + + +def call_client_cert_callback(): + """Calls the client cert callback and returns the certificate and key.""" + _, cert_bytes, key_bytes, passphrase = _mtls_helper.get_client_ssl_credentials( + generate_encrypted_key=True + ) + return cert_bytes, key_bytes + + +def get_cached_cert_fingerprint(cached_cert): + """Returns the fingerprint of the cached certificate.""" + if cached_cert: + cert_obj = parse_certificate(cached_cert) + cached_cert_fingerprint = calculate_certificate_fingerprint(cert_obj) + else: + raise ValueError("mTLS connection is not configured.") + return cached_cert_fingerprint diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py index 497613a4f..5b56b60b1 100644 --- a/google/auth/transport/_mtls_helper.py +++ b/google/auth/transport/_mtls_helper.py @@ -20,6 +20,7 @@ import re import subprocess +from google.auth import _agent_identity_utils from google.auth import environment_vars from google.auth import exceptions @@ -474,3 +475,29 @@ def check_use_client_cert(): ) as e: _LOGGER.debug("error decoding certificate: %s", e) return False + + +def check_parameters_for_unauthorized_response(cached_cert): + """Returns the cached and current cert fingerprint for reconfiguring mTLS. + + Args: + cached_cert(bytes): The cached client certificate. + + Returns: + bytes: The client callback cert bytes. + bytes: The client callback key bytes. + str: The base64-encoded SHA256 cached fingerprint. + str: The base64-encoded SHA256 current cert fingerprint. + """ + call_cert_bytes, call_key_bytes = _agent_identity_utils.call_client_cert_callback() + cert_obj = _agent_identity_utils.parse_certificate(call_cert_bytes) + current_cert_fingerprint = _agent_identity_utils.calculate_certificate_fingerprint( + cert_obj + ) + if cached_cert: + cached_fingerprint = _agent_identity_utils.get_cached_cert_fingerprint( + cached_cert + ) + else: + cached_fingerprint = current_cert_fingerprint + return call_cert_bytes, call_key_bytes, cached_fingerprint, current_cert_fingerprint diff --git a/google/auth/transport/requests.py b/google/auth/transport/requests.py index d1ff8f368..b750e2b3d 100644 --- a/google/auth/transport/requests.py +++ b/google/auth/transport/requests.py @@ -17,6 +17,7 @@ from __future__ import absolute_import import functools +import http.client as http_client import logging import numbers import time @@ -36,6 +37,7 @@ from google.auth import _helpers from google.auth import exceptions from google.auth import transport +from google.auth.transport import _mtls_helper import google.auth.transport._mtls_helper from google.oauth2 import service_account @@ -463,6 +465,7 @@ def configure_mtls_channel(self, client_cert_callback=None): if self._is_mtls: mtls_adapter = _MutualTlsAdapter(cert, key) + self._cached_cert = cert self.mount("https://", mtls_adapter) except ( exceptions.ClientCertError, @@ -502,6 +505,10 @@ def request( itself does not timeout, e.g. if a large file is being transmitted. The timout error will be raised after such request completes. + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS + channel creation fails for any reason. + ValueError: If the client certificate is invalid. """ # pylint: disable=arguments-differ # Requests has a ton of arguments to request, but only two @@ -551,7 +558,31 @@ def request( response.status_code in self._refresh_status_codes and _credential_refresh_attempt < self._max_refresh_attempts ): - + # Handle unauthorized permission error(401 status code) + if response.status_code == http_client.UNAUTHORIZED: + if self.is_mtls: + call_cert_bytes, call_key_bytes, cached_fingerprint, current_cert_fingerprint = _mtls_helper.check_parameters_for_unauthorized_response( + self._cached_cert + ) + if cached_fingerprint != current_cert_fingerprint: + try: + _LOGGER.info( + "Client certificate has changed, reconfiguring mTLS " + "channel." + ) + self.configure_mtls_channel( + lambda: (call_cert_bytes, call_key_bytes) + ) + except Exception as e: + _LOGGER.error("Failed to reconfigure mTLS channel: %s", e) + raise exceptions.MutualTLSChannelError( + "Failed to reconfigure mTLS channel" + ) from e + else: + _LOGGER.info( + "Skipping reconfiguration of mTLS channel because the client" + " certificate has not changed." + ) _LOGGER.info( "Refreshing credentials due to a %s response. Attempt %s/%s.", response.status_code, diff --git a/google/auth/transport/urllib3.py b/google/auth/transport/urllib3.py index 353cb8e08..3c16831f2 100644 --- a/google/auth/transport/urllib3.py +++ b/google/auth/transport/urllib3.py @@ -16,6 +16,7 @@ from __future__ import absolute_import +import http.client as http_client import logging import warnings @@ -52,6 +53,7 @@ from google.auth import _helpers from google.auth import exceptions from google.auth import transport +from google.auth.transport import _mtls_helper from google.oauth2 import service_account if version.parse(urllib3.__version__) >= version.parse("2.0.0"): # pragma: NO COVER @@ -299,6 +301,7 @@ def __init__( # Request instance used by internal methods (for example, # credentials.refresh). self._request = Request(self.http) + self._is_mtls = False # https://google.aip.dev/auth/4111 # Attempt to use self-signed JWTs when a service account is used. @@ -335,7 +338,10 @@ def configure_mtls_channel(self, client_cert_callback=None): """ use_client_cert = transport._mtls_helper.check_use_client_cert() if not use_client_cert: + self._is_mtls = False return False + else: + self._is_mtls = True try: import OpenSSL except ImportError as caught_exc: @@ -349,6 +355,7 @@ def configure_mtls_channel(self, client_cert_callback=None): if found_cert_key: self.http = _make_mutual_tls_http(cert, key) + self._cached_cert = cert else: self.http = _make_default_http() except ( @@ -381,6 +388,11 @@ def urlopen(self, method, url, body=None, headers=None, **kwargs): if headers is None: headers = self.headers + use_mtls = False + if self._is_mtls: + MTLS_URL_PREFIXES = ["mtls.googleapis.com", "mtls.sandbox.googleapis.com"] + use_mtls = any([prefix in url for prefix in MTLS_URL_PREFIXES]) + # Make a copy of the headers. They will be modified by the credentials # and we want to pass the original headers if we recurse. request_headers = headers.copy() @@ -402,6 +414,34 @@ def urlopen(self, method, url, body=None, headers=None, **kwargs): response.status in self._refresh_status_codes and _credential_refresh_attempt < self._max_refresh_attempts ): + if response.status == http_client.UNAUTHORIZED: + if use_mtls: + call_cert_bytes, call_key_bytes, cached_fingerprint, current_cert_fingerprint = _mtls_helper.check_parameters_for_unauthorized_response( + self._cached_cert + ) + if cached_fingerprint != current_cert_fingerprint: + try: + _LOGGER.info( + "Client certificate has changed, reconfiguring mTLS " + "channel." + ) + self.configure_mtls_channel( + client_cert_callback=lambda: ( + call_cert_bytes, + call_key_bytes, + ) + ) + except Exception as e: + _LOGGER.error("Failed to reconfigure mTLS channel: %s", e) + raise exceptions.MutualTLSChannelError( + "Failed to reconfigure mTLS channel" + ) from e + + else: + _LOGGER.info( + "Skipping reconfiguration of mTLS channel because the " + "client certificate has not changed." + ) _LOGGER.info( "Refreshing credentials due to a %s response. Attempt %s/%s.", diff --git a/noxfile.py b/noxfile.py index 3a2c0e883..dbe8b6bf9 100644 --- a/noxfile.py +++ b/noxfile.py @@ -105,7 +105,7 @@ def mypy(session): "types-requests", "types-setuptools", "types-mock", - "pytest", + "pytest<8.0.0", ) session.run("mypy", "-p", "google", "-p", "tests", "-p", "tests_async") @@ -130,6 +130,7 @@ def unit(session): @nox.session(python=DEFAULT_PYTHON_VERSION) def cover(session): + session.env["PIP_EXTRA_INDEX_URL"] = "https://pypi.org/simple" session.install("-e", ".[testing]") session.run( "pytest", diff --git a/tests/test_agent_identity_utils.py b/tests/test_agent_identity_utils.py index 65c6bf144..86a63e82e 100644 --- a/tests/test_agent_identity_utils.py +++ b/tests/test_agent_identity_utils.py @@ -15,7 +15,6 @@ import base64 import hashlib import json - import urllib.parse from cryptography import x509 @@ -104,7 +103,9 @@ def test_calculate_certificate_fingerprint(self): mock_cert.public_bytes.return_value = b"der-bytes" # Expected: base64 (standard), unpadded, then URL-encoded - base64_fingerprint = base64.b64encode(hashlib.sha256(b"der-bytes").digest()).decode("utf-8") + base64_fingerprint = base64.b64encode( + hashlib.sha256(b"der-bytes").digest() + ).decode("utf-8") unpadded_base64_fingerprint = base64_fingerprint.rstrip("=") expected_fingerprint = urllib.parse.quote(unpadded_base64_fingerprint) @@ -260,6 +261,54 @@ def test_get_and_parse_agent_identity_certificate_success( mock_parse_certificate.assert_called_once_with(b"cert_bytes") assert result == mock_parse_certificate.return_value + @mock.patch("time.sleep", return_value=None) + @mock.patch("google.auth._agent_identity_utils._is_certificate_file_ready") + def test_get_agent_identity_certificate_path_fallback_to_well_known_path( + self, mock_is_ready, mock_sleep, monkeypatch + ): + # Set a dummy config path that won't be found. + monkeypatch.setenv( + environment_vars.GOOGLE_API_CERTIFICATE_CONFIG, "/dummy/config.json" + ) + + # First, the primary path from the (mocked) config is not ready. + # Then, the fallback well-known path is ready. + mock_is_ready.side_effect = [False, True] + + result = _agent_identity_utils.get_agent_identity_certificate_path() + + assert result == _agent_identity_utils._WELL_KNOWN_CERT_PATH + # The sleep should have been called once before the fallback is checked. + mock_sleep.assert_called_once() + assert mock_is_ready.call_count == 2 + + @mock.patch("google.auth.transport._mtls_helper.get_client_ssl_credentials") + def test_call_client_cert_callback(self, mock_get_client_ssl_credentials): + mock_get_client_ssl_credentials.return_value = ( + True, + b"cert_bytes", + b"key_bytes", + b"passphrase", + ) + + cert, key = _agent_identity_utils.call_client_cert_callback() + + assert cert == b"cert_bytes" + assert key == b"key_bytes" + mock_get_client_ssl_credentials.assert_called_once_with( + generate_encrypted_key=True + ) + + def test_get_cached_cert_fingerprint_no_cert(self): + with pytest.raises(ValueError, match="mTLS connection is not configured."): + _agent_identity_utils.get_cached_cert_fingerprint(None) + + def test_get_cached_cert_fingerprint_with_cert(self): + fingerprint = _agent_identity_utils.get_cached_cert_fingerprint( + NON_AGENT_IDENTITY_CERT_BYTES + ) + assert isinstance(fingerprint, str) + class TestAgentIdentityUtilsNoCryptography: @pytest.fixture(autouse=True) diff --git a/tests/transport/test__mtls_helper.py b/tests/transport/test__mtls_helper.py index 2a7a524b1..7b49215cc 100644 --- a/tests/transport/test__mtls_helper.py +++ b/tests/transport/test__mtls_helper.py @@ -23,8 +23,9 @@ from google.auth import exceptions from google.auth.transport import _mtls_helper +CERT_MOCK_VAL = b"cert" +KEY_MOCK_VAL = b"key" CONTEXT_AWARE_METADATA = {"cert_provider_command": ["some command"]} - ENCRYPTED_EC_PRIVATE_KEY = b"""-----BEGIN ENCRYPTED PRIVATE KEY----- MIHkME8GCSqGSIb3DQEFDTBCMCkGCSqGSIb3DQEFDDAcBAgl2/yVgs1h3QICCAAw DAYIKoZIhvcNAgkFADAVBgkrBgEEAZdVAQIECJk2GRrvxOaJBIGQXIBnMU4wmciT @@ -813,3 +814,64 @@ def test_check_use_client_cert_when_file_does_not_exist(self, monkeypatch): monkeypatch.setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "") use_client_cert = _mtls_helper.check_use_client_cert() assert use_client_cert is False + + +class TestMtlsHelper: + @mock.patch("google.auth.transport._mtls_helper._agent_identity_utils") + def test_check_parameters_for_unauthorized_response_with_cached_cert( + self, mock_agent_identity_utils + ): + mock_agent_identity_utils.call_client_cert_callback.return_value = ( + CERT_MOCK_VAL, + KEY_MOCK_VAL, + ) + mock_agent_identity_utils.get_cached_cert_fingerprint.return_value = ( + "cached_fingerprint" + ) + mock_agent_identity_utils.calculate_certificate_fingerprint.return_value = ( + "current_fingerprint" + ) + + ( + cert, + key, + cached_fingerprint, + current_fingerprint, + ) = _mtls_helper.check_parameters_for_unauthorized_response( + cached_cert=b"cached_cert_bytes" + ) + + assert cert == CERT_MOCK_VAL + assert key == KEY_MOCK_VAL + assert cached_fingerprint == "cached_fingerprint" + assert current_fingerprint == "current_fingerprint" + mock_agent_identity_utils.call_client_cert_callback.assert_called_once() + mock_agent_identity_utils.get_cached_cert_fingerprint.assert_called_once_with( + b"cached_cert_bytes" + ) + + @mock.patch("google.auth.transport._mtls_helper._agent_identity_utils") + def test_check_parameters_for_unauthorized_response_without_cached_cert( + self, mock_agent_identity_utils + ): + mock_agent_identity_utils.call_client_cert_callback.return_value = ( + CERT_MOCK_VAL, + KEY_MOCK_VAL, + ) + mock_agent_identity_utils.calculate_certificate_fingerprint.return_value = ( + "current_fingerprint" + ) + + ( + cert, + key, + cached_fingerprint, + current_fingerprint, + ) = _mtls_helper.check_parameters_for_unauthorized_response(cached_cert=None) + + assert cert == CERT_MOCK_VAL + assert key == KEY_MOCK_VAL + assert cached_fingerprint == "current_fingerprint" + assert current_fingerprint == "current_fingerprint" + mock_agent_identity_utils.call_client_cert_callback.assert_called_once() + mock_agent_identity_utils.get_cached_cert_fingerprint.assert_not_called() diff --git a/tests/transport/test_requests.py b/tests/transport/test_requests.py index 0da3e36d9..ccc937527 100644 --- a/tests/transport/test_requests.py +++ b/tests/transport/test_requests.py @@ -41,6 +41,10 @@ def frozen_time(): yield frozen +CERT_MOCK_VAL = b"-----BEGIN CERTIFICATE-----\nMIIDIzCCAgugAwIBAgIJAMfISuBQ5m+5MA0GCSqGSIb3DQEBBQUAMBUxEzARBgNV\nBAMTCnVuaXQtdGVzdHMwHhcNMTExMjA2MTYyNjAyWhcNMjExMjAzMTYyNjAyWjAV\nMRMwEQYDVQQDEwp1bml0LXRlc3RzMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB\nCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZgkdmM\n7oVK2OfgrSj/FCTkInKPqaCR0gD7K80q+mLBrN3PUkDrJQZpvRZIff3/xmVU1Wer\nuQLFJjnFb2dqu0s/FY/2kWiJtBCakXvXEOb7zfbINuayL+MSsCGSdVYsSliS5qQp\ngyDap+8b5fpXZVJkq92hrcNtbkg7hCYUJczt8n9hcCTJCfUpApvaFQ18pe+zpyl4\n+WzkP66I28hniMQyUlA1hBiskT7qiouq0m8IOodhv2fagSZKjOTTU2xkSBc//fy3\nZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQABo3YwdDAdBgNVHQ4EFgQU2RQ8yO+O\ngN8oVW2SW7RLrfYd9jEwRQYDVR0jBD4wPIAU2RQ8yO+OgN8oVW2SW7RLrfYd9jGh\nGaQXMBUxEzARBgNVBAMTCnVuaXQtdGVzdHOCCQDHyErgUOZvuTAMBgNVHRMEBTAD\nAQH/MA0GCSqGSIb3DQEBBQUAA4IBAQBRv+M/6+FiVu7KXNjFI5pSN17OcW5QUtPr\nodJMlWrJBtynn/TA1oJlYu3yV5clc/71Vr/AxuX5xGP+IXL32YDF9lTUJXG/uUGk\n+JETpKmQviPbRsvzYhz4pf6ZIOZMc3/GIcNq92ECbseGO+yAgyWUVKMmZM0HqXC9\novNslqe0M8C1sLm1zAR5z/h/litE7/8O2ietija3Q/qtl2TOXJdCA6sgjJX2WUql\nybrC55ct18NKf3qhpcEkGQvFU40rVYApJpi98DiZPYFdx1oBDp/f4uZ3ojpxRVFT\ncDwcJLfNRCPUhormsY7fDS9xSyThiHsW9mjJYdcaKQkwYZ0F11yB\n-----END CERTIFICATE-----\n" +KEY_MOCK_VAL = b"-----BEGIN ENCRYPTED PRIVATE KEY-----\nMIHeMEkGCSqGSIb3DQEFDTA8MBsGCSqGSIb3DQEFDDAOBAj9XnJ2h78QVAICCAAw\nHQYJYIZIAWUDBAECBBBeiiOF2LnLzq/wjb/viwMwBIGQk28Zkfj2EIk42bgc7UzC\nSf98qssCVhsIYz0Xa3eSATg8Cpn83YieaBeyxdk/tXTnrOhxMV/vt7T98kWhaGbH\n5Z9CdGVLfes0UFvVJqrlk6vcf2sOnLCGbrn78HS+ayrGOCRSCd/7+dnEiB/7Um1B\nMk6BBJHsLEnZZSHyfrw8jvYgVmcSBy/WdY0pqldD/+4D\n-----END ENCRYPTED PRIVATE KEY-----\n" + + class TestRequestResponse(compliance.RequestResponseTests): def make_request(self): return google.auth.transport.requests.Request() @@ -537,6 +541,226 @@ def test_close_w_passed_in_auth_request(self): authed_session.close() # no raise + def test_cert_rotation_when_cert_mismatch_and_mtls_enabled(self): + credentials = mock.Mock(wraps=CredentialsStub()) + final_response = make_response(status=http_client.OK) + # First request will 401, second request will succeed. + adapter = AdapterStub( + [make_response(status=http_client.UNAUTHORIZED), final_response] + ) + + authed_session = google.auth.transport.requests.AuthorizedSession( + credentials, refresh_timeout=60 + ) + authed_session.mount(self.TEST_URL, adapter) + + old_cert = b"-----BEGIN CERTIFICATE-----\nMIIBdTCCARqgAwIBAgIJAOYVvu/axMxvMAoGCCqGSM49BAMCMCcxJTAjBgNVBAMM\nHEdvb2dsZSBFbmRwb2ludCBWZXJpZmljYXRpb24wHhcNMjUwNzMwMjMwNjA4WhcN\nMjYwNzMxMjMwNjA4WjAnMSUwIwYDVQQDDBxHb29nbGUgRW5kcG9pbnQgVmVyaWZp\nY2F0aW9uMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEbtr18gkEtwPow2oqyZsU\n4KLwFaLFlRlYv55UATS3QTDykDnIufC42TJCnqFRYhwicwpE2jnUV+l9g3Voias8\nraMvMC0wCQYDVR0TBAIwADALBgNVHQ8EBAMCB4AwEwYDVR0lBAwwCgYIKwYBBQUH\nAwIwCgYIKoZIzj0EAwIDSQAwRgIhAKcjW6dmF1YCksXPgDPlPu/nSnOjb3qCcivz\n/Jxq2zoeAiEA7/aNxcEoCGS3hwMIXoaaD/vPcZOOopKSyqXCvxRooKQ=\n-----END CERTIFICATE-----\n" + + # New certificate and key to simulate rotation. + new_cert = CERT_MOCK_VAL + new_key = KEY_MOCK_VAL + + # Set _cached_cert to a callable that returns the old certificate. + authed_session._cached_cert = old_cert + authed_session._is_mtls = True + + # Mock call_client_cert_callback to return the new certificate. + with mock.patch.object( + google.auth.transport._mtls_helper._agent_identity_utils, + "call_client_cert_callback", + return_value=(new_cert, new_key), + ) as mock_callback: + result = authed_session.request("GET", self.TEST_URL) + + # Asserts to verify the behavior. + assert mock_callback.called + assert credentials.refresh.called + assert credentials.refresh.call_count == 1 + assert result.status_code == final_response.status_code + + def test_no_cert_rotation_when_cert_match_and_mTLS_enabled(self): + credentials = mock.Mock(wraps=CredentialsStub()) + final_response = make_response(status=http_client.UNAUTHORIZED) + adapter = AdapterStub( + [ + make_response(status=http_client.UNAUTHORIZED), + make_response(status=http_client.UNAUTHORIZED), + make_response(status=http_client.UNAUTHORIZED), + ] + ) + authed_session = google.auth.transport.requests.AuthorizedSession( + credentials, refresh_timeout=60 + ) + authed_session.mount(self.TEST_URL, adapter) + authed_session._is_mtls = True + + old_cert = CERT_MOCK_VAL + + # New certificate and key to simulate rotation. + new_cert = old_cert + new_key = KEY_MOCK_VAL + + # Set _cached_cert to a callable that returns the old certificate. + authed_session._cached_cert = old_cert + + # Mock call_client_cert_callback to return the new certificate. + with mock.patch.object( + google.auth.transport._mtls_helper._agent_identity_utils, + "call_client_cert_callback", + return_value=(new_cert, new_key), + ): + result = authed_session.request("GET", self.TEST_URL) + + # Asserts to verify the behavior. + assert credentials.refresh.call_count == 2 + assert result.status_code == final_response.status_code + + def test_no_cert_match_check_when_mtls_disabled(self): + credentials = mock.Mock(wraps=CredentialsStub()) + final_response = make_response(status=http_client.UNAUTHORIZED) + adapter = AdapterStub( + [ + make_response(status=http_client.UNAUTHORIZED), + make_response(status=http_client.UNAUTHORIZED), + make_response(status=http_client.UNAUTHORIZED), + ] + ) + authed_session = google.auth.transport.requests.AuthorizedSession( + credentials, refresh_timeout=60 + ) + authed_session.mount(self.TEST_URL, adapter) + authed_session._is_mtls = False + + new_cert = CERT_MOCK_VAL + + # New certificate and key to simulate rotation. + new_key = KEY_MOCK_VAL + + # Mock call_client_cert_callback to return the new certificate. + with mock.patch.object( + google.auth.transport._mtls_helper._agent_identity_utils, + "call_client_cert_callback", + return_value=(new_cert, new_key), + ) as mock_callback: + result = authed_session.request("GET", self.TEST_URL) + + # Asserts to verify the behavior. + assert not mock_callback.called + assert result.status_code == final_response.status_code + + def test_no_cert_rotation_when_no_unauthorized_response(self): + credentials = mock.Mock(wraps=CredentialsStub()) + final_response = make_response(status=http_client.UPGRADE_REQUIRED) + + # Response is set to code other than 401(Unauthorized). + adapter = AdapterStub([make_response(status=http_client.UPGRADE_REQUIRED)]) + authed_session = google.auth.transport.requests.AuthorizedSession( + credentials, refresh_timeout=60 + ) + authed_session.mount(self.TEST_URL, adapter) + + authed_session._is_mtls = True + + result = authed_session.request("GET", self.TEST_URL) + assert result.status_code == final_response.status_code + + # Asserts to verify the behavior. + assert not credentials.refresh.called + assert credentials.refresh.call_count == 0 + + def test_cert_rotation_failure_raises_error(self): + credentials = mock.Mock(wraps=CredentialsStub()) + # First request will 401, second request will fail to reconfigure mTLS. + adapter = AdapterStub([make_response(status=http_client.UNAUTHORIZED)]) + + authed_session = google.auth.transport.requests.AuthorizedSession( + credentials, refresh_timeout=60 + ) + authed_session.mount(self.TEST_URL, adapter) + + old_cert = b"-----BEGIN CERTIFICATE-----\nMIIBdTCCARqgAwIBAgIJAOYVvu/axMxvMAoGCCqGSM49BAMCMCcxJTAjBgNVBAMM\nHEdvb2dsZSBFbmRwb2ludCBWZXJpZmljYXRpb24wHhcNMjUwNzMwMjMwNjA4WhcN\nMjYwNzMxMjMwNjA4WjAnMSUwIwYDVQQDDBxHb29nbGUgRW5kcG9pbnQgVmVyaWZp\nY2F0aW9uMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEbtr18gkEtwPow2oqyZsU\n4KLwFaLFlRlYv55UATS3QTDykDnIufC42TJCnqFRYhwicwpE2jnUV+l9g3Voias8\nraMvMC0wCQYDVR0TBAIwADALBgNVHQ8EBAMCB4AwEwYDVR0lBAwwCgYIKwYBBQUH\nAwIwCgYIKoZIzj0EAwIDSQAwRgIhAKcjW6dmF1YCksXPgDPlPu/nSnOjb3qCcivz\n/Jxq2zoeAiEA7/aNxcEoCGS3hwMIXoaaD/vPcZOOopKSyqXCvxRooKQ=\n-----END CERTIFICATE-----\n" + + # New certificate and key to simulate rotation. + new_cert = CERT_MOCK_VAL + new_key = KEY_MOCK_VAL + + authed_session._cached_cert = old_cert + authed_session._is_mtls = True + + with mock.patch.object( + google.auth.transport._mtls_helper._agent_identity_utils, + "call_client_cert_callback", + return_value=(new_cert, new_key), + ): + with mock.patch.object( + authed_session, + "configure_mtls_channel", + side_effect=Exception("Failed to reconfigure"), + ): + with pytest.raises(exceptions.MutualTLSChannelError): + authed_session.request("GET", self.TEST_URL) + + # Assert to verify behavior + credentials.refresh.assert_not_called() + + def test_cert_rotation_check_params_fails(self): + credentials = mock.Mock(wraps=CredentialsStub()) + adapter = AdapterStub([make_response(status=http_client.UNAUTHORIZED)]) + + authed_session = google.auth.transport.requests.AuthorizedSession( + credentials, refresh_timeout=60 + ) + authed_session.mount(self.TEST_URL, adapter) + authed_session._is_mtls = True + authed_session._cached_cert = b"cached_cert" + + with mock.patch( + "google.auth.transport.requests._mtls_helper.check_parameters_for_unauthorized_response", + side_effect=Exception("check_params failed"), + ) as mock_check_params: + with pytest.raises(Exception, match="check_params failed"): + authed_session.request("GET", self.TEST_URL) + + mock_check_params.assert_called_once() + credentials.refresh.assert_not_called() + + def test_cert_rotation_logic_skipped_on_other_refresh_status_codes(self): + """ + Tests that the code can handle a refresh triggered by a status code + other than 401 (UNAUTHORIZED). This covers the 'else' branch of the + 'if response.status_code == http_client.UNAUTHORIZED' check + """ + credentials = mock.Mock(wraps=CredentialsStub()) + # Configure the session to treat 503 (Service Unavailable) as a refreshable error + custom_refresh_codes = [http_client.SERVICE_UNAVAILABLE] + + # Return 503 first, then 200 + adapter = AdapterStub( + [ + make_response(status=http_client.SERVICE_UNAVAILABLE), + make_response(status=http_client.OK), + ] + ) + + authed_session = google.auth.transport.requests.AuthorizedSession( + credentials, refresh_status_codes=custom_refresh_codes + ) + authed_session.mount(self.TEST_URL, adapter) + + # Enable mTLS to prove it is skipped despite being enabled + authed_session._is_mtls = True + + with mock.patch( + "google.auth.transport.requests._mtls_helper", autospec=True + ) as mock_helper: + authed_session.request("GET", self.TEST_URL) + + # Assert refresh happened (Outer Check was True) + assert credentials.refresh.called + + # Assert mTLS check logic was SKIPPED (Inner Check was False) + assert not mock_helper.check_parameters_for_unauthorized_response.called + class TestMutualTlsOffloadAdapter(object): @mock.patch.object(requests.adapters.HTTPAdapter, "init_poolmanager") diff --git a/tests/transport/test_urllib3.py b/tests/transport/test_urllib3.py index e83230032..7872a7187 100644 --- a/tests/transport/test_urllib3.py +++ b/tests/transport/test_urllib3.py @@ -29,6 +29,9 @@ from google.oauth2 import service_account from tests.transport import compliance +CERT_MOCK_VAL = b"-----BEGIN CERTIFICATE-----\nMIIDIzCCAgugAwIBAgIJAMfISuBQ5m+5MA0GCSqGSIb3DQEBBQUAMBUxEzARBgNV\nBAMTCnVuaXQtdGVzdHMwHhcNMTExMjA2MTYyNjAyWhcNMjExMjAzMTYyNjAyWjAV\nMRMwEQYDVQQDEwp1bml0LXRlc3RzMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB\nCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZgkdmM\n7oVK2OfgrSj/FCTkInKPqaCR0gD7K80q+mLBrN3PUkDrJQZpvRZIff3/xmVU1Wer\nuQLFJjnFb2dqu0s/FY/2kWiJtBCakXvXEOb7zfbINuayL+MSsCGSdVYsSliS5qQp\ngyDap+8b5fpXZVJkq92hrcNtbkg7hCYUJczt8n9hcCTJCfUpApvaFQ18pe+zpyl4\n+WzkP66I28hniMQyUlA1hBiskT7qiouq0m8IOodhv2fagSZKjOTTU2xkSBc//fy3\nZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQABo3YwdDAdBgNVHQ4EFgQU2RQ8yO+O\ngN8oVW2SW7RLrfYd9jEwRQYDVR0jBD4wPIAU2RQ8yO+OgN8oVW2SW7RLrfYd9jGh\nGaQXMBUxEzARBgNVBAMTCnVuaXQtdGVzdHOCCQDHyErgUOZvuTAMBgNVHRMEBTAD\nAQH/MA0GCSqGSIb3DQEBBQUAA4IBAQBRv+M/6+FiVu7KXNjFI5pSN17OcW5QUtPr\nodJMlWrJBtynn/TA1oJlYu3yV5clc/71Vr/AxuX5xGP+IXL32YDF9lTUJXG/uUGk\n+JETpKmQviPbRsvzYhz4pf6ZIOZMc3/GIcNq92ECbseGO+yAgyWUVKMmZM0HqXC9\novNslqe0M8C1sLm1zAR5z/h/litE7/8O2ietija3Q/qtl2TOXJdCA6sgjJX2WUql\nybrC55ct18NKf3qhpcEkGQvFU40rVYApJpi98DiZPYFdx1oBDp/f4uZ3ojpxRVFT\ncDwcJLfNRCPUhormsY7fDS9xSyThiHsW9mjJYdcaKQkwYZ0F11yB\n-----END CERTIFICATE-----\n" +KEY_MOCK_VAL = b"-----BEGIN ENCRYPTED PRIVATE KEY-----\nMIHeMEkGCSqGSIb3DQEFDTA8MBsGCSqGSIb3DQEFDDAOBAj9XnJ2h78QVAICCAAw\nHQYJYIZIAWUDBAECBBBeiiOF2LnLzq/wjb/viwMwBIGQk28Zkfj2EIk42bgc7UzC\nSf98qssCVhsIYz0Xa3eSATg8Cpn83YieaBeyxdk/tXTnrOhxMV/vt7T98kWhaGbH\n5Z9CdGVLfes0UFvVJqrlk6vcf2sOnLCGbrn78HS+ayrGOCRSCd/7+dnEiB/7Um1B\nMk6BBJHsLEnZZSHyfrw8jvYgVmcSBy/WdY0pqldD/+4D\n-----END ENCRYPTED PRIVATE KEY-----\n" + class TestRequestResponse(compliance.RequestResponseTests): def make_request(self): @@ -320,3 +323,208 @@ def test_clear_pool_on_del(self): authed_http.http = None authed_http.__del__() # Expect it to not crash + + def test_cert_rotation_when_cert_mismatch_and_mtls_endpoint_used(self): + credentials = mock.Mock(wraps=CredentialsStub()) + final_response = ResponseStub(status=http_client.OK) + http = HttpStub([ResponseStub(status=http_client.UNAUTHORIZED), final_response]) + + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials, http=http + ) + + old_cert = b"-----BEGIN CERTIFICATE-----\nMIIBdTCCARqgAwIBAgIJAOYVvu/axMxvMAoGCCqGSM49BAMCMCcxJTAjBgNVBAMM\nHEdvb2dsZSBFbmRwb2ludCBWZXJpZmljYXRpb24wHhcNMjUwNzMwMjMwNjA4WhcN\nMjYwNzMxMjMwNjA4WjAnMSUwIwYDVQQDDBxHb29nbGUgRW5kcG9pbnQgVmVyaWZp\nY2F0aW9uMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEbtr18gkEtwPow2oqyZsU\n4KLwFaLFlRlYv55UATS3QTDykDnIufC42TJCnqFRYhwicwpE2jnUV+l9g3Voias8\nraMvMC0wCQYDVR0TBAIwADALBgNVHQ8EBAMCB4AwEwYDVR0lBAwwCgYIKwYBBQUH\nAwIwCgYIKoZIzj0EAwIDSQAwRgIhAKcjW6dmF1YCksXPgDPlPu/nSnOjb3qCcivz\n/Jxq2zoeAiEA7/aNxcEoCGS3hwMIXoaaD/vPcZOOopKSyqXCvxRooKQ=\n-----END CERTIFICATE-----\n" + + # New certificate and key to simulate rotation. + new_cert = CERT_MOCK_VAL + new_key = KEY_MOCK_VAL + # Set _cached_cert to a callable that returns the old certificate. + authed_http._cached_cert = old_cert + authed_http._is_mtls = True + # Mock call_client_cert_callback to return the new certificate. + with mock.patch.object( + google.auth._agent_identity_utils, + "call_client_cert_callback", + return_value=(new_cert, new_key), + ) as mock_callback: + # mTLS endpoint is used + result = authed_http.urlopen("GET", "http://example.mtls.googleapis.com") + + # Asserts to verify the behavior. + assert result == final_response + assert credentials.refresh.called + assert credentials.refresh.call_count == 1 + assert mock_callback.called + + def test_no_cert_rotation_when_cert_match_and_mtls_endpoint_used(self): + credentials = mock.Mock(wraps=CredentialsStub()) + final_response = ResponseStub(status=http_client.UNAUTHORIZED) + http = HttpStub( + [ + ResponseStub(status=http_client.UNAUTHORIZED), + ResponseStub(status=http_client.UNAUTHORIZED), + ResponseStub(status=http_client.UNAUTHORIZED), + ] + ) + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials, http=http + ) + old_cert = CERT_MOCK_VAL + + new_cert = old_cert + new_key = KEY_MOCK_VAL + # Set _cached_cert to a callable that returns the same certificate. + authed_http._cached_cert = old_cert + authed_http._is_mtls = True + # Mock call_client_cert_callback to return the certificate. + with mock.patch.object( + google.auth._agent_identity_utils, + "call_client_cert_callback", + return_value=(new_cert, new_key), + ): + # mTLS endpoint is used + result = authed_http.urlopen("GET", "http://example.mtls.googleapis.com") + + # Asserts to verify the behavior. + assert credentials.refresh.call_count == 2 + assert result.status == final_response.status + + def test_no_cert_match_check_when_mtls_endpoint_not_used(self): + credentials = mock.Mock(wraps=CredentialsStub()) + final_response = ResponseStub(status=http_client.UNAUTHORIZED) + http = HttpStub( + [ + ResponseStub(status=http_client.UNAUTHORIZED), + ResponseStub(status=http_client.UNAUTHORIZED), + ResponseStub(status=http_client.UNAUTHORIZED), + ] + ) + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials, http=http + ) + authed_http._is_mtls = False + new_cert = CERT_MOCK_VAL + new_key = KEY_MOCK_VAL + + # Mock call_client_cert_callback to return the certificate. + with mock.patch.object( + google.auth._agent_identity_utils, + "call_client_cert_callback", + return_value=(new_cert, new_key), + ) as mock_callback: + # non-mTLS endpoint is used + result = authed_http.urlopen("GET", "http://example.googleapis.com") + + # Asserts to verify the behavior. + assert not mock_callback.called + assert result.status == final_response.status + + def test_no_cert_rotation_when_no_unauthorized_response(self): + credentials = mock.Mock(wraps=CredentialsStub()) + final_response = ResponseStub(status=http_client.UPGRADE_REQUIRED) + + # Response is set to code other than 401(Unauthorized). + http = HttpStub([ResponseStub(status=http_client.UPGRADE_REQUIRED)]) + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials, http=http + ) + authed_http._is_mtls = True + with mock.patch.dict( + os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"} + ): + # mTLS endpoint is used + result = authed_http.urlopen("GET", "http://example.mtls.googleapis.com") + assert result.status == final_response.status + assert not credentials.refresh.called + assert credentials.refresh.call_count == 0 + + def test_cert_rotation_failure_raises_error(self): + credentials = mock.Mock(wraps=CredentialsStub()) + http = HttpStub([ResponseStub(status=http_client.UNAUTHORIZED)]) + + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials, http=http + ) + + old_cert = b"-----BEGIN CERTIFICATE-----\nMIIBdTCCARqgAwIBAgIJAOYVvu/axMxvMAoGCCqGSM49BAMCMCcxJTAjBgNVBAMM\nHEdvb2dsZSBFbmRwb2ludCBWZXJpZmljYXRpb24wHhcNMjUwNzMwMjMwNjA4WhcN\nMjYwNzMxMjMwNjA4WjAnMSUwIwYDVQQDDBxHb29nbGUgRW5kcG9pbnQgVmVyaWZp\nY2F0aW9uMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEbtr18gkEtwPow2oqyZsU\n4KLwFaLFlRlYv55UATS3QTDykDnIufC42TJCnqFRYhwicwpE2jnUV+l9g3Voias8\nraMvMC0wCQYDVR0TBAIwADALBgNVHQ8EBAMCB4AwEwYDVR0lBAwwCgYIKwYBBQUH\nAwIwCgYIKoZIzj0EAwIDSQAwRgIhAKcjW6dmF1YCksXPgDPlPu/nSnOjb3qCcivz\n/Jxq2zoeAiEA7/aNxcEoCGS3hwMIXoaaD/vPcZOOopKSyqXCvxRooKQ=\n-----END CERTIFICATE-----\n" + + # New certificate and key to simulate rotation. + new_cert = CERT_MOCK_VAL + new_key = KEY_MOCK_VAL + authed_http._cached_cert = old_cert + authed_http._is_mtls = True + + # Mock call_client_cert_callback to return the new certificate. + with mock.patch.object( + google.auth.transport._mtls_helper, + "check_parameters_for_unauthorized_response", + return_value=(new_cert, new_key, "old_fingerprint", "new_fingerprint"), + ) as mock_check_params: + with mock.patch.object( + authed_http, + "configure_mtls_channel", + side_effect=Exception("Failed to reconfigure"), + ) as mock_reconfigure: + with pytest.raises(exceptions.MutualTLSChannelError): + authed_http.urlopen("GET", "https://example.mtls.googleapis.com") + + mock_check_params.assert_called_once() + mock_reconfigure.assert_called_once() + credentials.refresh.assert_not_called() + + def test_cert_rotation_check_params_fails(self): + credentials = mock.Mock(wraps=CredentialsStub()) + http = HttpStub([ResponseStub(status=http_client.UNAUTHORIZED)]) + + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials, http=http + ) + authed_http._is_mtls = True + authed_http._cached_cert = b"cached_cert" + + with mock.patch( + "google.auth.transport.urllib3._mtls_helper.check_parameters_for_unauthorized_response", + side_effect=Exception("check_params failed"), + ) as mock_check_params: + with pytest.raises(Exception, match="check_params failed"): + authed_http.urlopen("GET", "http://example.mtls.googleapis.com") + + mock_check_params.assert_called_once() + credentials.refresh.assert_not_called() + + def test_cert_rotation_logic_skipped_on_other_refresh_status_codes(self): + """ + Tests that the code can handle a refresh triggered by a status code + other than 401 (UNAUTHORIZED). This covers the 'else' branch of the + 'if response.status_code == http_client.UNAUTHORIZED' check + """ + credentials = mock.Mock(wraps=CredentialsStub()) + # Configure the session to treat 503 (Service Unavailable) as a refreshable error + custom_codes = [http_client.SERVICE_UNAVAILABLE] + + # Return 503 first, then 200 + http = HttpStub( + [ + ResponseStub(status=http_client.SERVICE_UNAVAILABLE), + ResponseStub(status=http_client.OK), + ] + ) + + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials, http=http, refresh_status_codes=custom_codes + ) + + # Enable mTLS to prove it is skipped despite being enabled + authed_http._is_mtls = True + mtls_url = "https://mtls.googleapis.com/test" + + with mock.patch( + "google.auth.transport.urllib3._mtls_helper", autospec=True + ) as mock_helper: + authed_http.urlopen("GET", mtls_url) + + # Assert refresh happened (Outer Check was True) + assert credentials.refresh.called + + # Assert mTLS check logic was SKIPPED (Inner Check was False) + assert not mock_helper.check_parameters_for_unauthorized_response.called