Skip to content

Commit 791552b

Browse files
simplified query generation
1 parent 344e2de commit 791552b

File tree

5 files changed

+58
-224
lines changed

5 files changed

+58
-224
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java

Lines changed: 13 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public class SemanticKnnVectorQueryRewriteInterceptor extends SemanticQueryRewri
3333
public SemanticKnnVectorQueryRewriteInterceptor() {}
3434

3535
@Override
36-
protected Map<String, Float> getFieldNamesWithWeights(QueryBuilder queryBuilder) {
36+
protected Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder) {
3737
assert (queryBuilder instanceof KnnVectorQueryBuilder);
3838
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
3939
return Map.of(knnVectorQueryBuilder.getFieldName(), 1.0f);
@@ -47,7 +47,12 @@ protected String getQuery(QueryBuilder queryBuilder) {
4747
return queryVectorBuilder != null ? queryVectorBuilder.getModelText() : null;
4848
}
4949

50-
private QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
50+
@Override
51+
protected QueryBuilder buildInferenceQuery(
52+
QueryBuilder queryBuilder,
53+
InferenceIndexInformationForField indexInformation,
54+
Float fieldBoost
55+
) {
5156
assert (queryBuilder instanceof KnnVectorQueryBuilder);
5257
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
5358
Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
@@ -62,26 +67,11 @@ private QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceInd
6267
// Multiple inference IDs, construct a boolean query
6368
finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices);
6469
}
65-
finalQueryBuilder.boost(queryBuilder.boost());
70+
finalQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
6671
finalQueryBuilder.queryName(queryBuilder.queryName());
6772
return finalQueryBuilder;
6873
}
6974

70-
@Override
71-
protected QueryBuilder buildInferenceQuery(
72-
QueryBuilder queryBuilder,
73-
InferenceIndexInformationForField indexInformation,
74-
Float fieldWeight
75-
) {
76-
QueryBuilder inferenceQuery = buildInferenceQuery(queryBuilder, indexInformation);
77-
78-
if (fieldWeight != null && fieldWeight.equals(1.0f) == false) {
79-
inferenceQuery.boost(fieldWeight);
80-
}
81-
82-
return inferenceQuery;
83-
}
84-
8575
private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
8676
KnnVectorQueryBuilder queryBuilder,
8777
Map<String, List<String>> inferenceIdsIndices
@@ -98,9 +88,11 @@ private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
9888
return boolQueryBuilder;
9989
}
10090

101-
private QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
91+
@Override
92+
protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
10293
QueryBuilder queryBuilder,
103-
InferenceIndexInformationForField indexInformation
94+
InferenceIndexInformationForField indexInformation,
95+
Float fieldBoost
10496
) {
10597
assert (queryBuilder instanceof KnnVectorQueryBuilder);
10698
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
@@ -119,26 +111,11 @@ private QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
119111
)
120112
);
121113
}
122-
boolQueryBuilder.boost(queryBuilder.boost());
114+
boolQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
123115
boolQueryBuilder.queryName(queryBuilder.queryName());
124116
return boolQueryBuilder;
125117
}
126118

127-
@Override
128-
protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
129-
QueryBuilder queryBuilder,
130-
InferenceIndexInformationForField indexInformation,
131-
Float fieldWeight
132-
) {
133-
QueryBuilder inferenceQuery = buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation);
134-
135-
if (fieldWeight != null && fieldWeight.equals(1.0f) == false) {
136-
inferenceQuery.boost(fieldWeight);
137-
}
138-
139-
return inferenceQuery;
140-
}
141-
142119
private QueryBuilder buildNestedQueryFromKnnVectorQuery(
143120
KnnVectorQueryBuilder knnVectorQueryBuilder,
144121
List<String> indices,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java

Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteIn
2323
public SemanticMatchQueryRewriteInterceptor() {}
2424

2525
@Override
26-
protected Map<String, Float> getFieldNamesWithWeights(QueryBuilder queryBuilder) {
26+
protected Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder) {
2727
assert (queryBuilder instanceof MatchQueryBuilder);
2828
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
2929
return Map.of(matchQueryBuilder.fieldName(), 1.0f);
@@ -36,31 +36,23 @@ protected String getQuery(QueryBuilder queryBuilder) {
3636
return (String) matchQueryBuilder.value();
3737
}
3838

39-
private QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
40-
SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
41-
semanticQueryBuilder.boost(queryBuilder.boost());
42-
semanticQueryBuilder.queryName(queryBuilder.queryName());
43-
return semanticQueryBuilder;
44-
}
45-
4639
@Override
4740
protected QueryBuilder buildInferenceQuery(
4841
QueryBuilder queryBuilder,
4942
InferenceIndexInformationForField indexInformation,
50-
Float fieldWeight
43+
Float fieldBoost
5144
) {
52-
QueryBuilder inferenceQuery = buildInferenceQuery(queryBuilder, indexInformation);
53-
54-
if (fieldWeight != null && fieldWeight.equals(1.0f) == false) {
55-
inferenceQuery.boost(fieldWeight);
56-
}
57-
58-
return inferenceQuery;
45+
SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
46+
semanticQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
47+
semanticQueryBuilder.queryName(queryBuilder.queryName());
48+
return semanticQueryBuilder;
5949
}
6050

51+
@Override
6152
protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
6253
QueryBuilder queryBuilder,
63-
InferenceIndexInformationForField indexInformation
54+
InferenceIndexInformationForField indexInformation,
55+
Float fieldBoost
6456
) {
6557
assert (queryBuilder instanceof MatchQueryBuilder);
6658
MatchQueryBuilder originalMatchQueryBuilder = (MatchQueryBuilder) queryBuilder;
@@ -76,26 +68,11 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
7668
)
7769
);
7870
boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), matchQueryBuilder));
79-
boolQueryBuilder.boost(queryBuilder.boost());
71+
boolQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
8072
boolQueryBuilder.queryName(queryBuilder.queryName());
8173
return boolQueryBuilder;
8274
}
8375

84-
@Override
85-
protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
86-
QueryBuilder queryBuilder,
87-
InferenceIndexInformationForField indexInformation,
88-
Float fieldWeight
89-
) {
90-
QueryBuilder inferenceQuery = buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation);
91-
92-
if (fieldWeight != null && fieldWeight.equals(1.0f) == false) {
93-
inferenceQuery.boost(fieldWeight);
94-
}
95-
96-
return inferenceQuery;
97-
}
98-
9976
@Override
10077
public String getQueryName() {
10178
return MatchQueryBuilder.NAME;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryRewriteInterceptor.java

Lines changed: 8 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,14 @@
88
package org.elasticsearch.xpack.inference.queries;
99

1010
import org.elasticsearch.index.query.BoolQueryBuilder;
11-
import org.elasticsearch.index.query.MatchQueryBuilder;
1211
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
1312
import org.elasticsearch.index.query.QueryBuilder;
1413

1514
import java.util.Map;
1615

1716
public class SemanticMultiMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {
1817
@Override
19-
protected Map<String, Float> getFieldNamesWithWeights(QueryBuilder queryBuilder) {
18+
protected Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder) {
2019
assert (queryBuilder instanceof MultiMatchQueryBuilder);
2120
MultiMatchQueryBuilder multiMatchQueryBuilder = (MultiMatchQueryBuilder) queryBuilder;
2221
return multiMatchQueryBuilder.fields();
@@ -29,61 +28,23 @@ protected String getQuery(QueryBuilder queryBuilder) {
2928
return (String) multiMatchQueryBuilder.value();
3029
}
3130

32-
private QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
33-
SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
34-
semanticQueryBuilder.boost(queryBuilder.boost());
35-
semanticQueryBuilder.queryName(queryBuilder.queryName());
36-
return semanticQueryBuilder;
37-
}
38-
3931
@Override
4032
protected QueryBuilder buildInferenceQuery(
4133
QueryBuilder queryBuilder,
4234
InferenceIndexInformationForField indexInformation,
43-
Float fieldWeight
35+
Float fieldWBoost
4436
) {
45-
QueryBuilder inferenceQuery = buildInferenceQuery(queryBuilder, indexInformation);
46-
47-
if (fieldWeight != null && fieldWeight.equals(1.0f) == false) {
48-
inferenceQuery.boost(fieldWeight);
49-
}
50-
51-
return inferenceQuery;
52-
}
53-
54-
private QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
55-
QueryBuilder queryBuilder,
56-
InferenceIndexInformationForField indexInformation
57-
) {
58-
assert (queryBuilder instanceof MultiMatchQueryBuilder);
59-
MultiMatchQueryBuilder originalMultiMatchQueryBuilder = (MultiMatchQueryBuilder) queryBuilder;
60-
61-
// Create a copy for non-inference fields with only this specific field
62-
MultiMatchQueryBuilder multiMatchQueryBuilder = createSingleFieldMultiMatch(
63-
originalMultiMatchQueryBuilder,
64-
indexInformation.fieldName()
65-
);
66-
67-
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
68-
69-
// Add semantic query for inference indices
70-
boolQueryBuilder.should(
71-
createSemanticSubQuery(indexInformation.getInferenceIndices(), indexInformation.fieldName(), getQuery(queryBuilder))
72-
);
73-
74-
// Add regular query for non-inference indices
75-
boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), multiMatchQueryBuilder));
76-
77-
// TODO:: add boost
78-
boolQueryBuilder.queryName(queryBuilder.queryName());
79-
return boolQueryBuilder;
37+
SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
38+
semanticQueryBuilder.boost(queryBuilder.boost() * fieldWBoost);
39+
semanticQueryBuilder.queryName(queryBuilder.queryName());
40+
return semanticQueryBuilder;
8041
}
8142

8243
@Override
8344
protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
8445
QueryBuilder queryBuilder,
8546
InferenceIndexInformationForField indexInformation,
86-
Float fieldWeight
47+
Float fieldBoost
8748
) {
8849
assert (queryBuilder instanceof MultiMatchQueryBuilder);
8950
MultiMatchQueryBuilder multiMatchQueryBuilder = (MultiMatchQueryBuilder) queryBuilder;
@@ -101,20 +62,7 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
10162
)
10263
);
10364

104-
105-
QueryBuilder nonSemanticFieldQuery = buildNonSemanticFieldQuery(
106-
queryBuilder,
107-
indexInformation.fieldName(),
108-
fieldWeight
109-
);
110-
// boolQueryBuilder.should(
111-
// createSubQueryForIndices(indexInformation.nonInferenceIndices(), nonSemanticFieldQuery)
112-
// );
113-
// boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), multiMatchQueryBuilder));
114-
115-
if (fieldWeight != null && fieldWeight.equals(1.0f) == false) {
116-
boolQueryBuilder.boost(fieldWeight);
117-
}
65+
boolQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
11866
boolQueryBuilder.queryName(queryBuilder.queryName());
11967
return boolQueryBuilder;
12068
}
@@ -123,31 +71,4 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
12371
public String getQueryName() {
12472
return MultiMatchQueryBuilder.NAME;
12573
}
126-
127-
/**
128-
* Create a MultiMatchQueryBuilder with only a single field for non-inference indices
129-
*/
130-
private MultiMatchQueryBuilder createSingleFieldMultiMatch(MultiMatchQueryBuilder original, String fieldName) {
131-
MultiMatchQueryBuilder singleFieldQuery = new MultiMatchQueryBuilder(original.value());
132-
133-
// Copy all properties from original query
134-
singleFieldQuery.type(original.type());
135-
singleFieldQuery.operator(original.operator());
136-
singleFieldQuery.analyzer(original.analyzer());
137-
singleFieldQuery.fuzziness(original.fuzziness());
138-
singleFieldQuery.prefixLength(original.prefixLength());
139-
singleFieldQuery.maxExpansions(original.maxExpansions());
140-
singleFieldQuery.minimumShouldMatch(original.minimumShouldMatch());
141-
singleFieldQuery.fuzzyRewrite(original.fuzzyRewrite());
142-
singleFieldQuery.tieBreaker(original.tieBreaker());
143-
singleFieldQuery.lenient(original.lenient());
144-
singleFieldQuery.zeroTermsQuery(original.zeroTermsQuery());
145-
singleFieldQuery.autoGenerateSynonymsPhraseQuery(original.autoGenerateSynonymsPhraseQuery());
146-
singleFieldQuery.fuzzyTranspositions(original.fuzzyTranspositions());
147-
148-
// Add only the specific field (without boost for now)
149-
singleFieldQuery.field(fieldName);
150-
151-
return singleFieldQuery;
152-
}
15374
}

0 commit comments

Comments
 (0)