Skip to content

Commit 5ed5d6c

Browse files
supply modelRegistry and add warning
1 parent 497f20c commit 5ed5d6c

File tree

2 files changed

+56
-9
lines changed

2 files changed

+56
-9
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ public List<QueryRewriteInterceptor> getQueryRewriteInterceptors() {
573573
new SemanticKnnVectorQueryRewriteInterceptor(),
574574
new SemanticMatchQueryRewriteInterceptor(),
575575
new SemanticSparseVectorQueryRewriteInterceptor(),
576-
new SemanticMultiMatchQueryRewriteInterceptor()
576+
new SemanticMultiMatchQueryRewriteInterceptor(getModelRegistry())
577577
);
578578
}
579579

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryRewriteInterceptor.java

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,37 @@
99

1010
import org.elasticsearch.action.ResolvedIndices;
1111
import org.elasticsearch.cluster.metadata.IndexMetadata;
12+
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
13+
import org.elasticsearch.common.logging.HeaderWarning;
1214
import org.elasticsearch.index.query.BoolQueryBuilder;
1315
import org.elasticsearch.index.query.DisMaxQueryBuilder;
1416
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
1517
import org.elasticsearch.index.query.QueryBuilder;
1618
import org.elasticsearch.index.query.QueryBuilders;
1719
import org.elasticsearch.index.query.QueryRewriteContext;
20+
import org.elasticsearch.inference.MinimalServiceSettings;
21+
import org.elasticsearch.inference.TaskType;
1822
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
23+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
1924

2025
import java.io.IOException;
2126
import java.util.Collection;
2227
import java.util.HashMap;
2328
import java.util.Map;
29+
import java.util.Objects;
30+
import java.util.function.Supplier;
2431

2532
public class SemanticMultiMatchQueryRewriteInterceptor implements QueryRewriteInterceptor {
2633

34+
private static final String SCORE_MISMATCH_WARNING = "multi_match query is targeting a mixture of semantic_text fields with dense "
35+
+ "and sparse models, or a mixture of semantic_text and non-inference fields. Score ranges will not be comparable.";
36+
37+
private final Supplier<ModelRegistry> modelRegistrySupplier;
38+
39+
public SemanticMultiMatchQueryRewriteInterceptor(Supplier<ModelRegistry> modelRegistrySupplier) {
40+
this.modelRegistrySupplier = Objects.requireNonNull(modelRegistrySupplier);
41+
}
42+
2743
@Override
2844
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
2945
if (queryBuilder instanceof MultiMatchQueryBuilder == false) {
@@ -40,12 +56,29 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
4056
Map<String, Float> otherFields = new HashMap<>();
4157
Collection<IndexMetadata> allIndicesMetadata = resolvedIndices.getConcreteLocalIndicesMetadata().values();
4258

59+
boolean hasDenseSemanticField = false;
60+
boolean hasSparseSemanticField = false;
61+
62+
ModelRegistry modelRegistry = modelRegistrySupplier.get();
63+
if (modelRegistry == null) {
64+
// Should not happen in a sane lifecycle, but protect against it
65+
return queryBuilder;
66+
}
67+
4368
for (Map.Entry<String, Float> fieldEntry : multiMatchBuilder.fields().entrySet()) {
4469
String fieldName = fieldEntry.getKey();
45-
boolean isSemanticInAnyIndex = allIndicesMetadata.stream()
46-
.anyMatch(indexMetadata -> indexMetadata.getInferenceFields().containsKey(fieldName));
47-
if (isSemanticInAnyIndex) {
70+
InferenceFieldMetadata inferenceMetadata = findInferenceMetadata(fieldName, allIndicesMetadata);
71+
72+
if (inferenceMetadata != null) {
4873
semanticFields.put(fieldName, fieldEntry.getValue());
74+
MinimalServiceSettings settings = modelRegistry.getMinimalServiceSettings(inferenceMetadata.getSearchInferenceId());
75+
if (settings != null) {
76+
if (settings.taskType() == TaskType.TEXT_EMBEDDING) {
77+
hasDenseSemanticField = true;
78+
} else if (settings.taskType() == TaskType.SPARSE_EMBEDDING) {
79+
hasSparseSemanticField = true;
80+
}
81+
}
4982
} else {
5083
otherFields.put(fieldName, fieldEntry.getValue());
5184
}
@@ -55,6 +88,10 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
5588
return queryBuilder;
5689
}
5790

91+
if (hasDenseSemanticField && (hasSparseSemanticField || otherFields.isEmpty() == false)) {
92+
HeaderWarning.addWarning(SCORE_MISMATCH_WARNING);
93+
}
94+
5895
MultiMatchQueryBuilder.Type type = multiMatchBuilder.type();
5996
if (type == MultiMatchQueryBuilder.Type.CROSS_FIELDS ||
6097
type == MultiMatchQueryBuilder.Type.PHRASE ||
@@ -97,6 +134,21 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
97134
return rewrittenQuery;
98135
}
99136

137+
@Override
138+
public String getQueryName() {
139+
return MultiMatchQueryBuilder.NAME;
140+
}
141+
142+
private InferenceFieldMetadata findInferenceMetadata(String fieldName, Collection<IndexMetadata> allIndicesMetadata) {
143+
for (IndexMetadata indexMetadata : allIndicesMetadata) {
144+
InferenceFieldMetadata inferenceMetadata = indexMetadata.getInferenceFields().get(fieldName);
145+
if (inferenceMetadata != null) {
146+
return inferenceMetadata;
147+
}
148+
}
149+
return null;
150+
}
151+
100152
private QueryBuilder createLexicalQuery(MultiMatchQueryBuilder original, Map<String, Float> lexicalFields) {
101153
MultiMatchQueryBuilder lexicalPart = new MultiMatchQueryBuilder(original.value());
102154
lexicalPart.fields(lexicalFields);
@@ -140,9 +192,4 @@ private QueryBuilder createSemanticQuery(String queryText, Map.Entry<String, Flo
140192
}
141193
return semanticQuery;
142194
}
143-
144-
@Override
145-
public String getQueryName() {
146-
return MultiMatchQueryBuilder.NAME;
147-
}
148195
}

0 commit comments

Comments
 (0)