Skip to content

Commit daf2cb4

Browse files
refactoring sparse query to adjust boost and queryname and move to copy constructor
1 parent 6db0abf commit daf2cb4

File tree

3 files changed

+14
-8
lines changed

3 files changed

+14
-8
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@ public static SparseVectorQueryBuilder from(
170170
shouldPruneTokens,
171171
tokenPruningConfig
172172
);
173-
sparseVectorQueryBuilder.boost(queryBuilder.boost());
174-
sparseVectorQueryBuilder.queryName(queryBuilder.queryName());
173+
// sparseVectorQueryBuilder.boost(queryBuilder.boost());
174+
// sparseVectorQueryBuilder.queryName(queryBuilder.queryName());
175175
return sparseVectorQueryBuilder;
176176
}
177177

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,18 @@ protected String getQuery(QueryBuilder queryBuilder) {
4343
@Override
4444
protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
4545
Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
46+
QueryBuilder finalQueryBuilder;
4647
if (inferenceIdsIndices.size() == 1) {
4748
// Simple case, everything uses the same inference ID
4849
String searchInferenceId = inferenceIdsIndices.keySet().iterator().next();
49-
return buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId);
50+
finalQueryBuilder = buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId);
5051
} else {
5152
// Multiple inference IDs, construct a boolean query
52-
return buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices);
53+
finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices);
5354
}
55+
finalQueryBuilder.queryName(queryBuilder.queryName());
56+
finalQueryBuilder.boost(queryBuilder.boost());
57+
return finalQueryBuilder;
5458
}
5559

5660
private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
@@ -95,6 +99,8 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
9599
)
96100
);
97101
}
102+
boolQueryBuilder.boost(queryBuilder.boost());
103+
boolQueryBuilder.queryName(queryBuilder.queryName());
98104
return boolQueryBuilder;
99105
}
100106

@@ -114,8 +120,7 @@ private QueryBuilder buildNestedQueryFromSparseVectorQuery(QueryBuilder queryBui
114120
sparseVectorQueryBuilder.shouldPruneTokens(),
115121
sparseVectorQueryBuilder.getTokenPruningConfig()
116122
),
117-
ScoreMode.Max,
118-
queryBuilder.queryName()
123+
ScoreMode.Max
119124
);
120125
}
121126

x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ public void testBoostAndQueryNameOnSparseVectorQueryRewrite() throws IOException
128128
rewritten instanceof InterceptedQueryBuilderWrapper
129129
);
130130
InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
131+
assertEquals(BOOST, intercepted.boost(), 1.0f);
132+
assertEquals(QUERY_NAME, intercepted.queryName());
131133
assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
132134
NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
133135
assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
@@ -137,8 +139,7 @@ public void testBoostAndQueryNameOnSparseVectorQueryRewrite() throws IOException
137139
assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
138140
assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
139141
assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
140-
assertEquals(BOOST, sparseVectorQueryBuilder.boost(), 0.0f);
141-
assertEquals(QUERY_NAME, sparseVectorQueryBuilder.queryName());
142+
assertEquals(BOOST, sparseVectorQueryBuilder.boost(), 1.0f);
142143
}
143144

144145
private QueryRewriteContext createQueryRewriteContext(Map<String, InferenceFieldMetadata> inferenceFields) {

0 commit comments

Comments
 (0)