Skip to content

Commit 54152ca

Browse files
authored
Support querying multiple indices with the simplified linear retriever (#133720)
1 parent 4cb5ec3 commit 54152ca

File tree

6 files changed

+1006
-133
lines changed

6 files changed

+1006
-133
lines changed

docs/changelog/133720.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 133720
2+
summary: Support querying multiple indices with the simplified linear retriever
3+
area: Relevance
4+
type: enhancement
5+
issues: []

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/MultiFieldsInnerRetrieverUtils.java

Lines changed: 181 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,13 @@
1313
import org.elasticsearch.common.regex.Regex;
1414
import org.elasticsearch.common.settings.Settings;
1515
import org.elasticsearch.core.Nullable;
16+
import org.elasticsearch.core.Tuple;
17+
import org.elasticsearch.index.mapper.IndexFieldMapper;
18+
import org.elasticsearch.index.query.BoolQueryBuilder;
1619
import org.elasticsearch.index.query.MatchQueryBuilder;
1720
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
21+
import org.elasticsearch.index.query.QueryBuilder;
22+
import org.elasticsearch.index.query.TermsQueryBuilder;
1823
import org.elasticsearch.index.search.QueryParserHelper;
1924
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
2025
import org.elasticsearch.search.retriever.RetrieverBuilder;
@@ -114,20 +119,39 @@ public static ActionRequestValidationException validateParams(
114119
* Generate the inner retriever tree for the given fields, weights, and query. The tree follows this structure:
115120
*
116121
* <pre>
117-
* multi_match query on all lexical fields
122+
* standard retriever for querying lexical fields using multi_match.
118123
* normalizer retriever
119-
* match query on semantic_text field A
120-
* match query on semantic_text field B
124+
* match query on semantic_text field A with inference ID id1
125+
* match query on semantic_text field A with inference ID id2
126+
* match query on semantic_text field B with inference ID id1
121127
* ...
122-
* match query on semantic_text field Z
128+
* match query on semantic_text field Z with inference ID idN
123129
* </pre>
124130
*
125131
* <p>
126132
* Where the normalizer retriever is constructed by the {@code innerNormalizerGenerator} function.
127133
* </p>
134+
*
135+
* <p>
136+
* When the same lexical fields are queried for all indices, we use a single multi_match query to query them.
137+
* Otherwise, we create a boolean query with the following structure:
138+
* </p>
139+
*
140+
* <pre>
141+
* bool
142+
* should
143+
* bool
144+
* match query on lexical fields for index A
145+
* filter on indexA
146+
* bool
147+
* match query on lexical fields for index B
148+
* filter on indexB
149+
* ...
150+
* </pre>
151+
*
128152
* <p>
129-
* This tree structure is repeated for each index in {@code indicesMetadata}. That is to say, that for each index in
130-
* {@code indicesMetadata}, (up to) a pair of retrievers will be added to the returned {@code RetrieverBuilder} list.
153+
* The semantic_text fields are grouped by inference ID. For each (fieldName, inferenceID) pair we generate a match query.
154+
* Since we have no way to effectively filter on inference IDs, we filter on index names instead.
131155
* </p>
132156
*
133157
* @param fieldsAndWeights The fields to query and their respective weights, in "field^weight" format
@@ -150,32 +174,105 @@ public static List<RetrieverBuilder> generateInnerRetrievers(
150174
if (weightValidator != null) {
151175
parsedFieldsAndWeights.values().forEach(weightValidator);
152176
}
153-
154-
// We expect up to 2 inner retrievers to be generated for each index queried
155-
List<RetrieverBuilder> innerRetrievers = new ArrayList<>(indicesMetadata.size() * 2);
156-
for (IndexMetadata indexMetadata : indicesMetadata) {
157-
innerRetrievers.addAll(
158-
generateInnerRetrieversForIndex(parsedFieldsAndWeights, query, indexMetadata, innerNormalizerGenerator, weightValidator)
159-
);
177+
List<RetrieverBuilder> innerRetrievers = new ArrayList<>(2);
178+
// add lexical retriever
179+
RetrieverBuilder lexicalRetriever = generateLexicalRetriever(parsedFieldsAndWeights, indicesMetadata, query, weightValidator);
180+
if (lexicalRetriever != null) {
181+
innerRetrievers.add(lexicalRetriever);
182+
}
183+
// add semantic retriever
184+
RetrieverBuilder semanticRetriever = generateSemanticRetriever(
185+
parsedFieldsAndWeights,
186+
indicesMetadata,
187+
query,
188+
innerNormalizerGenerator,
189+
weightValidator
190+
);
191+
if (semanticRetriever != null) {
192+
innerRetrievers.add(semanticRetriever);
160193
}
194+
161195
return innerRetrievers;
162196
}
163197

164-
private static List<RetrieverBuilder> generateInnerRetrieversForIndex(
198+
private static RetrieverBuilder generateSemanticRetriever(
165199
Map<String, Float> parsedFieldsAndWeights,
200+
Collection<IndexMetadata> indicesMetadata,
166201
String query,
167-
IndexMetadata indexMetadata,
168202
Function<List<WeightedRetrieverSource>, CompoundRetrieverBuilder<?>> innerNormalizerGenerator,
169203
@Nullable Consumer<Float> weightValidator
204+
) {
205+
// Form groups of (fieldName, inferenceID) that need to be queried.
206+
// For each (fieldName, inferenceID) pair determine the weight that needs to be applied and the indices that need to be queried.
207+
Map<Tuple<String, String>, List<String>> groupedIndices = new HashMap<>();
208+
Map<Tuple<String, String>, Float> groupedWeights = new HashMap<>();
209+
for (IndexMetadata indexMetadata : indicesMetadata) {
210+
inferenceFieldsAndWeightsForIndex(parsedFieldsAndWeights, indexMetadata, weightValidator).forEach((fieldName, weight) -> {
211+
String indexName = indexMetadata.getIndex().getName();
212+
Tuple<String, String> fieldAndInferenceId = new Tuple<>(
213+
fieldName,
214+
indexMetadata.getInferenceFields().get(fieldName).getInferenceId()
215+
);
216+
217+
List<String> existingIndexNames = groupedIndices.get(fieldAndInferenceId);
218+
if (existingIndexNames != null && groupedWeights.get(fieldAndInferenceId).equals(weight) == false) {
219+
String conflictingIndexName = existingIndexNames.getFirst();
220+
throw new IllegalArgumentException(
221+
"field [" + fieldName + "] has different weights in indices [" + conflictingIndexName + "] and [" + indexName + "]"
222+
);
223+
}
224+
225+
groupedWeights.put(fieldAndInferenceId, weight);
226+
groupedIndices.computeIfAbsent(fieldAndInferenceId, k -> new ArrayList<>()).add(indexName);
227+
});
228+
}
229+
230+
// there are no semantic_text fields that need to be queried, no need to create a retriever
231+
if (groupedIndices.isEmpty()) {
232+
return null;
233+
}
234+
235+
// for each (fieldName, inferenceID) pair generate a standard retriever with a semantic query
236+
List<WeightedRetrieverSource> semanticRetrievers = new ArrayList<>(groupedIndices.size());
237+
groupedIndices.forEach((fieldAndInferenceId, indexNames) -> {
238+
String fieldName = fieldAndInferenceId.v1();
239+
Float weight = groupedWeights.get(fieldAndInferenceId);
240+
241+
QueryBuilder queryBuilder = new MatchQueryBuilder(fieldName, query);
242+
243+
// if indices does not contain all index names, we need to add a filter
244+
if (indicesMetadata.size() != indexNames.size()) {
245+
queryBuilder = new BoolQueryBuilder().must(queryBuilder).filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indexNames));
246+
}
247+
248+
RetrieverBuilder retrieverBuilder = new StandardRetrieverBuilder(queryBuilder);
249+
semanticRetrievers.add(new WeightedRetrieverSource(CompoundRetrieverBuilder.RetrieverSource.from(retrieverBuilder), weight));
250+
});
251+
252+
return innerNormalizerGenerator.apply(semanticRetrievers);
253+
}
254+
255+
private static Map<String, Float> defaultFieldsAndWeightsForIndex(
256+
IndexMetadata indexMetadata,
257+
@Nullable Consumer<Float> weightValidator
258+
) {
259+
Settings settings = indexMetadata.getSettings();
260+
List<String> defaultFields = settings.getAsList(DEFAULT_FIELD_SETTING.getKey(), DEFAULT_FIELD_SETTING.getDefault(settings));
261+
Map<String, Float> fieldsAndWeights = QueryParserHelper.parseFieldsAndWeights(defaultFields);
262+
if (weightValidator != null) {
263+
fieldsAndWeights.values().forEach(weightValidator);
264+
}
265+
return fieldsAndWeights;
266+
}
267+
268+
private static Map<String, Float> inferenceFieldsAndWeightsForIndex(
269+
Map<String, Float> parsedFieldsAndWeights,
270+
IndexMetadata indexMetadata,
271+
@Nullable Consumer<Float> weightValidator
170272
) {
171273
Map<String, Float> fieldsAndWeightsToQuery = parsedFieldsAndWeights;
172274
if (fieldsAndWeightsToQuery.isEmpty()) {
173-
Settings settings = indexMetadata.getSettings();
174-
List<String> defaultFields = settings.getAsList(DEFAULT_FIELD_SETTING.getKey(), DEFAULT_FIELD_SETTING.getDefault(settings));
175-
fieldsAndWeightsToQuery = QueryParserHelper.parseFieldsAndWeights(defaultFields);
176-
if (weightValidator != null) {
177-
fieldsAndWeightsToQuery.values().forEach(weightValidator);
178-
}
275+
fieldsAndWeightsToQuery = defaultFieldsAndWeightsForIndex(indexMetadata, weightValidator);
179276
}
180277

181278
Map<String, Float> inferenceFields = new HashMap<>();
@@ -198,30 +295,75 @@ private static List<RetrieverBuilder> generateInnerRetrieversForIndex(
198295
}
199296
}
200297
}
298+
return inferenceFields;
299+
}
201300

301+
private static Map<String, Float> nonInferenceFieldsAndWeightsForIndex(
302+
Map<String, Float> fieldsAndWeightsToQuery,
303+
IndexMetadata indexMetadata,
304+
@Nullable Consumer<Float> weightValidator
305+
) {
202306
Map<String, Float> nonInferenceFields = new HashMap<>(fieldsAndWeightsToQuery);
203-
nonInferenceFields.keySet().removeAll(inferenceFields.keySet()); // Remove all inference fields from non-inference fields map
204307

205-
// TODO: Set index pre-filters on returned retrievers when we want to implement multi-index support
206-
List<RetrieverBuilder> innerRetrievers = new ArrayList<>(2);
207-
if (nonInferenceFields.isEmpty() == false) {
208-
MultiMatchQueryBuilder nonInferenceFieldQueryBuilder = new MultiMatchQueryBuilder(query).type(
209-
MultiMatchQueryBuilder.Type.MOST_FIELDS
210-
).fields(nonInferenceFields);
211-
innerRetrievers.add(new StandardRetrieverBuilder(nonInferenceFieldQueryBuilder));
308+
if (nonInferenceFields.isEmpty()) {
309+
nonInferenceFields = defaultFieldsAndWeightsForIndex(indexMetadata, weightValidator);
212310
}
213-
if (inferenceFields.isEmpty() == false) {
214-
List<WeightedRetrieverSource> inferenceFieldRetrievers = new ArrayList<>(inferenceFields.size());
215-
inferenceFields.forEach((f, w) -> {
216-
RetrieverBuilder retrieverBuilder = new StandardRetrieverBuilder(new MatchQueryBuilder(f, query));
217-
inferenceFieldRetrievers.add(
218-
new WeightedRetrieverSource(CompoundRetrieverBuilder.RetrieverSource.from(retrieverBuilder), w)
219-
);
220-
});
221311

222-
innerRetrievers.add(innerNormalizerGenerator.apply(inferenceFieldRetrievers));
312+
nonInferenceFields.keySet().removeAll(indexMetadata.getInferenceFields().keySet());
313+
return nonInferenceFields;
314+
}
315+
316+
private static RetrieverBuilder generateLexicalRetriever(
317+
Map<String, Float> fieldsAndWeightsToQuery,
318+
Collection<IndexMetadata> indicesMetadata,
319+
String query,
320+
@Nullable Consumer<Float> weightValidator
321+
) {
322+
Map<Map<String, Float>, List<String>> groupedIndices = new HashMap<>();
323+
324+
for (IndexMetadata indexMetadata : indicesMetadata) {
325+
Map<String, Float> nonInferenceFieldsForIndex = nonInferenceFieldsAndWeightsForIndex(
326+
fieldsAndWeightsToQuery,
327+
indexMetadata,
328+
weightValidator
329+
);
330+
331+
if (nonInferenceFieldsForIndex.isEmpty()) {
332+
continue;
333+
}
334+
335+
groupedIndices.computeIfAbsent(nonInferenceFieldsForIndex, k -> new ArrayList<>()).add(indexMetadata.getIndex().getName());
223336
}
224-
return innerRetrievers;
337+
338+
// there are no lexical fields that need to be queried, no need to create a retriever
339+
if (groupedIndices.isEmpty()) {
340+
return null;
341+
}
342+
343+
List<QueryBuilder> lexicalQueryBuilders = new ArrayList<>();
344+
for (var entry : groupedIndices.entrySet()) {
345+
Map<String, Float> fieldsAndWeights = entry.getKey();
346+
List<String> indices = entry.getValue();
347+
348+
QueryBuilder queryBuilder = new MultiMatchQueryBuilder(query).type(MultiMatchQueryBuilder.Type.MOST_FIELDS)
349+
.fields(fieldsAndWeights);
350+
351+
// if indices does not contain all index names, we need to add a filter
352+
if (indices.size() != indicesMetadata.size()) {
353+
queryBuilder = new BoolQueryBuilder().must(queryBuilder).filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
354+
}
355+
356+
lexicalQueryBuilders.add(queryBuilder);
357+
}
358+
359+
// only a single lexical query, no need to wrap in a boolean query
360+
if (lexicalQueryBuilders.size() == 1) {
361+
return new StandardRetrieverBuilder(lexicalQueryBuilders.getFirst());
362+
}
363+
364+
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
365+
lexicalQueryBuilders.forEach(boolQueryBuilder::should);
366+
return new StandardRetrieverBuilder(boolQueryBuilder);
225367
}
226368

227369
private static void addToInferenceFieldsMap(Map<String, Float> inferenceFields, String field, Float weight) {

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ public Set<NodeFeature> getTestFeatures() {
3939
LinearRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT,
4040
RRFRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT,
4141
RRFRetrieverBuilder.WEIGHTED_SUPPORT,
42-
LINEAR_RETRIEVER_TOP_LEVEL_NORMALIZER
42+
LINEAR_RETRIEVER_TOP_LEVEL_NORMALIZER,
43+
LinearRetrieverBuilder.MULTI_INDEX_SIMPLIFIED_FORMAT_SUPPORT
4344
);
4445
}
4546
}

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder<Linea
5858
public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature(
5959
"linear_retriever.multi_fields_query_format_support"
6060
);
61+
public static final NodeFeature MULTI_INDEX_SIMPLIFIED_FORMAT_SUPPORT = new NodeFeature(
62+
"linear_retriever.multi_index_simplified_format_support"
63+
);
6164

6265
public static final NodeFeature LINEAR_RETRIEVER_MINSCORE_FIX = new NodeFeature("linear_retriever_minscore_fix");
6366
public static final String NAME = "linear";
@@ -318,11 +321,7 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
318321
if (resolvedIndices != null && query != null) {
319322
// Using the multi-fields query format
320323
var localIndicesMetadata = resolvedIndices.getConcreteLocalIndicesMetadata();
321-
if (localIndicesMetadata.size() > 1) {
322-
throw new IllegalArgumentException(
323-
"[" + NAME + "] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying multiple indices"
324-
);
325-
} else if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) {
324+
if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) {
326325
throw new IllegalArgumentException(
327326
"[" + NAME + "] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying remote indices"
328327
);

0 commit comments

Comments
 (0)