@@ -79,7 +79,7 @@ def _drain_and_coalesce_results(document_producers_to_drain):
79
79
return all_results , is_singleton
80
80
81
81
82
- def _rewrite_query_infos (hybrid_search_query_info , global_statistics ):
82
+ def _rewrite_query_infos (hybrid_search_query_info , global_statistics , parameters = None ):
83
83
rewritten_query_infos = []
84
84
for query_info in hybrid_search_query_info ['componentQueryInfos' ]:
85
85
assert query_info ['orderBy' ]
@@ -90,6 +90,8 @@ def _rewrite_query_infos(hybrid_search_query_info, global_statistics):
90
90
_format_component_query_workaround (order_by_expression , global_statistics ,
91
91
len (hybrid_search_query_info [
92
92
'componentQueryInfos' ])))
93
+
94
+ query_info ['rewrittenQuery' ] = _attach_parameters (query_info ['rewrittenQuery' ], parameters )
93
95
rewritten_query = _format_component_query_workaround (query_info ['rewrittenQuery' ],
94
96
global_statistics ,
95
97
len (hybrid_search_query_info [
@@ -118,6 +120,10 @@ def _format_component_query(format_string, global_statistics):
118
120
119
121
def _format_component_query_workaround (format_string , global_statistics , component_count ):
120
122
# TODO: remove this method once the fix is live and switch back to one above
123
+ parameters = None
124
+ if isinstance (format_string , dict ):
125
+ parameters = format_string .get ('parameters' , None )
126
+ format_string = format_string ['query' ]
121
127
format_string = format_string .replace (_Placeholders .formattable_order_by , "true" )
122
128
query = format_string .replace (_Placeholders .total_document_count ,
123
129
str (global_statistics ['documentCount' ]))
@@ -137,7 +143,29 @@ def _format_component_query_workaround(format_string, global_statistics, compone
137
143
138
144
statistics_index += 1
139
145
140
- return query
146
+ return _attach_parameters (query , parameters )
147
+
148
+
149
+ def _attach_parameters (query , parameters = None ):
150
+ """Attach original query parameters (if any) without mutating the passed query object.
151
+
152
+ :param query: The original query text or a query payload dict which may already contain parameters.
153
+ :type query: str or dict
154
+ :param parameters: Optional sequence of parameter definitions to attach.
155
+ :type parameters: list or None
156
+ :returns: The original query if no parameters to attach or already present, otherwise a new dict containing the
157
+ query and parameters.
158
+ :rtype: str or dict
159
+ """
160
+ if not parameters :
161
+ return query
162
+ if isinstance (query , dict ):
163
+ if "parameters" not in query :
164
+ new_query = dict (query )
165
+ new_query ["parameters" ] = parameters
166
+ return new_query
167
+ return query
168
+ return {"query" : query , "parameters" : parameters }
141
169
142
170
143
171
class _HybridSearchContextAggregator (_QueryExecutionContextBase ): # pylint: disable=too-many-instance-attributes
@@ -159,9 +187,15 @@ def __init__(self, client, resource_link, options,
159
187
self ._client = client
160
188
self ._resource_link = resource_link
161
189
self ._partitioned_query_ex_info = partitioned_query_execution_info
190
+ self ._parameters = None
162
191
# If the query uses parameters, we must save them to add them back to the component queries
163
192
query_execution_info = getattr (self ._partitioned_query_ex_info , "_query_execution_info" , None )
164
- self ._parameters = getattr (query_execution_info , "parameters" , None ) if query_execution_info else None
193
+ if query_execution_info :
194
+ self ._parameters = (
195
+ query_execution_info .get ("parameters" )
196
+ if isinstance (query_execution_info , dict )
197
+ else getattr (query_execution_info , "parameters" , None )
198
+ )
165
199
self ._hybrid_search_query_info = hybrid_search_query_info
166
200
self ._final_results = []
167
201
self ._aggregated_global_statistics = None
@@ -174,13 +208,7 @@ def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-many-sta
174
208
if self ._hybrid_search_query_info ['requiresGlobalStatistics' ]:
175
209
target_partition_key_ranges = self ._get_target_partition_key_range (target_all_ranges = True )
176
210
global_statistics_doc_producers = []
177
- global_statistics_query = self ._hybrid_search_query_info ['globalStatisticsQuery' ]
178
- # If query was given parameters we must add them back in
179
- if self ._parameters :
180
- global_statistics_query = {
181
- 'query' : global_statistics_query ,
182
- 'parameters' : self ._parameters
183
- }
211
+ global_statistics_query = self ._attach_parameters (self ._hybrid_search_query_info ['globalStatisticsQuery' ])
184
212
partitioned_query_execution_context_list = []
185
213
for partition_key_target_range in target_partition_key_ranges :
186
214
# create a document producer for each partition key range
@@ -218,7 +246,7 @@ def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-many-sta
218
246
# re-write the component queries if needed
219
247
if self ._aggregated_global_statistics :
220
248
rewritten_query_infos = _rewrite_query_infos (self ._hybrid_search_query_info ,
221
- self ._aggregated_global_statistics )
249
+ self ._aggregated_global_statistics , self . _parameters )
222
250
else :
223
251
rewritten_query_infos = self ._hybrid_search_query_info ['componentQueryInfos' ]
224
252
@@ -229,10 +257,7 @@ def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-many-sta
229
257
for pk_range in target_partition_key_ranges :
230
258
# If query was given parameters we must add them back in
231
259
if self ._parameters :
232
- rewritten_query ['rewrittenQuery' ] = {
233
- 'query' : rewritten_query ['rewrittenQuery' ],
234
- 'parameters' : self ._parameters
235
- }
260
+ rewritten_query ['rewrittenQuery' ] = self ._attach_parameters (rewritten_query ['rewrittenQuery' ])
236
261
component_query_execution_list .append (
237
262
document_producer ._DocumentProducer (
238
263
pk_range ,
@@ -302,6 +327,25 @@ def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-many-sta
302
327
drained_results .sort (key = lambda x : x ['Score' ], reverse = True )
303
328
self ._format_final_results (drained_results )
304
329
330
+ def _attach_parameters (self , query ):
331
+ """Attach original query parameters (if any) without mutating the passed query object.
332
+
333
+ :param query: Query text or a query payload dict which may already contain parameters.
334
+ :type query: str or dict
335
+ :return: The original query if no parameters to attach or already present; otherwise a new dict containing the
336
+ query and parameters.
337
+ :rtype: str or dict
338
+ """
339
+ if not self ._parameters :
340
+ return query
341
+ if isinstance (query , dict ):
342
+ if "parameters" not in query :
343
+ new_query = dict (query )
344
+ new_query ["parameters" ] = self ._parameters
345
+ return new_query
346
+ return query
347
+ return {"query" : query , "parameters" : self ._parameters }
348
+
305
349
def _format_final_results (self , results ):
306
350
skip = self ._hybrid_search_query_info ['skip' ] or 0
307
351
take = self ._hybrid_search_query_info ['take' ]
@@ -320,7 +364,7 @@ def _rewrite_query_infos(self):
320
364
_format_component_query_workaround (order_by_expression , self ._aggregated_global_statistics ,
321
365
len (self ._hybrid_search_query_info [
322
366
'componentQueryInfos' ])))
323
-
367
+ query_info [ 'rewrittenQuery' ] = _attach_parameters ( query_info [ 'rewrittenQuery' ], self . _parameters )
324
368
rewritten_query = _format_component_query_workaround (query_info ['rewrittenQuery' ],
325
369
self ._aggregated_global_statistics ,
326
370
len (self ._hybrid_search_query_info [
0 commit comments