Skip to content

Commit 8a9cb7b

Browse files
agrawalradhika-cellnbayatiandyrzhao
authored
feat: Add retry logic when certificate mismatch for existing credentials & Agent Identity workloads (#1841)
feat: Add retry logic when certificate mismatch for existing credentials & Agent Identity workloads This change introduces retry support when requests are created for existing credentials and Agent Identities on GKE and Cloud Run Workloads. When 401(Unauthorized) error is created, due to certificate at time of configuration of mTLS channel being different from the current certificate, a retry is added to the request by configuring the mTLS channel with the current certificate. --------- Signed-off-by: Radhika Agrawal <[email protected]> Co-authored-by: nbayati <[email protected]> Co-authored-by: Andy Zhao <[email protected]>
1 parent cef6a2b commit 8a9cb7b

File tree

9 files changed

+667
-6
lines changed

9 files changed

+667
-6
lines changed

google/auth/_agent_identity_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@
2020
import os
2121
import re
2222
import time
23-
from urllib.parse import urlparse, quote
23+
from urllib.parse import quote, urlparse
2424

2525
from google.auth import environment_vars
2626
from google.auth import exceptions
27+
from google.auth.transport import _mtls_helper
2728

2829

2930
_LOGGER = logging.getLogger(__name__)
@@ -260,3 +261,21 @@ def should_request_bound_token(cert):
260261
== "true"
261262
)
262263
return is_agent_cert and is_opted_in
264+
265+
266+
def call_client_cert_callback():
267+
"""Calls the client cert callback and returns the certificate and key."""
268+
_, cert_bytes, key_bytes, passphrase = _mtls_helper.get_client_ssl_credentials(
269+
generate_encrypted_key=True
270+
)
271+
return cert_bytes, key_bytes
272+
273+
274+
def get_cached_cert_fingerprint(cached_cert):
275+
"""Returns the fingerprint of the cached certificate."""
276+
if cached_cert:
277+
cert_obj = parse_certificate(cached_cert)
278+
cached_cert_fingerprint = calculate_certificate_fingerprint(cert_obj)
279+
else:
280+
raise ValueError("mTLS connection is not configured.")
281+
return cached_cert_fingerprint

google/auth/transport/_mtls_helper.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import re
2121
import subprocess
2222

23+
from google.auth import _agent_identity_utils
2324
from google.auth import environment_vars
2425
from google.auth import exceptions
2526

@@ -474,3 +475,29 @@ def check_use_client_cert():
474475
) as e:
475476
_LOGGER.debug("error decoding certificate: %s", e)
476477
return False
478+
479+
480+
def check_parameters_for_unauthorized_response(cached_cert):
481+
"""Returns the cached and current cert fingerprint for reconfiguring mTLS.
482+
483+
Args:
484+
cached_cert(bytes): The cached client certificate.
485+
486+
Returns:
487+
bytes: The client callback cert bytes.
488+
bytes: The client callback key bytes.
489+
str: The base64-encoded SHA256 cached fingerprint.
490+
str: The base64-encoded SHA256 current cert fingerprint.
491+
"""
492+
call_cert_bytes, call_key_bytes = _agent_identity_utils.call_client_cert_callback()
493+
cert_obj = _agent_identity_utils.parse_certificate(call_cert_bytes)
494+
current_cert_fingerprint = _agent_identity_utils.calculate_certificate_fingerprint(
495+
cert_obj
496+
)
497+
if cached_cert:
498+
cached_fingerprint = _agent_identity_utils.get_cached_cert_fingerprint(
499+
cached_cert
500+
)
501+
else:
502+
cached_fingerprint = current_cert_fingerprint
503+
return call_cert_bytes, call_key_bytes, cached_fingerprint, current_cert_fingerprint

google/auth/transport/requests.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import absolute_import
1818

1919
import functools
20+
import http.client as http_client
2021
import logging
2122
import numbers
2223
import time
@@ -36,6 +37,7 @@
3637
from google.auth import _helpers
3738
from google.auth import exceptions
3839
from google.auth import transport
40+
from google.auth.transport import _mtls_helper
3941
import google.auth.transport._mtls_helper
4042
from google.oauth2 import service_account
4143

@@ -463,6 +465,7 @@ def configure_mtls_channel(self, client_cert_callback=None):
463465

464466
if self._is_mtls:
465467
mtls_adapter = _MutualTlsAdapter(cert, key)
468+
self._cached_cert = cert
466469
self.mount("https://", mtls_adapter)
467470
except (
468471
exceptions.ClientCertError,
@@ -502,6 +505,10 @@ def request(
502505
itself does not timeout, e.g. if a large file is being
503506
transmitted. The timout error will be raised after such
504507
request completes.
508+
Raises:
509+
google.auth.exceptions.MutualTLSChannelError: If mutual TLS
510+
channel creation fails for any reason.
511+
ValueError: If the client certificate is invalid.
505512
"""
506513
# pylint: disable=arguments-differ
507514
# Requests has a ton of arguments to request, but only two
@@ -551,7 +558,31 @@ def request(
551558
response.status_code in self._refresh_status_codes
552559
and _credential_refresh_attempt < self._max_refresh_attempts
553560
):
554-
561+
# Handle unauthorized permission error(401 status code)
562+
if response.status_code == http_client.UNAUTHORIZED:
563+
if self.is_mtls:
564+
call_cert_bytes, call_key_bytes, cached_fingerprint, current_cert_fingerprint = _mtls_helper.check_parameters_for_unauthorized_response(
565+
self._cached_cert
566+
)
567+
if cached_fingerprint != current_cert_fingerprint:
568+
try:
569+
_LOGGER.info(
570+
"Client certificate has changed, reconfiguring mTLS "
571+
"channel."
572+
)
573+
self.configure_mtls_channel(
574+
lambda: (call_cert_bytes, call_key_bytes)
575+
)
576+
except Exception as e:
577+
_LOGGER.error("Failed to reconfigure mTLS channel: %s", e)
578+
raise exceptions.MutualTLSChannelError(
579+
"Failed to reconfigure mTLS channel"
580+
) from e
581+
else:
582+
_LOGGER.info(
583+
"Skipping reconfiguration of mTLS channel because the client"
584+
" certificate has not changed."
585+
)
555586
_LOGGER.info(
556587
"Refreshing credentials due to a %s response. Attempt %s/%s.",
557588
response.status_code,

google/auth/transport/urllib3.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import absolute_import
1818

19+
import http.client as http_client
1920
import logging
2021
import warnings
2122

@@ -52,6 +53,7 @@
5253
from google.auth import _helpers
5354
from google.auth import exceptions
5455
from google.auth import transport
56+
from google.auth.transport import _mtls_helper
5557
from google.oauth2 import service_account
5658

5759
if version.parse(urllib3.__version__) >= version.parse("2.0.0"): # pragma: NO COVER
@@ -299,6 +301,7 @@ def __init__(
299301
# Request instance used by internal methods (for example,
300302
# credentials.refresh).
301303
self._request = Request(self.http)
304+
self._is_mtls = False
302305

303306
# https://google.aip.dev/auth/4111
304307
# Attempt to use self-signed JWTs when a service account is used.
@@ -335,7 +338,10 @@ def configure_mtls_channel(self, client_cert_callback=None):
335338
"""
336339
use_client_cert = transport._mtls_helper.check_use_client_cert()
337340
if not use_client_cert:
341+
self._is_mtls = False
338342
return False
343+
else:
344+
self._is_mtls = True
339345
try:
340346
import OpenSSL
341347
except ImportError as caught_exc:
@@ -349,6 +355,7 @@ def configure_mtls_channel(self, client_cert_callback=None):
349355

350356
if found_cert_key:
351357
self.http = _make_mutual_tls_http(cert, key)
358+
self._cached_cert = cert
352359
else:
353360
self.http = _make_default_http()
354361
except (
@@ -381,6 +388,11 @@ def urlopen(self, method, url, body=None, headers=None, **kwargs):
381388
if headers is None:
382389
headers = self.headers
383390

391+
use_mtls = False
392+
if self._is_mtls:
393+
MTLS_URL_PREFIXES = ["mtls.googleapis.com", "mtls.sandbox.googleapis.com"]
394+
use_mtls = any([prefix in url for prefix in MTLS_URL_PREFIXES])
395+
384396
# Make a copy of the headers. They will be modified by the credentials
385397
# and we want to pass the original headers if we recurse.
386398
request_headers = headers.copy()
@@ -402,6 +414,34 @@ def urlopen(self, method, url, body=None, headers=None, **kwargs):
402414
response.status in self._refresh_status_codes
403415
and _credential_refresh_attempt < self._max_refresh_attempts
404416
):
417+
if response.status == http_client.UNAUTHORIZED:
418+
if use_mtls:
419+
call_cert_bytes, call_key_bytes, cached_fingerprint, current_cert_fingerprint = _mtls_helper.check_parameters_for_unauthorized_response(
420+
self._cached_cert
421+
)
422+
if cached_fingerprint != current_cert_fingerprint:
423+
try:
424+
_LOGGER.info(
425+
"Client certificate has changed, reconfiguring mTLS "
426+
"channel."
427+
)
428+
self.configure_mtls_channel(
429+
client_cert_callback=lambda: (
430+
call_cert_bytes,
431+
call_key_bytes,
432+
)
433+
)
434+
except Exception as e:
435+
_LOGGER.error("Failed to reconfigure mTLS channel: %s", e)
436+
raise exceptions.MutualTLSChannelError(
437+
"Failed to reconfigure mTLS channel"
438+
) from e
439+
440+
else:
441+
_LOGGER.info(
442+
"Skipping reconfiguration of mTLS channel because the "
443+
"client certificate has not changed."
444+
)
405445

406446
_LOGGER.info(
407447
"Refreshing credentials due to a %s response. Attempt %s/%s.",

noxfile.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def mypy(session):
105105
"types-requests",
106106
"types-setuptools",
107107
"types-mock",
108-
"pytest",
108+
"pytest<8.0.0",
109109
)
110110
session.run("mypy", "-p", "google", "-p", "tests", "-p", "tests_async")
111111

@@ -130,6 +130,7 @@ def unit(session):
130130

131131
@nox.session(python=DEFAULT_PYTHON_VERSION)
132132
def cover(session):
133+
session.env["PIP_EXTRA_INDEX_URL"] = "https://pypi.org/simple"
133134
session.install("-e", ".[testing]")
134135
session.run(
135136
"pytest",

tests/test_agent_identity_utils.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import base64
1616
import hashlib
1717
import json
18-
1918
import urllib.parse
2019

2120
from cryptography import x509
@@ -104,7 +103,9 @@ def test_calculate_certificate_fingerprint(self):
104103
mock_cert.public_bytes.return_value = b"der-bytes"
105104

106105
# Expected: base64 (standard), unpadded, then URL-encoded
107-
base64_fingerprint = base64.b64encode(hashlib.sha256(b"der-bytes").digest()).decode("utf-8")
106+
base64_fingerprint = base64.b64encode(
107+
hashlib.sha256(b"der-bytes").digest()
108+
).decode("utf-8")
108109
unpadded_base64_fingerprint = base64_fingerprint.rstrip("=")
109110
expected_fingerprint = urllib.parse.quote(unpadded_base64_fingerprint)
110111

@@ -260,6 +261,54 @@ def test_get_and_parse_agent_identity_certificate_success(
260261
mock_parse_certificate.assert_called_once_with(b"cert_bytes")
261262
assert result == mock_parse_certificate.return_value
262263

264+
@mock.patch("time.sleep", return_value=None)
265+
@mock.patch("google.auth._agent_identity_utils._is_certificate_file_ready")
266+
def test_get_agent_identity_certificate_path_fallback_to_well_known_path(
267+
self, mock_is_ready, mock_sleep, monkeypatch
268+
):
269+
# Set a dummy config path that won't be found.
270+
monkeypatch.setenv(
271+
environment_vars.GOOGLE_API_CERTIFICATE_CONFIG, "/dummy/config.json"
272+
)
273+
274+
# First, the primary path from the (mocked) config is not ready.
275+
# Then, the fallback well-known path is ready.
276+
mock_is_ready.side_effect = [False, True]
277+
278+
result = _agent_identity_utils.get_agent_identity_certificate_path()
279+
280+
assert result == _agent_identity_utils._WELL_KNOWN_CERT_PATH
281+
# The sleep should have been called once before the fallback is checked.
282+
mock_sleep.assert_called_once()
283+
assert mock_is_ready.call_count == 2
284+
285+
@mock.patch("google.auth.transport._mtls_helper.get_client_ssl_credentials")
286+
def test_call_client_cert_callback(self, mock_get_client_ssl_credentials):
287+
mock_get_client_ssl_credentials.return_value = (
288+
True,
289+
b"cert_bytes",
290+
b"key_bytes",
291+
b"passphrase",
292+
)
293+
294+
cert, key = _agent_identity_utils.call_client_cert_callback()
295+
296+
assert cert == b"cert_bytes"
297+
assert key == b"key_bytes"
298+
mock_get_client_ssl_credentials.assert_called_once_with(
299+
generate_encrypted_key=True
300+
)
301+
302+
def test_get_cached_cert_fingerprint_no_cert(self):
303+
with pytest.raises(ValueError, match="mTLS connection is not configured."):
304+
_agent_identity_utils.get_cached_cert_fingerprint(None)
305+
306+
def test_get_cached_cert_fingerprint_with_cert(self):
307+
fingerprint = _agent_identity_utils.get_cached_cert_fingerprint(
308+
NON_AGENT_IDENTITY_CERT_BYTES
309+
)
310+
assert isinstance(fingerprint, str)
311+
263312

264313
class TestAgentIdentityUtilsNoCryptography:
265314
@pytest.fixture(autouse=True)

tests/transport/test__mtls_helper.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323
from google.auth import exceptions
2424
from google.auth.transport import _mtls_helper
2525

26+
CERT_MOCK_VAL = b"cert"
27+
KEY_MOCK_VAL = b"key"
2628
CONTEXT_AWARE_METADATA = {"cert_provider_command": ["some command"]}
27-
2829
ENCRYPTED_EC_PRIVATE_KEY = b"""-----BEGIN ENCRYPTED PRIVATE KEY-----
2930
MIHkME8GCSqGSIb3DQEFDTBCMCkGCSqGSIb3DQEFDDAcBAgl2/yVgs1h3QICCAAw
3031
DAYIKoZIhvcNAgkFADAVBgkrBgEEAZdVAQIECJk2GRrvxOaJBIGQXIBnMU4wmciT
@@ -813,3 +814,64 @@ def test_check_use_client_cert_when_file_does_not_exist(self, monkeypatch):
813814
monkeypatch.setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "")
814815
use_client_cert = _mtls_helper.check_use_client_cert()
815816
assert use_client_cert is False
817+
818+
819+
class TestMtlsHelper:
820+
@mock.patch("google.auth.transport._mtls_helper._agent_identity_utils")
821+
def test_check_parameters_for_unauthorized_response_with_cached_cert(
822+
self, mock_agent_identity_utils
823+
):
824+
mock_agent_identity_utils.call_client_cert_callback.return_value = (
825+
CERT_MOCK_VAL,
826+
KEY_MOCK_VAL,
827+
)
828+
mock_agent_identity_utils.get_cached_cert_fingerprint.return_value = (
829+
"cached_fingerprint"
830+
)
831+
mock_agent_identity_utils.calculate_certificate_fingerprint.return_value = (
832+
"current_fingerprint"
833+
)
834+
835+
(
836+
cert,
837+
key,
838+
cached_fingerprint,
839+
current_fingerprint,
840+
) = _mtls_helper.check_parameters_for_unauthorized_response(
841+
cached_cert=b"cached_cert_bytes"
842+
)
843+
844+
assert cert == CERT_MOCK_VAL
845+
assert key == KEY_MOCK_VAL
846+
assert cached_fingerprint == "cached_fingerprint"
847+
assert current_fingerprint == "current_fingerprint"
848+
mock_agent_identity_utils.call_client_cert_callback.assert_called_once()
849+
mock_agent_identity_utils.get_cached_cert_fingerprint.assert_called_once_with(
850+
b"cached_cert_bytes"
851+
)
852+
853+
@mock.patch("google.auth.transport._mtls_helper._agent_identity_utils")
854+
def test_check_parameters_for_unauthorized_response_without_cached_cert(
855+
self, mock_agent_identity_utils
856+
):
857+
mock_agent_identity_utils.call_client_cert_callback.return_value = (
858+
CERT_MOCK_VAL,
859+
KEY_MOCK_VAL,
860+
)
861+
mock_agent_identity_utils.calculate_certificate_fingerprint.return_value = (
862+
"current_fingerprint"
863+
)
864+
865+
(
866+
cert,
867+
key,
868+
cached_fingerprint,
869+
current_fingerprint,
870+
) = _mtls_helper.check_parameters_for_unauthorized_response(cached_cert=None)
871+
872+
assert cert == CERT_MOCK_VAL
873+
assert key == KEY_MOCK_VAL
874+
assert cached_fingerprint == "current_fingerprint"
875+
assert current_fingerprint == "current_fingerprint"
876+
mock_agent_identity_utils.call_client_cert_callback.assert_called_once()
877+
mock_agent_identity_utils.get_cached_cert_fingerprint.assert_not_called()

0 commit comments

Comments
 (0)