diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 0c7b619b0ae8..e28e85b85e3c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -65,7 +65,7 @@ from ._change_feed.feed_range_internal import FeedRangeInternalEpk from ._constants import _Constants as Constants from ._cosmos_http_logging_policy import CosmosHttpLoggingPolicy -from ._cosmos_responses import CosmosDict, CosmosList +from ._cosmos_responses import CosmosDict, CosmosList, CosmosItemPaged from ._range_partition_resolver import RangePartitionResolver from ._read_items_helper import ReadItemsHelperSync from ._request_object import RequestObject @@ -1138,7 +1138,7 @@ def QueryItems( partition_key: Optional[PartitionKeyType] = None, response_hook: Optional[Callable[[Mapping[str, Any], dict[str, Any]], None]] = None, **kwargs: Any - ) -> ItemPaged[dict[str, Any]]: + ) -> CosmosItemPaged: """Queries documents in a collection. 
:param str database_or_container_link: @@ -1162,7 +1162,7 @@ def QueryItems( options = {} if base.IsDatabaseLink(database_or_container_link): - return ItemPaged( + return CosmosItemPaged( self, query, options, @@ -1186,7 +1186,7 @@ def fetch_fn(options: Mapping[str, Any]) -> Tuple[list[dict[str, Any]], CaseInse response_hook=response_hook, **kwargs) - return ItemPaged( + return CosmosItemPaged( self, query, options, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_responses.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_responses.py index 16792ec7641b..ea45448a1068 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_responses.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_responses.py @@ -1,10 +1,113 @@ # The MIT License (MIT) # Copyright (c) 2024 Microsoft Corporation -from typing import Any, Iterable, Mapping, Optional +from typing import Any, AsyncIterator, Iterator, Iterable, List, Mapping, Optional + +from azure.core.async_paging import AsyncItemPaged +from azure.core.paging import ItemPaged from azure.core.utils import CaseInsensitiveDict +class CosmosItemPaged(ItemPaged[dict[str, Any]]): + """A custom ItemPaged class that provides access to response headers from query operations. + + This class wraps the standard ItemPaged and stores a reference to the underlying + QueryIterable to expose response headers captured during pagination. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._query_iterable: Optional[Any] = None + + def by_page(self, continuation_token: Optional[str] = None) -> Iterator[Iterator[dict[str, Any]]]: + """Get an iterator of pages of objects. + + :param str continuation_token: An opaque continuation token. 
+ :returns: An iterator of pages (themselves iterator of objects) + :rtype: iterator[iterator[dict[str, Any]]] + """ + # Call the parent's by_page to get the QueryIterable and store reference for header access + self._query_iterable = super().by_page(continuation_token) + return self._query_iterable + + def get_response_headers(self) -> List[CaseInsensitiveDict]: + """Returns a list of response headers captured from each page of results. + + Each element in the list corresponds to the headers from one page fetch. + Headers are only available after iterating through the results. + + :return: List of response headers from each page + :rtype: List[~azure.core.utils.CaseInsensitiveDict] + """ + if self._query_iterable is None: + return [] + if hasattr(self._query_iterable, 'get_response_headers'): + return self._query_iterable.get_response_headers() + return [] + + def get_last_response_headers(self) -> Optional[CaseInsensitiveDict]: + """Returns the response headers from the most recent page fetch. + + :return: Response headers from the last page, or None if no pages have been fetched + :rtype: Optional[~azure.core.utils.CaseInsensitiveDict] + """ + if self._query_iterable is None: + return None + if hasattr(self._query_iterable, 'get_last_response_headers'): + return self._query_iterable.get_last_response_headers() + return None + + +class CosmosAsyncItemPaged(AsyncItemPaged[dict[str, Any]]): + """A custom AsyncItemPaged class that provides access to response headers from async query operations. + + This class wraps the standard AsyncItemPaged and stores a reference to the underlying + QueryIterable to expose response headers captured during pagination. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._query_iterable: Optional[Any] = None + + def by_page(self, continuation_token: Optional[str] = None) -> AsyncIterator[AsyncIterator[dict[str, Any]]]: + """Get an async iterator of pages of objects. 
+ + :param str continuation_token: An opaque continuation token. + :returns: An async iterator of pages (themselves async iterator of objects) + :rtype: AsyncIterator[AsyncIterator[dict[str, Any]]] + """ + # Call the parent's by_page to get the QueryIterable and store reference for header access + self._query_iterable = super().by_page(continuation_token) + return self._query_iterable + + def get_response_headers(self) -> List[CaseInsensitiveDict]: + """Returns a list of response headers captured from each page of results. + + Each element in the list corresponds to the headers from one page fetch. + Headers are only available after iterating through the results. + + :return: List of response headers from each page + :rtype: List[~azure.core.utils.CaseInsensitiveDict] + """ + if self._query_iterable is None: + return [] + if hasattr(self._query_iterable, 'get_response_headers'): + return self._query_iterable.get_response_headers() + return [] + + def get_last_response_headers(self) -> Optional[CaseInsensitiveDict]: + """Returns the response headers from the most recent page fetch. + + :return: Response headers from the last page, or None if no pages have been fetched + :rtype: Optional[~azure.core.utils.CaseInsensitiveDict] + """ + if self._query_iterable is None: + return None + if hasattr(self._query_iterable, 'get_last_response_headers'): + return self._query_iterable.get_last_response_headers() + return None + + class CosmosDict(dict[str, Any]): def __init__(self, original_dict: Optional[Mapping[str, Any]], /, *, response_headers: CaseInsensitiveDict) -> None: if original_dict is None: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py index 599d884c9797..bdbc14582d08 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py @@ -22,7 +22,10 @@ """Iterable query results in the Azure Cosmos database service. 
""" import time +from typing import List, Optional + from azure.core.paging import PageIterator # type: ignore +from azure.core.utils import CaseInsensitiveDict from azure.cosmos._constants import _Constants, TimeoutScope from azure.cosmos._execution_context import execution_dispatcher from azure.cosmos import exceptions @@ -30,7 +33,7 @@ # pylint: disable=protected-access -class QueryIterable(PageIterator): +class QueryIterable(PageIterator): # pylint: disable=too-many-instance-attributes """Represents an iterable object of the query results. QueryIterable is a wrapper for query execution context. @@ -81,6 +84,10 @@ def __init__( self._ex_context = execution_dispatcher._ProxyQueryExecutionContext( self._client, self._collection_link, self._query, self._options, self._fetch_function, response_hook, raw_response_hook, resource_type) + + # Response headers tracking for query operations + self._response_headers: List[CaseInsensitiveDict] = [] + super(QueryIterable, self).__init__(self._fetch_next, self._unpack, continuation_token=continuation_token) def _unpack(self, block): @@ -114,6 +121,47 @@ def _fetch_next(self, *args): # pylint: disable=unused-argument raise exceptions.CosmosClientTimeoutError() block = self._ex_context.fetch_next_block() + + # Capture response headers after each page fetch + self._capture_response_headers() + if not block: raise StopIteration return block + + def _capture_response_headers(self) -> None: + """Capture response headers from the last request.""" + if self._client.last_response_headers: + headers = self._client.last_response_headers.copy() + self._response_headers.append(headers) + + def get_response_headers(self) -> List[CaseInsensitiveDict]: + """Get all response headers collected during query iteration. + + Each entry in the list corresponds to one page/request made during + the query execution. Headers are captured as queries are iterated, + so this list grows as you consume more results. 
This method is typically accessed via the + :class:`~azure.cosmos.CosmosItemPaged` object returned from + :meth:`~azure.cosmos.ContainerProxy.query_items`. + + :return: List of response headers from each page request. + :rtype: list[~azure.core.utils.CaseInsensitiveDict] + + Example: + >>> items = container.query_items(query="SELECT * FROM c") + >>> for item in items: + ... process(item) + >>> headers = items.get_response_headers() + >>> print(f"Total pages fetched: {len(headers)}") + """ + return [h.copy() for h in self._response_headers] + + def get_last_response_headers(self) -> Optional[CaseInsensitiveDict]: + """Get the response headers from the most recent page fetch. + + :return: Response headers from the last page, or None if no pages fetched yet. + :rtype: ~azure.core.utils.CaseInsensitiveDict or None + """ + if self._response_headers: + return self._response_headers[-1].copy() + return None diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py index 2ce07c61c9c9..95d56952729b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py @@ -42,7 +42,7 @@ build_options as _build_options, GenerateGuidId, validate_cache_staleness_value) from .._change_feed.feed_range_internal import FeedRangeInternalEpk -from .._cosmos_responses import CosmosDict, CosmosList +from .._cosmos_responses import CosmosDict, CosmosList, CosmosAsyncItemPaged from .._constants import _Constants as Constants, TimeoutScope from .._routing.routing_range import Range from .._session_token_helpers import get_latest_session_token @@ -548,7 +548,7 @@ def query_items( throughput_bucket: Optional[int] = None, availability_strategy_config: Optional[dict[str, Any]] = _Unset, **kwargs: Any - ) -> AsyncItemPaged[dict[str, Any]]: + ) -> CosmosAsyncItemPaged: """Return all results matching the given `query`. 
You can use any value for the container name in the FROM clause, but @@ -594,7 +594,7 @@ def query_items( The threshold-based availability strategy to use for this request. If not provided, the client's default strategy will be used. :returns: An Iterable of items (dicts). - :rtype: AsyncItemPaged[dict[str, Any]] + :rtype: CosmosAsyncItemPaged .. admonition:: Example: @@ -634,7 +634,7 @@ def query_items( throughput_bucket: Optional[int] = None, availability_strategy_config: Optional[dict[str, Any]] = _Unset, **kwargs: Any - ) -> AsyncItemPaged[dict[str, Any]]: + ) -> CosmosAsyncItemPaged: """Return all results matching the given `query`. You can use any value for the container name in the FROM clause, but @@ -677,7 +677,7 @@ def query_items( The threshold-based availability strategy to use for this request. If not provided, the client's default strategy will be used. :returns: An Iterable of items (dicts). - :rtype: AsyncItemPaged[dict[str, Any]] + :rtype: CosmosAsyncItemPaged .. admonition:: Example: @@ -716,7 +716,7 @@ def query_items( throughput_bucket: Optional[int] = None, availability_strategy_config: Optional[dict[str, Any]] = _Unset, **kwargs: Any - ) -> AsyncItemPaged[dict[str, Any]]: + ) -> CosmosAsyncItemPaged: """Return all results matching the given `query`. You can use any value for the container name in the FROM clause, but @@ -758,7 +758,7 @@ def query_items( The threshold-based availability strategy to use for this request. If not provided, the client's default strategy will be used. :returns: An Iterable of items (dicts). - :rtype: AsyncItemPaged[Dict[str, Any]] + :rtype: CosmosAsyncItemPaged .. admonition:: Example: @@ -783,7 +783,7 @@ def query_items( self, *args: Any, **kwargs: Any - ) -> AsyncItemPaged[dict[str, Any]]: + ) -> CosmosAsyncItemPaged: """Return all results matching the given `query`. 
You can use any value for the container name in the FROM clause, but @@ -833,7 +833,7 @@ def query_items( The threshold-based availability strategy to use for this request. If not provided, the client's default strategy will be used. :returns: An Iterable of items (dicts). - :rtype: AsyncItemPaged[dict[str, Any]] + :rtype: CosmosAsyncItemPaged .. admonition:: Example: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index efa5188c70fe..60701fb6ab60 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -55,7 +55,7 @@ from .._routing import routing_range from ..documents import ConnectionPolicy, DatabaseAccount from .._constants import _Constants as Constants -from .._cosmos_responses import CosmosDict, CosmosList +from .._cosmos_responses import CosmosDict, CosmosList, CosmosAsyncItemPaged from .. import http_constants, exceptions from . import _query_iterable_async as query_iterable from .. import _runtime_constants as runtime_constants @@ -2353,7 +2353,7 @@ def QueryItems( partition_key: Optional[PartitionKeyType] = None, response_hook: Optional[Callable[[Mapping[str, Any], dict[str, Any]], None]] = None, **kwargs: Any - ) -> AsyncItemPaged[dict[str, Any]]: + ) -> CosmosAsyncItemPaged: """Queries documents in a collection. 
:param str database_or_container_link: @@ -2375,7 +2375,7 @@ def QueryItems( options = {} if base.IsDatabaseLink(database_or_container_link): - return AsyncItemPaged( + return CosmosAsyncItemPaged( self, query, options, @@ -2406,7 +2406,7 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[list[dict[str, Any]], Ca self.last_response_headers, ) - return AsyncItemPaged( + return CosmosAsyncItemPaged( self, query, options, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py index bd0f6537dced..b30e840c8530 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py @@ -23,8 +23,10 @@ """ import asyncio # pylint: disable=do-not-import-asyncio import time +from typing import List, Optional from azure.core.async_paging import AsyncPageIterator +from azure.core.utils import CaseInsensitiveDict from azure.cosmos._constants import _Constants, TimeoutScope from azure.cosmos._execution_context.aio import execution_dispatcher @@ -33,7 +35,7 @@ # pylint: disable=protected-access -class QueryIterable(AsyncPageIterator): +class QueryIterable(AsyncPageIterator): # pylint: disable=too-many-instance-attributes """Represents an iterable object of the query results. QueryIterable is a wrapper for query execution context. 
@@ -84,6 +86,10 @@ def __init__( self._ex_context = execution_dispatcher._ProxyQueryExecutionContext( self._client, self._collection_link, self._query, self._options, self._fetch_function, response_hook, raw_response_hook, resource_type) + + # Response headers tracking for query operations + self._response_headers: List[CaseInsensitiveDict] = [] + super(QueryIterable, self).__init__(self._fetch_next, self._unpack, continuation_token=continuation_token) async def _unpack(self, block): @@ -122,6 +128,50 @@ async def _fetch_next(self, *args): # pylint: disable=unused-argument block = await self._ex_context.fetch_next_block() + # Capture response headers after each page fetch + self._capture_response_headers() + if not block: raise StopAsyncIteration return block + + def _capture_response_headers(self) -> None: + """Capture response headers from the last request.""" + if self._client.last_response_headers: + headers = self._client.last_response_headers.copy() + self._response_headers.append(headers) + + def get_response_headers(self) -> List[CaseInsensitiveDict]: + """Get all response headers collected during query iteration. + + Each entry in the list corresponds to one page/request made during + the query execution. Headers are captured as queries are iterated, + so this list grows as you consume more results. + + This method is typically accessed via the + :class:`~azure.cosmos.aio.CosmosAsyncItemPaged` object returned from + :meth:`~azure.cosmos.aio.ContainerProxy.query_items`. + + :return: List of response headers from each page request. + :rtype: list[~azure.core.utils.CaseInsensitiveDict] + + Example:: + + # container.query_items returns a CosmosAsyncItemPaged instance + >>> paged_items = container.query_items(query="SELECT * FROM c") + >>> async for item in paged_items: + ... 
process(item) + >>> headers = paged_items.get_response_headers() + >>> print(f"Total pages fetched: {len(headers)}") + """ + return [h.copy() for h in self._response_headers] + + def get_last_response_headers(self) -> Optional[CaseInsensitiveDict]: + """Get the response headers from the most recent page fetch. + + :return: Response headers from the last page, or None if no pages fetched yet. + :rtype: ~azure.core.utils.CaseInsensitiveDict or None + """ + if self._response_headers: + return self._response_headers[-1].copy() + return None diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index 1bb4e95caa8a..39cf1bd0f9c8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -39,7 +39,7 @@ from ._change_feed.feed_range_internal import FeedRangeInternalEpk from ._constants import _Constants as Constants, TimeoutScope from ._cosmos_client_connection import CosmosClientConnection -from ._cosmos_responses import CosmosDict, CosmosList +from ._cosmos_responses import CosmosDict, CosmosList, CosmosItemPaged from ._routing.routing_range import Range from ._session_token_helpers import get_latest_session_token from .exceptions import CosmosHttpResponseError @@ -736,7 +736,7 @@ def query_items( throughput_bucket: Optional[int] = None, availability_strategy_config: Optional[dict[str, Any]] = _Unset, **kwargs: Any - ) -> ItemPaged[dict[str, Any]]: + ) -> CosmosItemPaged: """Return all results matching the given `query`. You can use any value for the container name in the FROM clause, but @@ -785,7 +785,7 @@ def query_items( The threshold-based availability strategy to use for this request. If not provided, the client's default strategy will be used. :returns: An Iterable of items (dicts). - :rtype: ItemPaged[dict[str, Any]] + :rtype: CosmosItemPaged .. 
admonition:: Example: @@ -826,7 +826,7 @@ def query_items( throughput_bucket: Optional[int] = None, availability_strategy_config: Optional[dict[str, Any]] = _Unset, **kwargs: Any - ) -> ItemPaged[dict[str, Any]]: + ) -> CosmosItemPaged: """Return all results matching the given `query`. You can use any value for the container name in the FROM clause, but @@ -872,7 +872,7 @@ def query_items( The threshold-based availability strategy to use for this request. If not provided, the client's default strategy will be used. :returns: An Iterable of items (dicts). - :rtype: ItemPaged[dict[str, Any]] + :rtype: CosmosItemPaged .. admonition:: Example: @@ -897,7 +897,7 @@ def query_items( # pylint:disable=docstring-missing-param self, *args: Any, **kwargs: Any - ) -> ItemPaged[dict[str, Any]]: + ) -> CosmosItemPaged: """Return all results matching the given `query`. You can use any value for the container name in the FROM clause, but @@ -948,7 +948,7 @@ def query_items( # pylint:disable=docstring-missing-param :keyword str session_token: Token for use with Session consistency. :keyword int throughput_bucket: The desired throughput bucket for the client. :returns: An Iterable of items (dicts). - :rtype: ItemPaged[dict[str, Any]] + :rtype: CosmosItemPaged .. admonition:: Example: diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_response_headers.py b/sdk/cosmos/azure-cosmos/tests/test_query_response_headers.py new file mode 100644 index 000000000000..734ead97c091 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_query_response_headers.py @@ -0,0 +1,284 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+ +import os +import unittest +import uuid + +import pytest + +import azure.cosmos.cosmos_client as cosmos_client +import test_config +from azure.cosmos import DatabaseProxy +from azure.cosmos.partition_key import PartitionKey + + +@pytest.mark.cosmosEmulator +@pytest.mark.cosmosQuery +class TestQueryResponseHeaders(unittest.TestCase): + """Tests for query response headers functionality.""" + + created_db: DatabaseProxy = None + client: cosmos_client.CosmosClient = None + config = test_config.TestConfig + host = config.host + masterKey = config.masterKey + TEST_DATABASE_ID = config.TEST_DATABASE_ID + + @classmethod + def setUpClass(cls): + use_multiple_write_locations = False + if os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True": + use_multiple_write_locations = True + cls.client = cosmos_client.CosmosClient( + cls.host, cls.masterKey, multiple_write_locations=use_multiple_write_locations + ) + cls.created_db = cls.client.get_database_client(cls.TEST_DATABASE_ID) + + def test_query_response_headers_single_page(self): + """Test that response headers are captured for a single page query.""" + created_collection = self.created_db.create_container( + "test_headers_single_" + str(uuid.uuid4()), PartitionKey(path="/pk") + ) + try: + # Create a few items + for i in range(5): + created_collection.create_item(body={"pk": "test", "id": f"item_{i}", "value": i}) + + query = "SELECT * FROM c WHERE c.pk = @pk" + query_iterable = created_collection.query_items( + query=query, + parameters=[{"name": "@pk", "value": "test"}], + partition_key="test" + ) + + # Iterate through items using for loop (pagination) + items = [] + for item in query_iterable: + items.append(item) + + # Verify items were returned + self.assertEqual(len(items), 5) + + # Verify response headers were captured + response_headers = query_iterable.get_response_headers() + self.assertIsNotNone(response_headers) + self.assertGreater(len(response_headers), 0) + + # Verify headers contain 
expected fields + first_page_headers = response_headers[0] + self.assertIn("x-ms-request-charge", first_page_headers) + self.assertIn("x-ms-activity-id", first_page_headers) + + # Verify get_last_response_headers works + last_headers = query_iterable.get_last_response_headers() + self.assertIsNotNone(last_headers) + self.assertIn("x-ms-request-charge", last_headers) + + finally: + self.created_db.delete_container(created_collection.id) + + def test_query_response_headers_multiple_pages(self): + """Test that response headers are captured for each page in a paginated query.""" + created_collection = self.created_db.create_container( + "test_headers_multi_" + str(uuid.uuid4()), PartitionKey(path="/pk") + ) + try: + # Create enough items to span multiple pages + num_items = 15 + for i in range(num_items): + created_collection.create_item(body={"pk": "test", "id": f"item_{i}", "value": i}) + + query = "SELECT * FROM c WHERE c.pk = @pk" + # Use small page size to force multiple pages + query_iterable = created_collection.query_items( + query=query, + parameters=[{"name": "@pk", "value": "test"}], + partition_key="test", + max_item_count=5 # Force pagination with 5 items per page + ) + + # Iterate through items using for loop (pagination) + items = [] + for item in query_iterable: + items.append(item) + + # Verify all items were returned + self.assertEqual(len(items), num_items) + + # Verify response headers were captured for multiple pages + response_headers = query_iterable.get_response_headers() + self.assertIsNotNone(response_headers) + # With 15 items and max_item_count=5, we expect at least 3 pages + self.assertGreaterEqual(len(response_headers), 3) + + # Verify each page has headers + for i, headers in enumerate(response_headers): + self.assertIn("x-ms-request-charge", headers, f"Page {i} missing request charge header") + self.assertIn("x-ms-activity-id", headers, f"Page {i} missing activity id header") + + # Each page should have a different activity ID + 
activity_ids = [h.get("x-ms-activity-id") for h in response_headers] + # Note: Activity IDs might be the same for some edge cases, but generally should differ + self.assertEqual(len(activity_ids), len(response_headers)) + + finally: + self.created_db.delete_container(created_collection.id) + + def test_query_response_headers_empty_result(self): + """Test that response headers are captured even when query returns no results.""" + created_collection = self.created_db.create_container( + "test_headers_empty_" + str(uuid.uuid4()), PartitionKey(path="/pk") + ) + try: + # Create an item with different pk + created_collection.create_item(body={"pk": "other", "id": "item_1"}) + + query = "SELECT * FROM c WHERE c.pk = @pk" + query_iterable = created_collection.query_items( + query=query, + parameters=[{"name": "@pk", "value": "nonexistent"}], + partition_key="nonexistent" + ) + + # Iterate through items (should be empty) + items = [] + for item in query_iterable: + items.append(item) + + # Verify no items were returned + self.assertEqual(len(items), 0) + + # Response headers may or may not be captured depending on implementation + # The key is that the method doesn't throw an error + response_headers = query_iterable.get_response_headers() + self.assertIsNotNone(response_headers) + + # get_last_response_headers should return None or headers depending on if a request was made + last_headers = query_iterable.get_last_response_headers() + # This can be None if no request was made, or headers if at least one request was made + # Both are valid behaviors + + finally: + self.created_db.delete_container(created_collection.id) + + def test_query_response_headers_with_query_metrics(self): + """Test that query metrics are included in response headers when enabled.""" + created_collection = self.created_db.create_container( + "test_headers_metrics_" + str(uuid.uuid4()), PartitionKey(path="/pk") + ) + try: + # Create items + for i in range(5): + 
created_collection.create_item(body={"pk": "test", "id": f"item_{i}", "value": i}) + + query = "SELECT * FROM c WHERE c.pk = @pk" + query_iterable = created_collection.query_items( + query=query, + parameters=[{"name": "@pk", "value": "test"}], + partition_key="test", + populate_query_metrics=True + ) + + # Iterate through items + items = [] + for item in query_iterable: + items.append(item) + + self.assertEqual(len(items), 5) + + # Verify response headers contain query metrics + response_headers = query_iterable.get_response_headers() + self.assertGreater(len(response_headers), 0) + + # Check for query metrics header + metrics_header_name = "x-ms-documentdb-query-metrics" + first_page_headers = response_headers[0] + self.assertIn(metrics_header_name, first_page_headers) + + # Validate metrics header is well-formed + metrics_header = first_page_headers[metrics_header_name] + metrics = metrics_header.split(";") + self.assertGreater(len(metrics), 1) + self.assertTrue(all("=" in x for x in metrics)) + + finally: + self.created_db.delete_container(created_collection.id) + + def test_query_response_headers_by_page_iteration(self): + """Test response headers when using by_page() iteration.""" + created_collection = self.created_db.create_container( + "test_headers_by_page_" + str(uuid.uuid4()), PartitionKey(path="/pk") + ) + try: + # Create items + num_items = 10 + for i in range(num_items): + created_collection.create_item(body={"pk": "test", "id": f"item_{i}", "value": i}) + + query = "SELECT * FROM c WHERE c.pk = @pk" + query_iterable = created_collection.query_items( + query=query, + parameters=[{"name": "@pk", "value": "test"}], + partition_key="test", + max_item_count=3 # Force multiple pages + ) + + # Iterate by page + all_items = [] + page_count = 0 + for page in query_iterable.by_page(): + page_items = list(page) + all_items.extend(page_items) + page_count += 1 + + # After each page, we can check the last response headers + last_headers = 
query_iterable.get_last_response_headers() + self.assertIsNotNone(last_headers) + self.assertIn("x-ms-request-charge", last_headers) + + # Verify all items retrieved + self.assertEqual(len(all_items), num_items) + + # Verify we got headers for each page (at least as many as page_count) + response_headers = query_iterable.get_response_headers() + self.assertGreaterEqual(len(response_headers), page_count) + + finally: + self.created_db.delete_container(created_collection.id) + + def test_query_response_headers_returns_copies(self): + """Test that get_response_headers returns copies, not references.""" + created_collection = self.created_db.create_container( + "test_headers_copies_" + str(uuid.uuid4()), PartitionKey(path="/pk") + ) + try: + created_collection.create_item(body={"pk": "test", "id": "item_1"}) + + query = "SELECT * FROM c" + query_iterable = created_collection.query_items( + query=query, + partition_key="test" + ) + + # Iterate + for item in query_iterable: + pass + + # Get headers twice + headers1 = query_iterable.get_response_headers() + headers2 = query_iterable.get_response_headers() + + # They should be equal but not the same object + self.assertEqual(len(headers1), len(headers2)) + if len(headers1) > 0: + # Modifying one should not affect the other + headers1[0]["test-key"] = "test-value" + self.assertNotIn("test-key", headers2[0]) + + finally: + self.created_db.delete_container(created_collection.id) + + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_response_headers_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_response_headers_async.py new file mode 100644 index 000000000000..5501cb6ac08c --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_query_response_headers_async.py @@ -0,0 +1,287 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
import os
import unittest
import uuid
from contextlib import asynccontextmanager

import pytest

import test_config
from azure.cosmos.aio import CosmosClient, DatabaseProxy
from azure.cosmos.partition_key import PartitionKey


@pytest.mark.cosmosEmulator
@pytest.mark.cosmosQuery
class TestQueryResponseHeadersAsync(unittest.IsolatedAsyncioTestCase):
    """Tests for async query response headers functionality.

    Every test creates its own uniquely named container partitioned on ``/pk``,
    runs a query through ``query_items`` and then inspects the headers exposed
    by ``get_response_headers()`` / ``get_last_response_headers()`` on the
    returned iterable.
    """

    created_db: DatabaseProxy = None
    client: CosmosClient = None
    config = test_config.TestConfig
    host = config.host
    masterKey = config.masterKey
    TEST_DATABASE_ID = config.TEST_DATABASE_ID

    @classmethod
    def setUpClass(cls):
        # Multi-write is only switched on when the circuit breaker env flag is exactly "True".
        cls.use_multiple_write_locations = (
            os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True"
        )

    async def asyncSetUp(self):
        self.client = CosmosClient(
            self.host, self.masterKey, multiple_write_locations=self.use_multiple_write_locations
        )
        await self.client.__aenter__()
        self.created_db = self.client.get_database_client(self.TEST_DATABASE_ID)

    async def asyncTearDown(self):
        await self.client.close()

    @asynccontextmanager
    async def _temporary_container(self, name_prefix):
        """Create a container named ``name_prefix`` + random UUID and delete it on exit.

        The deletion runs inside the test body (not as a registered cleanup) so that it
        happens before ``asyncTearDown`` closes the client.

        :param str name_prefix: Prefix for the generated container id.
        """
        container = await self.created_db.create_container(
            name_prefix + str(uuid.uuid4()), PartitionKey(path="/pk")
        )
        try:
            yield container
        finally:
            await self.created_db.delete_container(container.id)

    @staticmethod
    async def _create_items(container, count):
        """Insert ``count`` items ``item_0 .. item_{count-1}`` into partition ``"test"``.

        :param container: Container proxy to insert into.
        :param int count: Number of items to create.
        """
        for i in range(count):
            await container.create_item(body={"pk": "test", "id": f"item_{i}", "value": i})

    async def test_query_response_headers_single_page_async(self):
        """Test that response headers are captured for a single page query."""
        async with self._temporary_container("test_headers_single_async_") as created_collection:
            await self._create_items(created_collection, 5)

            query_iterable = created_collection.query_items(
                query="SELECT * FROM c WHERE c.pk = @pk",
                parameters=[{"name": "@pk", "value": "test"}],
                partition_key="test"
            )

            # Drain the iterable; headers are inspected only after iteration.
            items = [item async for item in query_iterable]
            assert len(items) == 5

            # Verify response headers were captured
            response_headers = query_iterable.get_response_headers()
            assert response_headers is not None
            assert len(response_headers) > 0

            # Verify headers contain expected fields
            first_page_headers = response_headers[0]
            assert "x-ms-request-charge" in first_page_headers
            assert "x-ms-activity-id" in first_page_headers

            # Verify get_last_response_headers works
            last_headers = query_iterable.get_last_response_headers()
            assert last_headers is not None
            assert "x-ms-request-charge" in last_headers

    async def test_query_response_headers_multiple_pages_async(self):
        """Test that response headers are captured for each page in a paginated query."""
        async with self._temporary_container("test_headers_multi_async_") as created_collection:
            # Create enough items to span multiple pages
            num_items = 15
            await self._create_items(created_collection, num_items)

            query_iterable = created_collection.query_items(
                query="SELECT * FROM c WHERE c.pk = @pk",
                parameters=[{"name": "@pk", "value": "test"}],
                partition_key="test",
                max_item_count=5  # Force pagination with 5 items per page
            )

            items = [item async for item in query_iterable]
            assert len(items) == num_items

            response_headers = query_iterable.get_response_headers()
            assert response_headers is not None
            # With 15 items and max_item_count=5, we expect at least 3 pages
            assert len(response_headers) >= 3

            # Verify each page has headers
            for i, headers in enumerate(response_headers):
                assert "x-ms-request-charge" in headers, f"Page {i} missing request charge header"
                assert "x-ms-activity-id" in headers, f"Page {i} missing activity id header"

            # Each page should carry a non-empty activity ID (a bare length comparison
            # against response_headers would be true by construction).
            assert all(h.get("x-ms-activity-id") for h in response_headers)

    async def test_query_response_headers_empty_result_async(self):
        """Test that response headers are captured even when query returns no results."""
        async with self._temporary_container("test_headers_empty_async_") as created_collection:
            # Create an item with different pk
            await created_collection.create_item(body={"pk": "other", "id": "item_1"})

            query_iterable = created_collection.query_items(
                query="SELECT * FROM c WHERE c.pk = @pk",
                parameters=[{"name": "@pk", "value": "nonexistent"}],
                partition_key="nonexistent"
            )

            # Iterate through items (should be empty)
            items = [item async for item in query_iterable]
            assert len(items) == 0

            # Response headers may or may not be captured depending on implementation.
            # The key is that the method doesn't throw an error.
            response_headers = query_iterable.get_response_headers()
            assert response_headers is not None

            # May return None (no request made) or headers (at least one request made);
            # only the absence of an exception is checked here.
            query_iterable.get_last_response_headers()

    async def test_query_response_headers_with_query_metrics_async(self):
        """Test that query metrics are included in response headers when enabled."""
        async with self._temporary_container("test_headers_metrics_async_") as created_collection:
            await self._create_items(created_collection, 5)

            query_iterable = created_collection.query_items(
                query="SELECT * FROM c WHERE c.pk = @pk",
                parameters=[{"name": "@pk", "value": "test"}],
                partition_key="test",
                populate_query_metrics=True
            )

            items = [item async for item in query_iterable]
            assert len(items) == 5

            # Verify response headers contain query metrics
            response_headers = query_iterable.get_response_headers()
            assert len(response_headers) > 0

            # Check for query metrics header
            metrics_header_name = "x-ms-documentdb-query-metrics"
            first_page_headers = response_headers[0]
            assert metrics_header_name in first_page_headers

            # Validate metrics header is well-formed: several ';'-separated key=value pairs
            metrics = first_page_headers[metrics_header_name].split(";")
            assert len(metrics) > 1
            assert all("=" in x for x in metrics)

    async def test_query_response_headers_by_page_iteration_async(self):
        """Test response headers when using by_page() iteration."""
        async with self._temporary_container("test_headers_by_page_async_") as created_collection:
            num_items = 10
            await self._create_items(created_collection, num_items)

            query_iterable = created_collection.query_items(
                query="SELECT * FROM c WHERE c.pk = @pk",
                parameters=[{"name": "@pk", "value": "test"}],
                partition_key="test",
                max_item_count=3  # Force multiple pages
            )

            # Iterate by page
            all_items = []
            page_count = 0
            async for page in query_iterable.by_page():
                all_items.extend([item async for item in page])
                page_count += 1

                # After each page, the last response headers must be available
                last_headers = query_iterable.get_last_response_headers()
                assert last_headers is not None
                assert "x-ms-request-charge" in last_headers

            # Verify all items retrieved
            assert len(all_items) == num_items

            # Verify we got headers for each page (at least as many as page_count)
            response_headers = query_iterable.get_response_headers()
            assert len(response_headers) >= page_count

    async def test_query_response_headers_returns_copies_async(self):
        """Test that get_response_headers returns copies, not references."""
        async with self._temporary_container("test_headers_copies_async_") as created_collection:
            await created_collection.create_item(body={"pk": "test", "id": "item_1"})

            query_iterable = created_collection.query_items(
                query="SELECT * FROM c",
                partition_key="test"
            )

            # Drain the iterable so that headers get captured
            async for _ in query_iterable:
                pass

            # Get headers twice
            headers1 = query_iterable.get_response_headers()
            headers2 = query_iterable.get_response_headers()

            assert len(headers1) == len(headers2)
            # Require at least one captured page; otherwise the copy check below
            # would be skipped and the test would pass without verifying anything.
            assert len(headers1) > 0

            # Modifying one should not affect the other
            headers1[0]["test-key"] = "test-value"
            assert "test-key" not in headers2[0]


if __name__ == "__main__":
    unittest.main()