Skip to content

Commit 658f171

Browse files
refactor to push wildcard related work into multi_match
1 parent 8b45313 commit 658f171

File tree

2 files changed

+136
-47
lines changed

2 files changed

+136
-47
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryRewriteInterceptor.java

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,28 @@
77

88
package org.elasticsearch.xpack.inference.queries;
99

10+
import org.elasticsearch.action.ResolvedIndices;
11+
import org.elasticsearch.cluster.metadata.IndexMetadata;
12+
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
13+
import org.elasticsearch.common.regex.Regex;
14+
import org.elasticsearch.common.settings.Settings;
1015
import org.elasticsearch.features.NodeFeature;
1116
import org.elasticsearch.index.query.DisMaxQueryBuilder;
1217
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
1318
import org.elasticsearch.index.query.QueryBuilder;
1419
import org.elasticsearch.index.query.QueryBuilders;
20+
import org.elasticsearch.index.search.QueryParserHelper;
1521

22+
import java.util.Collection;
1623
import java.util.HashMap;
24+
import java.util.HashSet;
1725
import java.util.List;
1826
import java.util.Map;
1927
import java.util.Objects;
2028
import java.util.Set;
29+
import java.util.stream.Collectors;
30+
31+
import static org.elasticsearch.index.IndexSettings.DEFAULT_FIELD_SETTING;
2132

2233
public class SemanticMultiMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {
2334

@@ -108,6 +119,118 @@ protected boolean shouldUseDefaultFields() {
108119
return true;
109120
}
110121

122+
@Override
123+
protected InferenceIndexInformationForField resolveIndicesForFields(
124+
QueryBuilder queryBuilder,
125+
ResolvedIndices resolvedIndices,
126+
boolean resolveInferenceFieldWildcards
127+
) {
128+
assert (queryBuilder instanceof MultiMatchQueryBuilder);
129+
MultiMatchQueryBuilder multiMatchQuery = (MultiMatchQueryBuilder) queryBuilder;
130+
131+
// If wildcard resolution is disabled, use the simple parent class implementation
132+
if (resolveInferenceFieldWildcards == false || multiMatchQuery.resolveInferenceFieldWildcards() == false) {
133+
return super.resolveIndicesForFields(queryBuilder, resolvedIndices, resolveInferenceFieldWildcards);
134+
}
135+
136+
return resolveIndicesForFieldsWithWildcards(queryBuilder, resolvedIndices);
137+
}
138+
139+
private InferenceIndexInformationForField resolveIndicesForFieldsWithWildcards(
140+
QueryBuilder queryBuilder,
141+
ResolvedIndices resolvedIndices
142+
) {
143+
Map<String, Float> fieldsWithWeights = getFieldsWithWeights(queryBuilder);
144+
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
145+
146+
// Global wildcard resolution for inference fields
147+
Map<String, Float> globalInferenceFieldBoosts = new HashMap<>();
148+
149+
// Get all unique inference fields across all indices
150+
Set<String> allInferenceFields = indexMetadataCollection.stream()
151+
.flatMap(idx -> idx.getInferenceFields().keySet().stream())
152+
.collect(Collectors.toSet());
153+
154+
// Calculate boost for each inference field based on matching patterns
155+
for (String inferenceField : allInferenceFields) {
156+
for (Map.Entry<String, Float> entry : fieldsWithWeights.entrySet()) {
157+
String pattern = entry.getKey();
158+
Float boost = entry.getValue();
159+
160+
if (Regex.isMatchAllPattern(pattern)
161+
|| (Regex.isSimpleMatchPattern(pattern) && Regex.simpleMatch(pattern, inferenceField))
162+
|| pattern.equals(inferenceField)) {
163+
addToFieldBoostsMap(globalInferenceFieldBoosts, inferenceField, boost);
164+
}
165+
}
166+
}
167+
168+
// Per-index processing using pre-calculated global boosts
169+
Map<String, Map<String, InferenceFieldMetadata>> inferenceFieldsPerIndex = new HashMap<>();
170+
Map<String, Set<String>> nonInferenceFieldsPerIndex = new HashMap<>();
171+
Map<String, Float> allFieldBoosts = new HashMap<>(globalInferenceFieldBoosts);
172+
173+
for (IndexMetadata indexMetadata : indexMetadataCollection) {
174+
String indexName = indexMetadata.getIndex().getName();
175+
Map<String, InferenceFieldMetadata> indexInferenceFields = new HashMap<>();
176+
Map<String, InferenceFieldMetadata> indexInferenceMetadata = indexMetadata.getInferenceFields();
177+
178+
// Handle default fields per index when no fields are specified
179+
Map<String, Float> fieldsToProcess = fieldsWithWeights;
180+
if (fieldsToProcess.isEmpty() && shouldUseDefaultFields()) {
181+
Settings settings = indexMetadata.getSettings();
182+
List<String> defaultFields = settings.getAsList(DEFAULT_FIELD_SETTING.getKey(), DEFAULT_FIELD_SETTING.getDefault(settings));
183+
fieldsToProcess = QueryParserHelper.parseFieldsAndWeights(defaultFields);
184+
}
185+
186+
// Collect resolved inference fields for this index
187+
Set<String> resolvedInferenceFields = new HashSet<>();
188+
189+
// Add wildcard-resolved inference fields that exist in this index
190+
for (String inferenceField : globalInferenceFieldBoosts.keySet()) {
191+
if (indexInferenceMetadata.containsKey(inferenceField)) {
192+
indexInferenceFields.put(inferenceField, indexInferenceMetadata.get(inferenceField));
193+
resolvedInferenceFields.add(inferenceField);
194+
}
195+
}
196+
197+
// Always handle explicit inference fields (both wildcard and non-wildcard cases)
198+
for (Map.Entry<String, Float> entry : fieldsToProcess.entrySet()) {
199+
String field = entry.getKey();
200+
Float boost = entry.getValue();
201+
202+
if (indexInferenceMetadata.containsKey(field)) {
203+
indexInferenceFields.put(field, indexInferenceMetadata.get(field));
204+
resolvedInferenceFields.add(field);
205+
addToFieldBoostsMap(allFieldBoosts, field, boost);
206+
}
207+
}
208+
209+
// Non-inference fields: all patterns minus resolved inference fields
210+
Set<String> indexNonInferenceFields = new HashSet<>(fieldsToProcess.keySet());
211+
indexNonInferenceFields.removeAll(resolvedInferenceFields);
212+
213+
// Store boosts for non-inference field patterns
214+
for (String nonInferenceField : indexNonInferenceFields) {
215+
addToFieldBoostsMap(allFieldBoosts, nonInferenceField, fieldsToProcess.get(nonInferenceField));
216+
}
217+
218+
if (indexInferenceFields.isEmpty() == false) {
219+
inferenceFieldsPerIndex.put(indexName, indexInferenceFields);
220+
}
221+
222+
if (indexNonInferenceFields.isEmpty() == false) {
223+
nonInferenceFieldsPerIndex.put(indexName, indexNonInferenceFields);
224+
}
225+
}
226+
227+
return new InferenceIndexInformationForField(inferenceFieldsPerIndex, nonInferenceFieldsPerIndex, allFieldBoosts);
228+
}
229+
230+
private static void addToFieldBoostsMap(Map<String, Float> fieldBoosts, String field, Float boost) {
231+
fieldBoosts.compute(field, (k, v) -> v == null ? boost : v * boost);
232+
}
233+
111234
private QueryBuilder buildMultiFieldSemanticQuery(
112235
MultiMatchQueryBuilder originalQuery,
113236
Set<String> inferenceFields,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java

Lines changed: 13 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import org.elasticsearch.action.ResolvedIndices;
1111
import org.elasticsearch.cluster.metadata.IndexMetadata;
1212
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
13-
import org.elasticsearch.common.regex.Regex;
1413
import org.elasticsearch.common.settings.Settings;
1514
import org.elasticsearch.index.mapper.IndexFieldMapper;
1615
import org.elasticsearch.index.query.AbstractQueryBuilder;
@@ -136,41 +135,18 @@ private static void addToFieldBoostsMap(Map<String, Float> fieldBoosts, String f
136135
fieldBoosts.compute(field, (k, v) -> v == null ? boost : v * boost);
137136
}
138137

139-
private InferenceIndexInformationForField resolveIndicesForFields(
138+
protected InferenceIndexInformationForField resolveIndicesForFields(
140139
QueryBuilder queryBuilder,
141140
ResolvedIndices resolvedIndices,
142141
boolean resolveInferenceFieldWildcards
143142
) {
144143
Map<String, Float> fieldsWithWeights = getFieldsWithWeights(queryBuilder);
145144
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
146145

147-
// STEP 1: Global wildcard resolution for inference fields
148-
Map<String, Float> globalInferenceFieldBoosts = new HashMap<>();
149-
if (resolveInferenceFieldWildcards) {
150-
// Get all unique inference fields across all indices
151-
Set<String> allInferenceFields = indexMetadataCollection.stream()
152-
.flatMap(idx -> idx.getInferenceFields().keySet().stream())
153-
.collect(Collectors.toSet());
154-
155-
// Calculate boost for each inference field based on matching patterns
156-
for (String inferenceField : allInferenceFields) {
157-
for (Map.Entry<String, Float> entry : fieldsWithWeights.entrySet()) {
158-
String pattern = entry.getKey();
159-
Float boost = entry.getValue();
160-
161-
if (Regex.isMatchAllPattern(pattern)
162-
|| (Regex.isSimpleMatchPattern(pattern) && Regex.simpleMatch(pattern, inferenceField))
163-
|| pattern.equals(inferenceField)) {
164-
addToFieldBoostsMap(globalInferenceFieldBoosts, inferenceField, boost);
165-
}
166-
}
167-
}
168-
}
169-
170-
// STEP 2: Per-index processing using pre-calculated global boosts
146+
// Simple implementation: only handle explicit inference fields (no wildcards)
171147
Map<String, Map<String, InferenceFieldMetadata>> inferenceFieldsPerIndex = new HashMap<>();
172148
Map<String, Set<String>> nonInferenceFieldsPerIndex = new HashMap<>();
173-
Map<String, Float> allFieldBoosts = new HashMap<>(globalInferenceFieldBoosts);
149+
Map<String, Float> allFieldBoosts = new HashMap<>();
174150

175151
for (IndexMetadata indexMetadata : indexMetadataCollection) {
176152
String indexName = indexMetadata.getIndex().getName();
@@ -188,29 +164,19 @@ private InferenceIndexInformationForField resolveIndicesForFields(
188164
// Collect resolved inference fields for this index
189165
Set<String> resolvedInferenceFields = new HashSet<>();
190166

191-
if (resolveInferenceFieldWildcards) {
192-
// Add inference fields that exist in this index (using pre-calculated boosts)
193-
for (String inferenceField : globalInferenceFieldBoosts.keySet()) {
194-
if (indexInferenceMetadata.containsKey(inferenceField)) {
195-
indexInferenceFields.put(inferenceField, indexInferenceMetadata.get(inferenceField));
196-
resolvedInferenceFields.add(inferenceField);
197-
}
198-
}
199-
} else {
200-
// Handle explicit inference fields (non-wildcard)
201-
for (Map.Entry<String, Float> entry : fieldsToProcess.entrySet()) {
202-
String field = entry.getKey();
203-
Float boost = entry.getValue();
204-
205-
if (indexInferenceMetadata.containsKey(field)) {
206-
indexInferenceFields.put(field, indexInferenceMetadata.get(field));
207-
resolvedInferenceFields.add(field);
208-
addToFieldBoostsMap(allFieldBoosts, field, boost);
209-
}
167+
// Handle explicit inference fields only
168+
for (Map.Entry<String, Float> entry : fieldsToProcess.entrySet()) {
169+
String field = entry.getKey();
170+
Float boost = entry.getValue();
171+
172+
if (indexInferenceMetadata.containsKey(field)) {
173+
indexInferenceFields.put(field, indexInferenceMetadata.get(field));
174+
resolvedInferenceFields.add(field);
175+
addToFieldBoostsMap(allFieldBoosts, field, boost);
210176
}
211177
}
212178

213-
// Non-inference fields: all patterns minus resolved inference fields (simple approach like MultiFieldsInnerRetrieverUtils)
179+
// Non-inference fields
214180
Set<String> indexNonInferenceFields = new HashSet<>(fieldsToProcess.keySet());
215181
indexNonInferenceFields.removeAll(resolvedInferenceFields);
216182

0 commit comments

Comments
 (0)