1313import org .elasticsearch .common .regex .Regex ;
1414import org .elasticsearch .common .settings .Settings ;
1515import org .elasticsearch .core .Nullable ;
16+ import org .elasticsearch .core .Tuple ;
17+ import org .elasticsearch .index .mapper .IndexFieldMapper ;
18+ import org .elasticsearch .index .query .BoolQueryBuilder ;
1619import org .elasticsearch .index .query .MatchQueryBuilder ;
1720import org .elasticsearch .index .query .MultiMatchQueryBuilder ;
21+ import org .elasticsearch .index .query .QueryBuilder ;
22+ import org .elasticsearch .index .query .TermsQueryBuilder ;
1823import org .elasticsearch .index .search .QueryParserHelper ;
1924import org .elasticsearch .search .retriever .CompoundRetrieverBuilder ;
2025import org .elasticsearch .search .retriever .RetrieverBuilder ;
@@ -114,20 +119,39 @@ public static ActionRequestValidationException validateParams(
114119 * Generate the inner retriever tree for the given fields, weights, and query. The tree follows this structure:
115120 *
116121 * <pre>
117- * multi_match query on all lexical fields
122+ * standard retriever for querying lexical fields using multi_match.
118123 * normalizer retriever
119- * match query on semantic_text field A
120- * match query on semantic_text field B
124+ * match query on semantic_text field A with inference ID id1
125+ * match query on semantic_text field A with inference ID id2
126+ * match query on semantic_text field B with inference ID id1
121127 * ...
122- * match query on semantic_text field Z
128+ * match query on semantic_text field Z with inference ID idN
123129 * </pre>
124130 *
125131 * <p>
126132 * Where the normalizer retriever is constructed by the {@code innerNormalizerGenerator} function.
127133 * </p>
134+ *
135+ * <p>
136+ * When the same lexical fields are queried for all indices, we use a single multi_match query to query them.
137+ * Otherwise, we create a boolean query with the following structure:
138+ * </p>
139+ *
140+ * <pre>
141+ * bool
142+ * should
143+ * bool
144+ * multi_match query on lexical fields for index A
145+ * filter on indexA
146+ * bool
147+ * multi_match query on lexical fields for index B
148+ * filter on indexB
149+ * ...
150+ * </pre>
151+ *
128152 * <p>
129- * This tree structure is repeated for each index in {@code indicesMetadata}. That is to say, that for each index in
130- * {@code indicesMetadata}, (up to) a pair of retrievers will be added to the returned {@code RetrieverBuilder} list .
153+ * The semantic_text fields are grouped by inference ID. For each (fieldName, inferenceID) pair we generate a match query.
154+ * Since we have no way to effectively filter on inference IDs, we filter on index names instead .
131155 * </p>
132156 *
133157 * @param fieldsAndWeights The fields to query and their respective weights, in "field^weight" format
@@ -150,32 +174,105 @@ public static List<RetrieverBuilder> generateInnerRetrievers(
150174 if (weightValidator != null ) {
151175 parsedFieldsAndWeights .values ().forEach (weightValidator );
152176 }
153-
154- // We expect up to 2 inner retrievers to be generated for each index queried
155- List <RetrieverBuilder > innerRetrievers = new ArrayList <>(indicesMetadata .size () * 2 );
156- for (IndexMetadata indexMetadata : indicesMetadata ) {
157- innerRetrievers .addAll (
158- generateInnerRetrieversForIndex (parsedFieldsAndWeights , query , indexMetadata , innerNormalizerGenerator , weightValidator )
159- );
177+ List <RetrieverBuilder > innerRetrievers = new ArrayList <>(2 );
178+ // add lexical retriever
179+ RetrieverBuilder lexicalRetriever = generateLexicalRetriever (parsedFieldsAndWeights , indicesMetadata , query , weightValidator );
180+ if (lexicalRetriever != null ) {
181+ innerRetrievers .add (lexicalRetriever );
182+ }
183+ // add semantic retriever
184+ RetrieverBuilder semanticRetriever = generateSemanticRetriever (
185+ parsedFieldsAndWeights ,
186+ indicesMetadata ,
187+ query ,
188+ innerNormalizerGenerator ,
189+ weightValidator
190+ );
191+ if (semanticRetriever != null ) {
192+ innerRetrievers .add (semanticRetriever );
160193 }
194+
161195 return innerRetrievers ;
162196 }
163197
164- private static List < RetrieverBuilder > generateInnerRetrieversForIndex (
198+ private static RetrieverBuilder generateSemanticRetriever (
165199 Map <String , Float > parsedFieldsAndWeights ,
200+ Collection <IndexMetadata > indicesMetadata ,
166201 String query ,
167- IndexMetadata indexMetadata ,
168202 Function <List <WeightedRetrieverSource >, CompoundRetrieverBuilder <?>> innerNormalizerGenerator ,
169203 @ Nullable Consumer <Float > weightValidator
204+ ) {
205+ // Form groups of (fieldName, inferenceID) that need to be queried.
206+ // For each (fieldName, inferenceID) pair determine the weight that needs to be applied and the indices that need to be queried.
207+ Map <Tuple <String , String >, List <String >> groupedIndices = new HashMap <>();
208+ Map <Tuple <String , String >, Float > groupedWeights = new HashMap <>();
209+ for (IndexMetadata indexMetadata : indicesMetadata ) {
210+ inferenceFieldsAndWeightsForIndex (parsedFieldsAndWeights , indexMetadata , weightValidator ).forEach ((fieldName , weight ) -> {
211+ String indexName = indexMetadata .getIndex ().getName ();
212+ Tuple <String , String > fieldAndInferenceId = new Tuple <>(
213+ fieldName ,
214+ indexMetadata .getInferenceFields ().get (fieldName ).getInferenceId ()
215+ );
216+
217+ List <String > existingIndexNames = groupedIndices .get (fieldAndInferenceId );
218+ if (existingIndexNames != null && groupedWeights .get (fieldAndInferenceId ).equals (weight ) == false ) {
219+ String conflictingIndexName = existingIndexNames .getFirst ();
220+ throw new IllegalArgumentException (
221+ "field [" + fieldName + "] has different weights in indices [" + conflictingIndexName + "] and [" + indexName + "]"
222+ );
223+ }
224+
225+ groupedWeights .put (fieldAndInferenceId , weight );
226+ groupedIndices .computeIfAbsent (fieldAndInferenceId , k -> new ArrayList <>()).add (indexName );
227+ });
228+ }
229+
230+ // there are no semantic_text fields that need to be queried, no need to create a retriever
231+ if (groupedIndices .isEmpty ()) {
232+ return null ;
233+ }
234+
235+ // for each (fieldName, inferenceID) pair generate a standard retriever with a semantic query
236+ List <WeightedRetrieverSource > semanticRetrievers = new ArrayList <>(groupedIndices .size ());
237+ groupedIndices .forEach ((fieldAndInferenceId , indexNames ) -> {
238+ String fieldName = fieldAndInferenceId .v1 ();
239+ Float weight = groupedWeights .get (fieldAndInferenceId );
240+
241+ QueryBuilder queryBuilder = new MatchQueryBuilder (fieldName , query );
242+
243+ // if indices does not contain all index names, we need to add a filter
244+ if (indicesMetadata .size () != indexNames .size ()) {
245+ queryBuilder = new BoolQueryBuilder ().must (queryBuilder ).filter (new TermsQueryBuilder (IndexFieldMapper .NAME , indexNames ));
246+ }
247+
248+ RetrieverBuilder retrieverBuilder = new StandardRetrieverBuilder (queryBuilder );
249+ semanticRetrievers .add (new WeightedRetrieverSource (CompoundRetrieverBuilder .RetrieverSource .from (retrieverBuilder ), weight ));
250+ });
251+
252+ return innerNormalizerGenerator .apply (semanticRetrievers );
253+ }
254+
255+ private static Map <String , Float > defaultFieldsAndWeightsForIndex (
256+ IndexMetadata indexMetadata ,
257+ @ Nullable Consumer <Float > weightValidator
258+ ) {
259+ Settings settings = indexMetadata .getSettings ();
260+ List <String > defaultFields = settings .getAsList (DEFAULT_FIELD_SETTING .getKey (), DEFAULT_FIELD_SETTING .getDefault (settings ));
261+ Map <String , Float > fieldsAndWeights = QueryParserHelper .parseFieldsAndWeights (defaultFields );
262+ if (weightValidator != null ) {
263+ fieldsAndWeights .values ().forEach (weightValidator );
264+ }
265+ return fieldsAndWeights ;
266+ }
267+
268+ private static Map <String , Float > inferenceFieldsAndWeightsForIndex (
269+ Map <String , Float > parsedFieldsAndWeights ,
270+ IndexMetadata indexMetadata ,
271+ @ Nullable Consumer <Float > weightValidator
170272 ) {
171273 Map <String , Float > fieldsAndWeightsToQuery = parsedFieldsAndWeights ;
172274 if (fieldsAndWeightsToQuery .isEmpty ()) {
173- Settings settings = indexMetadata .getSettings ();
174- List <String > defaultFields = settings .getAsList (DEFAULT_FIELD_SETTING .getKey (), DEFAULT_FIELD_SETTING .getDefault (settings ));
175- fieldsAndWeightsToQuery = QueryParserHelper .parseFieldsAndWeights (defaultFields );
176- if (weightValidator != null ) {
177- fieldsAndWeightsToQuery .values ().forEach (weightValidator );
178- }
275+ fieldsAndWeightsToQuery = defaultFieldsAndWeightsForIndex (indexMetadata , weightValidator );
179276 }
180277
181278 Map <String , Float > inferenceFields = new HashMap <>();
@@ -198,30 +295,75 @@ private static List<RetrieverBuilder> generateInnerRetrieversForIndex(
198295 }
199296 }
200297 }
298+ return inferenceFields ;
299+ }
201300
301+ private static Map <String , Float > nonInferenceFieldsAndWeightsForIndex (
302+ Map <String , Float > fieldsAndWeightsToQuery ,
303+ IndexMetadata indexMetadata ,
304+ @ Nullable Consumer <Float > weightValidator
305+ ) {
202306 Map <String , Float > nonInferenceFields = new HashMap <>(fieldsAndWeightsToQuery );
203- nonInferenceFields .keySet ().removeAll (inferenceFields .keySet ()); // Remove all inference fields from non-inference fields map
204307
205- // TODO: Set index pre-filters on returned retrievers when we want to implement multi-index support
206- List <RetrieverBuilder > innerRetrievers = new ArrayList <>(2 );
207- if (nonInferenceFields .isEmpty () == false ) {
208- MultiMatchQueryBuilder nonInferenceFieldQueryBuilder = new MultiMatchQueryBuilder (query ).type (
209- MultiMatchQueryBuilder .Type .MOST_FIELDS
210- ).fields (nonInferenceFields );
211- innerRetrievers .add (new StandardRetrieverBuilder (nonInferenceFieldQueryBuilder ));
308+ if (nonInferenceFields .isEmpty ()) {
309+ nonInferenceFields = defaultFieldsAndWeightsForIndex (indexMetadata , weightValidator );
212310 }
213- if (inferenceFields .isEmpty () == false ) {
214- List <WeightedRetrieverSource > inferenceFieldRetrievers = new ArrayList <>(inferenceFields .size ());
215- inferenceFields .forEach ((f , w ) -> {
216- RetrieverBuilder retrieverBuilder = new StandardRetrieverBuilder (new MatchQueryBuilder (f , query ));
217- inferenceFieldRetrievers .add (
218- new WeightedRetrieverSource (CompoundRetrieverBuilder .RetrieverSource .from (retrieverBuilder ), w )
219- );
220- });
221311
222- innerRetrievers .add (innerNormalizerGenerator .apply (inferenceFieldRetrievers ));
312+ nonInferenceFields .keySet ().removeAll (indexMetadata .getInferenceFields ().keySet ());
313+ return nonInferenceFields ;
314+ }
315+
316+ private static RetrieverBuilder generateLexicalRetriever (
317+ Map <String , Float > fieldsAndWeightsToQuery ,
318+ Collection <IndexMetadata > indicesMetadata ,
319+ String query ,
320+ @ Nullable Consumer <Float > weightValidator
321+ ) {
322+ Map <Map <String , Float >, List <String >> groupedIndices = new HashMap <>();
323+
324+ for (IndexMetadata indexMetadata : indicesMetadata ) {
325+ Map <String , Float > nonInferenceFieldsForIndex = nonInferenceFieldsAndWeightsForIndex (
326+ fieldsAndWeightsToQuery ,
327+ indexMetadata ,
328+ weightValidator
329+ );
330+
331+ if (nonInferenceFieldsForIndex .isEmpty ()) {
332+ continue ;
333+ }
334+
335+ groupedIndices .computeIfAbsent (nonInferenceFieldsForIndex , k -> new ArrayList <>()).add (indexMetadata .getIndex ().getName ());
223336 }
224- return innerRetrievers ;
337+
338+ // there are no lexical fields that need to be queried, no need to create a retriever
339+ if (groupedIndices .isEmpty ()) {
340+ return null ;
341+ }
342+
343+ List <QueryBuilder > lexicalQueryBuilders = new ArrayList <>();
344+ for (var entry : groupedIndices .entrySet ()) {
345+ Map <String , Float > fieldsAndWeights = entry .getKey ();
346+ List <String > indices = entry .getValue ();
347+
348+ QueryBuilder queryBuilder = new MultiMatchQueryBuilder (query ).type (MultiMatchQueryBuilder .Type .MOST_FIELDS )
349+ .fields (fieldsAndWeights );
350+
351+ // if indices does not contain all index names, we need to add a filter
352+ if (indices .size () != indicesMetadata .size ()) {
353+ queryBuilder = new BoolQueryBuilder ().must (queryBuilder ).filter (new TermsQueryBuilder (IndexFieldMapper .NAME , indices ));
354+ }
355+
356+ lexicalQueryBuilders .add (queryBuilder );
357+ }
358+
359+ // only a single lexical query, no need to wrap in a boolean query
360+ if (lexicalQueryBuilders .size () == 1 ) {
361+ return new StandardRetrieverBuilder (lexicalQueryBuilders .getFirst ());
362+ }
363+
364+ BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder ();
365+ lexicalQueryBuilders .forEach (boolQueryBuilder ::should );
366+ return new StandardRetrieverBuilder (boolQueryBuilder );
225367 }
226368
227369 private static void addToInferenceFieldsMap (Map <String , Float > inferenceFields , String field , Float weight ) {
0 commit comments