Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from ._change_feed.feed_range_internal import FeedRangeInternalEpk
from ._constants import _Constants as Constants
from ._cosmos_http_logging_policy import CosmosHttpLoggingPolicy
from ._cosmos_responses import CosmosDict, CosmosList
from ._cosmos_responses import CosmosDict, CosmosList, CosmosItemPaged
from ._range_partition_resolver import RangePartitionResolver
from ._read_items_helper import ReadItemsHelperSync
from ._request_object import RequestObject
Expand Down Expand Up @@ -1138,7 +1138,7 @@ def QueryItems(
partition_key: Optional[PartitionKeyType] = None,
response_hook: Optional[Callable[[Mapping[str, Any], dict[str, Any]], None]] = None,
**kwargs: Any
) -> ItemPaged[dict[str, Any]]:
) -> CosmosItemPaged:
"""Queries documents in a collection.

:param str database_or_container_link:
Expand All @@ -1162,7 +1162,7 @@ def QueryItems(
options = {}

if base.IsDatabaseLink(database_or_container_link):
return ItemPaged(
return CosmosItemPaged(
self,
query,
options,
Expand All @@ -1186,7 +1186,7 @@ def fetch_fn(options: Mapping[str, Any]) -> Tuple[list[dict[str, Any]], CaseInse
response_hook=response_hook,
**kwargs)

return ItemPaged(
return CosmosItemPaged(
self,
query,
options,
Expand Down
105 changes: 104 additions & 1 deletion sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_responses.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,113 @@
# The MIT License (MIT)
# Copyright (c) 2024 Microsoft Corporation

from typing import Any, Iterable, Mapping, Optional
from typing import Any, AsyncIterator, Iterator, Iterable, List, Mapping, Optional

from azure.core.async_paging import AsyncItemPaged
from azure.core.paging import ItemPaged
from azure.core.utils import CaseInsensitiveDict


class CosmosItemPaged(ItemPaged[dict[str, Any]]):
    """ItemPaged variant that exposes response headers gathered while paging.

    Every call to :meth:`by_page` remembers the page iterator produced by the
    parent class. The header accessors below delegate to that iterator when it
    offers a method of the same name (presumably a ``QueryIterable`` for query
    operations - confirm against the caller) and otherwise fall back to an
    empty result.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Filled in by by_page(); stays None until paging has been started.
        self._query_iterable: Optional[Any] = None

    def by_page(self, continuation_token: Optional[str] = None) -> Iterator[Iterator[dict[str, Any]]]:
        """Get an iterator of pages of objects.

        :param str continuation_token: An opaque continuation token.
        :returns: An iterator of pages (each page is itself an iterator of objects).
        :rtype: iterator[iterator[dict[str, Any]]]
        """
        # Keep a handle on the freshly created page iterator so the header
        # accessors can reach it later; a new call replaces the previous handle.
        page_iterator = super().by_page(continuation_token)
        self._query_iterable = page_iterator
        return page_iterator

    def get_response_headers(self) -> List[CaseInsensitiveDict]:
        """Return the response headers captured for each page of results.

        Each element of the list corresponds to the headers of one page fetch.
        Headers are only available after iterating through the results.

        :return: List of response headers from each page; empty when paging has
            not started or the page iterator does not expose headers.
        :rtype: List[~azure.core.utils.CaseInsensitiveDict]
        """
        # A missing attribute covers both "no iterator yet" (None) and
        # "iterator without header support".
        try:
            collect_headers = self._query_iterable.get_response_headers
        except AttributeError:
            return []
        return collect_headers()

    def get_last_response_headers(self) -> Optional[CaseInsensitiveDict]:
        """Return the response headers of the most recent page fetch.

        :return: Response headers from the last page, or None when paging has
            not started or the page iterator does not expose headers.
        :rtype: Optional[~azure.core.utils.CaseInsensitiveDict]
        """
        try:
            latest_headers = self._query_iterable.get_last_response_headers
        except AttributeError:
            return None
        return latest_headers()


class CosmosAsyncItemPaged(AsyncItemPaged[dict[str, Any]]):
    """AsyncItemPaged variant that exposes response headers gathered while paging.

    Every call to :meth:`by_page` remembers the async page iterator produced by
    the parent class. The header accessors below delegate to that iterator when
    it offers a method of the same name (presumably the async ``QueryIterable``
    for query operations - confirm against the caller) and otherwise fall back
    to an empty result.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Filled in by by_page(); stays None until paging has been started.
        self._query_iterable: Optional[Any] = None

    def by_page(self, continuation_token: Optional[str] = None) -> AsyncIterator[AsyncIterator[dict[str, Any]]]:
        """Get an async iterator of pages of objects.

        :param str continuation_token: An opaque continuation token.
        :returns: An async iterator of pages (each page is itself an async iterator of objects).
        :rtype: AsyncIterator[AsyncIterator[dict[str, Any]]]
        """
        # Keep a handle on the freshly created page iterator so the header
        # accessors can reach it later; a new call replaces the previous handle.
        page_iterator = super().by_page(continuation_token)
        self._query_iterable = page_iterator
        return page_iterator

    def get_response_headers(self) -> List[CaseInsensitiveDict]:
        """Return the response headers captured for each page of results.

        Each element of the list corresponds to the headers of one page fetch.
        Headers are only available after iterating through the results.

        :return: List of response headers from each page; empty when paging has
            not started or the page iterator does not expose headers.
        :rtype: List[~azure.core.utils.CaseInsensitiveDict]
        """
        # A missing attribute covers both "no iterator yet" (None) and
        # "iterator without header support".
        try:
            collect_headers = self._query_iterable.get_response_headers
        except AttributeError:
            return []
        return collect_headers()

    def get_last_response_headers(self) -> Optional[CaseInsensitiveDict]:
        """Return the response headers of the most recent page fetch.

        :return: Response headers from the last page, or None when paging has
            not started or the page iterator does not expose headers.
        :rtype: Optional[~azure.core.utils.CaseInsensitiveDict]
        """
        try:
            latest_headers = self._query_iterable.get_last_response_headers
        except AttributeError:
            return None
        return latest_headers()


class CosmosDict(dict[str, Any]):
def __init__(self, original_dict: Optional[Mapping[str, Any]], /, *, response_headers: CaseInsensitiveDict) -> None:
if original_dict is None:
Expand Down
50 changes: 49 additions & 1 deletion sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,18 @@
"""Iterable query results in the Azure Cosmos database service.
"""
import time
from typing import List, Optional

from azure.core.paging import PageIterator # type: ignore
from azure.core.utils import CaseInsensitiveDict
from azure.cosmos._constants import _Constants, TimeoutScope
from azure.cosmos._execution_context import execution_dispatcher
from azure.cosmos import exceptions

# pylint: disable=protected-access


class QueryIterable(PageIterator):
class QueryIterable(PageIterator): # pylint: disable=too-many-instance-attributes
"""Represents an iterable object of the query results.

QueryIterable is a wrapper for query execution context.
Expand Down Expand Up @@ -81,6 +84,10 @@ def __init__(
self._ex_context = execution_dispatcher._ProxyQueryExecutionContext(
self._client, self._collection_link, self._query, self._options, self._fetch_function,
response_hook, raw_response_hook, resource_type)

# Response headers tracking for query operations
self._response_headers: List[CaseInsensitiveDict] = []

super(QueryIterable, self).__init__(self._fetch_next, self._unpack, continuation_token=continuation_token)

def _unpack(self, block):
Expand Down Expand Up @@ -114,6 +121,47 @@ def _fetch_next(self, *args): # pylint: disable=unused-argument
raise exceptions.CosmosClientTimeoutError()

block = self._ex_context.fetch_next_block()

# Capture response headers after each page fetch
self._capture_response_headers()

if not block:
raise StopIteration
return block

def _capture_response_headers(self) -> None:
    """Append a copy of the client's last response headers to the per-page list.

    Nothing is recorded when the client currently reports no (or empty)
    headers.
    """
    client = self._client
    if not client.last_response_headers:
        return
    # Store a copy so the recorded entry is a separate object from the one
    # the client holds.
    self._response_headers.append(client.last_response_headers.copy())

def get_response_headers(self) -> List[CaseInsensitiveDict]:
    """Return copies of every set of response headers recorded so far.

    One entry is recorded per page fetch for which headers were captured, so
    the list grows as more results are consumed. This method is typically
    reached through the :class:`~azure.cosmos.CosmosItemPaged` object returned
    from :meth:`~azure.cosmos.ContainerProxy.query_items`.

    :return: List of response headers from each page request.
    :rtype: list[~azure.core.utils.CaseInsensitiveDict]

    Example:
        >>> items = container.query_items(query="SELECT * FROM c")
        >>> for item in items:
        ...     process(item)
        >>> headers = items.get_response_headers()
        >>> print(f"Total pages fetched: {len(headers)}")
    """
    # Return copies so callers cannot alter the recorded entries.
    snapshots = []
    for page_headers in self._response_headers:
        snapshots.append(page_headers.copy())
    return snapshots

def get_last_response_headers(self) -> Optional[CaseInsensitiveDict]:
    """Return a copy of the most recently recorded response headers.

    :return: Response headers from the last page, or None if no pages fetched yet.
    :rtype: ~azure.core.utils.CaseInsensitiveDict or None
    """
    recorded = self._response_headers
    if not recorded:
        return None
    # Return a copy so callers cannot alter the recorded entry.
    return recorded[-1].copy()
18 changes: 9 additions & 9 deletions sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
build_options as _build_options, GenerateGuidId, validate_cache_staleness_value)
from .._change_feed.feed_range_internal import FeedRangeInternalEpk

from .._cosmos_responses import CosmosDict, CosmosList
from .._cosmos_responses import CosmosDict, CosmosList, CosmosAsyncItemPaged
from .._constants import _Constants as Constants, TimeoutScope
from .._routing.routing_range import Range
from .._session_token_helpers import get_latest_session_token
Expand Down Expand Up @@ -548,7 +548,7 @@ def query_items(
throughput_bucket: Optional[int] = None,
availability_strategy_config: Optional[dict[str, Any]] = _Unset,
**kwargs: Any
) -> AsyncItemPaged[dict[str, Any]]:
) -> CosmosAsyncItemPaged:
"""Return all results matching the given `query`.

You can use any value for the container name in the FROM clause, but
Expand Down Expand Up @@ -594,7 +594,7 @@ def query_items(
The threshold-based availability strategy to use for this request.
If not provided, the client's default strategy will be used.
:returns: An Iterable of items (dicts).
:rtype: AsyncItemPaged[dict[str, Any]]
:rtype: CosmosAsyncItemPaged

.. admonition:: Example:

Expand Down Expand Up @@ -634,7 +634,7 @@ def query_items(
throughput_bucket: Optional[int] = None,
availability_strategy_config: Optional[dict[str, Any]] = _Unset,
**kwargs: Any
) -> AsyncItemPaged[dict[str, Any]]:
) -> CosmosAsyncItemPaged:
"""Return all results matching the given `query`.

You can use any value for the container name in the FROM clause, but
Expand Down Expand Up @@ -677,7 +677,7 @@ def query_items(
The threshold-based availability strategy to use for this request.
If not provided, the client's default strategy will be used.
:returns: An Iterable of items (dicts).
:rtype: AsyncItemPaged[dict[str, Any]]
:rtype: CosmosAsyncItemPaged

.. admonition:: Example:

Expand Down Expand Up @@ -716,7 +716,7 @@ def query_items(
throughput_bucket: Optional[int] = None,
availability_strategy_config: Optional[dict[str, Any]] = _Unset,
**kwargs: Any
) -> AsyncItemPaged[dict[str, Any]]:
) -> CosmosAsyncItemPaged:
"""Return all results matching the given `query`.

You can use any value for the container name in the FROM clause, but
Expand Down Expand Up @@ -758,7 +758,7 @@ def query_items(
The threshold-based availability strategy to use for this request.
If not provided, the client's default strategy will be used.
:returns: An Iterable of items (dicts).
:rtype: AsyncItemPaged[Dict[str, Any]]
:rtype: CosmosAsyncItemPaged

.. admonition:: Example:

Expand All @@ -783,7 +783,7 @@ def query_items(
self,
*args: Any,
**kwargs: Any
) -> AsyncItemPaged[dict[str, Any]]:
) -> CosmosAsyncItemPaged:
"""Return all results matching the given `query`.

You can use any value for the container name in the FROM clause, but
Expand Down Expand Up @@ -833,7 +833,7 @@ def query_items(
The threshold-based availability strategy to use for this request.
If not provided, the client's default strategy will be used.
:returns: An Iterable of items (dicts).
:rtype: AsyncItemPaged[dict[str, Any]]
:rtype: CosmosAsyncItemPaged

.. admonition:: Example:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from .._routing import routing_range
from ..documents import ConnectionPolicy, DatabaseAccount
from .._constants import _Constants as Constants
from .._cosmos_responses import CosmosDict, CosmosList
from .._cosmos_responses import CosmosDict, CosmosList, CosmosAsyncItemPaged
from .. import http_constants, exceptions
from . import _query_iterable_async as query_iterable
from .. import _runtime_constants as runtime_constants
Expand Down Expand Up @@ -2353,7 +2353,7 @@ def QueryItems(
partition_key: Optional[PartitionKeyType] = None,
response_hook: Optional[Callable[[Mapping[str, Any], dict[str, Any]], None]] = None,
**kwargs: Any
) -> AsyncItemPaged[dict[str, Any]]:
) -> CosmosAsyncItemPaged:
"""Queries documents in a collection.

:param str database_or_container_link:
Expand All @@ -2375,7 +2375,7 @@ def QueryItems(
options = {}

if base.IsDatabaseLink(database_or_container_link):
return AsyncItemPaged(
return CosmosAsyncItemPaged(
self,
query,
options,
Expand Down Expand Up @@ -2406,7 +2406,7 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[list[dict[str, Any]], Ca
self.last_response_headers,
)

return AsyncItemPaged(
return CosmosAsyncItemPaged(
self,
query,
options,
Expand Down
Loading
Loading