Skip to content

Commit 28ecc72

Browse files
authored
[Cosmos]: Adding new options for semantic reranking, and adding more test cases (#43275)
* Updating options for semantic reranking, and adding more test cases * Resolving comments
1 parent d849b25 commit 28ecc72

File tree

6 files changed

+180
-75
lines changed

6 files changed

+180
-75
lines changed

sdk/cosmos/azure-cosmos/azure/cosmos/_inference_service.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import json
2323
import os
2424
import urllib
25-
from typing import Any, cast, Dict, List, Optional
25+
from typing import Any, cast, Optional
2626
from urllib3.util.retry import Retry
2727

2828
from azure.core import PipelineClient
@@ -142,24 +142,28 @@ def _create_inference_pipeline_client(self) -> PipelineClient:
142142
def rerank(
143143
self,
144144
reranking_context: str,
145-
documents: List[str],
146-
semantic_reranking_options: Optional[Dict[str, Any]] = None,
145+
documents: list[str],
146+
semantic_reranking_options: Optional[dict[str, Any]] = None,
147147
) -> CosmosDict:
148148
"""Rerank documents using the semantic reranking service.
149149
150-
:param reranking_context: Query / context string used to score documents.
151-
:type reranking_context: str
152-
:param documents: List of document strings to rerank.
153-
:type documents: List[str]
154-
:param semantic_reranking_options: Optional dictionary of tuning parameters. Supported keys:
155-
* return_documents (bool): Include original document text in results. Default True.
156-
* top_k (int): Limit number of scored documents returned.
157-
* batch_size (int): Batch size for internal scoring operations.
158-
* sort (bool): If True (default) results are ordered by descending score.
159-
:type semantic_reranking_options: Optional[Dict[str, Any]]
160-
:returns: Reranking result payload.
150+
:param str reranking_context: The context or query string to use for reranking the documents.
151+
:param list[str] documents: A list of documents (as strings) to be reranked.
152+
:param dict[str, Any] semantic_reranking_options: Optional dictionary of additional options to customize the semantic reranking process.
153+
154+
Supported options:
155+
156+
* **return_documents** (bool): Whether to return the document text in the response. If False, only scores and indices are returned. Default is True.
157+
* **top_k** (int): Maximum number of documents to return in the reranked results. If not specified, all documents are returned.
158+
* **batch_size** (int): Number of documents to process in each batch. Used for optimizing performance with large document sets.
159+
* **sort** (bool): Whether to sort the results by relevance score in descending order. Default is True.
160+
* **document_type** (str): Type of documents being reranked. Supported values are "string" and "json".
161+
* **target_paths** (str): If document_type is "json", a comma-separated string of JSON paths to extract text from for reranking.
162+
163+
:type semantic_reranking_options: Optional[dict[str, Any]]
164+
:returns: A CosmosDict containing the reranking results. The structure typically includes results list with reranked documents and their relevance scores. Each result contains index, relevance_score, and optionally document.
161165
:rtype: ~azure.cosmos.CosmosDict[str, Any]
162-
:raises ~azure.cosmos.exceptions.CosmosHttpResponseError: On HTTP or service error.
166+
:raises ~azure.cosmos.exceptions.CosmosHttpResponseError: If the semantic reranking operation fails.
163167
"""
164168
try:
165169
body = {
@@ -168,14 +172,7 @@ def rerank(
168172
}
169173

170174
if semantic_reranking_options:
171-
if "return_documents" in semantic_reranking_options:
172-
body["return_documents"] = semantic_reranking_options["return_documents"]
173-
if "top_k" in semantic_reranking_options:
174-
body["top_k"] = semantic_reranking_options["top_k"]
175-
if "batch_size" in semantic_reranking_options:
176-
body["batch_size"] = semantic_reranking_options["batch_size"]
177-
if "sort" in semantic_reranking_options:
178-
body["sort"] = semantic_reranking_options["sort"]
175+
body.update(semantic_reranking_options)
179176

180177
headers = {
181178
HttpHeaders.ContentType: "application/json"

sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,8 +1183,8 @@ async def upsert_item(
11831183
async def semantic_rerank(
11841184
self,
11851185
reranking_context: str,
1186-
documents: List[str],
1187-
semantic_reranking_options: Optional[Dict[str, Any]] = None
1186+
documents: list[str],
1187+
semantic_reranking_options: Optional[dict[str, Any]] = None
11881188
) -> CosmosDict:
11891189
"""Rerank a list of documents using semantic reranking.
11901190
@@ -1193,7 +1193,7 @@ async def semantic_rerank(
11931193
11941194
:param str reranking_context: The context or query string to use for reranking the documents.
11951195
:param list[str] documents: A list of documents (as strings) to be reranked.
1196-
:param semantic_reranking_options: Optional dictionary of additional options to customize the semantic reranking process.
1196+
:param dict[str, Any] semantic_reranking_options: Optional dictionary of additional options to customize the semantic reranking process.
11971197
11981198
Supported options:
11991199
@@ -1202,9 +1202,9 @@ async def semantic_rerank(
12021202
* **batch_size** (int): Number of documents to process in each batch. Used for optimizing performance with large document sets.
12031203
* **sort** (bool): Whether to sort the results by relevance score in descending order. Default is True.
12041204
* **document_type** (str): Type of documents being reranked. Supported values are "string" and "json".
1205-
* **target_paths** (list[str]): If document_type is "json", the list of JSON paths to extract text from for reranking.
1205+
* **target_paths** (str): If document_type is "json", a comma-separated string of JSON paths to extract text from for reranking.
12061206
1207-
:type semantic_reranking_options: Optional[Dict[str, Any]]
1207+
:type semantic_reranking_options: Optional[dict[str, Any]]
12081208
:returns: A CosmosDict containing the reranking results. The structure typically includes results list with reranked documents and their relevance scores. Each result contains index, relevance_score, and optionally document.
12091209
:rtype: ~azure.cosmos.CosmosDict[str, Any]
12101210
:raises ~azure.cosmos.exceptions.CosmosHttpResponseError: If the semantic reranking operation fails.

sdk/cosmos/azure-cosmos/azure/cosmos/aio/_inference_service_async.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import json
2323
import os
2424
import urllib
25-
from typing import Any, cast, Dict, List, Optional
25+
from typing import Any, cast, Optional
2626
from urllib.parse import urlparse
2727
from urllib3.util.retry import Retry
2828

@@ -171,41 +171,37 @@ def _get_ssl_verification_setting(self) -> bool:
171171
async def rerank(
172172
self,
173173
reranking_context: str,
174-
documents: List[str],
175-
semantic_reranking_options: Optional[Dict[str, Any]] = None,
174+
documents: list[str],
175+
semantic_reranking_options: Optional[dict[str, Any]] = None,
176176
) -> CosmosDict:
177177
"""Rerank documents using the semantic reranking service (async).
178178
179-
:param reranking_context: Query / context string used to score documents.
180-
:type reranking_context: str
181-
:param documents: List of document strings to rerank.
182-
:type documents: List[str]
183-
:param semantic_reranking_options: Optional dictionary of tuning parameters. Supported keys:
184-
* return_documents (bool): Include original document text in results. Default True.
185-
* top_k (int): Limit number of scored documents returned.
186-
* batch_size (int): Batch size for internal scoring operations.
187-
* sort (bool): If True (default) results are ordered by descending score.
188-
:type semantic_reranking_options: Optional[Dict[str, Any]]
189-
:returns: Reranking result payload.
179+
:param str reranking_context: The context or query string to use for reranking the documents.
180+
:param list[str] documents: A list of documents (as strings) to be reranked.
181+
:param dict[str, Any] semantic_reranking_options: Optional dictionary of additional options to customize the semantic reranking process.
182+
183+
Supported options:
184+
185+
* **return_documents** (bool): Whether to return the document text in the response. If False, only scores and indices are returned. Default is True.
186+
* **top_k** (int): Maximum number of documents to return in the reranked results. If not specified, all documents are returned.
187+
* **batch_size** (int): Number of documents to process in each batch. Used for optimizing performance with large document sets.
188+
* **sort** (bool): Whether to sort the results by relevance score in descending order. Default is True.
189+
* **document_type** (str): Type of documents being reranked. Supported values are "string" and "json".
190+
* **target_paths** (str): If document_type is "json", a comma-separated string of JSON paths to extract text from for reranking.
191+
192+
:type semantic_reranking_options: Optional[dict[str, Any]]
193+
:returns: A CosmosDict containing the reranking results. The structure typically includes results list with reranked documents and their relevance scores. Each result contains index, relevance_score, and optionally document.
190194
:rtype: ~azure.cosmos.CosmosDict[str, Any]
191-
:raises ~azure.cosmos.exceptions.CosmosHttpResponseError: On HTTP or service error.
195+
:raises ~azure.cosmos.exceptions.CosmosHttpResponseError: If the semantic reranking operation fails.
192196
"""
193197
try:
194198
body = {
195199
"query": reranking_context,
196200
"documents": documents,
197201
}
198202

199-
# Add optional parameters if provided
200203
if semantic_reranking_options:
201-
if "return_documents" in semantic_reranking_options:
202-
body["return_documents"] = semantic_reranking_options["return_documents"]
203-
if "top_k" in semantic_reranking_options:
204-
body["top_k"] = semantic_reranking_options["top_k"]
205-
if "batch_size" in semantic_reranking_options:
206-
body["batch_size"] = semantic_reranking_options["batch_size"]
207-
if "sort" in semantic_reranking_options:
208-
body["sort"] = semantic_reranking_options["sort"]
204+
body.update(semantic_reranking_options)
209205

210206
headers = {
211207
HttpHeaders.ContentType: "application/json"

sdk/cosmos/azure-cosmos/azure/cosmos/container.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,8 +1058,8 @@ def query_items( # pylint:disable=docstring-missing-param
10581058
def semantic_rerank(
10591059
self,
10601060
reranking_context: str,
1061-
documents: List[str],
1062-
semantic_reranking_options: Optional[Dict[str, Any]] = None
1061+
documents: list[str],
1062+
semantic_reranking_options: Optional[dict[str, Any]] = None
10631063
) -> CosmosDict:
10641064
"""Rerank a list of documents using semantic reranking.
10651065
@@ -1068,7 +1068,7 @@ def semantic_rerank(
10681068
10691069
:param str reranking_context: The context or query string to use for reranking the documents.
10701070
:param list[str] documents: A list of documents (as strings) to be reranked.
1071-
:param semantic_reranking_options: Optional dictionary of additional options to customize the semantic reranking process.
1071+
:param dict[str, Any] semantic_reranking_options: Optional dictionary of additional options to customize the semantic reranking process.
10721072
10731073
Supported options:
10741074
@@ -1077,9 +1077,9 @@ def semantic_rerank(
10771077
* **batch_size** (int): Number of documents to process in each batch. Used for optimizing performance with large document sets.
10781078
* **sort** (bool): Whether to sort the results by relevance score in descending order. Default is True.
10791079
* **document_type** (str): Type of documents being reranked. Supported values are "string" and "json".
1080-
* **target_paths** (list[str]): If document_type is "json", the list of JSON paths to extract text from for reranking.
1080+
* **target_paths** (str): If document_type is "json", a comma-separated string of JSON paths to extract text from for reranking.
10811081
1082-
:type semantic_reranking_options: Optional[Dict[str, Any]]
1082+
:type semantic_reranking_options: Optional[dict[str, Any]]
10831083
:returns: A CosmosDict containing the reranking results. The structure typically includes results list with reranked documents and their relevance scores. Each result contains index, relevance_score, and optionally document.
10841084
:rtype: ~azure.cosmos.CosmosDict[str, Any]
10851085
:raises ~azure.cosmos.exceptions.CosmosHttpResponseError: If the semantic reranking operation fails.

sdk/cosmos/azure-cosmos/tests/test_semantic_reranker.py

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# The MIT License (MIT)
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# cspell:ignore rerank reranker reranking
4+
import json
45
import unittest
56

67
import azure.cosmos.cosmos_client as cosmos_client
@@ -44,7 +45,7 @@ def tearDownClass(cls):
4445
pass
4546

4647
def test_semantic_reranker(self):
47-
documents = self._get_documents()
48+
documents = self._get_documents(document_type="string")
4849
results = self.test_container.semantic_rerank(
4950
reranking_context="What is the capital of France?",
5051
documents=documents,
@@ -59,9 +60,62 @@ def test_semantic_reranker(self):
5960
assert len(results["Scores"]) == len(documents)
6061
assert results["Scores"][0]["document"] == "Paris is the capital of France."
6162

62-
def _get_documents(self):
63-
return [
64-
"Berlin is the capital of Germany.",
65-
"Paris is the capital of France.",
66-
"Madrid is the capital of Spain."
67-
]
63+
def test_semantic_reranker_json_documents(self):
    """Rerank JSON-serialized documents using the top-level ``text`` path.

    Each sample dict is serialized with ``json.dumps`` before being sent, and
    the reranker is asked to treat the inputs as JSON (``document_type``) and
    to score on the ``text`` field (``target_paths``).
    """
    documents = self._get_documents(document_type="json")
    results = self.test_container.semantic_rerank(
        reranking_context="What is the capital of France?",
        documents=[json.dumps(item) for item in documents],
        semantic_reranking_options={
            "return_documents": True,
            "top_k": 10,
            "batch_size": 32,
            "sort": True,
            "document_type": "json",
            "target_paths": "text",
        },
    )

    # One score entry is expected per submitted document.
    assert len(results["Scores"]) == len(documents)
    # The first entry carries the document as a JSON string; the test expects
    # the France sentence to come back first.
    returned_document = json.loads(results["Scores"][0]["document"])
    assert returned_document["text"] == "Paris is the capital of France."
81+
82+
def test_semantic_reranker_nested_json_documents(self):
    """Rerank JSON-serialized documents using the nested ``info.text`` path.

    Each sample dict is serialized with ``json.dumps`` before being sent, and
    the reranker is asked to treat the inputs as JSON (``document_type``) and
    to score on the nested ``info.text`` field (``target_paths``).
    """
    documents = self._get_documents(document_type="nested_json")
    results = self.test_container.semantic_rerank(
        reranking_context="What is the capital of France?",
        documents=[json.dumps(item) for item in documents],
        semantic_reranking_options={
            "return_documents": True,
            "top_k": 10,
            "batch_size": 32,
            "sort": True,
            "document_type": "json",
            "target_paths": "info.text",
        },
    )

    # One score entry is expected per submitted document.
    assert len(results["Scores"]) == len(documents)
    # The first entry carries the document as a JSON string; the test expects
    # the France sentence to come back first.
    returned_document = json.loads(results["Scores"][0]["document"])
    assert returned_document["info"]["text"] == "Paris is the capital of France."
100+
101+
def _get_documents(self, document_type: str):
102+
if document_type == "string":
103+
return [
104+
"Berlin is the capital of Germany.",
105+
"Paris is the capital of France.",
106+
"Madrid is the capital of Spain."
107+
]
108+
elif document_type == "json":
109+
return [
110+
{"id": "1", "text": "Berlin is the capital of Germany."},
111+
{"id": "2", "text": "Paris is the capital of France."},
112+
{"id": "3", "text": "Madrid is the capital of Spain."}
113+
]
114+
elif document_type == "nested_json":
115+
return [
116+
{"id": "1", "info": {"text": "Berlin is the capital of Germany."}},
117+
{"id": "2", "info": {"text": "Paris is the capital of France."}},
118+
{"id": "3", "info": {"text": "Madrid is the capital of Spain."}}
119+
]
120+
else:
121+
raise ValueError("Unsupported document type")

0 commit comments

Comments
 (0)