Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticMultiMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
Expand Down Expand Up @@ -569,7 +570,8 @@ public List<QueryRewriteInterceptor> getQueryRewriteInterceptors() {
return List.of(
new SemanticKnnVectorQueryRewriteInterceptor(),
new SemanticMatchQueryRewriteInterceptor(),
new SemanticSparseVectorQueryRewriteInterceptor()
new SemanticSparseVectorQueryRewriteInterceptor(),
new SemanticMultiMatchQueryRewriteInterceptor()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ public class SemanticKnnVectorQueryRewriteInterceptor extends SemanticQueryRewri
public SemanticKnnVectorQueryRewriteInterceptor() {}

@Override
protected String getFieldName(QueryBuilder queryBuilder) {
protected Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof KnnVectorQueryBuilder);
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
return knnVectorQueryBuilder.getFieldName();
return Map.of(knnVectorQueryBuilder.getFieldName(), 1.0f);
}

@Override
Expand All @@ -48,7 +48,11 @@ protected String getQuery(QueryBuilder queryBuilder) {
}

@Override
protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
protected QueryBuilder buildInferenceQuery(
QueryBuilder queryBuilder,
InferenceIndexInformationForField indexInformation,
Float fieldBoost
) {
assert (queryBuilder instanceof KnnVectorQueryBuilder);
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
Expand All @@ -63,7 +67,7 @@ protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceI
// Multiple inference IDs, construct a boolean query
finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices);
}
finalQueryBuilder.boost(queryBuilder.boost());
finalQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
finalQueryBuilder.queryName(queryBuilder.queryName());
return finalQueryBuilder;
}
Expand All @@ -87,7 +91,8 @@ private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
@Override
protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
QueryBuilder queryBuilder,
InferenceIndexInformationForField indexInformation
InferenceIndexInformationForField indexInformation,
Float fieldBoost
) {
assert (queryBuilder instanceof KnnVectorQueryBuilder);
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
Expand All @@ -106,7 +111,7 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
)
);
}
boolQueryBuilder.boost(queryBuilder.boost());
boolQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
boolQueryBuilder.queryName(queryBuilder.queryName());
return boolQueryBuilder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.elasticsearch.index.query.MatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;

import java.util.Map;

public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {

public static final NodeFeature SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
Expand All @@ -21,10 +23,10 @@ public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteIn
public SemanticMatchQueryRewriteInterceptor() {}

@Override
protected String getFieldName(QueryBuilder queryBuilder) {
protected Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
return matchQueryBuilder.fieldName();
return Map.of(matchQueryBuilder.fieldName(), 1.0f);
}

@Override
Expand All @@ -35,17 +37,22 @@ protected String getQuery(QueryBuilder queryBuilder) {
}

@Override
protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
protected QueryBuilder buildInferenceQuery(
QueryBuilder queryBuilder,
InferenceIndexInformationForField indexInformation,
Float fieldBoost
) {
SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
semanticQueryBuilder.boost(queryBuilder.boost());
semanticQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
semanticQueryBuilder.queryName(queryBuilder.queryName());
return semanticQueryBuilder;
}

@Override
protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
QueryBuilder queryBuilder,
InferenceIndexInformationForField indexInformation
InferenceIndexInformationForField indexInformation,
Float fieldBoost
) {
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder originalMatchQueryBuilder = (MatchQueryBuilder) queryBuilder;
Expand All @@ -61,7 +68,7 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
)
);
boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), matchQueryBuilder));
boolQueryBuilder.boost(queryBuilder.boost());
boolQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
boolQueryBuilder.queryName(queryBuilder.queryName());
return boolQueryBuilder;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.queries;

import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;

import java.util.Map;

/**
 * {@link SemanticQueryRewriteInterceptor} for {@code multi_match} queries.
 * <p>
 * Exposes the fields (and per-field boosts) of a {@link MultiMatchQueryBuilder} to the base
 * interceptor and builds the replacement queries for fields that resolve to inference fields.
 */
public class SemanticMultiMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {

    /**
     * @param queryBuilder the intercepted query; must be a {@link MultiMatchQueryBuilder}
     * @return the field-name-to-boost map as reported by {@link MultiMatchQueryBuilder#fields()}
     */
    @Override
    protected Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder) {
        assert (queryBuilder instanceof MultiMatchQueryBuilder);
        return ((MultiMatchQueryBuilder) queryBuilder).fields();
    }

    /**
     * @param queryBuilder the intercepted query; must be a {@link MultiMatchQueryBuilder}
     * @return the query value cast to {@link String}
     */
    @Override
    protected String getQuery(QueryBuilder queryBuilder) {
        assert (queryBuilder instanceof MultiMatchQueryBuilder);
        // NOTE(review): value() is cast unconditionally; a non-String query value would fail here - confirm callers.
        return (String) ((MultiMatchQueryBuilder) queryBuilder).value();
    }

    /**
     * Builds a {@link SemanticQueryBuilder} on the resolved field, carrying over the original
     * query's name and its boost multiplied by the per-field boost.
     *
     * @param queryBuilder the intercepted query
     * @param indexInformation resolved index information for the field being rewritten
     * @param fieldBoost per field boost value
     * @return the semantic query replacing the original query for this field
     */
    @Override
    protected QueryBuilder buildInferenceQuery(
        QueryBuilder queryBuilder,
        InferenceIndexInformationForField indexInformation,
        Float fieldBoost
    ) {
        SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
        semanticQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
        semanticQueryBuilder.queryName(queryBuilder.queryName());
        return semanticQueryBuilder;
    }

    /**
     * Builds a bool query with two should clauses: a semantic sub-query for the inference indices
     * and a match sub-query for the non-inference indices. The bool query carries the original
     * query's name and its boost multiplied by the per-field boost.
     *
     * @param queryBuilder the intercepted query; must be a {@link MultiMatchQueryBuilder}
     * @param indexInformation resolved index information for the field being rewritten
     * @param fieldBoost per field boost value
     * @return the combined bool query
     */
    @Override
    protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
        QueryBuilder queryBuilder,
        InferenceIndexInformationForField indexInformation,
        Float fieldBoost
    ) {
        // getQuery performs the MultiMatchQueryBuilder assertion and the String cast in one place.
        String queryText = getQuery(queryBuilder);
        String fieldName = indexInformation.fieldName();

        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
        boolQueryBuilder.should(createSemanticSubQuery(indexInformation.getInferenceIndices(), fieldName, queryText));
        boolQueryBuilder.should(createMatchSubQuery(indexInformation.nonInferenceIndices(), fieldName, queryText));

        boolQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
        boolQueryBuilder.queryName(queryBuilder.queryName());
        return boolQueryBuilder;
    }

    /**
     * @return the name of the query type this interceptor handles, {@link MultiMatchQueryBuilder#NAME}
     */
    @Override
    public String getQueryName() {
        return MultiMatchQueryBuilder.NAME;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.index.mapper.IndexFieldMapper;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.TermsQueryBuilder;
Expand All @@ -33,14 +34,21 @@ public SemanticQueryRewriteInterceptor() {}

@Override
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
String fieldName = getFieldName(queryBuilder);
Map<String, Float> fieldsWithBoosts = getFieldNamesWithBoosts(queryBuilder);
ResolvedIndices resolvedIndices = context.getResolvedIndices();

if (resolvedIndices == null) {
// No resolved indices, so return the original query.
return queryBuilder;
}

if (fieldsWithBoosts.size() > 1) {
// Multi-field query: rewrite each field separately (the original query is returned if no field is an inference field).
return handleMultiFieldQuery(queryBuilder, fieldsWithBoosts, resolvedIndices);
}

String fieldName = fieldsWithBoosts.keySet().iterator().next();
Float weight = fieldsWithBoosts.get(fieldName);
InferenceIndexInformationForField indexInformation = resolveIndicesForField(fieldName, resolvedIndices);
if (indexInformation.getInferenceIndices().isEmpty()) {
// No inference fields were identified, so return the original query.
Expand All @@ -49,19 +57,66 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
// Combined case where the field name requested by this query contains both
// semantic_text and non-inference fields, so we have to combine queries per index
// containing each field type.
return buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation);
return buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation, weight);
} else {
// The only fields we've identified are inference fields (e.g. semantic_text),
// so rewrite the entire query to work on a semantic_text field.
return buildInferenceQuery(queryBuilder, indexInformation);
return buildInferenceQuery(queryBuilder, indexInformation, weight);
}
}

/**
* @param queryBuilder {@link QueryBuilder}
* @return The singular field name requested by the provided query builder.
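* Handles a query that targets more than one field: each field is rewritten on its own and the
* results are combined as should clauses of a bool query. The original query is returned
* unchanged if none of the fields resolves to an inference field.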
*/
protected abstract String getFieldName(QueryBuilder queryBuilder);
private QueryBuilder handleMultiFieldQuery(
    QueryBuilder queryBuilder,
    Map<String, Float> fieldNamesWithWeights,
    ResolvedIndices resolvedIndices
) {
    // Rewrites every requested field independently and ORs the per-field queries together.
    // queryBuilder: the intercepted query; fieldNamesWithWeights: field name -> per-field boost;
    // resolvedIndices: used to decide, per field, which indices hold an inference field.
    BoolQueryBuilder finalQueryBuilder = new BoolQueryBuilder();
    boolean hasAnySemanticFields = false;

    for (Map.Entry<String, Float> fieldEntry : fieldNamesWithWeights.entrySet()) {
        String fieldName = fieldEntry.getKey();
        Float fieldWeight = fieldEntry.getValue();
        InferenceIndexInformationForField indexInformation = resolveIndicesForField(fieldName, resolvedIndices);

        final QueryBuilder fieldQuery;
        if (indexInformation.getInferenceIndices().isEmpty()) {
            // Pure non-semantic field - plain match query restricted to the non-inference indices.
            fieldQuery = createMatchSubQuery(indexInformation.nonInferenceIndices(), fieldName, getQuery(queryBuilder));
            // Apply the same boost product the semantic branches apply, so a weighted
            // non-semantic field does not lose its weight relative to the semantic ones.
            fieldQuery.boost(queryBuilder.boost() * fieldWeight);
        } else {
            hasAnySemanticFields = true;
            if (indexInformation.nonInferenceIndices().isEmpty()) {
                // Pure semantic field - create semantic query
                fieldQuery = buildInferenceQuery(queryBuilder, indexInformation, fieldWeight);
            } else {
                // Mixed semantic/non-semantic field - use combined approach
                fieldQuery = buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation, fieldWeight);
            }
        }
        finalQueryBuilder.should(fieldQuery);
    }

    // If no semantic fields were found, there is nothing to rewrite: return the original query.
    if (hasAnySemanticFields == false) {
        return queryBuilder;
    }

    return finalQueryBuilder;
}

/**
* Extracts field names and their associated boost values from the query builder.
*
* @param queryBuilder the query builder to extract field information from
* @return a map where keys are field names and values are their boost multipliers
*/
protected abstract Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder);

/**
* @param queryBuilder {@link QueryBuilder}
Expand All @@ -74,20 +129,27 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
*
* @param queryBuilder {@link QueryBuilder}
* @param indexInformation {@link InferenceIndexInformationForField}
* @param fieldBoost per field boost value
* @return {@link QueryBuilder}
*/
protected abstract QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation);
protected abstract QueryBuilder buildInferenceQuery(
QueryBuilder queryBuilder,
InferenceIndexInformationForField indexInformation,
Float fieldBoost
);

/**
* Builds a combined inference and non-inference query,
* which separates the different queries into appropriate indices based on field type.
* @param queryBuilder {@link QueryBuilder}
* @param indexInformation {@link InferenceIndexInformationForField}
* @param fieldBoost per field boost value
* @return {@link QueryBuilder}
*/
protected abstract QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
QueryBuilder queryBuilder,
InferenceIndexInformationForField indexInformation
InferenceIndexInformationForField indexInformation,
Float fieldBoost
);

private InferenceIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) {
Expand All @@ -107,6 +169,14 @@ private InferenceIndexInformationForField resolveIndicesForField(String fieldNam
return new InferenceIndexInformationForField(fieldName, inferenceIndicesMetadata, nonInferenceIndices);
}

/**
 * Builds a bool query that must match {@code value} on {@code fieldName} and is filtered,
 * via a terms query on {@link IndexFieldMapper#NAME}, to the given indices.
 *
 * @param indices the indices the sub-query is restricted to
 * @param fieldName the field to run the match query against
 * @param value the text to match
 * @return the index-restricted match query
 */
protected QueryBuilder createMatchSubQuery(Collection<String> indices, String fieldName, String value) {
    return new BoolQueryBuilder().must(new MatchQueryBuilder(fieldName, value))
        .filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
}

protected QueryBuilder createSubQueryForIndices(Collection<String> indices, QueryBuilder queryBuilder) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(queryBuilder);
Expand Down
Loading