Skip to content

Commit 7ab0fce

Browse files
feat: add universe domain support for VM cred (#1409)
* feat: add universe domain support for VM cred * chore: refresh sys test cred
1 parent 8eaa878 commit 7ab0fce

File tree

5 files changed

+132
-9
lines changed

5 files changed

+132
-9
lines changed

google/auth/compute_engine/_metadata.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def get(
156156
recursive=False,
157157
retry_count=5,
158158
headers=None,
159+
return_none_for_not_found_error=False,
159160
):
160161
"""Fetch a resource from the metadata server.
161162
@@ -173,6 +174,8 @@ def get(
173174
retry_count (int): How many times to attempt connecting to metadata
174175
server using above timeout.
175176
headers (Optional[Mapping[str, str]]): Headers for the request.
177+
return_none_for_not_found_error (Optional[bool]): If True, returns None
178+
for 404 error instead of throwing an exception.
176179
177180
Returns:
178181
Union[Mapping, str]: If the metadata server returns JSON, a mapping of
@@ -216,8 +219,17 @@ def get(
216219
"metadata service. Compute Engine Metadata server unavailable".format(url)
217220
)
218221

222+
content = _helpers.from_bytes(response.data)
223+
224+
if response.status == http_client.NOT_FOUND and return_none_for_not_found_error:
225+
_LOGGER.info(
226+
"Compute Engine Metadata server call to %s returned 404, reason: %s",
227+
path,
228+
content,
229+
)
230+
return None
231+
219232
if response.status == http_client.OK:
220-
content = _helpers.from_bytes(response.data)
221233
if (
222234
_helpers.parse_content_type(response.headers["content-type"])
223235
== "application/json"
@@ -232,14 +244,14 @@ def get(
232244
raise new_exc from caught_exc
233245
else:
234246
return content
235-
else:
236-
raise exceptions.TransportError(
237-
"Failed to retrieve {} from the Google Compute Engine "
238-
"metadata service. Status: {} Response:\n{}".format(
239-
url, response.status, response.data
240-
),
241-
response,
242-
)
247+
248+
raise exceptions.TransportError(
249+
"Failed to retrieve {} from the Google Compute Engine "
250+
"metadata service. Status: {} Response:\n{}".format(
251+
url, response.status, response.data
252+
),
253+
response,
254+
)
243255

244256

245257
def get_project_id(request):
@@ -259,6 +271,29 @@ def get_project_id(request):
259271
return get(request, "project/project-id")
260272

261273

274+
def get_universe_domain(request):
275+
"""Get the universe domain value from the metadata server.
276+
277+
Args:
278+
request (google.auth.transport.Request): A callable used to make
279+
HTTP requests.
280+
281+
Returns:
282+
str: The universe domain value. If the universe domain endpoint is not
283+
not found, return the default value, which is googleapis.com
284+
285+
Raises:
286+
google.auth.exceptions.TransportError: if an error other than
287+
404 occurs while retrieving metadata.
288+
"""
289+
universe_domain = get(
290+
request, "universe/universe_domain", return_none_for_not_found_error=True
291+
)
292+
if not universe_domain:
293+
return "googleapis.com"
294+
return universe_domain
295+
296+
262297
def get_service_account_info(request, service_account="default"):
263298
"""Get information about a service account from the metadata server.
264299

google/auth/compute_engine/credentials.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def __init__(
7373
self._quota_project_id = quota_project_id
7474
self._scopes = scopes
7575
self._default_scopes = default_scopes
76+
self._universe_domain_cached = False
7677

7778
def _retrieve_info(self, request):
7879
"""Retrieve information about the service account.
@@ -131,6 +132,14 @@ def service_account_email(self):
131132
def requires_scopes(self):
132133
return not self._scopes
133134

135+
@property
136+
def universe_domain(self):
137+
if self._universe_domain_cached:
138+
return self._universe_domain
139+
self._universe_domain = _metadata.get_universe_domain()
140+
self._universe_domain_cached = True
141+
return self._universe_domain
142+
134143
@_helpers.copy_docstring(credentials.CredentialsWithQuotaProject)
135144
def with_quota_project(self, quota_project_id):
136145
return self.__class__(

system_tests/secrets.tar.enc

0 Bytes
Binary file not shown.

tests/compute_engine/test__metadata.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,18 @@ def test_get_failure():
325325
)
326326

327327

328+
def test_get_return_none_for_not_found_error():
329+
request = make_request("Metadata error", status=http_client.NOT_FOUND)
330+
331+
assert _metadata.get(request, PATH, return_none_for_not_found_error=True) is None
332+
333+
request.assert_called_once_with(
334+
method="GET",
335+
url=_metadata._METADATA_ROOT + PATH,
336+
headers=_metadata._METADATA_HEADERS,
337+
)
338+
339+
328340
def test_get_failure_connection_failed():
329341
request = make_request("")
330342
request.side_effect = exceptions.TransportError()
@@ -371,6 +383,53 @@ def test_get_project_id():
371383
assert project_id == project
372384

373385

386+
def test_get_universe_domain_success():
387+
request = make_request(
388+
"fake_universe_domain", headers={"content-type": "text/plain"}
389+
)
390+
391+
universe_domain = _metadata.get_universe_domain(request)
392+
393+
request.assert_called_once_with(
394+
method="GET",
395+
url=_metadata._METADATA_ROOT + "universe/universe_domain",
396+
headers=_metadata._METADATA_HEADERS,
397+
)
398+
assert universe_domain == "fake_universe_domain"
399+
400+
401+
def test_get_universe_domain_not_found():
402+
# Test that if the universe domain endpoint returns 404 error, we should
403+
# use googleapis.com as the universe domain
404+
request = make_request("not found", status=http_client.NOT_FOUND)
405+
406+
universe_domain = _metadata.get_universe_domain(request)
407+
408+
request.assert_called_once_with(
409+
method="GET",
410+
url=_metadata._METADATA_ROOT + "universe/universe_domain",
411+
headers=_metadata._METADATA_HEADERS,
412+
)
413+
assert universe_domain == "googleapis.com"
414+
415+
416+
def test_get_universe_domain_other_error():
417+
# Test that if the universe domain endpoint returns an error other than 404
418+
# we should throw the error
419+
request = make_request("unauthorized", status=http_client.UNAUTHORIZED)
420+
421+
with pytest.raises(exceptions.TransportError) as excinfo:
422+
_metadata.get_universe_domain(request)
423+
424+
assert excinfo.match(r"unauthorized")
425+
426+
request.assert_called_once_with(
427+
method="GET",
428+
url=_metadata._METADATA_ROOT + "universe/universe_domain",
429+
headers=_metadata._METADATA_HEADERS,
430+
)
431+
432+
374433
@mock.patch(
375434
"google.auth.metrics.token_request_access_token_mds",
376435
return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE,

tests/compute_engine/test_credentials.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,26 @@ def test_token_usage_metrics(self):
208208
assert headers["authorization"] == "Bearer token"
209209
assert headers["x-goog-api-client"] == "cred-type/mds"
210210

211+
@mock.patch(
212+
"google.auth.compute_engine._metadata.get_universe_domain",
213+
return_value="fake_universe_domain",
214+
)
215+
def test_universe_domain(self, get_universe_domain):
216+
self.credentials._universe_domain_cached = False
217+
self.credentials._universe_domain = "googleapis.com"
218+
219+
# calling the universe_domain property should trigger a call to
220+
# get_universe_domain to fetch the value. The value should be cached.
221+
assert self.credentials.universe_domain == "fake_universe_domain"
222+
assert self.credentials._universe_domain == "fake_universe_domain"
223+
assert self.credentials._universe_domain_cached
224+
get_universe_domain.assert_called_once()
225+
226+
# calling the universe_domain property the second time should use the
227+
# cached value instead of calling get_universe_domain
228+
assert self.credentials.universe_domain == "fake_universe_domain"
229+
get_universe_domain.assert_called_once()
230+
211231

212232
class TestIDTokenCredentials(object):
213233
credentials = None

0 commit comments

Comments
 (0)