diff --git a/google/auth/_agent_identity_utils.py b/google/auth/_agent_identity_utils.py new file mode 100644 index 000000000..54fcc60fd --- /dev/null +++ b/google/auth/_agent_identity_utils.py @@ -0,0 +1,281 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for Agent Identity credentials.""" + +import base64 +import hashlib +import logging +import os +import re +import time +from urllib.parse import quote, urlparse + +from google.auth import environment_vars +from google.auth import exceptions +from google.auth.transport import _mtls_helper + + +_LOGGER = logging.getLogger(__name__) + +CRYPTOGRAPHY_NOT_FOUND_ERROR = ( + "The cryptography library is required for certificate-based authentication." + "Please install it with `pip install google-auth[cryptography]`." +) + +# SPIFFE trust domain patterns for Agent Identities. +_AGENT_IDENTITY_SPIFFE_TRUST_DOMAIN_PATTERNS = [ + r"^agents\.global\.org-\d+\.system\.id\.goog$", + r"^agents\.global\.proj-\d+\.system\.id\.goog$", +] + +_WELL_KNOWN_CERT_PATH = "/var/run/secrets/workload-spiffe-credentials/certificates.pem" + +# Constants for polling the certificate file. +_FAST_POLL_CYCLES = 50 +_FAST_POLL_INTERVAL = 0.1 # 100ms +_SLOW_POLL_INTERVAL = 0.5 # 500ms +_TOTAL_TIMEOUT = 30 # seconds + +# Calculate the number of slow poll cycles based on the total timeout. +_SLOW_POLL_CYCLES = int( + (_TOTAL_TIMEOUT - (_FAST_POLL_CYCLES * _FAST_POLL_INTERVAL)) / _SLOW_POLL_INTERVAL +) + +_POLLING_INTERVALS = ([_FAST_POLL_INTERVAL] * _FAST_POLL_CYCLES) + ( + [_SLOW_POLL_INTERVAL] * _SLOW_POLL_CYCLES +) + + +def _is_certificate_file_ready(path): + """Checks if a file exists and is not empty.""" + return path and os.path.exists(path) and os.path.getsize(path) > 0 + + +def get_agent_identity_certificate_path(): + """Gets the certificate path from the certificate config file. + + The path to the certificate config file is read from the + GOOGLE_API_CERTIFICATE_CONFIG environment variable. This function + implements a retry mechanism to handle cases where the environment + variable is set before the files are available on the filesystem. + + Returns: + str: The path to the leaf certificate file. + + Raises: + google.auth.exceptions.RefreshError: If the certificate config file + or the certificate file cannot be found after retries. + """ + import json + + cert_config_path = os.environ.get(environment_vars.GOOGLE_API_CERTIFICATE_CONFIG) + if not cert_config_path: + return None + + has_logged_warning = False + + for interval in _POLLING_INTERVALS: + try: + with open(cert_config_path, "r") as f: + cert_config = json.load(f) + cert_path = ( + cert_config.get("cert_configs", {}) + .get("workload", {}) + .get("cert_path") + ) + if _is_certificate_file_ready(cert_path): + return cert_path + except (IOError, ValueError, KeyError): + if not has_logged_warning: + _LOGGER.warning( + "Certificate config file not found at %s (from %s environment " + "variable). Retrying for up to %s seconds.", + cert_config_path, + environment_vars.GOOGLE_API_CERTIFICATE_CONFIG, + _TOTAL_TIMEOUT, + ) + has_logged_warning = True + pass + + # As a fallback, check the well-known certificate path. + if _is_certificate_file_ready(_WELL_KNOWN_CERT_PATH): + return _WELL_KNOWN_CERT_PATH + + # A sleep is required in two cases: + # 1. The config file is not found (the except block). + # 2. The config file is found, but the certificate is not yet available. + # In both cases, we need to poll, so we sleep on every iteration + # that doesn't return a certificate. + time.sleep(interval) + + raise exceptions.RefreshError( + "Certificate config or certificate file not found after multiple retries. " + f"Token binding protection is failing. You can turn off this protection by setting " + f"{environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES} to false " + "to fall back to unbound tokens." + ) + + +def get_and_parse_agent_identity_certificate(): + """Gets and parses the agent identity certificate if not opted out. + + Checks if the user has opted out of certificate-bound tokens. If not, + it gets the certificate path, reads the file, and parses it. + + Returns: + The parsed certificate object if found and not opted out, otherwise None. + """ + # If the user has opted out of cert bound tokens, there is no need to + # look up the certificate. + is_opted_out = ( + os.environ.get( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES, + "true", + ).lower() + == "false" + ) + if is_opted_out: + return None + + cert_path = get_agent_identity_certificate_path() + if not cert_path: + return None + + with open(cert_path, "rb") as cert_file: + cert_bytes = cert_file.read() + + return parse_certificate(cert_bytes) + + +def parse_certificate(cert_bytes): + """Parses a PEM-encoded certificate. + + Args: + cert_bytes (bytes): The PEM-encoded certificate bytes. + + Returns: + cryptography.x509.Certificate: The parsed certificate object. + """ + try: + from cryptography import x509 + + return x509.load_pem_x509_certificate(cert_bytes) + except ImportError as e: + raise ImportError(CRYPTOGRAPHY_NOT_FOUND_ERROR) from e + + +def _is_agent_identity_certificate(cert): + """Checks if a certificate is an Agent Identity certificate. + + This is determined by checking the Subject Alternative Name (SAN) for a + SPIFFE ID with a trust domain matching Agent Identity patterns. + + Args: + cert (cryptography.x509.Certificate): The parsed certificate object. + + Returns: + bool: True if the certificate is an Agent Identity certificate, + False otherwise. + """ + try: + from cryptography import x509 + from cryptography.x509.oid import ExtensionOID + + try: + ext = cert.extensions.get_extension_for_oid( + ExtensionOID.SUBJECT_ALTERNATIVE_NAME + ) + except x509.ExtensionNotFound: + return False + uris = ext.value.get_values_for_type(x509.UniformResourceIdentifier) + + for uri in uris: + parsed_uri = urlparse(uri) + if parsed_uri.scheme == "spiffe": + trust_domain = parsed_uri.netloc + for pattern in _AGENT_IDENTITY_SPIFFE_TRUST_DOMAIN_PATTERNS: + if re.match(pattern, trust_domain): + return True + return False + except ImportError as e: + raise ImportError(CRYPTOGRAPHY_NOT_FOUND_ERROR) from e + + +def calculate_certificate_fingerprint(cert): + """Calculates the URL-encoded, unpadded, base64-encoded SHA256 hash of a + DER-encoded certificate. + + Args: + cert (cryptography.x509.Certificate): The parsed certificate object. + + Returns: + str: The URL-encoded, unpadded, base64-encoded SHA256 fingerprint. + """ + try: + from cryptography.hazmat.primitives import serialization + + der_cert = cert.public_bytes(serialization.Encoding.DER) + fingerprint = hashlib.sha256(der_cert).digest() + # The certificate fingerprint is generated in two steps to align with GFE's + # expectations and ensure proper URL transmission: + # 1. Standard base64 encoding is applied, and padding ('=') is removed. + # 2. The resulting string is then URL-encoded to handle special characters + # ('+', '/') that would otherwise be misinterpreted in URL parameters. + base64_fingerprint = base64.b64encode(fingerprint).decode("utf-8") + unpadded_base64_fingerprint = base64_fingerprint.rstrip("=") + return quote(unpadded_base64_fingerprint) + except ImportError as e: + raise ImportError(CRYPTOGRAPHY_NOT_FOUND_ERROR) from e + + +def should_request_bound_token(cert): + """Determines if a bound token should be requested. + + This is based on the GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES + environment variable and whether the certificate is an agent identity cert. + + Args: + cert (cryptography.x509.Certificate): The parsed certificate object. + + Returns: + bool: True if a bound token should be requested, False otherwise. + """ + is_agent_cert = _is_agent_identity_certificate(cert) + is_opted_in = ( + os.environ.get( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES, + "true", + ).lower() + == "true" + ) + return is_agent_cert and is_opted_in + + +def call_client_cert_callback(): + """Calls the client cert callback and returns the certificate and key.""" + _, cert_bytes, key_bytes, passphrase = _mtls_helper.get_client_ssl_credentials( + generate_encrypted_key=True + ) + return cert_bytes, key_bytes + + +def get_cached_cert_fingerprint(cached_cert): + """Returns the fingerprint of the cached certificate.""" + if cached_cert: + cert_obj = parse_certificate(cached_cert) + cached_cert_fingerprint = calculate_certificate_fingerprint(cert_obj) + else: + raise ValueError("mTLS connection is not configured.") + return cached_cert_fingerprint diff --git a/google/auth/_oauth2client.py b/google/auth/_oauth2client.py index 8b83ff23c..8032b26ad 100644 --- a/google/auth/_oauth2client.py +++ b/google/auth/_oauth2client.py @@ -127,7 +127,7 @@ def _convert_appengine_app_assertion_credentials(credentials): oauth2client.contrib.gce.AppAssertionCredentials: _convert_gce_app_assertion_credentials, } -if _HAS_APPENGINE: +if _HAS_APPENGINE: # pragma: no cover _CLASS_CONVERSION_MAP[ oauth2client.contrib.appengine.AppAssertionCredentials ] = _convert_appengine_app_assertion_credentials diff --git a/google/auth/compute_engine/_metadata.py b/google/auth/compute_engine/_metadata.py index 96f1ff526..35b6c4495 100644 --- a/google/auth/compute_engine/_metadata.py +++ b/google/auth/compute_engine/_metadata.py @@ -451,12 +451,19 @@ def get_service_account_token(request, service_account="default", scopes=None): google.auth.exceptions.TransportError: if an error occurred while retrieving metadata. """ + from google.auth import _agent_identity_utils + + params = {} if scopes: if not isinstance(scopes, str): scopes = ",".join(scopes) - params = {"scopes": scopes} - else: - params = None + params["scopes"] = scopes + + cert = _agent_identity_utils.get_and_parse_agent_identity_certificate() + if cert: + if _agent_identity_utils.should_request_bound_token(cert): + fingerprint = _agent_identity_utils.calculate_certificate_fingerprint(cert) + params["bindCertificateFingerprint"] = fingerprint metrics_header = { metrics.API_CLIENT_HEADER: metrics.token_request_access_token_mds() diff --git a/google/auth/compute_engine/credentials.py b/google/auth/compute_engine/credentials.py index 0f518166a..554547619 100644 --- a/google/auth/compute_engine/credentials.py +++ b/google/auth/compute_engine/credentials.py @@ -135,9 +135,9 @@ def _refresh_token(self, request): service can't be reached if if the instance has not credentials. """ - scopes = self._scopes if self._scopes is not None else self._default_scopes try: self._retrieve_info(request) + scopes = self._scopes if self._scopes is not None else self._default_scopes # Always fetch token with default service account email. self.token, self.expiry = _metadata.get_service_account_token( request, service_account="default", scopes=scopes diff --git a/google/auth/environment_vars.py b/google/auth/environment_vars.py index 5da3a7382..1e6557272 100644 --- a/google/auth/environment_vars.py +++ b/google/auth/environment_vars.py @@ -92,3 +92,12 @@ GOOGLE_AUTH_TRUST_BOUNDARY_ENABLED = "GOOGLE_AUTH_TRUST_BOUNDARY_ENABLED" """Environment variable controlling whether to enable trust boundary feature. The default value is false. Users have to explicitly set this value to true.""" + +GOOGLE_API_CERTIFICATE_CONFIG = "GOOGLE_API_CERTIFICATE_CONFIG" +"""Environment variable defining the location of Google API certificate config +file.""" + +GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES = ( + "GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES" +) +"""Environment variable to prevent agent token sharing for GCP services.""" diff --git a/google/auth/external_account.py b/google/auth/external_account.py index 8eba0d249..8cf928778 100644 --- a/google/auth/external_account.py +++ b/google/auth/external_account.py @@ -420,6 +420,9 @@ def refresh(self, request): credentials, it will refresh the access token and the trust boundary. """ self._refresh_token(request) + self._handle_trust_boundary(request) + + def _handle_trust_boundary(self, request): # If we are impersonating, the trust boundary is handled by the # impersonated credentials object. We need to get it from there. if self._service_account_impersonation_url: @@ -428,7 +431,7 @@ def refresh(self, request): # Otherwise, refresh the trust boundary for the external account. self._refresh_trust_boundary(request) - def _refresh_token(self, request): + def _refresh_token(self, request, cert_fingerprint=None): scopes = self._scopes if self._scopes is not None else self._default_scopes # Inject client certificate into request. @@ -446,11 +449,15 @@ def _refresh_token(self, request): self.expiry = self._impersonated_credentials.expiry else: now = _helpers.utcnow() - additional_options = None + additional_options = {} # Do not pass workforce_pool_user_project when client authentication # is used. The client ID is sufficient for determining the user project. if self._workforce_pool_user_project and not self._client_id: - additional_options = {"userProject": self._workforce_pool_user_project} + additional_options["userProject"] = self._workforce_pool_user_project + + if cert_fingerprint: + additional_options["bindCertFingerprint"] = cert_fingerprint + additional_headers = { metrics.API_CLIENT_HEADER: metrics.byoid_metrics_header( self._metrics_options @@ -464,7 +471,7 @@ def _refresh_token(self, request): audience=self._audience, scopes=scopes, requested_token_type=_STS_REQUESTED_TOKEN_TYPE, - additional_options=additional_options, + additional_options=additional_options if additional_options else None, additional_headers=additional_headers, ) self.token = response_data.get("access_token") diff --git a/google/auth/identity_pool.py b/google/auth/identity_pool.py index 79b7de920..d2ed8c85a 100644 --- a/google/auth/identity_pool.py +++ b/google/auth/identity_pool.py @@ -550,3 +550,25 @@ def from_file(cls, filename, **kwargs): credentials. """ return super(Credentials, cls).from_file(filename, **kwargs) + + def refresh(self, request): + """Refreshes the access token. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + """ + from google.auth import _agent_identity_utils + + cert_fingerprint = None + # Check if the credential is X.509 based. + if self._credential_source_certificate is not None: + cert_bytes = self._get_cert_bytes() + cert = _agent_identity_utils.parse_certificate(cert_bytes) + if _agent_identity_utils.should_request_bound_token(cert): + cert_fingerprint = _agent_identity_utils.calculate_certificate_fingerprint( + cert + ) + + self._refresh_token(request, cert_fingerprint=cert_fingerprint) + self._handle_trust_boundary(request) diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py index f5d6b6724..5b56b60b1 100644 --- a/google/auth/transport/_mtls_helper.py +++ b/google/auth/transport/_mtls_helper.py @@ -20,11 +20,12 @@ import re import subprocess +from google.auth import _agent_identity_utils +from google.auth import environment_vars from google.auth import exceptions CONTEXT_AWARE_METADATA_PATH = "~/.secureConnect/context_aware_metadata.json" CERTIFICATE_CONFIGURATION_DEFAULT_PATH = "~/.config/gcloud/certificate_config.json" -_CERTIFICATE_CONFIGURATION_ENV = "GOOGLE_API_CERTIFICATE_CONFIG" _CERT_PROVIDER_COMMAND = "cert_provider_command" _CERT_REGEX = re.compile( b"-----BEGIN CERTIFICATE-----.+-----END CERTIFICATE-----\r?\n?", re.DOTALL @@ -146,7 +147,7 @@ def _get_cert_config_path(certificate_config_path=None): """ if certificate_config_path is None: - env_path = environ.get(_CERTIFICATE_CONFIGURATION_ENV, None) + env_path = environ.get(environment_vars.GOOGLE_API_CERTIFICATE_CONFIG, None) if env_path is not None and env_path != "": certificate_config_path = env_path else: @@ -474,3 +475,29 @@ def check_use_client_cert(): ) as e: _LOGGER.debug("error decoding certificate: %s", e) return False + + +def check_parameters_for_unauthorized_response(cached_cert): + """Returns the cached and current cert fingerprint for reconfiguring mTLS. + + Args: + cached_cert(bytes): The cached client certificate. + + Returns: + bytes: The client callback cert bytes. + bytes: The client callback key bytes. + str: The base64-encoded SHA256 cached fingerprint. + str: The base64-encoded SHA256 current cert fingerprint. + """ + call_cert_bytes, call_key_bytes = _agent_identity_utils.call_client_cert_callback() + cert_obj = _agent_identity_utils.parse_certificate(call_cert_bytes) + current_cert_fingerprint = _agent_identity_utils.calculate_certificate_fingerprint( + cert_obj + ) + if cached_cert: + cached_fingerprint = _agent_identity_utils.get_cached_cert_fingerprint( + cached_cert + ) + else: + cached_fingerprint = current_cert_fingerprint + return call_cert_bytes, call_key_bytes, cached_fingerprint, current_cert_fingerprint diff --git a/google/auth/transport/requests.py b/google/auth/transport/requests.py index d1ff8f368..b750e2b3d 100644 --- a/google/auth/transport/requests.py +++ b/google/auth/transport/requests.py @@ -17,6 +17,7 @@ from __future__ import absolute_import import functools +import http.client as http_client import logging import numbers import time @@ -36,6 +37,7 @@ from google.auth import _helpers from google.auth import exceptions from google.auth import transport +from google.auth.transport import _mtls_helper import google.auth.transport._mtls_helper from google.oauth2 import service_account @@ -463,6 +465,7 @@ def configure_mtls_channel(self, client_cert_callback=None): if self._is_mtls: mtls_adapter = _MutualTlsAdapter(cert, key) + self._cached_cert = cert self.mount("https://", mtls_adapter) except ( exceptions.ClientCertError, @@ -502,6 +505,10 @@ def request( itself does not timeout, e.g. if a large file is being transmitted. The timout error will be raised after such request completes. + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS + channel creation fails for any reason. + ValueError: If the client certificate is invalid. """ # pylint: disable=arguments-differ # Requests has a ton of arguments to request, but only two @@ -551,7 +558,31 @@ def request( response.status_code in self._refresh_status_codes and _credential_refresh_attempt < self._max_refresh_attempts ): - + # Handle unauthorized permission error(401 status code) + if response.status_code == http_client.UNAUTHORIZED: + if self.is_mtls: + call_cert_bytes, call_key_bytes, cached_fingerprint, current_cert_fingerprint = _mtls_helper.check_parameters_for_unauthorized_response( + self._cached_cert + ) + if cached_fingerprint != current_cert_fingerprint: + try: + _LOGGER.info( + "Client certificate has changed, reconfiguring mTLS " + "channel." + ) + self.configure_mtls_channel( + lambda: (call_cert_bytes, call_key_bytes) + ) + except Exception as e: + _LOGGER.error("Failed to reconfigure mTLS channel: %s", e) + raise exceptions.MutualTLSChannelError( + "Failed to reconfigure mTLS channel" + ) from e + else: + _LOGGER.info( + "Skipping reconfiguration of mTLS channel because the client" + " certificate has not changed." + ) _LOGGER.info( "Refreshing credentials due to a %s response. Attempt %s/%s.", response.status_code, diff --git a/google/auth/transport/urllib3.py b/google/auth/transport/urllib3.py index 353cb8e08..3c16831f2 100644 --- a/google/auth/transport/urllib3.py +++ b/google/auth/transport/urllib3.py @@ -16,6 +16,7 @@ from __future__ import absolute_import +import http.client as http_client import logging import warnings @@ -52,6 +53,7 @@ from google.auth import _helpers from google.auth import exceptions from google.auth import transport +from google.auth.transport import _mtls_helper from google.oauth2 import service_account if version.parse(urllib3.__version__) >= version.parse("2.0.0"): # pragma: NO COVER @@ -299,6 +301,7 @@ def __init__( # Request instance used by internal methods (for example, # credentials.refresh). self._request = Request(self.http) + self._is_mtls = False # https://google.aip.dev/auth/4111 # Attempt to use self-signed JWTs when a service account is used. @@ -335,7 +338,10 @@ def configure_mtls_channel(self, client_cert_callback=None): """ use_client_cert = transport._mtls_helper.check_use_client_cert() if not use_client_cert: + self._is_mtls = False return False + else: + self._is_mtls = True try: import OpenSSL except ImportError as caught_exc: @@ -349,6 +355,7 @@ def configure_mtls_channel(self, client_cert_callback=None): if found_cert_key: self.http = _make_mutual_tls_http(cert, key) + self._cached_cert = cert else: self.http = _make_default_http() except ( @@ -381,6 +388,11 @@ def urlopen(self, method, url, body=None, headers=None, **kwargs): if headers is None: headers = self.headers + use_mtls = False + if self._is_mtls: + MTLS_URL_PREFIXES = ["mtls.googleapis.com", "mtls.sandbox.googleapis.com"] + use_mtls = any([prefix in url for prefix in MTLS_URL_PREFIXES]) + # Make a copy of the headers. They will be modified by the credentials # and we want to pass the original headers if we recurse. request_headers = headers.copy() @@ -402,6 +414,34 @@ def urlopen(self, method, url, body=None, headers=None, **kwargs): response.status in self._refresh_status_codes and _credential_refresh_attempt < self._max_refresh_attempts ): + if response.status == http_client.UNAUTHORIZED: + if use_mtls: + call_cert_bytes, call_key_bytes, cached_fingerprint, current_cert_fingerprint = _mtls_helper.check_parameters_for_unauthorized_response( + self._cached_cert + ) + if cached_fingerprint != current_cert_fingerprint: + try: + _LOGGER.info( + "Client certificate has changed, reconfiguring mTLS " + "channel." + ) + self.configure_mtls_channel( + client_cert_callback=lambda: ( + call_cert_bytes, + call_key_bytes, + ) + ) + except Exception as e: + _LOGGER.error("Failed to reconfigure mTLS channel: %s", e) + raise exceptions.MutualTLSChannelError( + "Failed to reconfigure mTLS channel" + ) from e + + else: + _LOGGER.info( + "Skipping reconfiguration of mTLS channel because the " + "client certificate has not changed." + ) _LOGGER.info( "Refreshing credentials due to a %s response. Attempt %s/%s.", diff --git a/mypy.ini b/mypy.ini index 574c5aed3..c129006db 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,3 +1,3 @@ [mypy] -python_version = 3.7 +python_version = 3.9 namespace_packages = True diff --git a/noxfile.py b/noxfile.py index 11f677a3b..e91b7d5d6 100644 --- a/noxfile.py +++ b/noxfile.py @@ -106,6 +106,7 @@ def mypy(session): "types-requests", "types-setuptools", "types-mock", + "pytest<8.0.0", ) session.run("mypy", "-p", "google", "-p", "tests", "-p", "tests_async") @@ -130,6 +131,7 @@ def unit(session): @nox.session(python=DEFAULT_PYTHON_VERSION) def cover(session): + session.env["PIP_EXTRA_INDEX_URL"] = "https://pypi.org/simple" session.install("-e", ".[testing]") session.run( "pytest", diff --git a/setup.py b/setup.py index 014b32a95..51d946e5c 100644 --- a/setup.py +++ b/setup.py @@ -81,6 +81,7 @@ ] extras = { + "cryptography": cryptography_base_require, "aiohttp": aiohttp_extra_require, "enterprise_cert": enterprise_cert_extra_require, "pyopenssl": pyopenssl_extra_require, diff --git a/tests/compute_engine/test__metadata.py b/tests/compute_engine/test__metadata.py index adb63f667..5bb85c264 100644 --- a/tests/compute_engine/test__metadata.py +++ b/tests/compute_engine/test__metadata.py @@ -40,6 +40,29 @@ DATA_DIR, "smbios_product_name_non_google" ) +# A mock PEM-encoded certificate without an Agent Identity SPIFFE ID. +NON_AGENT_IDENTITY_CERT_BYTES = ( + b"-----BEGIN CERTIFICATE-----\n" + b"MIIDIzCCAgugAwIBAgIJAMfISuBQ5m+5MA0GCSqGSIb3DQEBBQUAMBUxEzARBgNV\n" + b"BAMTCnVuaXQtdGVzdHMwHhcNMTExMjA2MTYyNjAyWhcNMjExMjAzMTYyNjAyWjAV\n" + b"MRMwEQYDVQQDEwp1bml0LXRlc3RzMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB\n" + b"CgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZgkdmM\n" + b"7oVK2OfgrSj/FCTkInKPqaCR0gD7K80q+mLBrN3PUkDrJQZpvRZIff3/xmVU1Wer\n" + b"uQLFJjnFb2dqu0s/FY/2kWiJtBCakXvXEOb7zfbINuayL+MSsCGSdVYsSliS5qQp\n" + b"gyDap+8b5fpXZVJkq92hrcNtbkg7hCYUJczt8n9hcCTJCfUpApvaFQ18pe+zpyl4\n" + b"+WzkP66I28hniMQyUlA1hBiskT7qiouq0m8IOodhv2fagSZKjOTTU2xkSBc//fy3\n" + b"ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQABo3YwdDAdBgNVHQ4EFgQU2RQ8yO+O\n" + b"gN8oVW2SW7RLrfYd9jEwRQYDVR0jBD4wPIAU2RQ8yO+OgN8oVW2SW7RLrfYd9jGh\n" + b"GaQXMBUxEzARBgNVBAMTCnVuaXQtdGVzdHOCCQDHyErgUOZvuTAMBgNVHRMEBTAD\n" + b"AQH/MA0GCSqGSIb3DQEBBQUAA4IBAQBRv+M/6+FiVu7KXNjFI5pSN17OcW5QUtPr\n" + b"odJMlWrJBtynn/TA1oJlYu3yV5clc/71Vr/AxuX5xGP+IXL32YDF9lTUJXG/uUGk\n" + b"+JETpKmQviPbRsvzYhz4pf6ZIOZMc3/GIcNq92ECbseGO+yAgyWUVKMmZM0HqXC9\n" + b"ovNslqe0M8C1sLm1zAR5z/h/litE7/8O2ietija3Q/qtl2TOXJdCA6sgjJX2WUql\n" + b"ybrC55ct18NKf3qhpcEkGQvFU40rVYApJpi98DiZPYFdx1oBDp/f4uZ3ojpxRVFT\n" + b"cDwcJLfNRCPUhormsY7fDS9xSyThiHsW9mjJYdcaKQkwYZ0F11yB\n" + b"-----END CERTIFICATE-----\n" +) + ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE = ( "gl-python/3.7 auth/1.1 auth-request-type/at cred-type/mds" ) @@ -670,6 +693,83 @@ def test_get_service_account_token_with_scopes_string( assert expiry == utcnow() + datetime.timedelta(seconds=ttl) +@mock.patch("google.auth._agent_identity_utils.calculate_certificate_fingerprint") +@mock.patch("google.auth._agent_identity_utils.should_request_bound_token") +@mock.patch( + "google.auth._agent_identity_utils.get_and_parse_agent_identity_certificate" +) +@mock.patch( + "google.auth.metrics.token_request_access_token_mds", + return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, +) +@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) +def test_get_service_account_token_with_bound_token( + utcnow, + mock_metrics_header_value, + mock_get_and_parse, + mock_should_request, + mock_calculate_fingerprint, +): + # Test the successful path where a certificate is found and a bound token + # is requested. + mock_cert = mock.sentinel.cert + mock_get_and_parse.return_value = mock_cert + mock_should_request.return_value = True + mock_calculate_fingerprint.return_value = "fake_fingerprint" + + token_response = json.dumps({"access_token": "token", "expires_in": 3600}) + request = make_request(token_response, headers={"content-type": "application/json"}) + + _metadata.get_service_account_token(request) + + mock_get_and_parse.assert_called_once() + mock_should_request.assert_called_once_with(mock_cert) + mock_calculate_fingerprint.assert_called_once_with(mock_cert) + + request.assert_called_once() + _, kwargs = request.call_args + url = kwargs["url"] + assert "bindCertificateFingerprint=fake_fingerprint" in url + + +@mock.patch( + "google.auth._agent_identity_utils.get_and_parse_agent_identity_certificate" +) +def test_get_service_account_token_no_cert(mock_get_and_parse): + # Test that no fingerprint is added when no certificate is found. + mock_get_and_parse.return_value = None + token_response = json.dumps({"access_token": "token", "expires_in": 3600}) + request = make_request(token_response, headers={"content-type": "application/json"}) + + _metadata.get_service_account_token(request) + + request.assert_called_once() + _, kwargs = request.call_args + url = kwargs["url"] + assert "bindCertificateFingerprint" not in url + + +@mock.patch("google.auth._agent_identity_utils.should_request_bound_token") +@mock.patch( + "google.auth._agent_identity_utils.get_and_parse_agent_identity_certificate" +) +def test_get_service_account_token_should_not_bind( + mock_get_and_parse, mock_should_request +): + # Test that no fingerprint is added when a cert is found but should not be used. + mock_get_and_parse.return_value = mock.sentinel.cert + mock_should_request.return_value = False + token_response = json.dumps({"access_token": "token", "expires_in": 3600}) + request = make_request(token_response, headers={"content-type": "application/json"}) + + _metadata.get_service_account_token(request) + + request.assert_called_once() + _, kwargs = request.call_args + url = kwargs["url"] + assert "bindCertificateFingerprint" not in url + + def test_get_service_account_info(): key, value = "foo", "bar" request = make_request( diff --git a/tests/compute_engine/test_credentials.py b/tests/compute_engine/test_credentials.py index 1c7706993..9ef671425 100644 --- a/tests/compute_engine/test_credentials.py +++ b/tests/compute_engine/test_credentials.py @@ -658,6 +658,78 @@ def test_build_trust_boundary_lookup_url_no_email( assert excinfo.match(r"missing 'email' field") + @mock.patch("google.auth.compute_engine._metadata.get") + @mock.patch("google.auth._agent_identity_utils.get_agent_identity_certificate_path") + @mock.patch("google.auth._agent_identity_utils.parse_certificate") + @mock.patch( + "google.auth._agent_identity_utils.should_request_bound_token", + return_value=True, + ) + @mock.patch( + "google.auth._agent_identity_utils.calculate_certificate_fingerprint", + return_value="fingerprint", + ) + def test_refresh_with_agent_identity( + self, + mock_calculate_fingerprint, + mock_should_request, + mock_parse_certificate, + mock_get_path, + mock_metadata_get, + tmpdir, + ): + cert_path = tmpdir.join("cert.pem") + cert_path.write(b"cert_content") + mock_get_path.return_value = str(cert_path) + + mock_metadata_get.side_effect = [ + {"email": "service-account@example.com", "scopes": ["one", "two"]}, + {"access_token": "token", "expires_in": 500}, + ] + + self.credentials.refresh(None) + + assert self.credentials.token == "token" + mock_parse_certificate.assert_called_once_with(b"cert_content") + mock_should_request.assert_called_once_with(mock_parse_certificate.return_value) + kwargs = mock_metadata_get.call_args[1] + assert kwargs["params"] == { + "scopes": "one,two", + "bindCertificateFingerprint": "fingerprint", + } + + @mock.patch("google.auth.compute_engine._metadata.get") + @mock.patch("google.auth._agent_identity_utils.get_agent_identity_certificate_path") + @mock.patch("google.auth._agent_identity_utils.parse_certificate") + @mock.patch( + "google.auth._agent_identity_utils.should_request_bound_token", + return_value=False, + ) + def test_refresh_with_agent_identity_opt_out_or_not_agent( + self, + mock_should_request, + mock_parse_certificate, + mock_get_path, + mock_metadata_get, + tmpdir, + ): + cert_path = tmpdir.join("cert.pem") + cert_path.write(b"cert_content") + mock_get_path.return_value = str(cert_path) + + mock_metadata_get.side_effect = [ + {"email": "service-account@example.com", "scopes": ["one", "two"]}, + {"access_token": "token", "expires_in": 500}, + ] + + self.credentials.refresh(None) + + assert self.credentials.token == "token" + mock_parse_certificate.assert_called_once_with(b"cert_content") + mock_should_request.assert_called_once_with(mock_parse_certificate.return_value) + kwargs = mock_metadata_get.call_args[1] + assert "bindCertificateFingerprint" not in kwargs.get("params", {}) + class TestIDTokenCredentials(object): credentials = None diff --git a/tests/test_agent_identity_utils.py b/tests/test_agent_identity_utils.py new file mode 100644 index 000000000..86a63e82e --- /dev/null +++ b/tests/test_agent_identity_utils.py @@ -0,0 +1,337 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import hashlib +import json +import urllib.parse + +from cryptography import x509 +import mock +import pytest + +from google.auth import _agent_identity_utils +from google.auth import environment_vars +from google.auth import exceptions + +# A mock PEM-encoded certificate without an Agent Identity SPIFFE ID. +NON_AGENT_IDENTITY_CERT_BYTES = ( + b"-----BEGIN CERTIFICATE-----\n" + b"MIIDIzCCAgugAwIBAgIJAMfISuBQ5m+5MA0GCSqGSIb3DQEBBQUAMBUxEzARBgNV\n" + b"BAMTCnVuaXQtdGVzdHMwHhcNMTExMjA2MTYyNjAyWhcNMjExMjAzMTYyNjAyWjAV\n" + b"MRMwEQYDVQQDEwp1bml0LXRlc3RzMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB\n" + b"CgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZgkdmM\n" + b"7oVK2OfgrSj/FCTkInKPqaCR0gD7K80q+mLBrN3PUkDrJQZpvRZIff3/xmVU1Wer\n" + b"uQLFJjnFb2dqu0s/FY/2kWiJtBCakXvXEOb7zfbINuayL+MSsCGSdVYsSliS5qQp\n" + b"gyDap+8b5fpXZVJkq92hrcNtbkg7hCYUJczt8n9hcCTJCfUpApvaFQ18pe+zpyl4\n" + b"+WzkP66I28hniMQyUlA1hBiskT7qiouq0m8IOodhv2fagSZKjOTTU2xkSBc//fy3\n" + b"ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQABo3YwdDAdBgNVHQ4EFgQU2RQ8yO+O\n" + b"gN8oVW2SW7RLrfYd9jEwRQYDVR0jBD4wPIAU2RQ8yO+OgN8oVW2SW7RLrfYd9jGh\n" + b"GaQXMBUxEzARBgNVBAMTCnVuaXQtdGVzdHOCCQDHyErgUOZvuTAMBgNVHRMEBTAD\n" + b"AQH/MA0GCSqGSIb3DQEBBQUAA4IBAQBRv+M/6+FiVu7KXNjFI5pSN17OcW5QUtPr\n" + b"odJMlWrJBtynn/TA1oJlYu3yV5clc/71Vr/AxuX5xGP+IXL32YDF9lTUJXG/uUGk\n" + b"+JETpKmQviPbRsvzYhz4pf6ZIOZMc3/GIcNq92ECbseGO+yAgyWUVKMmZM0HqXC9\n" + b"ovNslqe0M8C1sLm1zAR5z/h/litE7/8O2ietija3Q/qtl2TOXJdCA6sgjJX2WUql\n" + b"ybrC55ct18NKf3qhpcEkGQvFU40rVYApJpi98DiZPYFdx1oBDp/f4uZ3ojpxRVFT\n" + b"cDwcJLfNRCPUhormsY7fDS9xSyThiHsW9mjJYdcaKQkwYZ0F11yB\n" + b"-----END CERTIFICATE-----\n" +) + + +class TestAgentIdentityUtils: + @mock.patch("cryptography.x509.load_pem_x509_certificate") + def test_parse_certificate(self, mock_load_cert): + result = _agent_identity_utils.parse_certificate(b"cert_bytes") + mock_load_cert.assert_called_once_with(b"cert_bytes") + assert result == mock_load_cert.return_value + + def test__is_agent_identity_certificate_invalid(self): + cert = _agent_identity_utils.parse_certificate(NON_AGENT_IDENTITY_CERT_BYTES) + assert not _agent_identity_utils._is_agent_identity_certificate(cert) + + def test__is_agent_identity_certificate_valid_spiffe(self): + mock_cert = mock.MagicMock() + mock_ext = mock.MagicMock() + mock_san_value = mock.MagicMock() + mock_cert.extensions.get_extension_for_oid.return_value = mock_ext + mock_ext.value = mock_san_value + mock_san_value.get_values_for_type.return_value = [ + "spiffe://agents.global.proj-12345.system.id.goog/workload" + ] + assert _agent_identity_utils._is_agent_identity_certificate(mock_cert) + + def test__is_agent_identity_certificate_non_matching_spiffe(self): + mock_cert = mock.MagicMock() + mock_ext = mock.MagicMock() + mock_san_value = mock.MagicMock() + mock_cert.extensions.get_extension_for_oid.return_value = mock_ext + mock_ext.value = mock_san_value + mock_san_value.get_values_for_type.return_value = [ + "spiffe://other.domain.com/workload" + ] + assert not _agent_identity_utils._is_agent_identity_certificate(mock_cert) + + def test__is_agent_identity_certificate_no_san(self): + mock_cert = mock.MagicMock() + mock_cert.extensions.get_extension_for_oid.side_effect = x509.ExtensionNotFound( + "Test extension not found", None + ) + assert not _agent_identity_utils._is_agent_identity_certificate(mock_cert) + + def test__is_agent_identity_certificate_not_spiffe_uri(self): + mock_cert = mock.MagicMock() + mock_ext = mock.MagicMock() + mock_san_value = mock.MagicMock() + mock_cert.extensions.get_extension_for_oid.return_value = mock_ext + mock_ext.value = mock_san_value + mock_san_value.get_values_for_type.return_value = ["https://example.com"] + assert not _agent_identity_utils._is_agent_identity_certificate(mock_cert) + + def test_calculate_certificate_fingerprint(self): + mock_cert = mock.MagicMock() + mock_cert.public_bytes.return_value = b"der-bytes" + + # Expected: base64 (standard), unpadded, then URL-encoded + base64_fingerprint = base64.b64encode( + hashlib.sha256(b"der-bytes").digest() + ).decode("utf-8") + unpadded_base64_fingerprint = base64_fingerprint.rstrip("=") + expected_fingerprint = urllib.parse.quote(unpadded_base64_fingerprint) + + fingerprint = _agent_identity_utils.calculate_certificate_fingerprint(mock_cert) + + assert fingerprint == expected_fingerprint + + @mock.patch("google.auth._agent_identity_utils._is_agent_identity_certificate") + def test_should_request_bound_token(self, mock_is_agent, monkeypatch): + # Agent cert, default env var (opt-in) + mock_is_agent.return_value = True + monkeypatch.delenv( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES, + raising=False, + ) + assert _agent_identity_utils.should_request_bound_token(mock.sentinel.cert) + + # Agent cert, explicit opt-in + monkeypatch.setenv( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES, + "true", + ) + assert _agent_identity_utils.should_request_bound_token(mock.sentinel.cert) + + # Agent cert, explicit opt-out + monkeypatch.setenv( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES, + "false", + ) + assert not _agent_identity_utils.should_request_bound_token(mock.sentinel.cert) + + # Non-agent cert, opt-in + mock_is_agent.return_value = False + monkeypatch.setenv( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES, + "true", + ) + assert not _agent_identity_utils.should_request_bound_token(mock.sentinel.cert) + + def test_get_agent_identity_certificate_path_success(self, tmpdir, monkeypatch): + cert_path = tmpdir.join("cert.pem") + cert_path.write("cert_content") + config_path = tmpdir.join("config.json") + config_path.write( + json.dumps({"cert_configs": {"workload": {"cert_path": str(cert_path)}}}) + ) + monkeypatch.setenv( + environment_vars.GOOGLE_API_CERTIFICATE_CONFIG, str(config_path) + ) + + result = _agent_identity_utils.get_agent_identity_certificate_path() + assert result == str(cert_path) + + @mock.patch("time.sleep") + def test_get_agent_identity_certificate_path_retry( + self, mock_sleep, tmpdir, monkeypatch + ): + config_path = tmpdir.join("config.json") + monkeypatch.setenv( + environment_vars.GOOGLE_API_CERTIFICATE_CONFIG, str(config_path) + ) + + # File doesn't exist initially + with pytest.raises(exceptions.RefreshError): + _agent_identity_utils.get_agent_identity_certificate_path() + + assert mock_sleep.call_count == 100 + + @mock.patch("time.sleep") + def test_get_agent_identity_certificate_path_failure( + self, mock_sleep, tmpdir, monkeypatch + ): + config_path = tmpdir.join("non_existent_config.json") + monkeypatch.setenv( + environment_vars.GOOGLE_API_CERTIFICATE_CONFIG, str(config_path) + ) + + with pytest.raises(exceptions.RefreshError) as excinfo: + _agent_identity_utils.get_agent_identity_certificate_path() + + assert "not found after multiple retries" in str(excinfo.value) + assert ( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES + in str(excinfo.value) + ) + assert mock_sleep.call_count == 100 + + @mock.patch("time.sleep") + @mock.patch("os.path.exists") + def test_get_agent_identity_certificate_path_cert_not_found( + self, mock_exists, mock_sleep, tmpdir, monkeypatch + ): + cert_path_str = str(tmpdir.join("cert.pem")) + config_path = tmpdir.join("config.json") + config_path.write( + json.dumps({"cert_configs": {"workload": {"cert_path": cert_path_str}}}) + ) + monkeypatch.setenv( + environment_vars.GOOGLE_API_CERTIFICATE_CONFIG, str(config_path) + ) + + def exists_side_effect(path): + return path == str(config_path) + + mock_exists.side_effect = exists_side_effect + + with pytest.raises(exceptions.RefreshError): + _agent_identity_utils.get_agent_identity_certificate_path() + + assert mock_sleep.call_count == 100 + + @mock.patch("google.auth._agent_identity_utils.get_agent_identity_certificate_path") + def test_get_and_parse_agent_identity_certificate_opted_out( + self, mock_get_path, monkeypatch + ): + monkeypatch.setenv( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES, + "false", + ) + result = _agent_identity_utils.get_and_parse_agent_identity_certificate() + assert result is None + mock_get_path.assert_not_called() + + @mock.patch("google.auth._agent_identity_utils.get_agent_identity_certificate_path") + def test_get_and_parse_agent_identity_certificate_no_path( + self, mock_get_path, monkeypatch + ): + monkeypatch.setenv( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES, + "true", + ) + mock_get_path.return_value = None + result = _agent_identity_utils.get_and_parse_agent_identity_certificate() + assert result is None + mock_get_path.assert_called_once() + + @mock.patch("google.auth._agent_identity_utils.parse_certificate") + @mock.patch("google.auth._agent_identity_utils.get_agent_identity_certificate_path") + def test_get_and_parse_agent_identity_certificate_success( + self, mock_get_path, mock_parse_certificate, monkeypatch + ): + monkeypatch.setenv( + environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES, + "true", + ) + mock_get_path.return_value = "/fake/cert.pem" + mock_open = mock.mock_open(read_data=b"cert_bytes") + + with mock.patch("builtins.open", mock_open): + result = _agent_identity_utils.get_and_parse_agent_identity_certificate() + + mock_open.assert_called_once_with("/fake/cert.pem", "rb") + mock_parse_certificate.assert_called_once_with(b"cert_bytes") + assert result == mock_parse_certificate.return_value + + @mock.patch("time.sleep", return_value=None) + @mock.patch("google.auth._agent_identity_utils._is_certificate_file_ready") + def test_get_agent_identity_certificate_path_fallback_to_well_known_path( + self, mock_is_ready, mock_sleep, monkeypatch + ): + # Set a dummy config path that won't be found. + monkeypatch.setenv( + environment_vars.GOOGLE_API_CERTIFICATE_CONFIG, "/dummy/config.json" + ) + + # First, the primary path from the (mocked) config is not ready. + # Then, the fallback well-known path is ready. + mock_is_ready.side_effect = [False, True] + + result = _agent_identity_utils.get_agent_identity_certificate_path() + + assert result == _agent_identity_utils._WELL_KNOWN_CERT_PATH + # The sleep should have been called once before the fallback is checked. + mock_sleep.assert_called_once() + assert mock_is_ready.call_count == 2 + + @mock.patch("google.auth.transport._mtls_helper.get_client_ssl_credentials") + def test_call_client_cert_callback(self, mock_get_client_ssl_credentials): + mock_get_client_ssl_credentials.return_value = ( + True, + b"cert_bytes", + b"key_bytes", + b"passphrase", + ) + + cert, key = _agent_identity_utils.call_client_cert_callback() + + assert cert == b"cert_bytes" + assert key == b"key_bytes" + mock_get_client_ssl_credentials.assert_called_once_with( + generate_encrypted_key=True + ) + + def test_get_cached_cert_fingerprint_no_cert(self): + with pytest.raises(ValueError, match="mTLS connection is not configured."): + _agent_identity_utils.get_cached_cert_fingerprint(None) + + def test_get_cached_cert_fingerprint_with_cert(self): + fingerprint = _agent_identity_utils.get_cached_cert_fingerprint( + NON_AGENT_IDENTITY_CERT_BYTES + ) + assert isinstance(fingerprint, str) + + +class TestAgentIdentityUtilsNoCryptography: + @pytest.fixture(autouse=True) + def mock_cryptography_import(self): + with mock.patch.dict( + "sys.modules", + { + "cryptography": None, + "cryptography.hazmat": None, + "cryptography.hazmat.primitives": None, + "cryptography.hazmat.primitives.serialization": None, + }, + ): + yield + + def test_parse_certificate_raises_import_error(self): + with pytest.raises(ImportError, match="The cryptography library is required"): + _agent_identity_utils.parse_certificate(b"cert_bytes") + + def test_is_agent_identity_certificate_raises_import_error(self): + with pytest.raises(ImportError, match="The cryptography library is required"): + _agent_identity_utils._is_agent_identity_certificate(mock.sentinel.cert) + + def test_calculate_certificate_fingerprint_raises_import_error(self): + with pytest.raises(ImportError, match="The cryptography library is required"): + _agent_identity_utils.calculate_certificate_fingerprint(mock.sentinel.cert) diff --git a/tests/test_external_account.py b/tests/test_external_account.py index 2fa64361d..a56b54a43 100644 --- a/tests/test_external_account.py +++ b/tests/test_external_account.py @@ -737,6 +737,24 @@ def test_refresh_skips_trust_boundary_lookup_when_disabled( credentials.apply(headers_applied) assert "x-allowed-locations" not in headers_applied + def test_refresh_token_with_cert_fingerprint(self): + credentials = self.make_credentials() + credentials._sts_client = mock.MagicMock() + credentials._sts_client.exchange_token.return_value = { + "access_token": "token", + "expires_in": 3600, + } + credentials.retrieve_subject_token = mock.MagicMock( + return_value="subject_token" + ) + + credentials._refresh_token( + request=mock.sentinel.request, cert_fingerprint="my-fingerprint" + ) + + _, kwargs = credentials._sts_client.exchange_token.call_args + assert kwargs["additional_options"]["bindCertFingerprint"] == "my-fingerprint" + def test_refresh_skips_sending_allowed_locations_header_with_trust_boundary(self): # This test verifies that the x-allowed-locations header is not sent with # the STS request even if a trust boundary is cached. diff --git a/tests/test_identity_pool.py b/tests/test_identity_pool.py index dbbdbf53a..529d83d65 100644 --- a/tests/test_identity_pool.py +++ b/tests/test_identity_pool.py @@ -1772,3 +1772,59 @@ def test_get_mtls_certs_invalid(self): assert excinfo.match( 'The credential is not configured to use mtls requests. The credential should include a "certificate" section in the credential source.' ) + + @mock.patch("google.auth._agent_identity_utils.parse_certificate") + @mock.patch( + "google.auth._agent_identity_utils.should_request_bound_token", + return_value=True, + ) + @mock.patch( + "google.auth._agent_identity_utils.calculate_certificate_fingerprint", + return_value="fingerprint", + ) + @mock.patch.object( + identity_pool.Credentials, "_get_cert_bytes", return_value=b"cert" + ) + @mock.patch.object(external_account.Credentials, "_refresh_token") + def test_refresh_with_agent_identity( + self, + mock_refresh_token, + mock_get_cert_bytes, + mock_calculate_fingerprint, + mock_should_request, + mock_parse_certificate, + ): + mock_parse_certificate.return_value = mock.sentinel.cert + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_CERTIFICATE.copy() + ) + credentials.refresh(None) + mock_parse_certificate.assert_called_once_with(b"cert") + mock_should_request.assert_called_once_with(mock.sentinel.cert) + mock_calculate_fingerprint.assert_called_once_with(mock.sentinel.cert) + mock_refresh_token.assert_called_once_with(None, cert_fingerprint="fingerprint") + + @mock.patch("google.auth._agent_identity_utils.parse_certificate") + @mock.patch( + "google.auth._agent_identity_utils.should_request_bound_token", + return_value=False, + ) + @mock.patch.object( + identity_pool.Credentials, "_get_cert_bytes", return_value=b"cert" + ) + @mock.patch.object(external_account.Credentials, "_refresh_token") + def test_refresh_with_agent_identity_opt_out_or_not_agent( + self, + mock_refresh_token, + mock_get_cert_bytes, + mock_should_request, + mock_parse_certificate, + ): + mock_parse_certificate.return_value = mock.sentinel.cert + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_CERTIFICATE.copy() + ) + credentials.refresh(None) + mock_parse_certificate.assert_called_once_with(b"cert") + mock_should_request.assert_called_once_with(mock.sentinel.cert) + mock_refresh_token.assert_called_once_with(None, cert_fingerprint=None) diff --git a/tests/transport/test__mtls_helper.py b/tests/transport/test__mtls_helper.py index 2a7a524b1..7b49215cc 100644 --- a/tests/transport/test__mtls_helper.py +++ b/tests/transport/test__mtls_helper.py @@ -23,8 +23,9 @@ from google.auth import exceptions from google.auth.transport import _mtls_helper +CERT_MOCK_VAL = b"cert" +KEY_MOCK_VAL = b"key" CONTEXT_AWARE_METADATA = {"cert_provider_command": ["some command"]} - ENCRYPTED_EC_PRIVATE_KEY = b"""-----BEGIN ENCRYPTED PRIVATE KEY----- MIHkME8GCSqGSIb3DQEFDTBCMCkGCSqGSIb3DQEFDDAcBAgl2/yVgs1h3QICCAAw DAYIKoZIhvcNAgkFADAVBgkrBgEEAZdVAQIECJk2GRrvxOaJBIGQXIBnMU4wmciT @@ -813,3 +814,64 @@ def test_check_use_client_cert_when_file_does_not_exist(self, monkeypatch): monkeypatch.setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "") use_client_cert = _mtls_helper.check_use_client_cert() assert use_client_cert is False + + +class TestMtlsHelper: + @mock.patch("google.auth.transport._mtls_helper._agent_identity_utils") + def test_check_parameters_for_unauthorized_response_with_cached_cert( + self, mock_agent_identity_utils + ): + mock_agent_identity_utils.call_client_cert_callback.return_value = ( + CERT_MOCK_VAL, + KEY_MOCK_VAL, + ) + mock_agent_identity_utils.get_cached_cert_fingerprint.return_value = ( + "cached_fingerprint" + ) + mock_agent_identity_utils.calculate_certificate_fingerprint.return_value = ( + "current_fingerprint" + ) + + ( + cert, + key, + cached_fingerprint, + current_fingerprint, + ) = _mtls_helper.check_parameters_for_unauthorized_response( + cached_cert=b"cached_cert_bytes" + ) + + assert cert == CERT_MOCK_VAL + assert key == KEY_MOCK_VAL + assert cached_fingerprint == "cached_fingerprint" + assert current_fingerprint == "current_fingerprint" + mock_agent_identity_utils.call_client_cert_callback.assert_called_once() + mock_agent_identity_utils.get_cached_cert_fingerprint.assert_called_once_with( + b"cached_cert_bytes" + ) + + @mock.patch("google.auth.transport._mtls_helper._agent_identity_utils") + def test_check_parameters_for_unauthorized_response_without_cached_cert( + self, mock_agent_identity_utils + ): + mock_agent_identity_utils.call_client_cert_callback.return_value = ( + CERT_MOCK_VAL, + KEY_MOCK_VAL, + ) + mock_agent_identity_utils.calculate_certificate_fingerprint.return_value = ( + "current_fingerprint" + ) + + ( + cert, + key, + cached_fingerprint, + current_fingerprint, + ) = _mtls_helper.check_parameters_for_unauthorized_response(cached_cert=None) + + assert cert == CERT_MOCK_VAL + assert key == KEY_MOCK_VAL + assert cached_fingerprint == "current_fingerprint" + assert current_fingerprint == "current_fingerprint" + mock_agent_identity_utils.call_client_cert_callback.assert_called_once() + mock_agent_identity_utils.get_cached_cert_fingerprint.assert_not_called() diff --git a/tests/transport/test_requests.py b/tests/transport/test_requests.py index 0da3e36d9..ccc937527 100644 --- a/tests/transport/test_requests.py +++ b/tests/transport/test_requests.py @@ -41,6 +41,10 @@ def frozen_time(): yield frozen +CERT_MOCK_VAL = b"-----BEGIN CERTIFICATE-----\nMIIDIzCCAgugAwIBAgIJAMfISuBQ5m+5MA0GCSqGSIb3DQEBBQUAMBUxEzARBgNV\nBAMTCnVuaXQtdGVzdHMwHhcNMTExMjA2MTYyNjAyWhcNMjExMjAzMTYyNjAyWjAV\nMRMwEQYDVQQDEwp1bml0LXRlc3RzMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB\nCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZgkdmM\n7oVK2OfgrSj/FCTkInKPqaCR0gD7K80q+mLBrN3PUkDrJQZpvRZIff3/xmVU1Wer\nuQLFJjnFb2dqu0s/FY/2kWiJtBCakXvXEOb7zfbINuayL+MSsCGSdVYsSliS5qQp\ngyDap+8b5fpXZVJkq92hrcNtbkg7hCYUJczt8n9hcCTJCfUpApvaFQ18pe+zpyl4\n+WzkP66I28hniMQyUlA1hBiskT7qiouq0m8IOodhv2fagSZKjOTTU2xkSBc//fy3\nZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQABo3YwdDAdBgNVHQ4EFgQU2RQ8yO+O\ngN8oVW2SW7RLrfYd9jEwRQYDVR0jBD4wPIAU2RQ8yO+OgN8oVW2SW7RLrfYd9jGh\nGaQXMBUxEzARBgNVBAMTCnVuaXQtdGVzdHOCCQDHyErgUOZvuTAMBgNVHRMEBTAD\nAQH/MA0GCSqGSIb3DQEBBQUAA4IBAQBRv+M/6+FiVu7KXNjFI5pSN17OcW5QUtPr\nodJMlWrJBtynn/TA1oJlYu3yV5clc/71Vr/AxuX5xGP+IXL32YDF9lTUJXG/uUGk\n+JETpKmQviPbRsvzYhz4pf6ZIOZMc3/GIcNq92ECbseGO+yAgyWUVKMmZM0HqXC9\novNslqe0M8C1sLm1zAR5z/h/litE7/8O2ietija3Q/qtl2TOXJdCA6sgjJX2WUql\nybrC55ct18NKf3qhpcEkGQvFU40rVYApJpi98DiZPYFdx1oBDp/f4uZ3ojpxRVFT\ncDwcJLfNRCPUhormsY7fDS9xSyThiHsW9mjJYdcaKQkwYZ0F11yB\n-----END CERTIFICATE-----\n" +KEY_MOCK_VAL = b"-----BEGIN ENCRYPTED PRIVATE KEY-----\nMIHeMEkGCSqGSIb3DQEFDTA8MBsGCSqGSIb3DQEFDDAOBAj9XnJ2h78QVAICCAAw\nHQYJYIZIAWUDBAECBBBeiiOF2LnLzq/wjb/viwMwBIGQk28Zkfj2EIk42bgc7UzC\nSf98qssCVhsIYz0Xa3eSATg8Cpn83YieaBeyxdk/tXTnrOhxMV/vt7T98kWhaGbH\n5Z9CdGVLfes0UFvVJqrlk6vcf2sOnLCGbrn78HS+ayrGOCRSCd/7+dnEiB/7Um1B\nMk6BBJHsLEnZZSHyfrw8jvYgVmcSBy/WdY0pqldD/+4D\n-----END ENCRYPTED PRIVATE KEY-----\n" + + class TestRequestResponse(compliance.RequestResponseTests): def make_request(self): return google.auth.transport.requests.Request() @@ -537,6 +541,226 @@ def test_close_w_passed_in_auth_request(self): authed_session.close() # no raise + def test_cert_rotation_when_cert_mismatch_and_mtls_enabled(self): + credentials = mock.Mock(wraps=CredentialsStub()) + final_response = make_response(status=http_client.OK) + # First request will 401, second request will succeed. + adapter = AdapterStub( + [make_response(status=http_client.UNAUTHORIZED), final_response] + ) + + authed_session = google.auth.transport.requests.AuthorizedSession( + credentials, refresh_timeout=60 + ) + authed_session.mount(self.TEST_URL, adapter) + + old_cert = b"-----BEGIN CERTIFICATE-----\nMIIBdTCCARqgAwIBAgIJAOYVvu/axMxvMAoGCCqGSM49BAMCMCcxJTAjBgNVBAMM\nHEdvb2dsZSBFbmRwb2ludCBWZXJpZmljYXRpb24wHhcNMjUwNzMwMjMwNjA4WhcN\nMjYwNzMxMjMwNjA4WjAnMSUwIwYDVQQDDBxHb29nbGUgRW5kcG9pbnQgVmVyaWZp\nY2F0aW9uMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEbtr18gkEtwPow2oqyZsU\n4KLwFaLFlRlYv55UATS3QTDykDnIufC42TJCnqFRYhwicwpE2jnUV+l9g3Voias8\nraMvMC0wCQYDVR0TBAIwADALBgNVHQ8EBAMCB4AwEwYDVR0lBAwwCgYIKwYBBQUH\nAwIwCgYIKoZIzj0EAwIDSQAwRgIhAKcjW6dmF1YCksXPgDPlPu/nSnOjb3qCcivz\n/Jxq2zoeAiEA7/aNxcEoCGS3hwMIXoaaD/vPcZOOopKSyqXCvxRooKQ=\n-----END CERTIFICATE-----\n" + + # New certificate and key to simulate rotation. + new_cert = CERT_MOCK_VAL + new_key = KEY_MOCK_VAL + + # Set _cached_cert to a callable that returns the old certificate. + authed_session._cached_cert = old_cert + authed_session._is_mtls = True + + # Mock call_client_cert_callback to return the new certificate. + with mock.patch.object( + google.auth.transport._mtls_helper._agent_identity_utils, + "call_client_cert_callback", + return_value=(new_cert, new_key), + ) as mock_callback: + result = authed_session.request("GET", self.TEST_URL) + + # Asserts to verify the behavior. + assert mock_callback.called + assert credentials.refresh.called + assert credentials.refresh.call_count == 1 + assert result.status_code == final_response.status_code + + def test_no_cert_rotation_when_cert_match_and_mTLS_enabled(self): + credentials = mock.Mock(wraps=CredentialsStub()) + final_response = make_response(status=http_client.UNAUTHORIZED) + adapter = AdapterStub( + [ + make_response(status=http_client.UNAUTHORIZED), + make_response(status=http_client.UNAUTHORIZED), + make_response(status=http_client.UNAUTHORIZED), + ] + ) + authed_session = google.auth.transport.requests.AuthorizedSession( + credentials, refresh_timeout=60 + ) + authed_session.mount(self.TEST_URL, adapter) + authed_session._is_mtls = True + + old_cert = CERT_MOCK_VAL + + # New certificate and key to simulate rotation. + new_cert = old_cert + new_key = KEY_MOCK_VAL + + # Set _cached_cert to a callable that returns the old certificate. + authed_session._cached_cert = old_cert + + # Mock call_client_cert_callback to return the new certificate. + with mock.patch.object( + google.auth.transport._mtls_helper._agent_identity_utils, + "call_client_cert_callback", + return_value=(new_cert, new_key), + ): + result = authed_session.request("GET", self.TEST_URL) + + # Asserts to verify the behavior. + assert credentials.refresh.call_count == 2 + assert result.status_code == final_response.status_code + + def test_no_cert_match_check_when_mtls_disabled(self): + credentials = mock.Mock(wraps=CredentialsStub()) + final_response = make_response(status=http_client.UNAUTHORIZED) + adapter = AdapterStub( + [ + make_response(status=http_client.UNAUTHORIZED), + make_response(status=http_client.UNAUTHORIZED), + make_response(status=http_client.UNAUTHORIZED), + ] + ) + authed_session = google.auth.transport.requests.AuthorizedSession( + credentials, refresh_timeout=60 + ) + authed_session.mount(self.TEST_URL, adapter) + authed_session._is_mtls = False + + new_cert = CERT_MOCK_VAL + + # New certificate and key to simulate rotation. + new_key = KEY_MOCK_VAL + + # Mock call_client_cert_callback to return the new certificate. + with mock.patch.object( + google.auth.transport._mtls_helper._agent_identity_utils, + "call_client_cert_callback", + return_value=(new_cert, new_key), + ) as mock_callback: + result = authed_session.request("GET", self.TEST_URL) + + # Asserts to verify the behavior. + assert not mock_callback.called + assert result.status_code == final_response.status_code + + def test_no_cert_rotation_when_no_unauthorized_response(self): + credentials = mock.Mock(wraps=CredentialsStub()) + final_response = make_response(status=http_client.UPGRADE_REQUIRED) + + # Response is set to code other than 401(Unauthorized). + adapter = AdapterStub([make_response(status=http_client.UPGRADE_REQUIRED)]) + authed_session = google.auth.transport.requests.AuthorizedSession( + credentials, refresh_timeout=60 + ) + authed_session.mount(self.TEST_URL, adapter) + + authed_session._is_mtls = True + + result = authed_session.request("GET", self.TEST_URL) + assert result.status_code == final_response.status_code + + # Asserts to verify the behavior. + assert not credentials.refresh.called + assert credentials.refresh.call_count == 0 + + def test_cert_rotation_failure_raises_error(self): + credentials = mock.Mock(wraps=CredentialsStub()) + # First request will 401, second request will fail to reconfigure mTLS. + adapter = AdapterStub([make_response(status=http_client.UNAUTHORIZED)]) + + authed_session = google.auth.transport.requests.AuthorizedSession( + credentials, refresh_timeout=60 + ) + authed_session.mount(self.TEST_URL, adapter) + + old_cert = b"-----BEGIN CERTIFICATE-----\nMIIBdTCCARqgAwIBAgIJAOYVvu/axMxvMAoGCCqGSM49BAMCMCcxJTAjBgNVBAMM\nHEdvb2dsZSBFbmRwb2ludCBWZXJpZmljYXRpb24wHhcNMjUwNzMwMjMwNjA4WhcN\nMjYwNzMxMjMwNjA4WjAnMSUwIwYDVQQDDBxHb29nbGUgRW5kcG9pbnQgVmVyaWZp\nY2F0aW9uMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEbtr18gkEtwPow2oqyZsU\n4KLwFaLFlRlYv55UATS3QTDykDnIufC42TJCnqFRYhwicwpE2jnUV+l9g3Voias8\nraMvMC0wCQYDVR0TBAIwADALBgNVHQ8EBAMCB4AwEwYDVR0lBAwwCgYIKwYBBQUH\nAwIwCgYIKoZIzj0EAwIDSQAwRgIhAKcjW6dmF1YCksXPgDPlPu/nSnOjb3qCcivz\n/Jxq2zoeAiEA7/aNxcEoCGS3hwMIXoaaD/vPcZOOopKSyqXCvxRooKQ=\n-----END CERTIFICATE-----\n" + + # New certificate and key to simulate rotation. + new_cert = CERT_MOCK_VAL + new_key = KEY_MOCK_VAL + + authed_session._cached_cert = old_cert + authed_session._is_mtls = True + + with mock.patch.object( + google.auth.transport._mtls_helper._agent_identity_utils, + "call_client_cert_callback", + return_value=(new_cert, new_key), + ): + with mock.patch.object( + authed_session, + "configure_mtls_channel", + side_effect=Exception("Failed to reconfigure"), + ): + with pytest.raises(exceptions.MutualTLSChannelError): + authed_session.request("GET", self.TEST_URL) + + # Assert to verify behavior + credentials.refresh.assert_not_called() + + def test_cert_rotation_check_params_fails(self): + credentials = mock.Mock(wraps=CredentialsStub()) + adapter = AdapterStub([make_response(status=http_client.UNAUTHORIZED)]) + + authed_session = google.auth.transport.requests.AuthorizedSession( + credentials, refresh_timeout=60 + ) + authed_session.mount(self.TEST_URL, adapter) + authed_session._is_mtls = True + authed_session._cached_cert = b"cached_cert" + + with mock.patch( + "google.auth.transport.requests._mtls_helper.check_parameters_for_unauthorized_response", + side_effect=Exception("check_params failed"), + ) as mock_check_params: + with pytest.raises(Exception, match="check_params failed"): + authed_session.request("GET", self.TEST_URL) + + mock_check_params.assert_called_once() + credentials.refresh.assert_not_called() + + def test_cert_rotation_logic_skipped_on_other_refresh_status_codes(self): + """ + Tests that the code can handle a refresh triggered by a status code + other than 401 (UNAUTHORIZED). This covers the 'else' branch of the + 'if response.status_code == http_client.UNAUTHORIZED' check + """ + credentials = mock.Mock(wraps=CredentialsStub()) + # Configure the session to treat 503 (Service Unavailable) as a refreshable error + custom_refresh_codes = [http_client.SERVICE_UNAVAILABLE] + + # Return 503 first, then 200 + adapter = AdapterStub( + [ + make_response(status=http_client.SERVICE_UNAVAILABLE), + make_response(status=http_client.OK), + ] + ) + + authed_session = google.auth.transport.requests.AuthorizedSession( + credentials, refresh_status_codes=custom_refresh_codes + ) + authed_session.mount(self.TEST_URL, adapter) + + # Enable mTLS to prove it is skipped despite being enabled + authed_session._is_mtls = True + + with mock.patch( + "google.auth.transport.requests._mtls_helper", autospec=True + ) as mock_helper: + authed_session.request("GET", self.TEST_URL) + + # Assert refresh happened (Outer Check was True) + assert credentials.refresh.called + + # Assert mTLS check logic was SKIPPED (Inner Check was False) + assert not mock_helper.check_parameters_for_unauthorized_response.called + class TestMutualTlsOffloadAdapter(object): @mock.patch.object(requests.adapters.HTTPAdapter, "init_poolmanager") diff --git a/tests/transport/test_urllib3.py b/tests/transport/test_urllib3.py index e83230032..7872a7187 100644 --- a/tests/transport/test_urllib3.py +++ b/tests/transport/test_urllib3.py @@ -29,6 +29,9 @@ from google.oauth2 import service_account from tests.transport import compliance +CERT_MOCK_VAL = b"-----BEGIN CERTIFICATE-----\nMIIDIzCCAgugAwIBAgIJAMfISuBQ5m+5MA0GCSqGSIb3DQEBBQUAMBUxEzARBgNV\nBAMTCnVuaXQtdGVzdHMwHhcNMTExMjA2MTYyNjAyWhcNMjExMjAzMTYyNjAyWjAV\nMRMwEQYDVQQDEwp1bml0LXRlc3RzMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB\nCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZgkdmM\n7oVK2OfgrSj/FCTkInKPqaCR0gD7K80q+mLBrN3PUkDrJQZpvRZIff3/xmVU1Wer\nuQLFJjnFb2dqu0s/FY/2kWiJtBCakXvXEOb7zfbINuayL+MSsCGSdVYsSliS5qQp\ngyDap+8b5fpXZVJkq92hrcNtbkg7hCYUJczt8n9hcCTJCfUpApvaFQ18pe+zpyl4\n+WzkP66I28hniMQyUlA1hBiskT7qiouq0m8IOodhv2fagSZKjOTTU2xkSBc//fy3\nZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQABo3YwdDAdBgNVHQ4EFgQU2RQ8yO+O\ngN8oVW2SW7RLrfYd9jEwRQYDVR0jBD4wPIAU2RQ8yO+OgN8oVW2SW7RLrfYd9jGh\nGaQXMBUxEzARBgNVBAMTCnVuaXQtdGVzdHOCCQDHyErgUOZvuTAMBgNVHRMEBTAD\nAQH/MA0GCSqGSIb3DQEBBQUAA4IBAQBRv+M/6+FiVu7KXNjFI5pSN17OcW5QUtPr\nodJMlWrJBtynn/TA1oJlYu3yV5clc/71Vr/AxuX5xGP+IXL32YDF9lTUJXG/uUGk\n+JETpKmQviPbRsvzYhz4pf6ZIOZMc3/GIcNq92ECbseGO+yAgyWUVKMmZM0HqXC9\novNslqe0M8C1sLm1zAR5z/h/litE7/8O2ietija3Q/qtl2TOXJdCA6sgjJX2WUql\nybrC55ct18NKf3qhpcEkGQvFU40rVYApJpi98DiZPYFdx1oBDp/f4uZ3ojpxRVFT\ncDwcJLfNRCPUhormsY7fDS9xSyThiHsW9mjJYdcaKQkwYZ0F11yB\n-----END CERTIFICATE-----\n" +KEY_MOCK_VAL = b"-----BEGIN ENCRYPTED PRIVATE KEY-----\nMIHeMEkGCSqGSIb3DQEFDTA8MBsGCSqGSIb3DQEFDDAOBAj9XnJ2h78QVAICCAAw\nHQYJYIZIAWUDBAECBBBeiiOF2LnLzq/wjb/viwMwBIGQk28Zkfj2EIk42bgc7UzC\nSf98qssCVhsIYz0Xa3eSATg8Cpn83YieaBeyxdk/tXTnrOhxMV/vt7T98kWhaGbH\n5Z9CdGVLfes0UFvVJqrlk6vcf2sOnLCGbrn78HS+ayrGOCRSCd/7+dnEiB/7Um1B\nMk6BBJHsLEnZZSHyfrw8jvYgVmcSBy/WdY0pqldD/+4D\n-----END ENCRYPTED PRIVATE KEY-----\n" + class TestRequestResponse(compliance.RequestResponseTests): def make_request(self): @@ -320,3 +323,208 @@ def test_clear_pool_on_del(self): authed_http.http = None authed_http.__del__() # Expect it to not crash + + def test_cert_rotation_when_cert_mismatch_and_mtls_endpoint_used(self): + credentials = mock.Mock(wraps=CredentialsStub()) + final_response = ResponseStub(status=http_client.OK) + http = HttpStub([ResponseStub(status=http_client.UNAUTHORIZED), final_response]) + + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials, http=http + ) + + old_cert = b"-----BEGIN CERTIFICATE-----\nMIIBdTCCARqgAwIBAgIJAOYVvu/axMxvMAoGCCqGSM49BAMCMCcxJTAjBgNVBAMM\nHEdvb2dsZSBFbmRwb2ludCBWZXJpZmljYXRpb24wHhcNMjUwNzMwMjMwNjA4WhcN\nMjYwNzMxMjMwNjA4WjAnMSUwIwYDVQQDDBxHb29nbGUgRW5kcG9pbnQgVmVyaWZp\nY2F0aW9uMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEbtr18gkEtwPow2oqyZsU\n4KLwFaLFlRlYv55UATS3QTDykDnIufC42TJCnqFRYhwicwpE2jnUV+l9g3Voias8\nraMvMC0wCQYDVR0TBAIwADALBgNVHQ8EBAMCB4AwEwYDVR0lBAwwCgYIKwYBBQUH\nAwIwCgYIKoZIzj0EAwIDSQAwRgIhAKcjW6dmF1YCksXPgDPlPu/nSnOjb3qCcivz\n/Jxq2zoeAiEA7/aNxcEoCGS3hwMIXoaaD/vPcZOOopKSyqXCvxRooKQ=\n-----END CERTIFICATE-----\n" + + # New certificate and key to simulate rotation. + new_cert = CERT_MOCK_VAL + new_key = KEY_MOCK_VAL + # Set _cached_cert to a callable that returns the old certificate. + authed_http._cached_cert = old_cert + authed_http._is_mtls = True + # Mock call_client_cert_callback to return the new certificate. + with mock.patch.object( + google.auth._agent_identity_utils, + "call_client_cert_callback", + return_value=(new_cert, new_key), + ) as mock_callback: + # mTLS endpoint is used + result = authed_http.urlopen("GET", "http://example.mtls.googleapis.com") + + # Asserts to verify the behavior. + assert result == final_response + assert credentials.refresh.called + assert credentials.refresh.call_count == 1 + assert mock_callback.called + + def test_no_cert_rotation_when_cert_match_and_mtls_endpoint_used(self): + credentials = mock.Mock(wraps=CredentialsStub()) + final_response = ResponseStub(status=http_client.UNAUTHORIZED) + http = HttpStub( + [ + ResponseStub(status=http_client.UNAUTHORIZED), + ResponseStub(status=http_client.UNAUTHORIZED), + ResponseStub(status=http_client.UNAUTHORIZED), + ] + ) + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials, http=http + ) + old_cert = CERT_MOCK_VAL + + new_cert = old_cert + new_key = KEY_MOCK_VAL + # Set _cached_cert to a callable that returns the same certificate. + authed_http._cached_cert = old_cert + authed_http._is_mtls = True + # Mock call_client_cert_callback to return the certificate. + with mock.patch.object( + google.auth._agent_identity_utils, + "call_client_cert_callback", + return_value=(new_cert, new_key), + ): + # mTLS endpoint is used + result = authed_http.urlopen("GET", "http://example.mtls.googleapis.com") + + # Asserts to verify the behavior. + assert credentials.refresh.call_count == 2 + assert result.status == final_response.status + + def test_no_cert_match_check_when_mtls_endpoint_not_used(self): + credentials = mock.Mock(wraps=CredentialsStub()) + final_response = ResponseStub(status=http_client.UNAUTHORIZED) + http = HttpStub( + [ + ResponseStub(status=http_client.UNAUTHORIZED), + ResponseStub(status=http_client.UNAUTHORIZED), + ResponseStub(status=http_client.UNAUTHORIZED), + ] + ) + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials, http=http + ) + authed_http._is_mtls = False + new_cert = CERT_MOCK_VAL + new_key = KEY_MOCK_VAL + + # Mock call_client_cert_callback to return the certificate. + with mock.patch.object( + google.auth._agent_identity_utils, + "call_client_cert_callback", + return_value=(new_cert, new_key), + ) as mock_callback: + # non-mTLS endpoint is used + result = authed_http.urlopen("GET", "http://example.googleapis.com") + + # Asserts to verify the behavior. + assert not mock_callback.called + assert result.status == final_response.status + + def test_no_cert_rotation_when_no_unauthorized_response(self): + credentials = mock.Mock(wraps=CredentialsStub()) + final_response = ResponseStub(status=http_client.UPGRADE_REQUIRED) + + # Response is set to code other than 401(Unauthorized). + http = HttpStub([ResponseStub(status=http_client.UPGRADE_REQUIRED)]) + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials, http=http + ) + authed_http._is_mtls = True + with mock.patch.dict( + os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"} + ): + # mTLS endpoint is used + result = authed_http.urlopen("GET", "http://example.mtls.googleapis.com") + assert result.status == final_response.status + assert not credentials.refresh.called + assert credentials.refresh.call_count == 0 + + def test_cert_rotation_failure_raises_error(self): + credentials = mock.Mock(wraps=CredentialsStub()) + http = HttpStub([ResponseStub(status=http_client.UNAUTHORIZED)]) + + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials, http=http + ) + + old_cert = b"-----BEGIN CERTIFICATE-----\nMIIBdTCCARqgAwIBAgIJAOYVvu/axMxvMAoGCCqGSM49BAMCMCcxJTAjBgNVBAMM\nHEdvb2dsZSBFbmRwb2ludCBWZXJpZmljYXRpb24wHhcNMjUwNzMwMjMwNjA4WhcN\nMjYwNzMxMjMwNjA4WjAnMSUwIwYDVQQDDBxHb29nbGUgRW5kcG9pbnQgVmVyaWZp\nY2F0aW9uMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEbtr18gkEtwPow2oqyZsU\n4KLwFaLFlRlYv55UATS3QTDykDnIufC42TJCnqFRYhwicwpE2jnUV+l9g3Voias8\nraMvMC0wCQYDVR0TBAIwADALBgNVHQ8EBAMCB4AwEwYDVR0lBAwwCgYIKwYBBQUH\nAwIwCgYIKoZIzj0EAwIDSQAwRgIhAKcjW6dmF1YCksXPgDPlPu/nSnOjb3qCcivz\n/Jxq2zoeAiEA7/aNxcEoCGS3hwMIXoaaD/vPcZOOopKSyqXCvxRooKQ=\n-----END CERTIFICATE-----\n" + + # New certificate and key to simulate rotation. + new_cert = CERT_MOCK_VAL + new_key = KEY_MOCK_VAL + authed_http._cached_cert = old_cert + authed_http._is_mtls = True + + # Mock call_client_cert_callback to return the new certificate. + with mock.patch.object( + google.auth.transport._mtls_helper, + "check_parameters_for_unauthorized_response", + return_value=(new_cert, new_key, "old_fingerprint", "new_fingerprint"), + ) as mock_check_params: + with mock.patch.object( + authed_http, + "configure_mtls_channel", + side_effect=Exception("Failed to reconfigure"), + ) as mock_reconfigure: + with pytest.raises(exceptions.MutualTLSChannelError): + authed_http.urlopen("GET", "https://example.mtls.googleapis.com") + + mock_check_params.assert_called_once() + mock_reconfigure.assert_called_once() + credentials.refresh.assert_not_called() + + def test_cert_rotation_check_params_fails(self): + credentials = mock.Mock(wraps=CredentialsStub()) + http = HttpStub([ResponseStub(status=http_client.UNAUTHORIZED)]) + + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials, http=http + ) + authed_http._is_mtls = True + authed_http._cached_cert = b"cached_cert" + + with mock.patch( + "google.auth.transport.urllib3._mtls_helper.check_parameters_for_unauthorized_response", + side_effect=Exception("check_params failed"), + ) as mock_check_params: + with pytest.raises(Exception, match="check_params failed"): + authed_http.urlopen("GET", "http://example.mtls.googleapis.com") + + mock_check_params.assert_called_once() + credentials.refresh.assert_not_called() + + def test_cert_rotation_logic_skipped_on_other_refresh_status_codes(self): + """ + Tests that the code can handle a refresh triggered by a status code + other than 401 (UNAUTHORIZED). This covers the 'else' branch of the + 'if response.status_code == http_client.UNAUTHORIZED' check + """ + credentials = mock.Mock(wraps=CredentialsStub()) + # Configure the session to treat 503 (Service Unavailable) as a refreshable error + custom_codes = [http_client.SERVICE_UNAVAILABLE] + + # Return 503 first, then 200 + http = HttpStub( + [ + ResponseStub(status=http_client.SERVICE_UNAVAILABLE), + ResponseStub(status=http_client.OK), + ] + ) + + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials, http=http, refresh_status_codes=custom_codes + ) + + # Enable mTLS to prove it is skipped despite being enabled + authed_http._is_mtls = True + mtls_url = "https://mtls.googleapis.com/test" + + with mock.patch( + "google.auth.transport.urllib3._mtls_helper", autospec=True + ) as mock_helper: + authed_http.urlopen("GET", mtls_url) + + # Assert refresh happened (Outer Check was True) + assert credentials.refresh.called + + # Assert mTLS check logic was SKIPPED (Inner Check was False) + assert not mock_helper.check_parameters_for_unauthorized_response.called