99
1010import org .elasticsearch .action .ResolvedIndices ;
1111import org .elasticsearch .cluster .metadata .IndexMetadata ;
12+ import org .elasticsearch .cluster .metadata .InferenceFieldMetadata ;
13+ import org .elasticsearch .common .logging .HeaderWarning ;
1214import org .elasticsearch .index .query .BoolQueryBuilder ;
1315import org .elasticsearch .index .query .DisMaxQueryBuilder ;
1416import org .elasticsearch .index .query .MultiMatchQueryBuilder ;
1517import org .elasticsearch .index .query .QueryBuilder ;
1618import org .elasticsearch .index .query .QueryBuilders ;
1719import org .elasticsearch .index .query .QueryRewriteContext ;
20+ import org .elasticsearch .inference .MinimalServiceSettings ;
21+ import org .elasticsearch .inference .TaskType ;
1822import org .elasticsearch .plugins .internal .rewriter .QueryRewriteInterceptor ;
23+ import org .elasticsearch .xpack .inference .registry .ModelRegistry ;
1924
2025import java .io .IOException ;
2126import java .util .Collection ;
2227import java .util .HashMap ;
2328import java .util .Map ;
29+ import java .util .Objects ;
30+ import java .util .function .Supplier ;
2431
2532public class SemanticMultiMatchQueryRewriteInterceptor implements QueryRewriteInterceptor {
2633
34+ private static final String SCORE_MISMATCH_WARNING = "multi_match query is targeting a mixture of semantic_text fields with dense "
35+ + "and sparse models, or a mixture of semantic_text and non-inference fields. Score ranges will not be comparable." ;
36+
37+ private final Supplier <ModelRegistry > modelRegistrySupplier ;
38+
/**
 * Creates the interceptor.
 *
 * @param modelRegistrySupplier supplies the {@link ModelRegistry}; the supplier itself must not be
 *                              {@code null}, it is only invoked later while rewriting a query
 * @throws NullPointerException if {@code modelRegistrySupplier} is {@code null}
 */
public SemanticMultiMatchQueryRewriteInterceptor(Supplier<ModelRegistry> modelRegistrySupplier) {
    // Fail fast on a missing supplier before storing it.
    Objects.requireNonNull(modelRegistrySupplier);
    this.modelRegistrySupplier = modelRegistrySupplier;
}
42+
2743 @ Override
2844 public QueryBuilder interceptAndRewrite (QueryRewriteContext context , QueryBuilder queryBuilder ) {
2945 if (queryBuilder instanceof MultiMatchQueryBuilder == false ) {
@@ -40,12 +56,29 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
4056 Map <String , Float > otherFields = new HashMap <>();
4157 Collection <IndexMetadata > allIndicesMetadata = resolvedIndices .getConcreteLocalIndicesMetadata ().values ();
4258
59+ boolean hasDenseSemanticField = false ;
60+ boolean hasSparseSemanticField = false ;
61+
62+ ModelRegistry modelRegistry = modelRegistrySupplier .get ();
63+ if (modelRegistry == null ) {
64+ // Should not happen in a sane lifecycle, but protect against it
65+ return queryBuilder ;
66+ }
67+
4368 for (Map .Entry <String , Float > fieldEntry : multiMatchBuilder .fields ().entrySet ()) {
4469 String fieldName = fieldEntry .getKey ();
45- boolean isSemanticInAnyIndex = allIndicesMetadata . stream ()
46- . anyMatch ( indexMetadata -> indexMetadata . getInferenceFields (). containsKey ( fieldName ));
47- if (isSemanticInAnyIndex ) {
70+ InferenceFieldMetadata inferenceMetadata = findInferenceMetadata ( fieldName , allIndicesMetadata );
71+
72+ if (inferenceMetadata != null ) {
4873 semanticFields .put (fieldName , fieldEntry .getValue ());
74+ MinimalServiceSettings settings = modelRegistry .getMinimalServiceSettings (inferenceMetadata .getSearchInferenceId ());
75+ if (settings != null ) {
76+ if (settings .taskType () == TaskType .TEXT_EMBEDDING ) {
77+ hasDenseSemanticField = true ;
78+ } else if (settings .taskType () == TaskType .SPARSE_EMBEDDING ) {
79+ hasSparseSemanticField = true ;
80+ }
81+ }
4982 } else {
5083 otherFields .put (fieldName , fieldEntry .getValue ());
5184 }
@@ -55,6 +88,10 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
5588 return queryBuilder ;
5689 }
5790
91+ if (hasDenseSemanticField && (hasSparseSemanticField || otherFields .isEmpty () == false )) {
92+ HeaderWarning .addWarning (SCORE_MISMATCH_WARNING );
93+ }
94+
5895 MultiMatchQueryBuilder .Type type = multiMatchBuilder .type ();
5996 if (type == MultiMatchQueryBuilder .Type .CROSS_FIELDS ||
6097 type == MultiMatchQueryBuilder .Type .PHRASE ||
@@ -97,6 +134,21 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
97134 return rewrittenQuery ;
98135 }
99136
137+ @ Override
138+ public String getQueryName () {
139+ return MultiMatchQueryBuilder .NAME ;
140+ }
141+
142+ private InferenceFieldMetadata findInferenceMetadata (String fieldName , Collection <IndexMetadata > allIndicesMetadata ) {
143+ for (IndexMetadata indexMetadata : allIndicesMetadata ) {
144+ InferenceFieldMetadata inferenceMetadata = indexMetadata .getInferenceFields ().get (fieldName );
145+ if (inferenceMetadata != null ) {
146+ return inferenceMetadata ;
147+ }
148+ }
149+ return null ;
150+ }
151+
100152 private QueryBuilder createLexicalQuery (MultiMatchQueryBuilder original , Map <String , Float > lexicalFields ) {
101153 MultiMatchQueryBuilder lexicalPart = new MultiMatchQueryBuilder (original .value ());
102154 lexicalPart .fields (lexicalFields );
@@ -140,9 +192,4 @@ private QueryBuilder createSemanticQuery(String queryText, Map.Entry<String, Flo
140192 }
141193 return semanticQuery ;
142194 }
143-
144- @ Override
145- public String getQueryName () {
146- return MultiMatchQueryBuilder .NAME ;
147- }
148195}
0 commit comments