Skip to content

Commit 3cc9109

Browse files
committed
add fallback to mds mtls
1 parent 53802bd commit 3cc9109

File tree

3 files changed

+120
-26
lines changed

3 files changed

+120
-26
lines changed

google/auth/compute_engine/_metadata.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def detect_gce_residency_linux():
142142
def _prepare_request_for_mds(request, use_mtls=False):
143143
"""Prepares a request for the metadata server.
144144
145-
This will check if mTLS should be used and return a new request object if so.
145+
This will check if mTLS should be used and mount the mTLS adapter if needed.
146146
147147
Args:
148148
request (google.auth.transport.Request): A callable used to make
@@ -151,8 +151,8 @@ def _prepare_request_for_mds(request, use_mtls=False):
151151
152152
Returns:
153153
google.auth.transport.Request: A request object to use.
154-
If mTLS is enabled, this will be a new request object with mTLS session configured.
155-
Otherwise, it will be the same as the input request.
154+
If mTLS is enabled, the request will have the mTLS adapter mounted.
155+
Otherwise, the original request will be returned unchanged.
156156
"""
157157
if not use_mtls:
158158
return request

google/auth/compute_engine/_mtls.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,13 @@
1818

1919
from dataclasses import dataclass, field
2020
import enum
21+
import logging
2122
import os
23+
from pathlib import Path
2224
import ssl
25+
from urllib.parse import urlparse, urlunparse
2326

27+
import requests
2428
from requests.adapters import HTTPAdapter
2529

2630
from google.auth import environment_vars, exceptions
@@ -30,22 +34,26 @@
3034
# https://cloud.google.com/compute/docs/metadata/overview#https-mds-certificates
3135

3236

37+
_WINDOWS_OS_NAME = "nt"
38+
# _WINDOWS_MTLS_COMPONENTS_BASE_PATH = os.path.join("C:\\", "ProgramData", "Google", "ComputeEngine")
39+
# _MTLS_COMPONENTS_BASE_PATH = os.path.join("/", "run", "google-mds-mtls")
40+
41+
_WINDOWS_MTLS_COMPONENTS_BASE_PATH = Path("C:/ProgramData/Google/ComputeEngine")
42+
_MTLS_COMPONENTS_BASE_PATH = Path("/run/google-mds-mtls")
43+
44+
3345
def _get_mds_root_crt_path():
34-
if os.name == "nt":
35-
return os.path.join(
36-
"C:\\", "ProgramData", "Google", "ComputeEngine", "mds-mtls-root.crt"
37-
)
46+
if os.name == _WINDOWS_OS_NAME:
47+
return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-root.crt"
3848
else:
39-
return os.path.join("/", "run", "google-mds-mtls", "root.crt")
49+
return _MTLS_COMPONENTS_BASE_PATH / "root.crt"
4050

4151

4252
def _get_mds_client_combined_cert_path():
43-
if os.name == "nt":
44-
return os.path.join(
45-
"C:\\", "ProgramData", "Google", "ComputeEngine", "mds-mtls-client.key"
46-
)
53+
if os.name == _WINDOWS_OS_NAME:
54+
return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-client.key"
4755
else:
48-
return os.path.join("/", "run", "google-mds-mtls", "client.key")
56+
return _MTLS_COMPONENTS_BASE_PATH / "client.key"
4957

5058

5159
@dataclass
@@ -106,6 +114,9 @@ def should_use_mds_mtls(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()):
106114
return _certs_exist(mds_mtls_config)
107115

108116

117+
_LOGGER = logging.getLogger(__name__)
118+
119+
109120
class MdsMtlsAdapter(HTTPAdapter):
110121
"""An HTTP adapter that uses mTLS for the metadata server."""
111122

@@ -126,3 +137,26 @@ def init_poolmanager(self, *args, **kwargs):
126137
def proxy_manager_for(self, *args, **kwargs):
127138
kwargs["ssl_context"] = self.ssl_context
128139
return super(MdsMtlsAdapter, self).proxy_manager_for(*args, **kwargs)
140+
141+
def send(self, request, **kwargs):
142+
# If we are in strict mode, always use mTLS (no HTTP fallback)
143+
if _parse_mds_mode() == MdsMtlsMode.STRICT:
144+
return super(MdsMtlsAdapter, self).send(request, **kwargs)
145+
146+
# In default mode, attempt mTLS first, then fallback to HTTP on failure
147+
try:
148+
return super(MdsMtlsAdapter, self).send(request, **kwargs)
149+
except (ssl.SSLError, requests.exceptions.SSLError) as e:
150+
_LOGGER.warning(
151+
"mTLS connection to Compute Engine Metadata server failed. "
152+
"Falling back to standard HTTP. Reason: %s",
153+
e,
154+
)
155+
# Fallback to standard HTTP
156+
parsed_original_url = urlparse(request.url)
157+
http_fallback_url = urlunparse(parsed_original_url._replace(scheme="http"))
158+
request.url = http_fallback_url
159+
160+
# Use a standard HTTPAdapter for the fallback
161+
http_adapter = HTTPAdapter()
162+
return http_adapter.send(request, **kwargs)

tests/compute_engine/test__mtls.py

Lines changed: 73 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
# limitations under the License.
1616
#
1717

18-
import os
19-
2018
import mock
2119
import pytest # type: ignore
2220
import requests
@@ -35,23 +33,21 @@ def mock_mds_mtls_config():
3533
@mock.patch("os.name", "nt")
3634
def test__MdsMtlsConfig_windows_defaults():
3735
config = _mtls.MdsMtlsConfig()
38-
assert config.ca_cert_path == os.path.join(
39-
"C:\\", "ProgramData", "Google", "ComputeEngine", "mds-mtls-root.crt"
36+
assert (
37+
str(config.ca_cert_path)
38+
== "C:/ProgramData/Google/ComputeEngine/mds-mtls-root.crt"
4039
)
41-
assert config.client_combined_cert_path == os.path.join(
42-
"C:\\", "ProgramData", "Google", "ComputeEngine", "mds-mtls-client.key"
40+
assert (
41+
str(config.client_combined_cert_path)
42+
== "C:/ProgramData/Google/ComputeEngine/mds-mtls-client.key"
4343
)
4444

4545

4646
@mock.patch("os.name", "posix")
4747
def test__MdsMtlsConfig_non_windows_defaults():
4848
config = _mtls.MdsMtlsConfig()
49-
assert config.ca_cert_path == os.path.join(
50-
"/", "run", "google-mds-mtls", "root.crt"
51-
)
52-
assert config.client_combined_cert_path == os.path.join(
53-
"/", "run", "google-mds-mtls", "client.key"
54-
)
49+
assert str(config.ca_cert_path) == "/run/google-mds-mtls/root.crt"
50+
assert str(config.client_combined_cert_path) == "/run/google-mds-mtls/client.key"
5551

5652

5753
def test__parse_mds_mode_default(monkeypatch):
@@ -157,7 +153,7 @@ def test_mds_mtls_adapter_session_request(mock_ssl_context, mock_mds_mtls_config
157153
session = requests.Session()
158154
session.mount("https://", adapter)
159155

160-
# Mock the adapter's send method to avoid actual network requests
156+
# Mock the adapter\'s send method to avoid actual network requests
161157
adapter.send = mock.Mock()
162158
response = requests.Response()
163159
response.status_code = 200
@@ -169,3 +165,67 @@ def test_mds_mtls_adapter_session_request(mock_ssl_context, mock_mds_mtls_config
169165
# Assert that the request was successful
170166
assert response.status_code == 200
171167
adapter.send.assert_called_once()
168+
169+
170+
@mock.patch("google.auth.compute_engine._mtls.HTTPAdapter")
171+
@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode")
172+
@mock.patch("ssl.create_default_context")
173+
def test_mds_mtls_adapter_send_fallback_default_mode(
174+
mock_ssl_context, mock_parse_mds_mode, mock_http_adapter_class, mock_mds_mtls_config
175+
):
176+
mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT
177+
adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config)
178+
179+
mock_fallback_send = mock.Mock()
180+
mock_http_adapter_class.return_value.send = mock_fallback_send
181+
182+
# Simulate SSLError on the super().send() call
183+
with mock.patch(
184+
"requests.adapters.HTTPAdapter.send", side_effect=requests.exceptions.SSLError
185+
):
186+
request = requests.Request(method="GET", url="https://example.com").prepare()
187+
adapter.send(request)
188+
189+
# Check that fallback to HTTPAdapter.send occurred
190+
mock_http_adapter_class.assert_called_once()
191+
mock_fallback_send.assert_called_once()
192+
fallback_request = mock_fallback_send.call_args[0][0]
193+
assert fallback_request.url == "http://example.com/"
194+
195+
196+
@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode")
197+
@mock.patch("ssl.create_default_context")
198+
def test_mds_mtls_adapter_send_no_fallback_strict_mode(
199+
mock_ssl_context, mock_parse_mds_mode, mock_mds_mtls_config
200+
):
201+
mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.STRICT
202+
adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config)
203+
204+
# Simulate SSLError on the super().send() call
205+
with mock.patch(
206+
"requests.adapters.HTTPAdapter.send", side_effect=requests.exceptions.SSLError
207+
):
208+
request = requests.Request(method="GET", url="https://example.com").prepare()
209+
with pytest.raises(requests.exceptions.SSLError):
210+
adapter.send(request)
211+
212+
213+
@mock.patch("requests.adapters.HTTPAdapter.send")
214+
@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode")
215+
@mock.patch("ssl.create_default_context")
216+
def test_mds_mtls_adapter_send_no_fallback_other_exception(
217+
mock_ssl_context, mock_parse_mds_mode, mock_http_adapter_send, mock_mds_mtls_config
218+
):
219+
mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT
220+
adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config)
221+
222+
# Simulate a different exception
223+
with mock.patch(
224+
"requests.adapters.HTTPAdapter.send",
225+
side_effect=requests.exceptions.ConnectionError,
226+
):
227+
request = requests.Request(method="GET", url="https://example.com").prepare()
228+
with pytest.raises(requests.exceptions.ConnectionError):
229+
adapter.send(request)
230+
231+
mock_http_adapter_send.assert_not_called()

0 commit comments

Comments
 (0)