 import org.elasticsearch.common.regex.Regex;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.Tuple;
+import org.elasticsearch.index.mapper.IndexFieldMapper;
+import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.MatchQueryBuilder;
 import org.elasticsearch.index.query.MultiMatchQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.TermsQueryBuilder;
 import org.elasticsearch.index.search.QueryParserHelper;
 import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
 import org.elasticsearch.search.retriever.RetrieverBuilder;
@@ -114,20 +119,39 @@ public static ActionRequestValidationException validateParams(
      * Generate the inner retriever tree for the given fields, weights, and query. The tree follows this structure:
      *
      * <pre>
-     * multi_match query on all lexical fields
+     * standard retriever for querying lexical fields using multi_match
      * normalizer retriever
-     *   match query on semantic_text field A
-     *   match query on semantic_text field B
+     *   match query on semantic_text field A with inference ID id1
+     *   match query on semantic_text field A with inference ID id2
+     *   match query on semantic_text field B with inference ID id1
      *   ...
-     *   match query on semantic_text field Z
+     *   match query on semantic_text field Z with inference ID idN
      * </pre>
      *
      * <p>
      * Where the normalizer retriever is constructed by the {@code innerNormalizerGenerator} function.
      * </p>
+     *
+     * <p>
+     * When the same lexical fields are queried for all indices, we use a single multi_match query.
+     * Otherwise, we create a boolean query with the following structure:
+     * </p>
+     *
+     * <pre>
+     * bool
+     *   should
+     *     bool
+     *       match query on lexical fields for index A
+     *       filter on index A
+     *     bool
+     *       match query on lexical fields for index B
+     *       filter on index B
+     *     ...
+     * </pre>
+     *
      * <p>
-     * This tree structure is repeated for each index in {@code indicesMetadata}. That is to say, that for each index in
-     * {@code indicesMetadata}, (up to) a pair of retrievers will be added to the returned {@code RetrieverBuilder} list.
+     * The semantic_text fields are grouped by inference ID. For each (fieldName, inferenceID) pair we generate a match query.
+     * Since we have no way to effectively filter on inference IDs, we filter on index names instead.
      * </p>
      *
      * @param fieldsAndWeights The fields to query and their respective weights, in "field^weight" format
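To make the tree described in the javadoc above concrete, here is a hand-written sketch that chains the same query builders this change uses, for two hypothetical indices. The index names (index-a, index-b), field names (title, body, summary), weights, and inference IDs (id1, id2) are invented for illustration; this is not the code path the production method takes, only an analogue of the structure it produces.

// Illustrative only: a hand-built analogue of the generated tree for two hypothetical indices.
// "title" is lexical in both indices, "body" is lexical only in index-a, and the semantic_text
// field "summary" uses inference ID "id1" in index-a but "id2" in index-b.
import java.util.List;
import java.util.Map;

import org.elasticsearch.index.mapper.IndexFieldMapper;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MatchQueryBuilder;
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;

class MultiFieldQueryTreeSketch {

    static QueryBuilder lexicalTree(String query) {
        // The lexical field sets differ per index, so each multi_match is wrapped in a bool query
        // with a terms filter on _index, and the per-index branches are combined with should clauses.
        QueryBuilder forIndexA = new BoolQueryBuilder().must(
            new MultiMatchQueryBuilder(query).type(MultiMatchQueryBuilder.Type.MOST_FIELDS)
                .fields(Map.of("title", 1.0f, "body", 1.0f))
        ).filter(new TermsQueryBuilder(IndexFieldMapper.NAME, List.of("index-a")));
        QueryBuilder forIndexB = new BoolQueryBuilder().must(
            new MultiMatchQueryBuilder(query).type(MultiMatchQueryBuilder.Type.MOST_FIELDS)
                .fields(Map.of("title", 1.0f))
        ).filter(new TermsQueryBuilder(IndexFieldMapper.NAME, List.of("index-b")));
        return new BoolQueryBuilder().should(forIndexA).should(forIndexB);
    }

    static List<QueryBuilder> semanticQueries(String query) {
        // One match query per (field, inference ID) pair; each group that does not span every
        // queried index is filtered to its index names via _index.
        QueryBuilder summaryWithId1 = new BoolQueryBuilder().must(new MatchQueryBuilder("summary", query))
            .filter(new TermsQueryBuilder(IndexFieldMapper.NAME, List.of("index-a")));
        QueryBuilder summaryWithId2 = new BoolQueryBuilder().must(new MatchQueryBuilder("summary", query))
            .filter(new TermsQueryBuilder(IndexFieldMapper.NAME, List.of("index-b")));
        return List.of(summaryWithId1, summaryWithId2);
    }
}

In the actual change each of these queries is wrapped in a StandardRetrieverBuilder; the lexical one becomes a sibling of the normalizer retriever that innerNormalizerGenerator builds from the weighted semantic retrievers.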
@@ -150,32 +174,105 @@ public static List<RetrieverBuilder> generateInnerRetrievers(
         if (weightValidator != null) {
             parsedFieldsAndWeights.values().forEach(weightValidator);
         }
-
-        // We expect up to 2 inner retrievers to be generated for each index queried
-        List<RetrieverBuilder> innerRetrievers = new ArrayList<>(indicesMetadata.size() * 2);
-        for (IndexMetadata indexMetadata : indicesMetadata) {
-            innerRetrievers.addAll(
-                generateInnerRetrieversForIndex(parsedFieldsAndWeights, query, indexMetadata, innerNormalizerGenerator, weightValidator)
-            );
+        List<RetrieverBuilder> innerRetrievers = new ArrayList<>(2);
+        // add lexical retriever
+        RetrieverBuilder lexicalRetriever = generateLexicalRetriever(parsedFieldsAndWeights, indicesMetadata, query, weightValidator);
+        if (lexicalRetriever != null) {
+            innerRetrievers.add(lexicalRetriever);
+        }
+        // add semantic retriever
+        RetrieverBuilder semanticRetriever = generateSemanticRetriever(
+            parsedFieldsAndWeights,
+            indicesMetadata,
+            query,
+            innerNormalizerGenerator,
+            weightValidator
+        );
+        if (semanticRetriever != null) {
+            innerRetrievers.add(semanticRetriever);
         }
+
         return innerRetrievers;
     }
 
-    private static List<RetrieverBuilder> generateInnerRetrieversForIndex(
+    private static RetrieverBuilder generateSemanticRetriever(
         Map<String, Float> parsedFieldsAndWeights,
+        Collection<IndexMetadata> indicesMetadata,
         String query,
-        IndexMetadata indexMetadata,
         Function<List<WeightedRetrieverSource>, CompoundRetrieverBuilder<?>> innerNormalizerGenerator,
         @Nullable Consumer<Float> weightValidator
+    ) {
+        // Form groups of (fieldName, inferenceID) that need to be queried.
+        // For each (fieldName, inferenceID) pair determine the weight that needs to be applied and the indices that need to be queried.
+        Map<Tuple<String, String>, List<String>> groupedIndices = new HashMap<>();
+        Map<Tuple<String, String>, Float> groupedWeights = new HashMap<>();
+        for (IndexMetadata indexMetadata : indicesMetadata) {
+            inferenceFieldsAndWeightsForIndex(parsedFieldsAndWeights, indexMetadata, weightValidator).forEach((fieldName, weight) -> {
+                String indexName = indexMetadata.getIndex().getName();
+                Tuple<String, String> fieldAndInferenceId = new Tuple<>(
+                    fieldName,
+                    indexMetadata.getInferenceFields().get(fieldName).getInferenceId()
+                );
+
+                List<String> existingIndexNames = groupedIndices.get(fieldAndInferenceId);
+                if (existingIndexNames != null && groupedWeights.get(fieldAndInferenceId).equals(weight) == false) {
+                    String conflictingIndexName = existingIndexNames.getFirst();
+                    throw new IllegalArgumentException(
+                        "field [" + fieldName + "] has different weights in indices [" + conflictingIndexName + "] and [" + indexName + "]"
+                    );
+                }
+
+                groupedWeights.put(fieldAndInferenceId, weight);
+                groupedIndices.computeIfAbsent(fieldAndInferenceId, k -> new ArrayList<>()).add(indexName);
+            });
+        }
+
+        // there are no semantic_text fields that need to be queried, no need to create a retriever
+        if (groupedIndices.isEmpty()) {
+            return null;
+        }
+
+        // for each (fieldName, inferenceID) pair generate a standard retriever with a semantic query
+        List<WeightedRetrieverSource> semanticRetrievers = new ArrayList<>(groupedIndices.size());
+        groupedIndices.forEach((fieldAndInferenceId, indexNames) -> {
+            String fieldName = fieldAndInferenceId.v1();
+            Float weight = groupedWeights.get(fieldAndInferenceId);
+
+            QueryBuilder queryBuilder = new MatchQueryBuilder(fieldName, query);
+
+            // if indices does not contain all index names, we need to add a filter
+            if (indicesMetadata.size() != indexNames.size()) {
+                queryBuilder = new BoolQueryBuilder().must(queryBuilder).filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indexNames));
+            }
+
+            RetrieverBuilder retrieverBuilder = new StandardRetrieverBuilder(queryBuilder);
+            semanticRetrievers.add(new WeightedRetrieverSource(CompoundRetrieverBuilder.RetrieverSource.from(retrieverBuilder), weight));
+        });
+
+        return innerNormalizerGenerator.apply(semanticRetrievers);
+    }
+
+    private static Map<String, Float> defaultFieldsAndWeightsForIndex(
+        IndexMetadata indexMetadata,
+        @Nullable Consumer<Float> weightValidator
+    ) {
+        Settings settings = indexMetadata.getSettings();
+        List<String> defaultFields = settings.getAsList(DEFAULT_FIELD_SETTING.getKey(), DEFAULT_FIELD_SETTING.getDefault(settings));
+        Map<String, Float> fieldsAndWeights = QueryParserHelper.parseFieldsAndWeights(defaultFields);
+        if (weightValidator != null) {
+            fieldsAndWeights.values().forEach(weightValidator);
+        }
+        return fieldsAndWeights;
+    }
+
+    private static Map<String, Float> inferenceFieldsAndWeightsForIndex(
+        Map<String, Float> parsedFieldsAndWeights,
+        IndexMetadata indexMetadata,
+        @Nullable Consumer<Float> weightValidator
     ) {
         Map<String, Float> fieldsAndWeightsToQuery = parsedFieldsAndWeights;
         if (fieldsAndWeightsToQuery.isEmpty()) {
-            Settings settings = indexMetadata.getSettings();
-            List<String> defaultFields = settings.getAsList(DEFAULT_FIELD_SETTING.getKey(), DEFAULT_FIELD_SETTING.getDefault(settings));
-            fieldsAndWeightsToQuery = QueryParserHelper.parseFieldsAndWeights(defaultFields);
-            if (weightValidator != null) {
-                fieldsAndWeightsToQuery.values().forEach(weightValidator);
-            }
+            fieldsAndWeightsToQuery = defaultFieldsAndWeightsForIndex(indexMetadata, weightValidator);
         }
 
         Map<String, Float> inferenceFields = new HashMap<>();
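As a standalone illustration of the grouping step in generateSemanticRetriever above, the sketch below shows how (field name, inference ID) keys accumulate index names and how a weight mismatch would be detected. The indices, fields, inference IDs, and weights are all made up for the example; only the Tuple-keyed map pattern mirrors the actual implementation.

// Hypothetical inputs: three indices map the semantic_text field "summary" with inference ID "id1",
// and index-c additionally maps "body" with inference ID "id2".
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.elasticsearch.core.Tuple;

class SemanticGroupingSketch {

    record SemanticField(String index, String field, String inferenceId, float weight) {}

    public static void main(String[] args) {
        List<SemanticField> sample = List.of(
            new SemanticField("index-a", "summary", "id1", 1.0f),
            new SemanticField("index-b", "summary", "id1", 1.0f),
            new SemanticField("index-c", "summary", "id1", 1.0f),
            new SemanticField("index-c", "body", "id2", 2.0f)
        );

        Map<Tuple<String, String>, List<String>> groupedIndices = new HashMap<>();
        Map<Tuple<String, String>, Float> groupedWeights = new HashMap<>();
        for (SemanticField f : sample) {
            Tuple<String, String> key = new Tuple<>(f.field(), f.inferenceId());
            List<String> existing = groupedIndices.get(key);
            if (existing != null && groupedWeights.get(key).equals(f.weight()) == false) {
                // The production code rejects the request when the same (field, inference ID) pair
                // is queried with different weights in different indices.
                throw new IllegalArgumentException("field [" + f.field() + "] has different weights");
            }
            groupedWeights.put(key, f.weight());
            groupedIndices.computeIfAbsent(key, k -> new ArrayList<>()).add(f.index());
        }

        // groupedIndices now holds two groups: (summary, id1) -> [index-a, index-b, index-c] and
        // (body, id2) -> [index-c]. The first group covers every queried index, so its match query
        // needs no _index filter; the second covers only index-c, so it gets a terms filter on _index.
        System.out.println(groupedIndices);
    }
}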
@@ -198,30 +295,75 @@ private static List<RetrieverBuilder> generateInnerRetrieversForIndex(
                 }
             }
         }
+        return inferenceFields;
+    }
 
+    private static Map<String, Float> nonInferenceFieldsAndWeightsForIndex(
+        Map<String, Float> fieldsAndWeightsToQuery,
+        IndexMetadata indexMetadata,
+        @Nullable Consumer<Float> weightValidator
+    ) {
         Map<String, Float> nonInferenceFields = new HashMap<>(fieldsAndWeightsToQuery);
-        nonInferenceFields.keySet().removeAll(inferenceFields.keySet()); // Remove all inference fields from non-inference fields map
 
-        // TODO: Set index pre-filters on returned retrievers when we want to implement multi-index support
-        List<RetrieverBuilder> innerRetrievers = new ArrayList<>(2);
-        if (nonInferenceFields.isEmpty() == false) {
-            MultiMatchQueryBuilder nonInferenceFieldQueryBuilder = new MultiMatchQueryBuilder(query).type(
-                MultiMatchQueryBuilder.Type.MOST_FIELDS
-            ).fields(nonInferenceFields);
-            innerRetrievers.add(new StandardRetrieverBuilder(nonInferenceFieldQueryBuilder));
+        if (nonInferenceFields.isEmpty()) {
+            nonInferenceFields = defaultFieldsAndWeightsForIndex(indexMetadata, weightValidator);
         }
-        if (inferenceFields.isEmpty() == false) {
-            List<WeightedRetrieverSource> inferenceFieldRetrievers = new ArrayList<>(inferenceFields.size());
-            inferenceFields.forEach((f, w) -> {
-                RetrieverBuilder retrieverBuilder = new StandardRetrieverBuilder(new MatchQueryBuilder(f, query));
-                inferenceFieldRetrievers.add(
-                    new WeightedRetrieverSource(CompoundRetrieverBuilder.RetrieverSource.from(retrieverBuilder), w)
-                );
-            });
 
-            innerRetrievers.add(innerNormalizerGenerator.apply(inferenceFieldRetrievers));
+        nonInferenceFields.keySet().removeAll(indexMetadata.getInferenceFields().keySet());
+        return nonInferenceFields;
+    }
+
+    private static RetrieverBuilder generateLexicalRetriever(
+        Map<String, Float> fieldsAndWeightsToQuery,
+        Collection<IndexMetadata> indicesMetadata,
+        String query,
+        @Nullable Consumer<Float> weightValidator
+    ) {
+        Map<Map<String, Float>, List<String>> groupedIndices = new HashMap<>();
+
+        for (IndexMetadata indexMetadata : indicesMetadata) {
+            Map<String, Float> nonInferenceFieldsForIndex = nonInferenceFieldsAndWeightsForIndex(
+                fieldsAndWeightsToQuery,
+                indexMetadata,
+                weightValidator
+            );
+
+            if (nonInferenceFieldsForIndex.isEmpty()) {
+                continue;
+            }
+
+            groupedIndices.computeIfAbsent(nonInferenceFieldsForIndex, k -> new ArrayList<>()).add(indexMetadata.getIndex().getName());
         }
-        return innerRetrievers;
+
+        // there are no lexical fields that need to be queried, no need to create a retriever
+        if (groupedIndices.isEmpty()) {
+            return null;
+        }
+
+        List<QueryBuilder> lexicalQueryBuilders = new ArrayList<>();
+        for (var entry : groupedIndices.entrySet()) {
+            Map<String, Float> fieldsAndWeights = entry.getKey();
+            List<String> indices = entry.getValue();
+
+            QueryBuilder queryBuilder = new MultiMatchQueryBuilder(query).type(MultiMatchQueryBuilder.Type.MOST_FIELDS)
+                .fields(fieldsAndWeights);
+
+            // if indices does not contain all index names, we need to add a filter
+            if (indices.size() != indicesMetadata.size()) {
+                queryBuilder = new BoolQueryBuilder().must(queryBuilder).filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
+            }
+
+            lexicalQueryBuilders.add(queryBuilder);
+        }
+
+        // only a single lexical query, no need to wrap in a boolean query
+        if (lexicalQueryBuilders.size() == 1) {
+            return new StandardRetrieverBuilder(lexicalQueryBuilders.getFirst());
+        }
+
+        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
+        lexicalQueryBuilders.forEach(boolQueryBuilder::should);
+        return new StandardRetrieverBuilder(boolQueryBuilder);
     }
 
     private static void addToInferenceFieldsMap(Map<String, Float> inferenceFields, String field, Float weight) {