Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions datacommons_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from datacommons_client.endpoints.node import NodeEndpoint
from datacommons_client.endpoints.observation import ObservationEndpoint
from datacommons_client.endpoints.resolve import ResolveEndpoint
from datacommons_client.utils.context import use_api_key

# Public API of the package: the names bound by
# `from datacommons_client import *`.
__all__ = [
"DataCommonsClient",
"API",
"NodeEndpoint",
"ObservationEndpoint",
"ResolveEndpoint",
"use_api_key",
]
9 changes: 8 additions & 1 deletion datacommons_client/endpoints/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
from typing import Any, Dict, Optional

from datacommons_client.utils.context import _API_KEY_CONTEXT_VAR
from datacommons_client.utils.request_handling import check_instance_is_valid
from datacommons_client.utils.request_handling import post_request
from datacommons_client.utils.request_handling import resolve_instance_url
Expand Down Expand Up @@ -94,9 +95,15 @@ def post(self,

url = (self.base_url if endpoint is None else f"{self.base_url}/{endpoint}")

headers = self.headers
ctx_api_key = _API_KEY_CONTEXT_VAR.get()
if ctx_api_key:
headers = self.headers.copy()
headers["X-API-Key"] = ctx_api_key

return post_request(url=url,
payload=payload,
headers=self.headers,
headers=headers,
all_pages=all_pages,
next_token=next_token)

Expand Down
8 changes: 6 additions & 2 deletions datacommons_client/endpoints/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,10 +447,14 @@ def _fetch_place_relationships(
)

# Use a thread pool to fetch ancestry graphs in parallel for each input entity
import contextvars
ctx = contextvars.copy_context()
with ThreadPoolExecutor(max_workers=max_concurrent_requests) as executor:
futures = [
executor.submit(build_graph_map, root=dcid, fetch_fn=fetch_fn)
for dcid in place_dcids
executor.submit(ctx.run,
build_graph_map,
root=dcid,
fetch_fn=fetch_fn) for dcid in place_dcids
]
# Gather ancestry maps and postprocess into flat or nested form
for future in futures:
Expand Down
52 changes: 52 additions & 0 deletions datacommons_client/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pandas as pd
import pytest

from datacommons_client import use_api_key
from datacommons_client.client import DataCommonsClient
from datacommons_client.endpoints.base import API
from datacommons_client.endpoints.node import NodeEndpoint
Expand Down Expand Up @@ -419,3 +420,54 @@ def test_client_end_to_end_surface_header_propagation_observation(
assert headers is not None
assert headers.get("x-surface") == "datagemma"
assert headers.get("X-API-Key") == "test_key"


@patch("datacommons_client.endpoints.base.post_request")
def test_use_api_key_with_observation_fetch(mock_post_request):
    """Check the use_api_key override on observation fetches (non-threaded).

    The key sent with each request is read from the headers passed to the
    patched `post_request`: the client's own key outside the context manager,
    the context key inside it, and the client's key again afterwards.
    """

    def sent_api_key():
        # Key carried by the most recent call to the patched post_request.
        return mock_post_request.call_args.kwargs["headers"]["X-API-Key"]

    def fetch_once():
        client.observation.fetch(variable_dcids="sv1", entity_dcids=["geo1"])

    # Client configured with its own default key.
    client = DataCommonsClient(api_key="default-key")

    # Response structure returned by the mock for every fetch.
    mock_post_request.return_value = {"byVariable": {}, "facets": {}}

    # Without any override the client's key is sent.
    fetch_once()
    mock_post_request.assert_called()
    assert sent_api_key() == "default-key"

    # Inside the context manager the context key takes over.
    with use_api_key("context-key"):
        fetch_once()
        assert sent_api_key() == "context-key"

    # Leaving the context manager restores the client's key.
    fetch_once()
    assert sent_api_key() == "default-key"


@patch("datacommons_client.endpoints.base.post_request")
def test_use_api_key_with_node_fetch_place_ancestors(mock_post_request):
    """Check use_api_key propagation for node graph methods (threaded).

    The key is read from the headers of the most recent call to the patched
    `post_request`, first without an override and then inside the context
    manager.
    """

    def sent_api_key():
        # Key carried by the most recent call to the patched post_request.
        return mock_post_request.call_args.kwargs["headers"]["X-API-Key"]

    client = DataCommonsClient(api_key="default-key")

    # fetch_place_ancestors expects a dict response or list of nodes;
    # NodeResponse.data is a dict.
    mock_post_request.return_value = {"data": {}}

    # Without any override the client's key is sent.
    client.node.fetch_place_ancestors(place_dcids=["geoId/06"])
    assert sent_api_key() == "default-key"

    # Inside the context manager the context key must be sent.
    with use_api_key("context-key"):
        # Use a different DCID to avoid hitting fetch_relationship_lru cache
        client.node.fetch_place_ancestors(place_dcids=["geoId/07"])
        assert sent_api_key() == "context-key"
45 changes: 45 additions & 0 deletions datacommons_client/tests/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from datacommons_client.utils.context import _API_KEY_CONTEXT_VAR
from datacommons_client.utils.context import use_api_key


class TestContext(unittest.TestCase):
    """Unit tests for the `use_api_key` context manager."""

    def test_use_api_key_sets_var(self):
        """Test that use_api_key sets the context variable."""
        self.assertIsNone(_API_KEY_CONTEXT_VAR.get())
        with use_api_key("test-key"):
            self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "test-key")
        self.assertIsNone(_API_KEY_CONTEXT_VAR.get())

    def test_use_api_key_nested(self):
        """Test nested usage of use_api_key."""
        with use_api_key("outer"):
            self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "outer")
            with use_api_key("inner"):
                self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "inner")
            self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "outer")
        self.assertIsNone(_API_KEY_CONTEXT_VAR.get())

    def test_use_api_key_none(self):
        """Test that use_api_key with None/empty does not set the variable."""
        self.assertIsNone(_API_KEY_CONTEXT_VAR.get())
        with use_api_key(None):
            self.assertIsNone(_API_KEY_CONTEXT_VAR.get())
        with use_api_key(""):
            self.assertIsNone(_API_KEY_CONTEXT_VAR.get())

    def test_use_api_key_none_keeps_outer_key(self):
        """Test that a None/empty key nested in an override keeps the outer key."""
        with use_api_key("outer"):
            with use_api_key(None):
                self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "outer")
            with use_api_key(""):
                self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "outer")
            # Leaving the no-op contexts must not reset the outer override.
            self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "outer")
        self.assertIsNone(_API_KEY_CONTEXT_VAR.get())

    def test_use_api_key_resets_on_exception(self):
        """Test that the variable is restored when the with-body raises."""
        with self.assertRaises(ValueError):
            with use_api_key("test-key"):
                self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "test-key")
                raise ValueError("boom")
        # The override must not leak past the failed block.
        self.assertIsNone(_API_KEY_CONTEXT_VAR.get())
54 changes: 54 additions & 0 deletions datacommons_client/utils/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from contextvars import ContextVar
from typing import Generator

_API_KEY_CONTEXT_VAR: ContextVar[str | None] = ContextVar("api_key",
default=None)


@contextmanager
def use_api_key(api_key: str | None) -> Generator[None, None, None]:
"""Context manager to set the API key for the current execution context.
If api_key is None or empty, this context manager does nothing, allowing
the underlying client to use its default API key.
Args:
api_key: The API key to use. If None or empty, no change is made.
Example:
client = DataCommonsClient(api_key="default-key")
# Uses "default-key"
client.observation.fetch(...)
with use_api_key("temp-key"):
# Uses "temp-key"
client.observation.fetch(...)
# Back to "default-key"
client.observation.fetch(...)
"""
if not api_key:
yield
return

token = _API_KEY_CONTEXT_VAR.set(api_key)
try:
yield
finally:
_API_KEY_CONTEXT_VAR.reset(token)
4 changes: 3 additions & 1 deletion datacommons_client/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import wait
import contextvars
from functools import lru_cache
from typing import Callable, Literal, Optional, TypeAlias

Expand Down Expand Up @@ -108,6 +109,7 @@ def build_graph_map(

original_root = root

ctx = contextvars.copy_context()
with ThreadPoolExecutor(max_workers=max_workers) as executor:
queue = deque([root])

Expand All @@ -119,7 +121,7 @@ def build_graph_map(
# Check if the node has already been visited or is in progress
if dcid not in visited and dcid not in in_progress:
# Submit the fetch task
in_progress[dcid] = executor.submit(fetch_fn, dcid=dcid)
in_progress[dcid] = executor.submit(ctx.run, fetch_fn, dcid=dcid)

# Check if any futures are still in progress
if not in_progress:
Expand Down