Skip to content

Commit 59d50f9

Browse files
committed
feat: Introduce a way to provide custom headers
1 parent c30a6a7 commit 59d50f9

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

google/auth/credentials.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(self):
6969

7070
self._use_non_blocking_refresh = False
7171
self._refresh_worker = RefreshThreadManager()
72+
self._custom_headers = {}
7273

7374
@property
7475
def expired(self):
@@ -185,6 +186,7 @@ def apply(self, headers, token=None):
185186
self._apply(headers, token)
186187
if self.quota_project_id:
187188
headers["x-goog-user-project"] = self.quota_project_id
189+
headers.update(self._custom_headers)
188190

189191
def _blocking_refresh(self, request):
190192
if not self.valid:
@@ -233,6 +235,38 @@ def before_request(self, request, method, url, headers):
233235
def with_non_blocking_refresh(self):
234236
self._use_non_blocking_refresh = True
235237

238+
def with_headers(self, headers):
239+
"""Returns a copy of these credentials with additional custom headers.
240+
241+
Args:
242+
headers (Mapping[str, str]): The custom headers to add.
243+
244+
Returns:
245+
google.auth.credentials.Credentials: A new credentials instance.
246+
247+
Raises:
248+
ValueError: If a protected header is included in the input headers.
249+
"""
250+
import copy
251+
252+
PROTECTED_HEADERS = {
253+
"authorization",
254+
"x-goog-user-project",
255+
"x-goog-api-client",
256+
"x-allowed-locations",
257+
}
258+
259+
for key in headers:
260+
if key.lower() in PROTECTED_HEADERS:
261+
raise ValueError(
262+
f"Header '{key}' is protected and cannot be set with with_headers. "
263+
"These headers are managed by the library."
264+
)
265+
266+
new_creds = copy.deepcopy(self)
267+
new_creds._custom_headers.update(headers)
268+
return new_creds
269+
236270

237271
class CredentialsWithQuotaProject(Credentials):
238272
"""Abstract base for credentials supporting ``with_quota_project`` factory"""

tests/test_credentials.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,48 @@ def test_with_non_blocking_refresh():
7171
assert c._use_non_blocking_refresh
7272

7373

74+
def test_with_headers():
75+
credentials = CredentialsImpl()
76+
request = mock.Mock()
77+
78+
# 1. Add a new custom header
79+
creds_with_header = credentials.with_headers({"X-Custom-Header": "value1"})
80+
headers = {}
81+
creds_with_header.before_request(request, "http://example.com", "GET", headers)
82+
assert headers["X-Custom-Header"] == "value1"
83+
assert "authorization" in headers # Ensure base apply logic ran
84+
assert creds_with_header is not credentials
85+
assert not hasattr(credentials, "_custom_headers") or not credentials._custom_headers
86+
87+
# 2. Update an existing custom header
88+
creds_updated = creds_with_header.with_headers({"X-Custom-Header": "value2"})
89+
headers = {}
90+
creds_updated.before_request(request, "http://example.com", "GET", headers)
91+
assert headers["X-Custom-Header"] == "value2"
92+
93+
# 3. Chaining with_headers calls
94+
creds_chained = credentials.with_headers({"X-Header-1": "v1"}).with_headers(
95+
{"X-Header-2": "v2"}
96+
)
97+
headers = {}
98+
creds_chained.before_request(request, "http://example.com", "GET", headers)
99+
assert headers["X-Header-1"] == "v1"
100+
assert headers["X-Header-2"] == "v2"
101+
102+
# 4. Ensure protected headers cannot be set
103+
with pytest.raises(ValueError):
104+
credentials.with_headers({"Authorization": "Bearer token"})
105+
with pytest.raises(ValueError):
106+
credentials.with_headers({"X-Goog-User-Project": "test"})
107+
with pytest.raises(ValueError):
108+
credentials.with_headers({"authorization": "Bearer token"}) # Case-insensitive
109+
110+
# 5. Check original credentials are not modified
111+
headers = {}
112+
credentials.before_request(request, "http://example.com", "GET", headers)
113+
assert "X-Custom-Header" not in headers
114+
115+
74116
def test_expired_and_valid():
75117
credentials = CredentialsImpl()
76118
credentials.token = "token"

0 commit comments

Comments
 (0)