Skip to content

Commit 53802bd

Browse files
committed
change the mds mtls implementation
1. now we do not create a new request. instead, create an mds mtls adapter and mount it on the request session. 2. added _validate_gce_mds_configured_environment, which ensures if we are using strict, that the host being contacted is default 3. fix unit tests and add new tests
1 parent 73ba95e commit 53802bd

File tree

4 files changed

+173
-83
lines changed

4 files changed

+173
-83
lines changed

google/auth/compute_engine/_metadata.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,38 +24,61 @@
2424
import os
2525
from urllib.parse import urljoin
2626

27+
import requests
28+
2729
from google.auth import _helpers
2830
from google.auth import environment_vars
2931
from google.auth import exceptions
3032
from google.auth import metrics
3133
from google.auth import transport
3234
from google.auth._exponential_backoff import ExponentialBackoff
3335
from google.auth.compute_engine import _mtls
34-
from google.auth.transport import requests
36+
3537

3638
_LOGGER = logging.getLogger(__name__)
3739

40+
_GCE_DEFAULT_MDS_IP = "169.254.169.254"
41+
_GCE_DEFAULT_HOST = "metadata.google.internal"
42+
_GCE_DEFAULT_MDS_HOSTS = [_GCE_DEFAULT_HOST, _GCE_DEFAULT_MDS_IP]
43+
3844
# Environment variable GCE_METADATA_HOST is originally named
3945
# GCE_METADATA_ROOT. For compatibility reasons, here it checks
4046
# the new variable first; if not set, the system falls back
4147
# to the old variable.
4248
_GCE_METADATA_HOST = os.getenv(environment_vars.GCE_METADATA_HOST, None)
4349
if not _GCE_METADATA_HOST:
4450
_GCE_METADATA_HOST = os.getenv(
45-
environment_vars.GCE_METADATA_ROOT, "metadata.google.internal"
51+
environment_vars.GCE_METADATA_ROOT, _GCE_DEFAULT_HOST
4652
)
4753

48-
_GCE_DEFAULT_MDS_IP = "169.254.169.254"
49-
_GCE_MDS_HOSTS = ["metadata.google.internal", _GCE_DEFAULT_MDS_IP]
5054

55+
def _validate_gce_mds_configured_environment():
56+
"""Validates the GCE metadata server environment configuration for mTLS.
5157
52-
def _get_metadata_root(use_mtls):
58+
Raises:
59+
google.auth.exceptions.MutualTLSChannelError: if the environment
60+
configuration is invalid for mTLS.
61+
"""
62+
mode = _mtls._parse_mds_mode()
63+
if mode == _mtls.MdsMtlsMode.STRICT:
64+
if _GCE_METADATA_HOST != _GCE_DEFAULT_HOST:
65+
# mTLS is only supported when connecting to the default metadata host.
66+
# Raise an exception if we are in strict mode (which requires mTLS)
67+
# but the metadata host has been overridden. (which means mTLS will fail)
68+
raise exceptions.MutualTLSChannelError(
69+
"Mutual TLS is required, but the metadata host has been overridden. "
70+
"mTLS is only supported when connecting to the default metadata host."
71+
)
72+
73+
74+
def _get_metadata_root(use_mtls: bool):
5375
"""Returns the metadata server root URL."""
76+
5477
scheme = "https" if use_mtls else "http"
5578
return "{}://{}/computeMetadata/v1/".format(scheme, _GCE_METADATA_HOST)
5679

5780

58-
def _get_metadata_ip_root(use_mtls):
81+
def _get_metadata_ip_root(use_mtls: bool):
5982
"""Returns the metadata server IP root URL."""
6083
scheme = "https" if use_mtls else "http"
6184
return "{}://{}".format(
@@ -131,8 +154,14 @@ def _prepare_request_for_mds(request, use_mtls=False):
131154
If mTLS is enabled, this will be a new request object with mTLS session configured.
132155
Otherwise, it will be the same as the input request.
133156
"""
134-
if use_mtls:
135-
request = requests.Request(_mtls.create_session())
157+
if not use_mtls:
158+
return request
159+
160+
adapter = _mtls.MdsMtlsAdapter()
161+
if not request.session:
162+
request.session = requests.Session()
163+
for host in _GCE_DEFAULT_MDS_HOSTS:
164+
request.session.mount(f"https://{host}/", adapter)
136165
return request
137166

138167

@@ -236,6 +265,7 @@ def get(
236265

237266
if root is None:
238267
root = _get_metadata_root(use_mtls)
268+
_validate_gce_mds_configured_environment()
239269

240270
base_url = urljoin(root, path)
241271
query_params = {} if params is None else params

google/auth/compute_engine/_mtls.py

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import os
2222
import ssl
2323

24-
import requests
2524
from requests.adapters import HTTPAdapter
2625

2726
from google.auth import environment_vars, exceptions
@@ -59,6 +58,13 @@ class MdsMtlsConfig:
5958
) # path to file containing client certificate and key
6059

6160

61+
def _certs_exist(mds_mtls_config: MdsMtlsConfig):
62+
"""Checks if the mTLS certificates exist."""
63+
return os.path.exists(mds_mtls_config.ca_cert_path) and os.path.exists(
64+
mds_mtls_config.client_combined_cert_path
65+
)
66+
67+
6268
class MdsMtlsMode(enum.Enum):
6369
"""MDS mTLS mode. Used to configure connection behavior when connecting to MDS.
6470
@@ -85,17 +91,27 @@ def _parse_mds_mode():
8591
)
8692

8793

88-
def _certs_exist(mds_mtls_config: MdsMtlsConfig):
89-
"""Checks if the mTLS certificates exist."""
90-
return os.path.exists(mds_mtls_config.ca_cert_path) and os.path.exists(
91-
mds_mtls_config.client_combined_cert_path
92-
)
94+
def should_use_mds_mtls(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()):
95+
"""Determines if mTLS should be used for the metadata server."""
96+
mode = _parse_mds_mode()
97+
if mode == MdsMtlsMode.STRICT:
98+
if not _certs_exist(mds_mtls_config):
99+
raise exceptions.MutualTLSChannelError(
100+
"mTLS certificates not found in strict mode."
101+
)
102+
return True
103+
elif mode == MdsMtlsMode.NONE:
104+
return False
105+
else: # Default mode
106+
return _certs_exist(mds_mtls_config)
93107

94108

95109
class MdsMtlsAdapter(HTTPAdapter):
96110
"""An HTTP adapter that uses mTLS for the metadata server."""
97111

98-
def __init__(self, mds_mtls_config: MdsMtlsConfig, *args, **kwargs):
112+
def __init__(
113+
self, mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig(), *args, **kwargs
114+
):
99115
self.ssl_context = ssl.create_default_context()
100116
self.ssl_context.load_verify_locations(cafile=mds_mtls_config.ca_cert_path)
101117
self.ssl_context.load_cert_chain(
@@ -110,26 +126,3 @@ def init_poolmanager(self, *args, **kwargs):
110126
def proxy_manager_for(self, *args, **kwargs):
111127
kwargs["ssl_context"] = self.ssl_context
112128
return super(MdsMtlsAdapter, self).proxy_manager_for(*args, **kwargs)
113-
114-
115-
def create_session(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()):
116-
"""Creates a requests.Session configured for mTLS."""
117-
session = requests.Session()
118-
adapter = MdsMtlsAdapter(mds_mtls_config)
119-
session.mount("https://", adapter)
120-
return session
121-
122-
123-
def should_use_mds_mtls(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()):
124-
"""Determines if mTLS should be used for the metadata server."""
125-
mode = _parse_mds_mode()
126-
if mode == MdsMtlsMode.STRICT:
127-
if not _certs_exist(mds_mtls_config):
128-
raise exceptions.MutualTLSChannelError(
129-
"mTLS certificates not found in strict mode."
130-
)
131-
return True
132-
elif mode == MdsMtlsMode.NONE:
133-
return False
134-
else: # Default mode
135-
return _certs_exist(mds_mtls_config)

tests/compute_engine/test__metadata.py

Lines changed: 83 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -712,12 +712,12 @@ def test__get_metadata_ip_root_no_mtls():
712712
assert _metadata._get_metadata_ip_root(use_mtls=False) == "http://169.254.169.254"
713713

714714

715-
@mock.patch("google.auth.compute_engine._mtls.create_session")
716-
def test__prepare_request_for_mds_mtls(mock_create_session):
717-
request = mock.Mock()
718-
new_request = _metadata._prepare_request_for_mds(request, use_mtls=True)
719-
mock_create_session.assert_called_once()
720-
assert isinstance(new_request, google_auth_requests.Request)
715+
@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter")
716+
def test__prepare_request_for_mds_mtls(mock_mds_mtls_adapter):
717+
request = google_auth_requests.Request(mock.create_autospec(requests.Session))
718+
_metadata._prepare_request_for_mds(request, use_mtls=True)
719+
mock_mds_mtls_adapter.assert_called_once()
720+
assert request.session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS)
721721

722722

723723
def test__prepare_request_for_mds_no_mtls():
@@ -726,53 +726,100 @@ def test__prepare_request_for_mds_no_mtls():
726726
assert new_request is request
727727

728728

729-
@mock.patch("google.auth.compute_engine._mtls.should_use_mds_mtls", return_value=True)
730-
@mock.patch("google.auth.compute_engine._mtls.create_session")
731729
@mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE)
730+
@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter")
731+
@mock.patch("google.auth.compute_engine._mtls.should_use_mds_mtls", return_value=True)
732+
@mock.patch("google.auth.transport.requests.Request")
732733
def test_ping_mtls(
733-
mock_metrics_header_value, mock_create_session, mock_should_use_mtls
734+
mock_request, mock_should_use_mtls, mock_mds_mtls_adapter, mock_metrics_header_value
734735
):
735-
response = mock.create_autospec(requests.Response, instance=True)
736-
response.status_code = http_client.OK
736+
response = mock.create_autospec(transport.Response, instance=True)
737+
response.status = http_client.OK
737738
response.headers = _metadata._METADATA_HEADERS
738-
mock_session = mock.Mock()
739-
mock_session.request.return_value = response
740-
mock_create_session.return_value = mock_session
739+
mock_request.return_value = response
741740

742-
initial_request = mock.Mock()
743-
assert _metadata.ping(initial_request)
741+
assert _metadata.ping(mock_request)
744742

745743
mock_should_use_mtls.assert_called_once()
746-
mock_create_session.assert_called_once()
747-
mock_session.request.assert_called_once_with(
748-
"GET",
749-
"https://169.254.169.254",
744+
mock_mds_mtls_adapter.assert_called_once()
745+
mock_request.assert_called_once_with(
746+
url="https://169.254.169.254",
747+
method="GET",
750748
headers=MDS_PING_REQUEST_HEADER,
751749
timeout=_metadata._METADATA_DEFAULT_TIMEOUT,
752-
data=None,
753750
)
754751

755752

753+
@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter")
756754
@mock.patch("google.auth.compute_engine._mtls.should_use_mds_mtls", return_value=True)
757-
@mock.patch("google.auth.compute_engine._mtls.create_session")
758-
def test_get_mtls(mock_create_session, mock_should_use_mtls):
759-
response = mock.create_autospec(requests.Response, instance=True)
760-
response.status_code = http_client.OK
761-
response.content = _helpers.to_bytes("{}")
755+
@mock.patch("google.auth.transport.requests.Request")
756+
def test_get_mtls(mock_request, mock_should_use_mtls, mock_mds_mtls_adapter):
757+
response = mock.create_autospec(transport.Response, instance=True)
758+
response.status = http_client.OK
759+
response.data = _helpers.to_bytes("{}")
762760
response.headers = {"content-type": "application/json"}
763-
mock_session = mock.Mock()
764-
mock_session.request.return_value = response
765-
mock_create_session.return_value = mock_session
761+
mock_request.return_value = response
766762

767-
initial_request = mock.Mock()
768-
_metadata.get(initial_request, "some/path")
763+
_metadata.get(mock_request, "some/path")
769764

770765
mock_should_use_mtls.assert_called_once()
771-
mock_create_session.assert_called_once()
772-
mock_session.request.assert_called_once_with(
773-
"GET",
774-
"https://metadata.google.internal/computeMetadata/v1/some/path",
775-
data=None,
766+
mock_mds_mtls_adapter.assert_called_once()
767+
mock_request.assert_called_once_with(
768+
url="https://metadata.google.internal/computeMetadata/v1/some/path",
769+
method="GET",
776770
headers=_metadata._METADATA_HEADERS,
777771
timeout=_metadata._METADATA_DEFAULT_TIMEOUT,
778772
)
773+
774+
775+
@pytest.mark.parametrize(
776+
"mds_mode, metadata_host, expect_exception",
777+
[
778+
(_metadata._mtls.MdsMtlsMode.STRICT, _metadata._GCE_DEFAULT_HOST, False),
779+
(_metadata._mtls.MdsMtlsMode.STRICT, "custom.host", True),
780+
(_metadata._mtls.MdsMtlsMode.NONE, "custom.host", False),
781+
(_metadata._mtls.MdsMtlsMode.DEFAULT, _metadata._GCE_DEFAULT_HOST, False),
782+
],
783+
)
784+
@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode")
785+
def test_validate_gce_mds_configured_environment(
786+
mock_parse_mds_mode, mds_mode, metadata_host, expect_exception
787+
):
788+
mock_parse_mds_mode.return_value = mds_mode
789+
with mock.patch(
790+
"google.auth.compute_engine._metadata._GCE_METADATA_HOST", new=metadata_host
791+
):
792+
if expect_exception:
793+
with pytest.raises(exceptions.MutualTLSChannelError):
794+
_metadata._validate_gce_mds_configured_environment()
795+
else:
796+
_metadata._validate_gce_mds_configured_environment()
797+
mock_parse_mds_mode.assert_called_once()
798+
799+
800+
@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter")
801+
def test__prepare_request_for_mds_mtls_session_exists(mock_mds_mtls_adapter):
802+
mock_session = mock.create_autospec(requests.Session)
803+
request = google_auth_requests.Request(mock_session)
804+
new_request = _metadata._prepare_request_for_mds(request, use_mtls=True)
805+
806+
mock_mds_mtls_adapter.assert_called_once()
807+
assert mock_session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS)
808+
assert new_request is request
809+
810+
811+
@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter")
812+
def test__prepare_request_for_mds_mtls_no_session(mock_mds_mtls_adapter):
813+
request = google_auth_requests.Request(None)
814+
# Explicitly set session to None to avoid a session being created in the Request constructor.
815+
request.session = None
816+
817+
with mock.patch("requests.Session") as mock_session_class:
818+
new_request = _metadata._prepare_request_for_mds(request, use_mtls=True)
819+
820+
mock_session_class.assert_called_once()
821+
mock_mds_mtls_adapter.assert_called_once()
822+
assert new_request.session.mount.call_count == len(
823+
_metadata._GCE_DEFAULT_MDS_HOSTS
824+
)
825+
assert new_request is request

tests/compute_engine/test__mtls.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import mock
2121
import pytest # type: ignore
22+
import requests
2223

2324
from google.auth import environment_vars, exceptions
2425
from google.auth.compute_engine import _mtls
@@ -127,15 +128,14 @@ def test_mds_mtls_adapter_init(mock_ssl_context, mock_mds_mtls_config):
127128
)
128129

129130

130-
@mock.patch("requests.Session")
131-
@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter")
132-
def test_create_session(mock_adapter, mock_session, mock_mds_mtls_config):
133-
session_instance = mock_session.return_value
134-
session = _mtls.create_session(mock_mds_mtls_config)
135-
assert session is session_instance
136-
mock_adapter.assert_called_once_with(mock_mds_mtls_config)
137-
session_instance.mount.assert_called_once_with(
138-
"https://", mock_adapter.return_value
131+
@mock.patch("ssl.create_default_context")
132+
@mock.patch("requests.adapters.HTTPAdapter.init_poolmanager")
133+
def test_mds_mtls_adapter_init_poolmanager(
134+
mock_init_poolmanager, mock_ssl_context, mock_mds_mtls_config
135+
):
136+
adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config)
137+
mock_init_poolmanager.assert_called_with(
138+
10, 10, block=False, ssl_context=adapter.ssl_context
139139
)
140140

141141

@@ -149,3 +149,23 @@ def test_mds_mtls_adapter_proxy_manager_for(
149149
mock_proxy_manager_for.assert_called_once_with(
150150
"test_proxy", ssl_context=adapter.ssl_context
151151
)
152+
153+
154+
@mock.patch("ssl.create_default_context")
155+
def test_mds_mtls_adapter_session_request(mock_ssl_context, mock_mds_mtls_config):
156+
adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config)
157+
session = requests.Session()
158+
session.mount("https://", adapter)
159+
160+
# Mock the adapter's send method to avoid actual network requests
161+
adapter.send = mock.Mock()
162+
response = requests.Response()
163+
response.status_code = 200
164+
adapter.send.return_value = response
165+
166+
# Make a request
167+
response = session.get("https://example.com")
168+
169+
# Assert that the request was successful
170+
assert response.status_code == 200
171+
adapter.send.assert_called_once()

0 commit comments

Comments
 (0)