Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticMultiMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
Expand Down Expand Up @@ -569,7 +570,8 @@ public List<QueryRewriteInterceptor> getQueryRewriteInterceptors() {
return List.of(
new SemanticKnnVectorQueryRewriteInterceptor(),
new SemanticMatchQueryRewriteInterceptor(),
new SemanticSparseVectorQueryRewriteInterceptor()
new SemanticSparseVectorQueryRewriteInterceptor(),
new SemanticMultiMatchQueryRewriteInterceptor()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ public class SemanticKnnVectorQueryRewriteInterceptor extends SemanticQueryRewri
public SemanticKnnVectorQueryRewriteInterceptor() {}

@Override
protected String getFieldName(QueryBuilder queryBuilder) {
protected Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof KnnVectorQueryBuilder);
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
return knnVectorQueryBuilder.getFieldName();
return Map.of(knnVectorQueryBuilder.getFieldName(), 1.0f);
}

@Override
Expand All @@ -48,7 +48,11 @@ protected String getQuery(QueryBuilder queryBuilder) {
}

@Override
protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
protected QueryBuilder buildInferenceQuery(
QueryBuilder queryBuilder,
InferenceIndexInformationForField indexInformation,
Float fieldBoost
) {
assert (queryBuilder instanceof KnnVectorQueryBuilder);
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
Expand All @@ -63,7 +67,7 @@ protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceI
// Multiple inference IDs, construct a boolean query
finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices);
}
finalQueryBuilder.boost(queryBuilder.boost());
finalQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
finalQueryBuilder.queryName(queryBuilder.queryName());
return finalQueryBuilder;
}
Expand All @@ -87,7 +91,8 @@ private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
@Override
protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
QueryBuilder queryBuilder,
InferenceIndexInformationForField indexInformation
InferenceIndexInformationForField indexInformation,
Float fieldBoost
) {
assert (queryBuilder instanceof KnnVectorQueryBuilder);
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
Expand All @@ -106,7 +111,7 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
)
);
}
boolQueryBuilder.boost(queryBuilder.boost());
boolQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
boolQueryBuilder.queryName(queryBuilder.queryName());
return boolQueryBuilder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.elasticsearch.index.query.MatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;

import java.util.Map;

public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {

public static final NodeFeature SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
Expand All @@ -21,10 +23,10 @@ public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteIn
public SemanticMatchQueryRewriteInterceptor() {}

@Override
protected String getFieldName(QueryBuilder queryBuilder) {
protected Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
return matchQueryBuilder.fieldName();
return Map.of(matchQueryBuilder.fieldName(), 1.0f);
}

@Override
Expand All @@ -35,17 +37,22 @@ protected String getQuery(QueryBuilder queryBuilder) {
}

@Override
protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
protected QueryBuilder buildInferenceQuery(
QueryBuilder queryBuilder,
InferenceIndexInformationForField indexInformation,
Float fieldBoost
) {
SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
semanticQueryBuilder.boost(queryBuilder.boost());
semanticQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
semanticQueryBuilder.queryName(queryBuilder.queryName());
return semanticQueryBuilder;
}

@Override
protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
QueryBuilder queryBuilder,
InferenceIndexInformationForField indexInformation
InferenceIndexInformationForField indexInformation,
Float fieldBoost
) {
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder originalMatchQueryBuilder = (MatchQueryBuilder) queryBuilder;
Expand All @@ -61,7 +68,7 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
)
);
boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), matchQueryBuilder));
boolQueryBuilder.boost(queryBuilder.boost());
boolQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
boolQueryBuilder.queryName(queryBuilder.queryName());
return boolQueryBuilder;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.queries;

import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;

import java.util.Map;

public class SemanticMultiMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {
@Override
protected Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof MultiMatchQueryBuilder);
MultiMatchQueryBuilder multiMatchQueryBuilder = (MultiMatchQueryBuilder) queryBuilder;
return multiMatchQueryBuilder.fields();
}

@Override
protected String getQuery(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof MultiMatchQueryBuilder);
MultiMatchQueryBuilder multiMatchQueryBuilder = (MultiMatchQueryBuilder) queryBuilder;
return (String) multiMatchQueryBuilder.value();
}

@Override
protected QueryBuilder buildInferenceQuery(
QueryBuilder queryBuilder,
InferenceIndexInformationForField indexInformation,
Float fieldBoost
) {
SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
semanticQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
semanticQueryBuilder.queryName(queryBuilder.queryName());
return semanticQueryBuilder;
}

@Override
protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
QueryBuilder queryBuilder,
InferenceIndexInformationForField indexInformation,
Float fieldBoost
) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();

// Add the semantic part for inference indices
boolQueryBuilder.should(
createSemanticSubQuery(indexInformation.getInferenceIndices(), indexInformation.fieldName(), getQuery(queryBuilder))
);

// Add the non-semantic part for non-inference indices
boolQueryBuilder.should(
createSubQueryForIndices(
indexInformation.nonInferenceIndices(),
QueryBuilders.matchQuery(indexInformation.fieldName(), getQuery(queryBuilder))
)
);

// Apply the field boost
boolQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
boolQueryBuilder.queryName(queryBuilder.queryName());

return boolQueryBuilder;
}

@Override
public String getQueryName() {
return MultiMatchQueryBuilder.NAME;
}

public static void copyMultiMatchConfiguration(MultiMatchQueryBuilder source, MultiMatchQueryBuilder target) {
target.type(source.type());
target.operator(source.operator());
if (source.analyzer() != null) {
target.analyzer(source.analyzer());
}
target.slop(source.slop());
if (source.fuzziness() != null) {
target.fuzziness(source.fuzziness());
}
target.prefixLength(source.prefixLength());
target.maxExpansions(source.maxExpansions());
if (source.minimumShouldMatch() != null) {
target.minimumShouldMatch(source.minimumShouldMatch());
}
if (source.fuzzyRewrite() != null) {
target.fuzzyRewrite(source.fuzzyRewrite());
}
if (source.tieBreaker() != null) {
target.tieBreaker(source.tieBreaker());
}
target.lenient(source.lenient());
target.zeroTermsQuery(source.zeroTermsQuery());
target.autoGenerateSynonymsPhraseQuery(source.autoGenerateSynonymsPhraseQuery());
target.fuzzyTranspositions(source.fuzzyTranspositions());
}

}
Loading