diff --git a/datacommons_client/__init__.py b/datacommons_client/__init__.py index 95b203a1..f7d02654 100644 --- a/datacommons_client/__init__.py +++ b/datacommons_client/__init__.py @@ -10,6 +10,7 @@ from datacommons_client.endpoints.node import NodeEndpoint from datacommons_client.endpoints.observation import ObservationEndpoint from datacommons_client.endpoints.resolve import ResolveEndpoint +from datacommons_client.utils.context import use_api_key __all__ = [ "DataCommonsClient", @@ -17,4 +18,5 @@ "NodeEndpoint", "ObservationEndpoint", "ResolveEndpoint", + "use_api_key", ] diff --git a/datacommons_client/endpoints/base.py b/datacommons_client/endpoints/base.py index 1452ea87..ff4adfdc 100644 --- a/datacommons_client/endpoints/base.py +++ b/datacommons_client/endpoints/base.py @@ -1,6 +1,7 @@ import re from typing import Any, Dict, Optional +from datacommons_client.utils.context import _API_KEY_CONTEXT_VAR from datacommons_client.utils.request_handling import check_instance_is_valid from datacommons_client.utils.request_handling import post_request from datacommons_client.utils.request_handling import resolve_instance_url @@ -94,9 +95,16 @@ def post(self, url = (self.base_url if endpoint is None else f"{self.base_url}/{endpoint}") + headers = self.headers + ctx_api_key = _API_KEY_CONTEXT_VAR.get() + if ctx_api_key: + # Copy headers to avoid mutating the shared client state + headers = self.headers.copy() + headers["X-API-Key"] = ctx_api_key + return post_request(url=url, payload=payload, - headers=self.headers, + headers=headers, all_pages=all_pages, next_token=next_token) diff --git a/datacommons_client/endpoints/node.py b/datacommons_client/endpoints/node.py index 69bf3e9d..808c3495 100644 --- a/datacommons_client/endpoints/node.py +++ b/datacommons_client/endpoints/node.py @@ -1,4 +1,5 @@ from concurrent.futures import ThreadPoolExecutor +import contextvars from functools import partial from functools import wraps from typing import Literal, Optional @@ 
-447,10 +448,13 @@ def _fetch_place_relationships( ) # Use a thread pool to fetch ancestry graphs in parallel for each input entity + ctx = contextvars.copy_context() with ThreadPoolExecutor(max_workers=max_concurrent_requests) as executor: futures = [ - executor.submit(build_graph_map, root=dcid, fetch_fn=fetch_fn) - for dcid in place_dcids + executor.submit(ctx.copy().run, + build_graph_map, + root=dcid, + fetch_fn=fetch_fn) for dcid in place_dcids ] # Gather ancestry maps and postprocess into flat or nested form for future in futures: diff --git a/datacommons_client/tests/test_client.py b/datacommons_client/tests/test_client.py index a17a2d9c..221befff 100644 --- a/datacommons_client/tests/test_client.py +++ b/datacommons_client/tests/test_client.py @@ -4,6 +4,7 @@ import pandas as pd import pytest +from datacommons_client import use_api_key from datacommons_client.client import DataCommonsClient from datacommons_client.endpoints.base import API from datacommons_client.endpoints.node import NodeEndpoint @@ -419,3 +420,54 @@ def test_client_end_to_end_surface_header_propagation_observation( assert headers is not None assert headers.get("x-surface") == "datagemma" assert headers.get("X-API-Key") == "test_key" + + +@patch("datacommons_client.endpoints.base.post_request") +def test_use_api_key_with_observation_fetch(mock_post_request): + """Test use_api_key override for observation fetches (non-threaded).""" + + # Setup client with default key + client = DataCommonsClient(api_key="default-key") + + # Configure mock to return valid response structure + mock_post_request.return_value = {"byVariable": {}, "facets": {}} + + # Default usage + client.observation.fetch(variable_dcids="sv1", entity_dcids=["geo1"]) + mock_post_request.assert_called() + _, kwargs = mock_post_request.call_args + assert kwargs["headers"]["X-API-Key"] == "default-key" + + # Context override + with use_api_key("context-key"): + client.observation.fetch(variable_dcids="sv1", entity_dcids=["geo1"]) + _, 
kwargs = mock_post_request.call_args + assert kwargs["headers"]["X-API-Key"] == "context-key" + + # Back to default + client.observation.fetch(variable_dcids="sv1", entity_dcids=["geo1"]) + _, kwargs = mock_post_request.call_args + assert kwargs["headers"]["X-API-Key"] == "default-key" + + +@patch("datacommons_client.endpoints.base.post_request") +def test_use_api_key_with_node_fetch_place_ancestors(mock_post_request): + """Test use_api_key propagation for node graph methods (threaded).""" + + client = DataCommonsClient(api_key="default-key") + + # Configure mock. fetch_place_ancestors expects a dict response or list of nodes. + # NodeResponse.data is a dict. + mock_post_request.return_value = {"data": {}} + + # Default usage + client.node.fetch_place_ancestors(place_dcids=["geoId/06"]) + _, kwargs = mock_post_request.call_args + assert kwargs["headers"]["X-API-Key"] == "default-key" + + # Context override + with use_api_key("context-key"): + # Use a different DCID to avoid hitting fetch_relationship_lru cache + client.node.fetch_place_ancestors(place_dcids=["geoId/07"]) + _, kwargs = mock_post_request.call_args + assert kwargs["headers"]["X-API-Key"] == "context-key" diff --git a/datacommons_client/tests/test_context.py b/datacommons_client/tests/test_context.py new file mode 100644 index 00000000..ef7117cb --- /dev/null +++ b/datacommons_client/tests/test_context.py @@ -0,0 +1,43 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from datacommons_client.utils.context import _API_KEY_CONTEXT_VAR +from datacommons_client.utils.context import use_api_key + + +def test_use_api_key_sets_var(): + """Test that use_api_key sets the context variable.""" + assert _API_KEY_CONTEXT_VAR.get() is None + with use_api_key("test-key"): + assert _API_KEY_CONTEXT_VAR.get() == "test-key" + assert _API_KEY_CONTEXT_VAR.get() is None + + +def test_use_api_key_nested(): + """Test nested usage of use_api_key.""" + with use_api_key("outer"): + assert _API_KEY_CONTEXT_VAR.get() == "outer" + with use_api_key("inner"): + assert _API_KEY_CONTEXT_VAR.get() == "inner" + assert _API_KEY_CONTEXT_VAR.get() == "outer" + assert _API_KEY_CONTEXT_VAR.get() is None + + +def test_use_api_key_none(): + """Test that use_api_key with None/empty does not set the variable.""" + assert _API_KEY_CONTEXT_VAR.get() is None + with use_api_key(None): + assert _API_KEY_CONTEXT_VAR.get() is None + with use_api_key(""): + assert _API_KEY_CONTEXT_VAR.get() is None diff --git a/datacommons_client/utils/context.py b/datacommons_client/utils/context.py new file mode 100644 index 00000000..c76944c5 --- /dev/null +++ b/datacommons_client/utils/context.py @@ -0,0 +1,56 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Generator + +_API_KEY_CONTEXT_VAR: ContextVar[str | None] = ContextVar("api_key", + default=None) + + +@contextmanager +def use_api_key(api_key: str | None) -> Generator[None, None, None]: + """Context manager to set the API key for the current execution context. + + If api_key is None or empty, this context manager does nothing, allowing + the underlying client to use its default API key. + + Args: + api_key: The API key to use. If None or empty, no change is made. + + Example: + from datacommons_client import use_api_key + # ... + client = DataCommonsClient(api_key="default-key") + + # Uses "default-key" + client.observation.fetch(...) + + with use_api_key("temp-key"): + # Uses "temp-key" + client.observation.fetch(...) + + # Back to "default-key" + client.observation.fetch(...) + """ + if not api_key: + yield + return + + token = _API_KEY_CONTEXT_VAR.set(api_key) + try: + yield + finally: + _API_KEY_CONTEXT_VAR.reset(token) diff --git a/datacommons_client/utils/graph.py b/datacommons_client/utils/graph.py index b9d14495..db636b5d 100644 --- a/datacommons_client/utils/graph.py +++ b/datacommons_client/utils/graph.py @@ -3,6 +3,7 @@ from concurrent.futures import Future from concurrent.futures import ThreadPoolExecutor from concurrent.futures import wait +import contextvars from functools import lru_cache from typing import Callable, Literal, Optional, TypeAlias @@ -108,6 +109,7 @@ def build_graph_map( original_root = root + ctx = contextvars.copy_context() with ThreadPoolExecutor(max_workers=max_workers) as executor: queue = deque([root]) @@ -119,7 +121,7 @@ def build_graph_map( # Check if the node has already been visited or is in progress if dcid not in visited and dcid not in in_progress: # Submit the fetch task - in_progress[dcid] = executor.submit(fetch_fn, dcid=dcid) + in_progress[dcid] = executor.submit(ctx.copy().run, fetch_fn, dcid=dcid) # Check if any futures are 
still in progress if not in_progress: