Skip to content

Commit e39f887

Browse files
authored
Support overriding API key at request-time (#283)
1 parent 6099e6f commit e39f887

File tree

7 files changed

+171
-4
lines changed

7 files changed

+171
-4
lines changed

datacommons_client/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
from datacommons_client.endpoints.node import NodeEndpoint
1111
from datacommons_client.endpoints.observation import ObservationEndpoint
1212
from datacommons_client.endpoints.resolve import ResolveEndpoint
13+
from datacommons_client.utils.context import use_api_key
1314

1415
__all__ = [
1516
"DataCommonsClient",
1617
"API",
1718
"NodeEndpoint",
1819
"ObservationEndpoint",
1920
"ResolveEndpoint",
21+
"use_api_key",
2022
]

datacommons_client/endpoints/base.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import re
22
from typing import Any, Dict, Optional
33

4+
from datacommons_client.utils.context import _API_KEY_CONTEXT_VAR
45
from datacommons_client.utils.request_handling import check_instance_is_valid
56
from datacommons_client.utils.request_handling import post_request
67
from datacommons_client.utils.request_handling import resolve_instance_url
@@ -94,9 +95,16 @@ def post(self,
9495

9596
url = (self.base_url if endpoint is None else f"{self.base_url}/{endpoint}")
9697

98+
headers = self.headers
99+
ctx_api_key = _API_KEY_CONTEXT_VAR.get()
100+
if ctx_api_key:
101+
# Copy headers to avoid mutating the shared client state
102+
headers = self.headers.copy()
103+
headers["X-API-Key"] = ctx_api_key
104+
97105
return post_request(url=url,
98106
payload=payload,
99-
headers=self.headers,
107+
headers=headers,
100108
all_pages=all_pages,
101109
next_token=next_token)
102110

datacommons_client/endpoints/node.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from concurrent.futures import ThreadPoolExecutor
2+
import contextvars
23
from functools import partial
34
from functools import wraps
45
from typing import Literal, Optional
@@ -447,10 +448,13 @@ def _fetch_place_relationships(
447448
)
448449

449450
# Use a thread pool to fetch ancestry graphs in parallel for each input entity
451+
ctx = contextvars.copy_context()
450452
with ThreadPoolExecutor(max_workers=max_concurrent_requests) as executor:
451453
futures = [
452-
executor.submit(build_graph_map, root=dcid, fetch_fn=fetch_fn)
453-
for dcid in place_dcids
454+
executor.submit(ctx.run,
455+
build_graph_map,
456+
root=dcid,
457+
fetch_fn=fetch_fn) for dcid in place_dcids
454458
]
455459
# Gather ancestry maps and postprocess into flat or nested form
456460
for future in futures:

datacommons_client/tests/test_client.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pandas as pd
55
import pytest
66

7+
from datacommons_client import use_api_key
78
from datacommons_client.client import DataCommonsClient
89
from datacommons_client.endpoints.base import API
910
from datacommons_client.endpoints.node import NodeEndpoint
@@ -419,3 +420,54 @@ def test_client_end_to_end_surface_header_propagation_observation(
419420
assert headers is not None
420421
assert headers.get("x-surface") == "datagemma"
421422
assert headers.get("X-API-Key") == "test_key"
423+
424+
425+
@patch("datacommons_client.endpoints.base.post_request")
426+
def test_use_api_key_with_observation_fetch(mock_post_request):
427+
"""Test use_api_key override for observation fetches (non-threaded)."""
428+
429+
# Setup client with default key
430+
client = DataCommonsClient(api_key="default-key")
431+
432+
# Configure mock to return valid response structure
433+
mock_post_request.return_value = {"byVariable": {}, "facets": {}}
434+
435+
# Default usage
436+
client.observation.fetch(variable_dcids="sv1", entity_dcids=["geo1"])
437+
mock_post_request.assert_called()
438+
_, kwargs = mock_post_request.call_args
439+
assert kwargs["headers"]["X-API-Key"] == "default-key"
440+
441+
# Context override
442+
with use_api_key("context-key"):
443+
client.observation.fetch(variable_dcids="sv1", entity_dcids=["geo1"])
444+
_, kwargs = mock_post_request.call_args
445+
assert kwargs["headers"]["X-API-Key"] == "context-key"
446+
447+
# Back to default
448+
client.observation.fetch(variable_dcids="sv1", entity_dcids=["geo1"])
449+
_, kwargs = mock_post_request.call_args
450+
assert kwargs["headers"]["X-API-Key"] == "default-key"
451+
452+
453+
@patch("datacommons_client.endpoints.base.post_request")
454+
def test_use_api_key_with_node_fetch_place_ancestors(mock_post_request):
455+
"""Test use_api_key propagation for node graph methods (threaded)."""
456+
457+
client = DataCommonsClient(api_key="default-key")
458+
459+
# Configure mock. fetch_place_ancestors expects a dict response or list of nodes.
460+
# NodeResponse.data is a dict.
461+
mock_post_request.return_value = {"data": {}}
462+
463+
# Default usage
464+
client.node.fetch_place_ancestors(place_dcids=["geoId/06"])
465+
_, kwargs = mock_post_request.call_args
466+
assert kwargs["headers"]["X-API-Key"] == "default-key"
467+
468+
# Context override
469+
with use_api_key("context-key"):
470+
# Use a different DCID to avoid hitting fetch_relationship_lru cache
471+
client.node.fetch_place_ancestors(place_dcids=["geoId/07"])
472+
_, kwargs = mock_post_request.call_args
473+
assert kwargs["headers"]["X-API-Key"] == "context-key"
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from datacommons_client.utils.context import _API_KEY_CONTEXT_VAR
16+
from datacommons_client.utils.context import use_api_key
17+
18+
19+
def test_use_api_key_sets_var():
20+
"""Test that use_api_key sets the context variable."""
21+
assert _API_KEY_CONTEXT_VAR.get() is None
22+
with use_api_key("test-key"):
23+
assert _API_KEY_CONTEXT_VAR.get() == "test-key"
24+
assert _API_KEY_CONTEXT_VAR.get() is None
25+
26+
27+
def test_use_api_key_nested():
28+
"""Test nested usage of use_api_key."""
29+
with use_api_key("outer"):
30+
assert _API_KEY_CONTEXT_VAR.get() == "outer"
31+
with use_api_key("inner"):
32+
assert _API_KEY_CONTEXT_VAR.get() == "inner"
33+
assert _API_KEY_CONTEXT_VAR.get() == "outer"
34+
assert _API_KEY_CONTEXT_VAR.get() is None
35+
36+
37+
def test_use_api_key_none():
38+
"""Test that use_api_key with None/empty does not set the variable."""
39+
assert _API_KEY_CONTEXT_VAR.get() is None
40+
with use_api_key(None):
41+
assert _API_KEY_CONTEXT_VAR.get() is None
42+
with use_api_key(""):
43+
assert _API_KEY_CONTEXT_VAR.get() is None
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from contextlib import contextmanager
16+
from contextvars import ContextVar
17+
from typing import Generator
18+
19+
_API_KEY_CONTEXT_VAR: ContextVar[str | None] = ContextVar("api_key",
20+
default=None)
21+
22+
23+
@contextmanager
24+
def use_api_key(api_key: str | None) -> Generator[None, None, None]:
25+
"""Context manager to set the API key for the current execution context.
26+
27+
If api_key is None or empty, this context manager does nothing, allowing
28+
the underlying client to use its default API key.
29+
30+
Args:
31+
api_key: The API key to use. If None or empty, no change is made.
32+
33+
Example:
34+
from datacommons_client import use_api_key
35+
# ...
36+
client = DataCommonsClient(api_key="default-key")
37+
38+
# Uses "default-key"
39+
client.observation.fetch(...)
40+
41+
with use_api_key("temp-key"):
42+
# Uses "temp-key"
43+
client.observation.fetch(...)
44+
45+
# Back to "default-key"
46+
client.observation.fetch(...)
47+
"""
48+
if not api_key:
49+
yield
50+
return
51+
52+
token = _API_KEY_CONTEXT_VAR.set(api_key)
53+
try:
54+
yield
55+
finally:
56+
_API_KEY_CONTEXT_VAR.reset(token)

datacommons_client/utils/graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from concurrent.futures import Future
44
from concurrent.futures import ThreadPoolExecutor
55
from concurrent.futures import wait
6+
import contextvars
67
from functools import lru_cache
78
from typing import Callable, Literal, Optional, TypeAlias
89

@@ -108,6 +109,7 @@ def build_graph_map(
108109

109110
original_root = root
110111

112+
ctx = contextvars.copy_context()
111113
with ThreadPoolExecutor(max_workers=max_workers) as executor:
112114
queue = deque([root])
113115

@@ -119,7 +121,7 @@ def build_graph_map(
119121
# Check if the node has already been visited or is in progress
120122
if dcid not in visited and dcid not in in_progress:
121123
# Submit the fetch task
122-
in_progress[dcid] = executor.submit(fetch_fn, dcid=dcid)
124+
in_progress[dcid] = executor.submit(ctx.run, fetch_fn, dcid=dcid)
123125

124126
# Check if any futures are still in progress
125127
if not in_progress:

0 commit comments

Comments
 (0)