Skip to content

Commit 9d9bdf6

Browse files
initial implementation for multi_match
1 parent fe2f1fe commit 9d9bdf6

File tree

2 files changed

+111
-1
lines changed

2 files changed

+111
-1
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
9696
import org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor;
9797
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
98+
import org.elasticsearch.xpack.inference.queries.SemanticMultiMatchQueryRewriteInterceptor;
9899
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
99100
import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
100101
import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
@@ -571,7 +572,8 @@ public List<QueryRewriteInterceptor> getQueryRewriteInterceptors() {
571572
return List.of(
572573
new SemanticKnnVectorQueryRewriteInterceptor(),
573574
new SemanticMatchQueryRewriteInterceptor(),
574-
new SemanticSparseVectorQueryRewriteInterceptor()
575+
new SemanticSparseVectorQueryRewriteInterceptor(),
576+
new SemanticMultiMatchQueryRewriteInterceptor()
575577
);
576578
}
577579

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.queries;
9+
10+
import org.elasticsearch.action.ResolvedIndices;
11+
import org.elasticsearch.cluster.metadata.IndexMetadata;
12+
import org.elasticsearch.index.query.BoolQueryBuilder;
13+
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
14+
import org.elasticsearch.index.query.QueryBuilder;
15+
import org.elasticsearch.index.query.QueryRewriteContext;
16+
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
17+
18+
import java.io.IOException;
19+
import java.util.Collection;
20+
import java.util.HashMap;
21+
import java.util.Map;
22+
23+
public class SemanticMultiMatchQueryRewriteInterceptor implements QueryRewriteInterceptor {
24+
25+
@Override
26+
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
27+
if (queryBuilder instanceof MultiMatchQueryBuilder == false) {
28+
return queryBuilder;
29+
}
30+
31+
MultiMatchQueryBuilder multiMatchBuilder = (MultiMatchQueryBuilder) queryBuilder;
32+
ResolvedIndices resolvedIndices = context.getResolvedIndices();
33+
if (resolvedIndices == null) {
34+
return queryBuilder;
35+
}
36+
37+
Map<String, Float> semanticFields = new HashMap<>();
38+
Map<String, Float> otherFields = new HashMap<>();
39+
Collection<IndexMetadata> allIndicesMetadata = resolvedIndices.getConcreteLocalIndicesMetadata().values();
40+
41+
for (Map.Entry<String, Float> fieldEntry : multiMatchBuilder.fields().entrySet()) {
42+
String fieldName = fieldEntry.getKey();
43+
boolean isSemanticInAnyIndex = false;
44+
for (IndexMetadata indexMetadata : allIndicesMetadata) {
45+
if (indexMetadata.getInferenceFields().containsKey(fieldName)) {
46+
isSemanticInAnyIndex = true;
47+
break;
48+
}
49+
}
50+
if (isSemanticInAnyIndex) {
51+
semanticFields.put(fieldName, fieldEntry.getValue());
52+
} else {
53+
otherFields.put(fieldName, fieldEntry.getValue());
54+
}
55+
}
56+
57+
if (semanticFields.isEmpty()) {
58+
return queryBuilder;
59+
}
60+
61+
BoolQueryBuilder rewrittenQuery = new BoolQueryBuilder();
62+
if (otherFields.isEmpty() == false) {
63+
MultiMatchQueryBuilder lexicalPart = new MultiMatchQueryBuilder(multiMatchBuilder.value());
64+
lexicalPart.fields(otherFields);
65+
lexicalPart.type(multiMatchBuilder.type());
66+
lexicalPart.operator(multiMatchBuilder.operator());
67+
lexicalPart.analyzer(multiMatchBuilder.analyzer());
68+
lexicalPart.slop(multiMatchBuilder.slop());
69+
if (multiMatchBuilder.fuzziness() != null) {
70+
lexicalPart.fuzziness(multiMatchBuilder.fuzziness());
71+
}
72+
lexicalPart.prefixLength(multiMatchBuilder.prefixLength());
73+
lexicalPart.maxExpansions(multiMatchBuilder.maxExpansions());
74+
lexicalPart.minimumShouldMatch(multiMatchBuilder.minimumShouldMatch());
75+
lexicalPart.fuzzyRewrite(multiMatchBuilder.fuzzyRewrite());
76+
if (multiMatchBuilder.tieBreaker() != null) {
77+
lexicalPart.tieBreaker(multiMatchBuilder.tieBreaker());
78+
}
79+
lexicalPart.lenient(multiMatchBuilder.lenient());
80+
lexicalPart.zeroTermsQuery(multiMatchBuilder.zeroTermsQuery());
81+
lexicalPart.autoGenerateSynonymsPhraseQuery(multiMatchBuilder.autoGenerateSynonymsPhraseQuery());
82+
lexicalPart.fuzzyTranspositions(multiMatchBuilder.fuzzyTranspositions());
83+
rewrittenQuery.should(lexicalPart);
84+
}
85+
86+
if (semanticFields.isEmpty() == false) {
87+
BoolQueryBuilder semanticPart = new BoolQueryBuilder();
88+
for (Map.Entry<String, Float> fieldEntry : semanticFields.entrySet()) {
89+
SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(fieldEntry.getKey(), multiMatchBuilder.value().toString(), true);
90+
if (fieldEntry.getValue() != 1.0f) {
91+
semanticQueryBuilder.boost(fieldEntry.getValue());
92+
}
93+
semanticPart.should(semanticQueryBuilder);
94+
}
95+
rewrittenQuery.should(semanticPart);
96+
}
97+
98+
rewrittenQuery.boost(multiMatchBuilder.boost());
99+
rewrittenQuery.queryName(multiMatchBuilder.queryName());
100+
101+
return rewrittenQuery;
102+
}
103+
104+
@Override
105+
public String getQueryName() {
106+
return MultiMatchQueryBuilder.NAME;
107+
}
108+
}

0 commit comments

Comments
 (0)