Skip to content

Commit 7f649e0

Browse files
update fieldNames to accomodate per field boosting from now
1 parent f952838 commit 7f649e0

File tree

4 files changed

+17
-9
lines changed

4 files changed

+17
-9
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ public class SemanticKnnVectorQueryRewriteInterceptor extends SemanticQueryRewri
3333
public SemanticKnnVectorQueryRewriteInterceptor() {}
3434

3535
@Override
36-
protected String getFieldName(QueryBuilder queryBuilder) {
36+
protected Map<String, Float> getFieldNamesWithWeights(QueryBuilder queryBuilder) {
3737
assert (queryBuilder instanceof KnnVectorQueryBuilder);
3838
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
39-
return knnVectorQueryBuilder.getFieldName();
39+
return Map.of(knnVectorQueryBuilder.getFieldName(), 1.0f);
4040
}
4141

4242
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import org.elasticsearch.index.query.MatchQueryBuilder;
1313
import org.elasticsearch.index.query.QueryBuilder;
1414

15+
import java.util.Map;
16+
1517
public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {
1618

1719
public static final NodeFeature SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
@@ -21,10 +23,10 @@ public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteIn
2123
public SemanticMatchQueryRewriteInterceptor() {}
2224

2325
@Override
24-
protected String getFieldName(QueryBuilder queryBuilder) {
26+
protected Map<String, Float> getFieldNamesWithWeights(QueryBuilder queryBuilder) {
2527
assert (queryBuilder instanceof MatchQueryBuilder);
2628
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
27-
return matchQueryBuilder.fieldName();
29+
return Map.of(matchQueryBuilder.fieldName(), 1.0f);
2830
}
2931

3032
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@ public SemanticQueryRewriteInterceptor() {}
3333

3434
@Override
3535
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
36-
String fieldName = getFieldName(queryBuilder);
36+
Map<String, Float> fieldNamesWithWeights = getFieldNamesWithWeights(queryBuilder);
37+
if (fieldNamesWithWeights.size() > 1) {
38+
// Multi-field query, so return the original query, and an exception will be thrown eventually
39+
return queryBuilder;
40+
}
41+
String fieldName = fieldNamesWithWeights.keySet().iterator().next();
3742
ResolvedIndices resolvedIndices = context.getResolvedIndices();
3843

3944
if (resolvedIndices == null) {
@@ -59,9 +64,10 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
5964

6065
/**
6166
* @param queryBuilder {@link QueryBuilder}
62-
* @return The singular field name requested by the provided query builder.
67+
* @return Map of field names with their weights for multi-field queries.
68+
* For single-field queries, return a map with one entry.
6369
*/
64-
protected abstract String getFieldName(QueryBuilder queryBuilder);
70+
protected abstract Map<String, Float> getFieldNamesWithWeights(QueryBuilder queryBuilder);
6571

6672
/**
6773
* @param queryBuilder {@link QueryBuilder}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ public class SemanticSparseVectorQueryRewriteInterceptor extends SemanticQueryRe
2727
public SemanticSparseVectorQueryRewriteInterceptor() {}
2828

2929
@Override
30-
protected String getFieldName(QueryBuilder queryBuilder) {
30+
protected Map<String, Float> getFieldNamesWithWeights(QueryBuilder queryBuilder) {
3131
assert (queryBuilder instanceof SparseVectorQueryBuilder);
3232
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
33-
return sparseVectorQueryBuilder.getFieldName();
33+
return Map.of(sparseVectorQueryBuilder.getFieldName(), 1.0f);
3434
}
3535

3636
@Override

0 commit comments

Comments
 (0)