Skip to content

Commit 497f20c

Browse files
optimized
1 parent 9a03651 commit 497f20c

File tree

1 file changed

+71
-65
lines changed

1 file changed

+71
-65
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryRewriteInterceptor.java

Lines changed: 71 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,8 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
4242

4343
for (Map.Entry<String, Float> fieldEntry : multiMatchBuilder.fields().entrySet()) {
4444
String fieldName = fieldEntry.getKey();
45-
boolean isSemanticInAnyIndex = false;
46-
for (IndexMetadata indexMetadata : allIndicesMetadata) {
47-
if (indexMetadata.getInferenceFields().containsKey(fieldName)) {
48-
isSemanticInAnyIndex = true;
49-
break;
50-
}
51-
}
45+
boolean isSemanticInAnyIndex = allIndicesMetadata.stream()
46+
.anyMatch(indexMetadata -> indexMetadata.getInferenceFields().containsKey(fieldName));
5247
if (isSemanticInAnyIndex) {
5348
semanticFields.put(fieldName, fieldEntry.getValue());
5449
} else {
@@ -67,72 +62,83 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
6762
throw new IllegalArgumentException("Query type [" + type.parseField().getPreferredName() + "] is not supported with semantic_text fields");
6863
}
6964

70-
if (type == MultiMatchQueryBuilder.Type.BEST_FIELDS) {
71-
DisMaxQueryBuilder disMaxQuery = QueryBuilders.disMaxQuery();
72-
if (otherFields.isEmpty() == false) {
73-
MultiMatchQueryBuilder lexicalPart = new MultiMatchQueryBuilder(multiMatchBuilder.value());
74-
lexicalPart.fields(otherFields);
75-
lexicalPart.type(multiMatchBuilder.type());
76-
disMaxQuery.add(lexicalPart);
77-
}
78-
for (Map.Entry<String, Float> fieldEntry : semanticFields.entrySet()) {
79-
SemanticQueryBuilder semanticQuery = new SemanticQueryBuilder(fieldEntry.getKey(), multiMatchBuilder.value().toString(), true);
80-
if (fieldEntry.getValue() != 1.0f) {
81-
semanticQuery.boost(fieldEntry.getValue());
65+
QueryBuilder rewrittenQuery;
66+
switch (type) {
67+
case BEST_FIELDS:
68+
DisMaxQueryBuilder disMaxQuery = QueryBuilders.disMaxQuery();
69+
if (otherFields.isEmpty() == false) {
70+
disMaxQuery.add(createLexicalQuery(multiMatchBuilder, otherFields));
8271
}
83-
disMaxQuery.add(semanticQuery);
84-
}
85-
Float tieBreaker = multiMatchBuilder.tieBreaker();
86-
if (tieBreaker != null) {
87-
disMaxQuery.tieBreaker(tieBreaker);
88-
}
89-
disMaxQuery.boost(multiMatchBuilder.boost());
90-
disMaxQuery.queryName(multiMatchBuilder.queryName());
91-
return disMaxQuery;
72+
for (Map.Entry<String, Float> fieldEntry : semanticFields.entrySet()) {
73+
disMaxQuery.add(createSemanticQuery(multiMatchBuilder.value().toString(), fieldEntry));
74+
}
75+
Float tieBreaker = multiMatchBuilder.tieBreaker();
76+
if (tieBreaker != null) {
77+
disMaxQuery.tieBreaker(tieBreaker);
78+
}
79+
rewrittenQuery = disMaxQuery;
80+
break;
81+
case MOST_FIELDS:
82+
case BOOL_PREFIX:
83+
default:
84+
BoolQueryBuilder boolQuery = new BoolQueryBuilder();
85+
if (otherFields.isEmpty() == false) {
86+
boolQuery.should(createLexicalQuery(multiMatchBuilder, otherFields));
87+
}
88+
if (semanticFields.isEmpty() == false) {
89+
boolQuery.should(createSemanticQuery(multiMatchBuilder.value().toString(), semanticFields));
90+
}
91+
rewrittenQuery = boolQuery;
92+
break;
9293
}
9394

94-
// Fallback for other types like MOST_FIELDS and BOOL_PREFIX
95-
BoolQueryBuilder rewrittenQuery = new BoolQueryBuilder();
96-
if (otherFields.isEmpty() == false) {
97-
MultiMatchQueryBuilder lexicalPart = new MultiMatchQueryBuilder(multiMatchBuilder.value());
98-
lexicalPart.fields(otherFields);
99-
lexicalPart.type(multiMatchBuilder.type());
100-
lexicalPart.operator(multiMatchBuilder.operator());
101-
lexicalPart.analyzer(multiMatchBuilder.analyzer());
102-
lexicalPart.slop(multiMatchBuilder.slop());
103-
if (multiMatchBuilder.fuzziness() != null) {
104-
lexicalPart.fuzziness(multiMatchBuilder.fuzziness());
105-
}
106-
lexicalPart.prefixLength(multiMatchBuilder.prefixLength());
107-
lexicalPart.maxExpansions(multiMatchBuilder.maxExpansions());
108-
lexicalPart.minimumShouldMatch(multiMatchBuilder.minimumShouldMatch());
109-
lexicalPart.fuzzyRewrite(multiMatchBuilder.fuzzyRewrite());
110-
if (multiMatchBuilder.tieBreaker() != null) {
111-
lexicalPart.tieBreaker(multiMatchBuilder.tieBreaker());
112-
}
113-
lexicalPart.lenient(multiMatchBuilder.lenient());
114-
lexicalPart.zeroTermsQuery(multiMatchBuilder.zeroTermsQuery());
115-
lexicalPart.autoGenerateSynonymsPhraseQuery(multiMatchBuilder.autoGenerateSynonymsPhraseQuery());
116-
lexicalPart.fuzzyTranspositions(multiMatchBuilder.fuzzyTranspositions());
117-
rewrittenQuery.should(lexicalPart);
95+
rewrittenQuery.boost(multiMatchBuilder.boost());
96+
rewrittenQuery.queryName(multiMatchBuilder.queryName());
97+
return rewrittenQuery;
98+
}
99+
100+
private QueryBuilder createLexicalQuery(MultiMatchQueryBuilder original, Map<String, Float> lexicalFields) {
101+
MultiMatchQueryBuilder lexicalPart = new MultiMatchQueryBuilder(original.value());
102+
lexicalPart.fields(lexicalFields);
103+
lexicalPart.type(original.type());
104+
lexicalPart.operator(original.operator());
105+
lexicalPart.analyzer(original.analyzer());
106+
lexicalPart.slop(original.slop());
107+
if (original.fuzziness() != null) {
108+
lexicalPart.fuzziness(original.fuzziness());
109+
}
110+
lexicalPart.prefixLength(original.prefixLength());
111+
lexicalPart.maxExpansions(original.maxExpansions());
112+
lexicalPart.minimumShouldMatch(original.minimumShouldMatch());
113+
lexicalPart.fuzzyRewrite(original.fuzzyRewrite());
114+
if (original.tieBreaker() != null) {
115+
lexicalPart.tieBreaker(original.tieBreaker());
118116
}
117+
lexicalPart.lenient(original.lenient());
118+
lexicalPart.zeroTermsQuery(original.zeroTermsQuery());
119+
lexicalPart.autoGenerateSynonymsPhraseQuery(original.autoGenerateSynonymsPhraseQuery());
120+
lexicalPart.fuzzyTranspositions(original.fuzzyTranspositions());
121+
return lexicalPart;
122+
}
119123

120-
if (semanticFields.isEmpty() == false) {
121-
BoolQueryBuilder semanticPart = new BoolQueryBuilder();
122-
for (Map.Entry<String, Float> fieldEntry : semanticFields.entrySet()) {
123-
SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(fieldEntry.getKey(), multiMatchBuilder.value().toString(), true);
124-
if (fieldEntry.getValue() != 1.0f) {
125-
semanticQueryBuilder.boost(fieldEntry.getValue());
126-
}
127-
semanticPart.should(semanticQueryBuilder);
128-
}
129-
rewrittenQuery.should(semanticPart);
124+
private QueryBuilder createSemanticQuery(String queryText, Map<String, Float> semanticFields) {
125+
if (semanticFields.size() == 1) {
126+
return createSemanticQuery(queryText, semanticFields.entrySet().iterator().next());
130127
}
131128

132-
rewrittenQuery.boost(multiMatchBuilder.boost());
133-
rewrittenQuery.queryName(multiMatchBuilder.queryName());
129+
BoolQueryBuilder boolQuery = new BoolQueryBuilder();
130+
for (Map.Entry<String, Float> fieldEntry : semanticFields.entrySet()) {
131+
boolQuery.should(createSemanticQuery(queryText, fieldEntry));
132+
}
133+
return boolQuery;
134+
}
134135

135-
return rewrittenQuery;
136+
private QueryBuilder createSemanticQuery(String queryText, Map.Entry<String, Float> fieldEntry) {
137+
SemanticQueryBuilder semanticQuery = new SemanticQueryBuilder(fieldEntry.getKey(), queryText, true);
138+
if (fieldEntry.getValue() != 1.0f) {
139+
semanticQuery.boost(fieldEntry.getValue());
140+
}
141+
return semanticQuery;
136142
}
137143

138144
@Override

0 commit comments

Comments
 (0)