Skip to content

Commit fa629aa

Browse files
aayush3011Copilot
andauthored
[Cosmos]: Adding Semantic Reranker API (#42991)
* Adding semantic reranker api * updating changelog * Update sdk/cosmos/azure-cosmos/azure/cosmos/container.py Co-authored-by: Copilot <[email protected]> * updating docstring * Adding async api changes * Resolving comments * Fixing build * Fixing build * Fixing build * Fixing build * Resolving comments * Updating changelog * Adding env variable for the inference service endpoint * Adding env variable for the inference service endpoint * Fixing build * Fixing build * Fixing build * Fixing build * Fixing build * Fixing build * Fixing build * Resolving comments * Resolving comments * Resolving comments --------- Co-authored-by: Copilot <[email protected]>
1 parent c0d0a7c commit fa629aa

12 files changed

+921
-47
lines changed

sdk/cosmos/azure-cosmos/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#### Features Added
66
* Added ability to return a tuple of a DatabaseProxy/ContainerProxy with the associated database/container properties when creating or reading databases/containers through `return_properties` parameter. See [PR 41742](https://github.com/Azure/azure-sdk-for-python/pull/41742)
7+
* Added a new API for Semantic Reranking. See [PR 42991](https://github.com/Azure/azure-sdk-for-python/pull/42991)
78
#### Breaking Changes
89

910
#### Bugs Fixed

sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from typing import Dict
2727
from typing_extensions import Literal
2828

29+
# cspell:ignore reranker
30+
2931

3032
class _Constants:
3133
"""Constants used in the azure-cosmos package"""
@@ -53,10 +55,12 @@ class _Constants:
5355
MAX_ITEM_BUFFER_VS_CONFIG_DEFAULT: int = 50000
5456
SESSION_TOKEN_FALSE_PROGRESS_MERGE_CONFIG: str = "AZURE_COSMOS_SESSION_TOKEN_FALSE_PROGRESS_MERGE"
5557
SESSION_TOKEN_FALSE_PROGRESS_MERGE_CONFIG_DEFAULT: str = "True"
56-
CIRCUIT_BREAKER_ENABLED_CONFIG: str = "AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"
58+
CIRCUIT_BREAKER_ENABLED_CONFIG: str = "AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"
5759
CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT: str = "False"
5860
AAD_SCOPE_OVERRIDE: str = "AZURE_COSMOS_AAD_SCOPE_OVERRIDE"
5961
AAD_DEFAULT_SCOPE: str = "https://cosmos.azure.com/.default"
62+
INFERENCE_SERVICE_DEFAULT_SCOPE = "https://dbinference.azure.com/.default"
63+
SEMANTIC_RERANKER_INFERENCE_ENDPOINT: str = "AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT"
6064

6165
# Database Account Retry Policy constants
6266
AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES: str = "AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES"

sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from ._request_object import RequestObject
7272
from ._retry_utility import ConnectionRetryPolicy
7373
from ._routing import routing_map_provider, routing_range
74+
from ._inference_service import _InferenceService
7475
from .documents import ConnectionPolicy, DatabaseAccount
7576
from .partition_key import (
7677
_Undefined,
@@ -236,6 +237,10 @@ def __init__( # pylint: disable=too-many-statements
236237
policies=policies
237238
)
238239

240+
self._inference_service: Optional[_InferenceService] = None
241+
if self.aad_credentials:
242+
self._inference_service = _InferenceService(self)
243+
239244
# Query compatibility mode.
240245
# Allows to specify compatibility mode used by client when making query requests. Should be removed when
241246
# application/sql is no longer supported.
@@ -302,6 +307,10 @@ def _set_client_consistency_level(
302307
else:
303308
self.session = None
304309

310+
def _get_inference_service(self) -> Optional[_InferenceService]:
311+
"""Get inference service instance"""
312+
return self._inference_service
313+
305314
@property
306315
def Session(self) -> Optional[_session.Session]:
307316
"""Gets the session object from the client.
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# The MIT License (MIT)
2+
# Copyright (c) 2014 Microsoft Corporation
3+
4+
# Permission is hereby granted, free of charge, to any person obtaining a copy
5+
# of this software and associated documentation files (the "Software"), to deal
6+
# in the Software without restriction, including without limitation the rights
7+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
# copies of the Software, and to permit persons to whom the Software is
9+
# furnished to do so, subject to the following conditions:
10+
11+
# The above copyright notice and this permission notice shall be included in all
12+
# copies or substantial portions of the Software.
13+
14+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20+
# SOFTWARE.
21+
from typing import TypeVar, Any, MutableMapping, cast
22+
23+
from azure.core.pipeline import PipelineRequest
24+
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
25+
from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest
26+
from azure.core.rest import HttpRequest
27+
from azure.core.credentials import AccessToken
28+
29+
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
30+
31+
32+
class InferenceServiceBearerTokenPolicy(BearerTokenCredentialPolicy):
33+
"""Bearer token authentication policy for inference service.
34+
35+
This policy preserves the standard JWT Bearer token format required by
36+
external inference services, unlike CosmosBearerTokenCredentialPolicy which
37+
modifies tokens for Cosmos DB authentication.
38+
"""
39+
40+
@staticmethod
41+
def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
42+
"""Updates the Authorization header with the standard-bearer token format.
43+
44+
:param MutableMapping[str, str] headers: The HTTP Request headers
45+
:param str token: The OAuth token.
46+
"""
47+
headers["Authorization"] = f"Bearer {token}"
48+
49+
def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
50+
"""Called before the policy sends a request.
51+
52+
The base implementation authorizes the request with a bearer token.
53+
54+
:param ~azure.core.pipeline.PipelineRequest request: the request
55+
"""
56+
super().on_request(request)
57+
# The None-check for self._token is done in the parent on_request
58+
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
59+
60+
def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
61+
"""Acquire a token from the credential and authorize the request with it.
62+
63+
Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
64+
authorize future requests.
65+
66+
:param ~azure.core.pipeline.PipelineRequest request: the request
67+
:param str scopes: required scopes of authentication
68+
"""
69+
super().authorize_request(request, *scopes, **kwargs)
70+
# The None-check for self._token is done in the parent authorize_request
71+
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
# The MIT License (MIT)
2+
# Copyright (c) 2014 Microsoft Corporation
3+
4+
# Permission is hereby granted, free of charge, to any person obtaining a copy
5+
# of this software and associated documentation files (the "Software"), to deal
6+
# in the Software without restriction, including without limitation the rights
7+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
# copies of the Software, and to permit persons to whom the Software is
9+
# furnished to do so, subject to the following conditions:
10+
11+
# The above copyright notice and this permission notice shall be included in all
12+
# copies or substantial portions of the Software.
13+
14+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20+
# SOFTWARE.
21+
22+
import json
23+
import os
24+
import urllib
25+
from typing import Any, cast, Dict, List, Optional
26+
from urllib3.util.retry import Retry
27+
28+
from azure.core import PipelineClient
29+
from azure.core.exceptions import DecodeError
30+
from azure.core.pipeline.policies import (ContentDecodePolicy, CustomHookPolicy, DistributedTracingPolicy,
31+
HeadersPolicy, HTTPPolicy, NetworkTraceLoggingPolicy, ProxyPolicy,
32+
UserAgentPolicy)
33+
from azure.core.pipeline.transport import HttpRequest
34+
from azure.core.utils import CaseInsensitiveDict
35+
36+
from . import exceptions
37+
from ._constants import _Constants as Constants
38+
from ._cosmos_http_logging_policy import CosmosHttpLoggingPolicy
39+
from ._cosmos_responses import CosmosDict
40+
from ._inference_auth_policy import InferenceServiceBearerTokenPolicy
41+
from ._retry_utility import ConnectionRetryPolicy
42+
from .http_constants import HttpHeaders
43+
44+
45+
# cspell:ignore rerank reranker reranking
46+
# pylint: disable=protected-access,line-too-long
47+
48+
49+
class _InferenceService:
50+
"""Internal client for inference service."""
51+
52+
TOTAL_RETRIES = 3
53+
RETRY_BACKOFF_MAX = 120 # seconds
54+
RETRY_AFTER_STATUS_CODES = frozenset([429, 500])
55+
RETRY_BACKOFF_FACTOR = 0.8
56+
inference_service_default_scope = Constants.INFERENCE_SERVICE_DEFAULT_SCOPE
57+
semantic_reranking_inference_endpoint = os.environ.get(Constants.SEMANTIC_RERANKER_INFERENCE_ENDPOINT)
58+
59+
def __init__(self, cosmos_client_connection):
60+
"""Initialize inference service with credentials and endpoint information.
61+
62+
:param cosmos_client_connection: Optional reference to cosmos client connection for accessing settings
63+
:type cosmos_client_connection: Optional[CosmosClientConnection]
64+
"""
65+
self._client_connection = cosmos_client_connection
66+
self._aad_credentials = self._client_connection.aad_credentials
67+
self._token_scope = self.inference_service_default_scope
68+
69+
self._inference_endpoint = f"{self.semantic_reranking_inference_endpoint}/inference/semanticReranking"
70+
self._inference_pipeline_client = self._create_inference_pipeline_client()
71+
72+
def _create_inference_pipeline_client(self) -> PipelineClient:
73+
"""Create a pipeline for inference requests.
74+
75+
:returns: A PipelineClient configured for inference calls.
76+
:rtype: ~azure.core.PipelineClient
77+
"""
78+
access_token = self._aad_credentials
79+
auth_policy = InferenceServiceBearerTokenPolicy(access_token, self._token_scope)
80+
81+
connection_policy = self._client_connection.connection_policy
82+
retry_policy = None
83+
if isinstance(connection_policy.ConnectionRetryConfiguration, HTTPPolicy):
84+
retry_policy = ConnectionRetryPolicy(
85+
retry_total=getattr(connection_policy.ConnectionRetryConfiguration, 'retry_total',
86+
self.TOTAL_RETRIES),
87+
retry_connect=getattr(connection_policy.ConnectionRetryConfiguration, 'retry_connect', None),
88+
retry_read=getattr(connection_policy.ConnectionRetryConfiguration, 'retry_read', None),
89+
retry_status=getattr(connection_policy.ConnectionRetryConfiguration, 'retry_status', None),
90+
retry_backoff_max=getattr(connection_policy.ConnectionRetryConfiguration, 'retry_backoff_max',
91+
self.RETRY_BACKOFF_MAX),
92+
retry_on_status_codes=getattr(connection_policy.ConnectionRetryConfiguration, 'retry_on_status_codes',
93+
self.RETRY_AFTER_STATUS_CODES),
94+
retry_backoff_factor=getattr(connection_policy.ConnectionRetryConfiguration, 'retry_backoff_factor',
95+
self.RETRY_BACKOFF_FACTOR)
96+
)
97+
elif isinstance(connection_policy.ConnectionRetryConfiguration, int):
98+
retry_policy = ConnectionRetryPolicy(total=connection_policy.ConnectionRetryConfiguration)
99+
elif isinstance(connection_policy.ConnectionRetryConfiguration, Retry):
100+
# Convert a urllib3 retry policy to a Pipeline policy
101+
retry_policy = ConnectionRetryPolicy(
102+
retry_total=connection_policy.ConnectionRetryConfiguration.total,
103+
retry_connect=connection_policy.ConnectionRetryConfiguration.connect,
104+
retry_read=connection_policy.ConnectionRetryConfiguration.read,
105+
retry_status=connection_policy.ConnectionRetryConfiguration.status,
106+
retry_backoff_max=connection_policy.ConnectionRetryConfiguration.DEFAULT_BACKOFF_MAX,
107+
retry_on_status_codes=list(connection_policy.ConnectionRetryConfiguration.status_forcelist),
108+
retry_backoff_factor=connection_policy.ConnectionRetryConfiguration.backoff_factor
109+
)
110+
else:
111+
raise TypeError(
112+
"Unsupported retry policy. Must be an azure.cosmos.ConnectionRetryPolicy, int, or urllib3.Retry")
113+
114+
proxies = {}
115+
if connection_policy.ProxyConfiguration and connection_policy.ProxyConfiguration.Host:
116+
host = connection_policy.ProxyConfiguration.Host
117+
url = urllib.parse.urlparse(host)
118+
proxy = host if url.port else host + ":" + str(connection_policy.ProxyConfiguration.Port)
119+
proxies.update({url.scheme: proxy})
120+
121+
self._user_agent: str = self._client_connection._user_agent
122+
policies = [
123+
HeadersPolicy(),
124+
ProxyPolicy(proxies=proxies),
125+
UserAgentPolicy(base_user_agent=self._user_agent),
126+
ContentDecodePolicy(),
127+
retry_policy,
128+
auth_policy,
129+
CustomHookPolicy(),
130+
NetworkTraceLoggingPolicy(),
131+
DistributedTracingPolicy(),
132+
CosmosHttpLoggingPolicy(
133+
enable_diagnostics_logging=self._client_connection._enable_diagnostics_logging,
134+
),
135+
]
136+
137+
return PipelineClient(
138+
base_url=self._inference_endpoint,
139+
policies=policies
140+
)
141+
142+
def rerank(
143+
self,
144+
reranking_context: str,
145+
documents: List[str],
146+
semantic_reranking_options: Optional[Dict[str, Any]] = None,
147+
) -> CosmosDict:
148+
"""Rerank documents using the semantic reranking service.
149+
150+
:param reranking_context: Query / context string used to score documents.
151+
:type reranking_context: str
152+
:param documents: List of document strings to rerank.
153+
:type documents: List[str]
154+
:param semantic_reranking_options: Optional dictionary of tuning parameters. Supported keys:
155+
* return_documents (bool): Include original document text in results. Default True.
156+
* top_k (int): Limit number of scored documents returned.
157+
* batch_size (int): Batch size for internal scoring operations.
158+
* sort (bool): If True (default) results are ordered by descending score.
159+
:type semantic_reranking_options: Optional[Dict[str, Any]]
160+
:returns: Reranking result payload.
161+
:rtype: ~azure.cosmos.CosmosDict[str, Any]
162+
:raises ~azure.cosmos.exceptions.CosmosHttpResponseError: On HTTP or service error.
163+
"""
164+
try:
165+
body = {
166+
"query": reranking_context,
167+
"documents": documents,
168+
}
169+
170+
if semantic_reranking_options:
171+
if "return_documents" in semantic_reranking_options:
172+
body["return_documents"] = semantic_reranking_options["return_documents"]
173+
if "top_k" in semantic_reranking_options:
174+
body["top_k"] = semantic_reranking_options["top_k"]
175+
if "batch_size" in semantic_reranking_options:
176+
body["batch_size"] = semantic_reranking_options["batch_size"]
177+
if "sort" in semantic_reranking_options:
178+
body["sort"] = semantic_reranking_options["sort"]
179+
180+
headers = {
181+
HttpHeaders.ContentType: "application/json"
182+
}
183+
184+
request = HttpRequest(
185+
method="POST",
186+
url=self._inference_endpoint,
187+
headers=headers,
188+
data=json.dumps(body, separators=(",", ":"))
189+
)
190+
191+
pipeline_response = self._inference_pipeline_client._pipeline.run(request)
192+
response = pipeline_response.http_response
193+
response_headers = cast(CaseInsensitiveDict, response.headers)
194+
195+
data = response.body()
196+
if data:
197+
data = data.decode("utf-8")
198+
199+
if response.status_code >= 400:
200+
raise exceptions.CosmosHttpResponseError(message=data, response=response)
201+
202+
result = None
203+
if data:
204+
try:
205+
result = json.loads(data)
206+
except Exception as e:
207+
raise DecodeError(
208+
message="Failed to decode JSON data: {}".format(e),
209+
response=response,
210+
error=e) from e
211+
212+
return CosmosDict(result, response_headers=response_headers)
213+
214+
except Exception as e:
215+
if isinstance(e, (exceptions.CosmosHttpResponseError, exceptions.CosmosResourceNotFoundError)):
216+
raise
217+
raise exceptions.CosmosHttpResponseError(
218+
message=f"Semantic reranking failed: {str(e)}",
219+
response=None
220+
) from e

0 commit comments

Comments
 (0)