Skip to content

Commit f488d01

Browse files
bambrizCopilot
andauthored
Fix hybrid text query with parameters (#42787)
* Fix hybrid text query with parameters * Update CHANGELOG.md * Update sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py Co-authored-by: Copilot <[email protected]> * Update sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/hybrid_search_aggregator.py Co-authored-by: Copilot <[email protected]> * pylint fix and formatting fix --------- Co-authored-by: Copilot <[email protected]>
1 parent 4ca8bbb commit f488d01

File tree

7 files changed

+241
-4
lines changed

7 files changed

+241
-4
lines changed

sdk/cosmos/azure-cosmos/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
* Fixed bug where during health checks read regions were marked as unavailable for write operations. See [PR 42525](https://github.com/Azure/azure-sdk-for-python/pull/42525).
1414
* Fixed bug where containers named with spaces or special characters using session consistency would fall back to eventual consistency. See [PR 42608](https://github.com/Azure/azure-sdk-for-python/pull/42608)
1515
* Fixed bug where `excluded_locations` was not being honored for some metadata calls. See [PR 42266](https://github.com/Azure/azure-sdk-for-python/pull/42266).
16+
* Fixed bug where Hybrid Search queries using parameters were not working. See [PR 42787](https://github.com/Azure/azure-sdk-for-python/pull/42787)
1617
* Fixed partition scoping for per partition circuit breaker. See [PR 42751](https://github.com/Azure/azure-sdk-for-python/pull/42751)
1718

1819
#### Other Changes

sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ async def _create_execution_context_with_query_plan(self):
6868
query_to_use = self._query if self._query is not None else "Select * from root r"
6969
query_execution_info = _PartitionedQueryExecutionInfo(await self._client._GetQueryPlanThroughGateway
7070
(query_to_use, self._resource_link, self._options.get('excludedLocations')))
71+
qe_info = getattr(query_execution_info, "_query_execution_info", None)
72+
if isinstance(qe_info, dict) and isinstance(query_to_use, dict):
73+
params = query_to_use.get("parameters")
74+
if params is not None:
75+
query_execution_info._query_execution_info['parameters'] = params
7176
self._execution_context = await self._create_pipelined_execution_context(query_execution_info)
7277

7378
async def __anext__(self):

sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ async def _drain_and_coalesce_results(document_producers_to_drain):
3434
return all_results, is_singleton
3535

3636

37-
class _HybridSearchContextAggregator(_QueryExecutionContextBase):
37+
class _HybridSearchContextAggregator(_QueryExecutionContextBase): # pylint: disable=too-many-instance-attributes
3838
"""This class is a subclass of the query execution context base and serves for
3939
full text search and hybrid search queries. It is very similar to the existing MultiExecutionContextAggregator,
4040
but is needed since we have a lot more additional client-side logic to take care of.
@@ -53,6 +53,9 @@ def __init__(self, client, resource_link, options, partitioned_query_execution_i
5353
self._client = client
5454
self._resource_link = resource_link
5555
self._partitioned_query_ex_info = partitioned_query_execution_info
56+
# If the query uses parameters, we must save them to add them back to the component queries
57+
query_execution_info = getattr(self._partitioned_query_ex_info, "_query_execution_info", None)
58+
self._parameters = getattr(query_execution_info, "parameters", None) if query_execution_info else None
5659
self._hybrid_search_query_info = hybrid_search_query_info
5760
self._final_results = []
5861
self._aggregated_global_statistics = None
@@ -66,6 +69,12 @@ async def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-ma
6669
target_partition_key_ranges = await self._get_target_partition_key_range(target_all_ranges=True)
6770
global_statistics_doc_producers = []
6871
global_statistics_query = self._hybrid_search_query_info['globalStatisticsQuery']
72+
# If query was given parameters we must add them back in
73+
if self._parameters:
74+
global_statistics_query = {
75+
'query': global_statistics_query,
76+
'parameters': self._parameters
77+
}
6978
partitioned_query_execution_context_list = []
7079
for partition_key_target_range in target_partition_key_ranges:
7180
# create a document producer for each partition key range
@@ -113,6 +122,11 @@ async def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-ma
113122
target_partition_key_ranges = await self._get_target_partition_key_range(target_all_ranges=False)
114123
for rewritten_query in rewritten_query_infos:
115124
for pk_range in target_partition_key_ranges:
125+
if self._parameters:
126+
rewritten_query['rewrittenQuery'] = {
127+
'query': rewritten_query['rewrittenQuery'],
128+
'parameters': self._parameters
129+
}
116130
component_query_execution_list.append(
117131
document_producer._DocumentProducer(
118132
pk_range,

sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ def _create_execution_context_with_query_plan(self):
9898
query_to_use = self._query if self._query is not None else "Select * from root r"
9999
query_execution_info = _PartitionedQueryExecutionInfo(self._client._GetQueryPlanThroughGateway
100100
(query_to_use, self._resource_link, self._options.get('excludedLocations')))
101+
102+
qe_info = getattr(query_execution_info, "_query_execution_info", None)
103+
if isinstance(qe_info, dict) and isinstance(query_to_use, dict):
104+
params = query_to_use.get("parameters")
105+
if params is not None:
106+
query_execution_info._query_execution_info['parameters'] = params
101107
self._execution_context = self._create_pipelined_execution_context(query_execution_info)
102108

103109
def __next__(self):

sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/hybrid_search_aggregator.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def _format_component_query_workaround(format_string, global_statistics, compone
140140
return query
141141

142142

143-
class _HybridSearchContextAggregator(_QueryExecutionContextBase):
143+
class _HybridSearchContextAggregator(_QueryExecutionContextBase): # pylint: disable=too-many-instance-attributes
144144
"""This class is a subclass of the query execution context base and serves for
145145
full text search and hybrid search queries. It is very similar to the existing MultiExecutionContextAggregator,
146146
but is needed since we have a lot more additional client-side logic to take care of.
@@ -159,6 +159,9 @@ def __init__(self, client, resource_link, options,
159159
self._client = client
160160
self._resource_link = resource_link
161161
self._partitioned_query_ex_info = partitioned_query_execution_info
162+
# If the query uses parameters, we must save them to add them back to the component queries
163+
query_execution_info = getattr(self._partitioned_query_ex_info, "_query_execution_info", None)
164+
self._parameters = getattr(query_execution_info, "parameters", None) if query_execution_info else None
162165
self._hybrid_search_query_info = hybrid_search_query_info
163166
self._final_results = []
164167
self._aggregated_global_statistics = None
@@ -172,6 +175,12 @@ def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-many-sta
172175
target_partition_key_ranges = self._get_target_partition_key_range(target_all_ranges=True)
173176
global_statistics_doc_producers = []
174177
global_statistics_query = self._hybrid_search_query_info['globalStatisticsQuery']
178+
# If query was given parameters we must add them back in
179+
if self._parameters:
180+
global_statistics_query = {
181+
'query': global_statistics_query,
182+
'parameters': self._parameters
183+
}
175184
partitioned_query_execution_context_list = []
176185
for partition_key_target_range in target_partition_key_ranges:
177186
# create a document producer for each partition key range
@@ -218,6 +227,12 @@ def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-many-sta
218227
target_partition_key_ranges = self._get_target_partition_key_range(target_all_ranges=False)
219228
for rewritten_query in rewritten_query_infos:
220229
for pk_range in target_partition_key_ranges:
230+
# If query was given parameters we must add them back in
231+
if self._parameters:
232+
rewritten_query['rewrittenQuery'] = {
233+
'query': rewritten_query['rewrittenQuery'],
234+
'parameters': self._parameters
235+
}
221236
component_query_execution_list.append(
222237
document_producer._DocumentProducer(
223238
pk_range,

sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,105 @@ def test_weighted_reciprocal_rank_fusion_with_response_hook(self):
365365
assert len(result_list) == 10
366366
assert response_hook.count > 0 # Ensure the response hook was called
367367

368+
def test_hybrid_search_query_with_params_equivalence(self):
369+
# Literal hybrid query
370+
literal_query = (
371+
"SELECT TOP 10 c.index, c.title FROM c "
372+
"WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John') "
373+
"ORDER BY RANK FullTextScore(c.title, 'John')"
374+
)
375+
literal_results = self.test_container.query_items(literal_query, enable_cross_partition_query=True)
376+
literal_indices = [res["index"] for res in literal_results]
377+
378+
# Parameterized hybrid query (same as above, but using @term)
379+
param_query = (
380+
"SELECT TOP 10 c.index, c.title FROM c "
381+
"WHERE FullTextContains(c.title, @term) OR FullTextContains(c.text, @term) "
382+
"ORDER BY RANK FullTextScore(c.title, @term)"
383+
)
384+
params = [{"name": "@term", "value": "John"}]
385+
param_results = self.test_container.query_items(
386+
param_query, parameters=params, enable_cross_partition_query=True
387+
)
388+
param_indices = [res["index"] for res in param_results]
389+
390+
# Checks: both forms produce the same results and match known expectation
391+
assert len(literal_indices) == len(param_indices) == 3
392+
assert set(literal_indices) == set(param_indices) == {2, 85, 57}
393+
394+
def test_weighted_rrf_hybrid_search_with_params_and_response_hook(self):
395+
# Literal weighted RRF hybrid query
396+
literal_query = (
397+
"SELECT TOP 10 c.index, c.title FROM c "
398+
"ORDER BY RANK RRF(FullTextScore(c.title, 'John'), FullTextScore(c.text, 'United States'), [1, 0.5])"
399+
)
400+
literal_results = self.test_container.query_items(literal_query, enable_cross_partition_query=True)
401+
literal_indices = [res["index"] for res in literal_results]
402+
403+
# Parameterized weighted RRF hybrid query (+ response hook)
404+
response_hook = test_config.ResponseHookCaller()
405+
param_query = (
406+
"SELECT TOP 10 c.index, c.title FROM c "
407+
"ORDER BY RANK RRF(FullTextScore(c.title, @titleTerm), FullTextScore(c.text, @textTerm), @weights)"
408+
)
409+
params = [
410+
{"name": "@titleTerm", "value": "John"},
411+
{"name": "@textTerm", "value": "United States"},
412+
{"name": "@weights", "value": [1, 0.5]},
413+
]
414+
param_results = self.test_container.query_items(
415+
param_query, parameters=params, enable_cross_partition_query=True, response_hook=response_hook
416+
)
417+
param_indices = [res["index"] for res in param_results]
418+
419+
# Checks: number of results, equality against literal, and hook invoked
420+
assert len(literal_indices) == len(param_indices) == 10
421+
assert set(literal_indices) == set(param_indices)
422+
assert response_hook.count > 0
423+
424+
def test_hybrid_and_non_hybrid_param_queries_equivalence(self):
425+
# Hybrid query with vector distance (literal vs param) and compare equality
426+
item_vector = self.test_container.read_item("50", "1")["vector"]
427+
literal_hybrid = (
428+
"SELECT c.index, c.title FROM c "
429+
"ORDER BY RANK RRF(FullTextScore(c.text, 'United States'), VectorDistance(c.vector, {})) "
430+
"OFFSET 0 LIMIT 10"
431+
).format(item_vector)
432+
literal_hybrid_results = self.test_container.query_items(literal_hybrid, enable_cross_partition_query=True)
433+
literal_hybrid_indices = [res["index"] for res in literal_hybrid_results]
434+
435+
param_hybrid = (
436+
"SELECT c.index, c.title FROM c "
437+
"ORDER BY RANK RRF(FullTextScore(c.text, @country), VectorDistance(c.vector, @vec)) "
438+
"OFFSET 0 LIMIT 10"
439+
)
440+
params_hybrid = [
441+
{"name": "@country", "value": "United States"},
442+
{"name": "@vec", "value": item_vector},
443+
]
444+
param_hybrid_results = self.test_container.query_items(
445+
param_hybrid, parameters=params_hybrid, enable_cross_partition_query=True
446+
)
447+
param_hybrid_indices = [res["index"] for res in param_hybrid_results]
448+
449+
assert len(literal_hybrid_indices) == len(param_hybrid_indices) == 10
450+
# Compare ordered lists to ensure identical ranking
451+
assert literal_hybrid_indices == param_hybrid_indices
452+
453+
# Non-hybrid parameterized query equivalence on same container
454+
literal_simple = "SELECT TOP 5 c.index FROM c WHERE c.pk = '1' ORDER BY c.index"
455+
literal_simple_results = self.test_container.query_items(literal_simple, enable_cross_partition_query=True)
456+
literal_simple_indices = [res["index"] for res in literal_simple_results]
457+
458+
param_simple = "SELECT TOP 5 c.index FROM c WHERE c.pk = @pk ORDER BY c.index"
459+
params_simple = [{"name": "@pk", "value": "1"}]
460+
param_simple_results = self.test_container.query_items(
461+
param_simple, parameters=params_simple, enable_cross_partition_query=True
462+
)
463+
param_simple_indices = [res["index"] for res in param_simple_results]
464+
465+
assert len(literal_simple_indices) == len(param_simple_indices) == 5
466+
assert literal_simple_indices == param_simple_indices
368467

369468

370469
if __name__ == "__main__":

sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py

Lines changed: 99 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,8 +368,105 @@ async def test_weighted_reciprocal_rank_fusion_with_response_hook_async(self):
368368
assert len(result_list) == 10
369369
assert response_hook.count > 0 # Ensure the response hook was called
370370

371-
372-
371+
async def test_hybrid_search_query_with_params_equivalence_async(self):
372+
# Literal hybrid query
373+
literal_query = (
374+
"SELECT TOP 10 c.index, c.title FROM c "
375+
"WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John') "
376+
"ORDER BY RANK FullTextScore(c.title, 'John')"
377+
)
378+
literal_results = self.test_container.query_items(literal_query, enable_cross_partition_query=True)
379+
literal_indices = [res["index"] async for res in literal_results]
380+
381+
# Parameterized hybrid query (same as above, but using @term)
382+
param_query = (
383+
"SELECT TOP 10 c.index, c.title FROM c "
384+
"WHERE FullTextContains(c.title, @term) OR FullTextContains(c.text, @term) "
385+
"ORDER BY RANK FullTextScore(c.title, @term)"
386+
)
387+
params = [{"name": "@term", "value": "John"}]
388+
param_results = self.test_container.query_items(
389+
param_query, parameters=params, enable_cross_partition_query=True
390+
)
391+
param_indices = [res["index"] async for res in param_results]
392+
393+
# Checks: both forms produce the same results and match known expectation
394+
assert len(literal_indices) == len(param_indices) == 3
395+
assert set(literal_indices) == set(param_indices) == {2, 85, 57}
396+
397+
async def test_weighted_rrf_hybrid_search_with_params_and_response_hook_async(self):
398+
# Literal weighted RRF hybrid query
399+
literal_query = (
400+
"SELECT TOP 10 c.index, c.title FROM c "
401+
"ORDER BY RANK RRF(FullTextScore(c.title, 'John'), FullTextScore(c.text, 'United States'), [1, 0.5])"
402+
)
403+
literal_results = self.test_container.query_items(literal_query, enable_cross_partition_query=True)
404+
literal_indices = [res["index"] async for res in literal_results]
405+
406+
# Parameterized weighted RRF hybrid query (+ response hook)
407+
response_hook = test_config.ResponseHookCaller()
408+
param_query = (
409+
"SELECT TOP 10 c.index, c.title FROM c "
410+
"ORDER BY RANK RRF(FullTextScore(c.title, @titleTerm), FullTextScore(c.text, @textTerm), @weights)"
411+
)
412+
params = [
413+
{"name": "@titleTerm", "value": "John"},
414+
{"name": "@textTerm", "value": "United States"},
415+
{"name": "@weights", "value": [1, 0.5]},
416+
]
417+
param_results = self.test_container.query_items(
418+
param_query, parameters=params, enable_cross_partition_query=True, response_hook=response_hook
419+
)
420+
param_indices = [res["index"] async for res in param_results]
421+
422+
# Checks: number of results, equality against literal, and hook invoked
423+
assert len(literal_indices) == len(param_indices) == 10
424+
assert set(literal_indices) == set(param_indices)
425+
assert response_hook.count > 0
426+
427+
async def test_hybrid_and_non_hybrid_param_queries_equivalence_async(self):
428+
# Hybrid query with vector distance (literal vs param) and compare equality
429+
item_vector = (await self.test_container.read_item("50", "1"))["vector"]
430+
literal_hybrid = (
431+
"SELECT c.index, c.title FROM c "
432+
"ORDER BY RANK RRF(FullTextScore(c.text, 'United States'), VectorDistance(c.vector, {})) "
433+
"OFFSET 0 LIMIT 10"
434+
).format(item_vector)
435+
literal_hybrid_results = self.test_container.query_items(literal_hybrid, enable_cross_partition_query=True)
436+
literal_hybrid_indices = [res["index"] async for res in literal_hybrid_results]
437+
438+
param_hybrid = (
439+
"SELECT c.index, c.title FROM c "
440+
"ORDER BY RANK RRF(FullTextScore(c.text, @country), VectorDistance(c.vector, @vec)) "
441+
"OFFSET 0 LIMIT 10"
442+
)
443+
params_hybrid = [
444+
{"name": "@country", "value": "United States"},
445+
{"name": "@vec", "value": item_vector},
446+
]
447+
param_hybrid_results = self.test_container.query_items(
448+
param_hybrid, parameters=params_hybrid, enable_cross_partition_query=True
449+
)
450+
param_hybrid_indices = [res["index"] async for res in param_hybrid_results]
451+
452+
assert len(literal_hybrid_indices) == len(param_hybrid_indices) == 10
453+
# Compare ordered lists to ensure identical ranking
454+
assert literal_hybrid_indices == param_hybrid_indices
455+
456+
# Non-hybrid parameterized query equivalence on same container
457+
literal_simple = "SELECT TOP 5 c.index FROM c WHERE c.pk = '1' ORDER BY c.index"
458+
literal_simple_results = self.test_container.query_items(literal_simple, enable_cross_partition_query=True)
459+
literal_simple_indices = [res["index"] async for res in literal_simple_results]
460+
461+
param_simple = "SELECT TOP 5 c.index FROM c WHERE c.pk = @pk ORDER BY c.index"
462+
params_simple = [{"name": "@pk", "value": "1"}]
463+
param_simple_results = self.test_container.query_items(
464+
param_simple, parameters=params_simple, enable_cross_partition_query=True
465+
)
466+
param_simple_indices = [res["index"] async for res in param_simple_results]
467+
468+
assert len(literal_simple_indices) == len(param_simple_indices) == 5
469+
assert literal_simple_indices == param_simple_indices
373470

374471

375472
if __name__ == "__main__":

0 commit comments

Comments
 (0)