Skip to content

Commit f525c46

Browse files
bambrizCopilot
andauthored
Fulltext query parameters async test fix (#42871)
* Fix hybrid text query with parameters * Update CHANGELOG.md * Update sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py Co-authored-by: Copilot <[email protected]> * Update sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/hybrid_search_aggregator.py Co-authored-by: Copilot <[email protected]> * pylint fix and formatting fix * Update test_query_hybrid_search_async.py Removed enable_cross_partition_query param from async tests as async cosmos db queries does not use this parameter. * Update how we grab parameters for hybrid search queries. Fixed issue that was causing original fix to using parameters with hybrid text queries to not work. * additional test fixes * test string change * update edge case of parameterized hybrid search queries * Fix pylint issues * Update test_query_hybrid_search_async.py --------- Co-authored-by: Copilot <[email protected]>
1 parent 1c0344e commit f525c46

File tree

4 files changed

+131
-53
lines changed

4 files changed

+131
-53
lines changed

sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from azure.cosmos._execution_context.aio.base_execution_context import _QueryExecutionContextBase
88
from azure.cosmos._execution_context.aio import document_producer
99
from azure.cosmos._execution_context.hybrid_search_aggregator import _retrieve_component_scores, _rewrite_query_infos, \
10-
_compute_rrf_scores, _compute_ranks, _coalesce_duplicate_rids
10+
_compute_rrf_scores, _compute_ranks, _coalesce_duplicate_rids, _attach_parameters
1111
from azure.cosmos._routing import routing_range
1212
from azure.cosmos import exceptions
1313

@@ -53,9 +53,15 @@ def __init__(self, client, resource_link, options, partitioned_query_execution_i
5353
self._client = client
5454
self._resource_link = resource_link
5555
self._partitioned_query_ex_info = partitioned_query_execution_info
56+
self._parameters = None
5657
# If the query uses parameters, we must save them to add them back to the component queries
5758
query_execution_info = getattr(self._partitioned_query_ex_info, "_query_execution_info", None)
58-
self._parameters = getattr(query_execution_info, "parameters", None) if query_execution_info else None
59+
if query_execution_info:
60+
self._parameters = (
61+
query_execution_info.get("parameters")
62+
if isinstance(query_execution_info, dict)
63+
else getattr(query_execution_info, "parameters", None)
64+
)
5965
self._hybrid_search_query_info = hybrid_search_query_info
6066
self._final_results = []
6167
self._aggregated_global_statistics = None
@@ -68,13 +74,8 @@ async def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-ma
6874
if self._hybrid_search_query_info['requiresGlobalStatistics']:
6975
target_partition_key_ranges = await self._get_target_partition_key_range(target_all_ranges=True)
7076
global_statistics_doc_producers = []
71-
global_statistics_query = self._hybrid_search_query_info['globalStatisticsQuery']
72-
# If query was given parameters we must add them back in
73-
if self._parameters:
74-
global_statistics_query = {
75-
'query': global_statistics_query,
76-
'parameters': self._parameters
77-
}
77+
global_statistics_query = self._attach_parameters(self._hybrid_search_query_info['globalStatisticsQuery'])
78+
7879
partitioned_query_execution_context_list = []
7980
for partition_key_target_range in target_partition_key_ranges:
8081
# create a document producer for each partition key range
@@ -113,7 +114,7 @@ async def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-ma
113114
component_query_infos = self._hybrid_search_query_info['componentQueryInfos']
114115
if self._aggregated_global_statistics:
115116
rewritten_query_infos = _rewrite_query_infos(self._hybrid_search_query_info,
116-
self._aggregated_global_statistics)
117+
self._aggregated_global_statistics, self._parameters)
117118
else:
118119
rewritten_query_infos = component_query_infos
119120

@@ -123,10 +124,8 @@ async def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-ma
123124
for rewritten_query in rewritten_query_infos:
124125
for pk_range in target_partition_key_ranges:
125126
if self._parameters:
126-
rewritten_query['rewrittenQuery'] = {
127-
'query': rewritten_query['rewrittenQuery'],
128-
'parameters': self._parameters
129-
}
127+
rewritten_query['rewrittenQuery'] = _attach_parameters(rewritten_query['rewrittenQuery'],
128+
self._parameters)
130129
component_query_execution_list.append(
131130
document_producer._DocumentProducer(
132131
pk_range,
@@ -195,6 +194,26 @@ async def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-ma
195194
drained_results.sort(key=lambda x: x['Score'], reverse=True)
196195
self._format_final_results(drained_results)
197196

197+
def _attach_parameters(self, query):
198+
"""Attach original query parameters (if any) without mutating the passed query object.
199+
200+
:param query: The original query (string or dict) to which saved parameters should be attached.
201+
:type query: str or dict
202+
:return: The query with parameters attached. Returns the original object if no parameters are stored.
203+
If the input was a string and parameters exist, a new dict is returned. If the input was a
204+
dict without "parameters", a shallow copied dict with "parameters" added is returned.
205+
:rtype: str or dict
206+
"""
207+
if not self._parameters:
208+
return query
209+
if isinstance(query, dict):
210+
if "parameters" not in query:
211+
new_query = dict(query)
212+
new_query["parameters"] = self._parameters
213+
return new_query
214+
return query
215+
return {"query": query, "parameters": self._parameters}
216+
198217
def _format_final_results(self, results):
199218
skip = self._hybrid_search_query_info['skip'] or 0
200219
take = self._hybrid_search_query_info['take']

sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/hybrid_search_aggregator.py

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def _drain_and_coalesce_results(document_producers_to_drain):
7979
return all_results, is_singleton
8080

8181

82-
def _rewrite_query_infos(hybrid_search_query_info, global_statistics):
82+
def _rewrite_query_infos(hybrid_search_query_info, global_statistics, parameters=None):
8383
rewritten_query_infos = []
8484
for query_info in hybrid_search_query_info['componentQueryInfos']:
8585
assert query_info['orderBy']
@@ -90,6 +90,8 @@ def _rewrite_query_infos(hybrid_search_query_info, global_statistics):
9090
_format_component_query_workaround(order_by_expression, global_statistics,
9191
len(hybrid_search_query_info[
9292
'componentQueryInfos'])))
93+
94+
query_info['rewrittenQuery'] = _attach_parameters(query_info['rewrittenQuery'], parameters)
9395
rewritten_query = _format_component_query_workaround(query_info['rewrittenQuery'],
9496
global_statistics,
9597
len(hybrid_search_query_info[
@@ -118,6 +120,10 @@ def _format_component_query(format_string, global_statistics):
118120

119121
def _format_component_query_workaround(format_string, global_statistics, component_count):
120122
# TODO: remove this method once the fix is live and switch back to one above
123+
parameters = None
124+
if isinstance(format_string, dict):
125+
parameters = format_string.get('parameters', None)
126+
format_string = format_string['query']
121127
format_string = format_string.replace(_Placeholders.formattable_order_by, "true")
122128
query = format_string.replace(_Placeholders.total_document_count,
123129
str(global_statistics['documentCount']))
@@ -137,7 +143,29 @@ def _format_component_query_workaround(format_string, global_statistics, compone
137143

138144
statistics_index += 1
139145

140-
return query
146+
return _attach_parameters(query, parameters)
147+
148+
149+
def _attach_parameters(query, parameters=None):
150+
"""Attach original query parameters (if any) without mutating the passed query object.
151+
152+
:param query: The original query text or a query payload dict which may already contain parameters.
153+
:type query: str or dict
154+
:param parameters: Optional sequence of parameter definitions to attach.
155+
:type parameters: list or None
156+
:returns: The original query if no parameters to attach or already present, otherwise a new dict containing the
157+
query and parameters.
158+
:rtype: str or dict
159+
"""
160+
if not parameters:
161+
return query
162+
if isinstance(query, dict):
163+
if "parameters" not in query:
164+
new_query = dict(query)
165+
new_query["parameters"] = parameters
166+
return new_query
167+
return query
168+
return {"query": query, "parameters": parameters}
141169

142170

143171
class _HybridSearchContextAggregator(_QueryExecutionContextBase): # pylint: disable=too-many-instance-attributes
@@ -159,9 +187,15 @@ def __init__(self, client, resource_link, options,
159187
self._client = client
160188
self._resource_link = resource_link
161189
self._partitioned_query_ex_info = partitioned_query_execution_info
190+
self._parameters = None
162191
# If the query uses parameters, we must save them to add them back to the component queries
163192
query_execution_info = getattr(self._partitioned_query_ex_info, "_query_execution_info", None)
164-
self._parameters = getattr(query_execution_info, "parameters", None) if query_execution_info else None
193+
if query_execution_info:
194+
self._parameters = (
195+
query_execution_info.get("parameters")
196+
if isinstance(query_execution_info, dict)
197+
else getattr(query_execution_info, "parameters", None)
198+
)
165199
self._hybrid_search_query_info = hybrid_search_query_info
166200
self._final_results = []
167201
self._aggregated_global_statistics = None
@@ -174,13 +208,7 @@ def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-many-sta
174208
if self._hybrid_search_query_info['requiresGlobalStatistics']:
175209
target_partition_key_ranges = self._get_target_partition_key_range(target_all_ranges=True)
176210
global_statistics_doc_producers = []
177-
global_statistics_query = self._hybrid_search_query_info['globalStatisticsQuery']
178-
# If query was given parameters we must add them back in
179-
if self._parameters:
180-
global_statistics_query = {
181-
'query': global_statistics_query,
182-
'parameters': self._parameters
183-
}
211+
global_statistics_query = self._attach_parameters(self._hybrid_search_query_info['globalStatisticsQuery'])
184212
partitioned_query_execution_context_list = []
185213
for partition_key_target_range in target_partition_key_ranges:
186214
# create a document producer for each partition key range
@@ -218,7 +246,7 @@ def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-many-sta
218246
# re-write the component queries if needed
219247
if self._aggregated_global_statistics:
220248
rewritten_query_infos = _rewrite_query_infos(self._hybrid_search_query_info,
221-
self._aggregated_global_statistics)
249+
self._aggregated_global_statistics, self._parameters)
222250
else:
223251
rewritten_query_infos = self._hybrid_search_query_info['componentQueryInfos']
224252

@@ -229,10 +257,7 @@ def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-many-sta
229257
for pk_range in target_partition_key_ranges:
230258
# If query was given parameters we must add them back in
231259
if self._parameters:
232-
rewritten_query['rewrittenQuery'] = {
233-
'query': rewritten_query['rewrittenQuery'],
234-
'parameters': self._parameters
235-
}
260+
rewritten_query['rewrittenQuery'] = self._attach_parameters(rewritten_query['rewrittenQuery'])
236261
component_query_execution_list.append(
237262
document_producer._DocumentProducer(
238263
pk_range,
@@ -302,6 +327,25 @@ def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-many-sta
302327
drained_results.sort(key=lambda x: x['Score'], reverse=True)
303328
self._format_final_results(drained_results)
304329

330+
def _attach_parameters(self, query):
331+
"""Attach original query parameters (if any) without mutating the passed query object.
332+
333+
:param query: Query text or a query payload dict which may already contain parameters.
334+
:type query: str or dict
335+
:return: The original query if no parameters to attach or already present; otherwise a new dict containing the
336+
query and parameters.
337+
:rtype: str or dict
338+
"""
339+
if not self._parameters:
340+
return query
341+
if isinstance(query, dict):
342+
if "parameters" not in query:
343+
new_query = dict(query)
344+
new_query["parameters"] = self._parameters
345+
return new_query
346+
return query
347+
return {"query": query, "parameters": self._parameters}
348+
305349
def _format_final_results(self, results):
306350
skip = self._hybrid_search_query_info['skip'] or 0
307351
take = self._hybrid_search_query_info['take']
@@ -320,7 +364,7 @@ def _rewrite_query_infos(self):
320364
_format_component_query_workaround(order_by_expression, self._aggregated_global_statistics,
321365
len(self._hybrid_search_query_info[
322366
'componentQueryInfos'])))
323-
367+
query_info['rewrittenQuery'] = _attach_parameters(query_info['rewrittenQuery'], self._parameters)
324368
rewritten_query = _format_component_query_workaround(query_info['rewrittenQuery'],
325369
self._aggregated_global_statistics,
326370
len(self._hybrid_search_query_info[

sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ def test_hybrid_search_query_with_params_equivalence(self):
373373
"ORDER BY RANK FullTextScore(c.title, 'John')"
374374
)
375375
literal_results = self.test_container.query_items(literal_query, enable_cross_partition_query=True)
376+
literal_results = list(literal_results)
376377
literal_indices = [res["index"] for res in literal_results]
377378

378379
# Parameterized hybrid query (same as above, but using @term)
@@ -385,6 +386,7 @@ def test_hybrid_search_query_with_params_equivalence(self):
385386
param_results = self.test_container.query_items(
386387
param_query, parameters=params, enable_cross_partition_query=True
387388
)
389+
param_results = list(param_results)
388390
param_indices = [res["index"] for res in param_results]
389391

390392
# Checks: both forms produce the same results and match known expectation
@@ -398,6 +400,7 @@ def test_weighted_rrf_hybrid_search_with_params_and_response_hook(self):
398400
"ORDER BY RANK RRF(FullTextScore(c.title, 'John'), FullTextScore(c.text, 'United States'), [1, 0.5])"
399401
)
400402
literal_results = self.test_container.query_items(literal_query, enable_cross_partition_query=True)
403+
literal_results = list(literal_results)
401404
literal_indices = [res["index"] for res in literal_results]
402405

403406
# Parameterized weighted RRF hybrid query (+ response hook)
@@ -414,6 +417,7 @@ def test_weighted_rrf_hybrid_search_with_params_and_response_hook(self):
414417
param_results = self.test_container.query_items(
415418
param_query, parameters=params, enable_cross_partition_query=True, response_hook=response_hook
416419
)
420+
param_results = list(param_results)
417421
param_indices = [res["index"] for res in param_results]
418422

419423
# Checks: number of results, equality against literal, and hook invoked
@@ -423,13 +427,13 @@ def test_weighted_rrf_hybrid_search_with_params_and_response_hook(self):
423427

424428
def test_hybrid_and_non_hybrid_param_queries_equivalence(self):
425429
# Hybrid query with vector distance (literal vs param) and compare equality
426-
item_vector = self.test_container.read_item("50", "1")["vector"]
427-
literal_hybrid = (
428-
"SELECT c.index, c.title FROM c "
429-
"ORDER BY RANK RRF(FullTextScore(c.text, 'United States'), VectorDistance(c.vector, {})) "
430-
"OFFSET 0 LIMIT 10"
431-
).format(item_vector)
430+
item = self.test_container.read_item('50', '1')
431+
item_vector = item['vector']
432+
literal_hybrid = "SELECT c.index, c.title FROM c " \
433+
"ORDER BY RANK RRF(FullTextScore(c.text, 'United States'), VectorDistance(c.vector, {})) " \
434+
"OFFSET 0 LIMIT 10".format(item_vector)
432435
literal_hybrid_results = self.test_container.query_items(literal_hybrid, enable_cross_partition_query=True)
436+
literal_hybrid_results = list(literal_hybrid_results)
433437
literal_hybrid_indices = [res["index"] for res in literal_hybrid_results]
434438

435439
param_hybrid = (
@@ -444,6 +448,7 @@ def test_hybrid_and_non_hybrid_param_queries_equivalence(self):
444448
param_hybrid_results = self.test_container.query_items(
445449
param_hybrid, parameters=params_hybrid, enable_cross_partition_query=True
446450
)
451+
param_hybrid_results = list(param_hybrid_results)
447452
param_hybrid_indices = [res["index"] for res in param_hybrid_results]
448453

449454
assert len(literal_hybrid_indices) == len(param_hybrid_indices) == 10
@@ -453,6 +458,7 @@ def test_hybrid_and_non_hybrid_param_queries_equivalence(self):
453458
# Non-hybrid parameterized query equivalence on same container
454459
literal_simple = "SELECT TOP 5 c.index FROM c WHERE c.pk = '1' ORDER BY c.index"
455460
literal_simple_results = self.test_container.query_items(literal_simple, enable_cross_partition_query=True)
461+
literal_simple_results = list(literal_simple_results)
456462
literal_simple_indices = [res["index"] for res in literal_simple_results]
457463

458464
param_simple = "SELECT TOP 5 c.index FROM c WHERE c.pk = @pk ORDER BY c.index"

0 commit comments

Comments
 (0)