diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 374770ad25eb1..56fb57a1446fc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -95,6 +95,7 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor; import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor; +import org.elasticsearch.xpack.inference.queries.SemanticMultiMatchQueryRewriteInterceptor; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor; import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder; @@ -569,7 +570,8 @@ public List getQueryRewriteInterceptors() { return List.of( new SemanticKnnVectorQueryRewriteInterceptor(), new SemanticMatchQueryRewriteInterceptor(), - new SemanticSparseVectorQueryRewriteInterceptor() + new SemanticSparseVectorQueryRewriteInterceptor(), + new SemanticMultiMatchQueryRewriteInterceptor() ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java index b1f5c240371f8..05821effa2f3a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java @@ -33,10 +33,10 @@ public class SemanticKnnVectorQueryRewriteInterceptor extends SemanticQueryRewri public SemanticKnnVectorQueryRewriteInterceptor() {} @Override - protected String getFieldName(QueryBuilder queryBuilder) { + protected Map getFieldNamesWithBoosts(QueryBuilder queryBuilder) { assert (queryBuilder instanceof KnnVectorQueryBuilder); KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder; - return knnVectorQueryBuilder.getFieldName(); + return Map.of(knnVectorQueryBuilder.getFieldName(), 1.0f); } @Override @@ -48,7 +48,11 @@ protected String getQuery(QueryBuilder queryBuilder) { } @Override - protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) { + protected QueryBuilder buildInferenceQuery( + QueryBuilder queryBuilder, + InferenceIndexInformationForField indexInformation, + Float fieldBoost + ) { assert (queryBuilder instanceof KnnVectorQueryBuilder); KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder; Map> inferenceIdsIndices = indexInformation.getInferenceIdsIndices(); @@ -63,7 +67,7 @@ protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceI // Multiple inference IDs, construct a boolean query finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices); } - finalQueryBuilder.boost(queryBuilder.boost()); + finalQueryBuilder.boost(queryBuilder.boost() * fieldBoost); finalQueryBuilder.queryName(queryBuilder.queryName()); return finalQueryBuilder; } @@ -87,7 +91,8 @@ private QueryBuilder buildInferenceQueryWithMultipleInferenceIds( @Override protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery( QueryBuilder queryBuilder, - InferenceIndexInformationForField indexInformation + InferenceIndexInformationForField indexInformation, + Float fieldBoost ) { assert (queryBuilder instanceof KnnVectorQueryBuilder); KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder; @@ -106,7 +111,7 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery( ) ); } - boolQueryBuilder.boost(queryBuilder.boost()); + boolQueryBuilder.boost(queryBuilder.boost() * fieldBoost); boolQueryBuilder.queryName(queryBuilder.queryName()); return boolQueryBuilder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java index a6599afc66c3f..915234f4a5a66 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java @@ -12,6 +12,8 @@ import org.elasticsearch.index.query.MatchQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import java.util.Map; + public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor { public static final NodeFeature SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature( @@ -21,10 +23,10 @@ public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteIn public SemanticMatchQueryRewriteInterceptor() {} @Override - protected String getFieldName(QueryBuilder queryBuilder) { + protected Map getFieldNamesWithBoosts(QueryBuilder queryBuilder) { assert (queryBuilder instanceof MatchQueryBuilder); MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder; - return matchQueryBuilder.fieldName(); + return Map.of(matchQueryBuilder.fieldName(), 1.0f); } @Override @@ -35,9 +37,13 @@ protected String getQuery(QueryBuilder queryBuilder) { } @Override - protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) { + protected QueryBuilder buildInferenceQuery( + QueryBuilder queryBuilder, + InferenceIndexInformationForField indexInformation, + Float fieldBoost + ) { SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false); - semanticQueryBuilder.boost(queryBuilder.boost()); + semanticQueryBuilder.boost(queryBuilder.boost() * fieldBoost); semanticQueryBuilder.queryName(queryBuilder.queryName()); return semanticQueryBuilder; } @@ -45,7 +51,8 @@ protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceI @Override protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery( QueryBuilder queryBuilder, - InferenceIndexInformationForField indexInformation + InferenceIndexInformationForField indexInformation, + Float fieldBoost ) { assert (queryBuilder instanceof MatchQueryBuilder); MatchQueryBuilder originalMatchQueryBuilder = (MatchQueryBuilder) queryBuilder; @@ -61,7 +68,7 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery( ) ); boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), matchQueryBuilder)); - boolQueryBuilder.boost(queryBuilder.boost()); + boolQueryBuilder.boost(queryBuilder.boost() * fieldBoost); boolQueryBuilder.queryName(queryBuilder.queryName()); return boolQueryBuilder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryRewriteInterceptor.java new file mode 100644 index 0000000000000..095e61b25dca8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryRewriteInterceptor.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.queries; + +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.MultiMatchQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; + +import java.util.Map; + +public class SemanticMultiMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor { + @Override + protected Map getFieldNamesWithBoosts(QueryBuilder queryBuilder) { + assert (queryBuilder instanceof MultiMatchQueryBuilder); + MultiMatchQueryBuilder multiMatchQueryBuilder = (MultiMatchQueryBuilder) queryBuilder; + return multiMatchQueryBuilder.fields(); + } + + @Override + protected String getQuery(QueryBuilder queryBuilder) { + assert (queryBuilder instanceof MultiMatchQueryBuilder); + MultiMatchQueryBuilder multiMatchQueryBuilder = (MultiMatchQueryBuilder) queryBuilder; + return (String) multiMatchQueryBuilder.value(); + } + + @Override + protected QueryBuilder buildInferenceQuery( + QueryBuilder queryBuilder, + InferenceIndexInformationForField indexInformation, + Float fieldBoost + ) { + SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false); + semanticQueryBuilder.boost(queryBuilder.boost() * fieldBoost); + semanticQueryBuilder.queryName(queryBuilder.queryName()); + return semanticQueryBuilder; + } + + @Override + protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery( + QueryBuilder queryBuilder, + InferenceIndexInformationForField indexInformation, + Float fieldBoost + ) { + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + + // Add the semantic part for inference indices + boolQueryBuilder.should( + createSemanticSubQuery(indexInformation.getInferenceIndices(), indexInformation.fieldName(), getQuery(queryBuilder)) + ); + + // Add the non-semantic part for non-inference indices + boolQueryBuilder.should( + createSubQueryForIndices( + indexInformation.nonInferenceIndices(), + QueryBuilders.matchQuery(indexInformation.fieldName(), getQuery(queryBuilder)) + ) + ); + + // Apply the field boost + boolQueryBuilder.boost(queryBuilder.boost() * fieldBoost); + boolQueryBuilder.queryName(queryBuilder.queryName()); + + return boolQueryBuilder; + } + + @Override + public String getQueryName() { + return MultiMatchQueryBuilder.NAME; + } + + public static void copyMultiMatchConfiguration(MultiMatchQueryBuilder source, MultiMatchQueryBuilder target) { + target.type(source.type()); + target.operator(source.operator()); + if (source.analyzer() != null) { + target.analyzer(source.analyzer()); + } + target.slop(source.slop()); + if (source.fuzziness() != null) { + target.fuzziness(source.fuzziness()); + } + target.prefixLength(source.prefixLength()); + target.maxExpansions(source.maxExpansions()); + if (source.minimumShouldMatch() != null) { + target.minimumShouldMatch(source.minimumShouldMatch()); + } + if (source.fuzzyRewrite() != null) { + target.fuzzyRewrite(source.fuzzyRewrite()); + } + if (source.tieBreaker() != null) { + target.tieBreaker(source.tieBreaker()); + } + target.lenient(source.lenient()); + target.zeroTermsQuery(source.zeroTermsQuery()); + target.autoGenerateSynonymsPhraseQuery(source.autoGenerateSynonymsPhraseQuery()); + target.fuzzyTranspositions(source.fuzzyTranspositions()); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java index bb76ef0be24e9..d0c2a5b8d1632 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java @@ -12,7 +12,9 @@ import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.index.mapper.IndexFieldMapper; import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.MultiMatchQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.TermsQueryBuilder; import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; @@ -33,7 +35,7 @@ public SemanticQueryRewriteInterceptor() {} @Override public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) { - String fieldName = getFieldName(queryBuilder); + Map fieldsWithBoosts = getFieldNamesWithBoosts(queryBuilder); ResolvedIndices resolvedIndices = context.getResolvedIndices(); if (resolvedIndices == null) { @@ -41,6 +43,13 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde return queryBuilder; } + if (queryBuilder instanceof MultiMatchQueryBuilder) { + return handleMultiFieldQuery(queryBuilder, fieldsWithBoosts, resolvedIndices); + } + + String fieldName = fieldsWithBoosts.keySet().iterator().next(); + Float fieldBoost = fieldsWithBoosts.get(fieldName); + InferenceIndexInformationForField indexInformation = resolveIndicesForField(fieldName, resolvedIndices); if (indexInformation.getInferenceIndices().isEmpty()) { // No inference fields were identified, so return the original query. @@ -49,19 +58,103 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde // Combined case where the field name requested by this query contains both // semantic_text and non-inference fields, so we have to combine queries per index // containing each field type. - return buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation); + return buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation, fieldBoost); } else { // The only fields we've identified are inference fields (e.g. semantic_text), // so rewrite the entire query to work on a semantic_text field. - return buildInferenceQuery(queryBuilder, indexInformation); + return buildInferenceQuery(queryBuilder, indexInformation, fieldBoost); } } /** - * @param queryBuilder {@link QueryBuilder} - * @return The singular field name requested by the provided query builder. + * Handle multi-field queries by analyzing each field for semantic_text fields + * and creating appropriate queries */ - protected abstract String getFieldName(QueryBuilder queryBuilder); + private QueryBuilder handleMultiFieldQuery( + QueryBuilder queryBuilder, + Map fieldsWithBoosts, + ResolvedIndices resolvedIndices + ) { + assert (queryBuilder instanceof MultiMatchQueryBuilder); + MultiMatchQueryBuilder multiMatchQueryBuilder = (MultiMatchQueryBuilder) queryBuilder; + String queryText = getQuery(queryBuilder); + + boolean hasSemanticField = false; + Map fieldInfoMap = new HashMap<>(); + + // Analyze each field to determine if it's semantic or not + for (Map.Entry fieldEntry : fieldsWithBoosts.entrySet()) { + String fieldName = fieldEntry.getKey(); + InferenceIndexInformationForField indexInfo = resolveIndicesForField(fieldName, resolvedIndices); + fieldInfoMap.put(fieldName, indexInfo); + + if (indexInfo.getInferenceIndices().isEmpty() == false) { + hasSemanticField = true; + } + } + + // If no semantic fields were found, return the original query + if (hasSemanticField == false) { + return queryBuilder; + } + + // Create a combined query + BoolQueryBuilder combinedQuery = new BoolQueryBuilder(); + + // Apply the MultiMatch type and tie-breaker + MultiMatchQueryBuilder.Type type = multiMatchQueryBuilder.type(); + Float tieBreaker = multiMatchQueryBuilder.tieBreaker(); + boolean shouldUseTieBreaker = (tieBreaker != null); + + Map nonSemanticFields = new HashMap<>(); + for (Map.Entry fieldEntry : fieldsWithBoosts.entrySet()) { + String fieldName = fieldEntry.getKey(); + Float fieldBoost = fieldEntry.getValue(); + InferenceIndexInformationForField indexInfo = fieldInfoMap.get(fieldName); + + if (indexInfo.getInferenceIndices().isEmpty()) { + nonSemanticFields.put(fieldName, fieldBoost); + } else if (indexInfo.nonInferenceIndices().isEmpty()) { + QueryBuilder semanticQuery = buildInferenceQuery(queryBuilder, indexInfo, fieldBoost); + if (shouldUseTieBreaker) { + semanticQuery.boost(semanticQuery.boost() * (type == MultiMatchQueryBuilder.Type.MOST_FIELDS ? 1.0f : tieBreaker)); + } + combinedQuery.should(semanticQuery); + } else { + QueryBuilder mixedQuery = buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInfo, fieldBoost); + if (shouldUseTieBreaker) { + mixedQuery.boost(mixedQuery.boost() * (type == MultiMatchQueryBuilder.Type.MOST_FIELDS ? 1.0f : tieBreaker)); + } + combinedQuery.should(mixedQuery); + } + } + + if (nonSemanticFields.isEmpty() == false) { + MultiMatchQueryBuilder nonSemanticQuery = QueryBuilders.multiMatchQuery(queryText); + nonSemanticQuery.fields(nonSemanticFields); + + SemanticMultiMatchQueryRewriteInterceptor.copyMultiMatchConfiguration(multiMatchQueryBuilder, nonSemanticQuery); + + if (shouldUseTieBreaker) { + nonSemanticQuery.boost(nonSemanticQuery.boost() * (type == MultiMatchQueryBuilder.Type.MOST_FIELDS ? 1.0f : tieBreaker)); + } + + combinedQuery.should(nonSemanticQuery); + } + + combinedQuery.boost(queryBuilder.boost()); + combinedQuery.queryName(queryBuilder.queryName()); + + return combinedQuery; + } + + /** + * Extracts field names and their associated boost values from the query builder. + * + * @param queryBuilder the query builder to extract field information from + * @return a map where keys are field names and values are their boost multipliers + */ + protected abstract Map getFieldNamesWithBoosts(QueryBuilder queryBuilder); /** * @param queryBuilder {@link QueryBuilder} @@ -74,20 +167,27 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde * * @param queryBuilder {@link QueryBuilder} * @param indexInformation {@link InferenceIndexInformationForField} + * @param fieldBoost per field boost value * @return {@link QueryBuilder} */ - protected abstract QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation); + protected abstract QueryBuilder buildInferenceQuery( + QueryBuilder queryBuilder, + InferenceIndexInformationForField indexInformation, + Float fieldBoost + ); /** * Builds a combined inference and non-inference query, * which separates the different queries into appropriate indices based on field type. * @param queryBuilder {@link QueryBuilder} * @param indexInformation {@link InferenceIndexInformationForField} + * @param fieldBoost per field boost value * @return {@link QueryBuilder} */ protected abstract QueryBuilder buildCombinedInferenceAndNonInferenceQuery( QueryBuilder queryBuilder, - InferenceIndexInformationForField indexInformation + InferenceIndexInformationForField indexInformation, + Float fieldBoost ); private InferenceIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java index c85a21f10301d..228a07769c520 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java @@ -27,10 +27,10 @@ public class SemanticSparseVectorQueryRewriteInterceptor extends SemanticQueryRe public SemanticSparseVectorQueryRewriteInterceptor() {} @Override - protected String getFieldName(QueryBuilder queryBuilder) { + protected Map getFieldNamesWithBoosts(QueryBuilder queryBuilder) { assert (queryBuilder instanceof SparseVectorQueryBuilder); SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder; - return sparseVectorQueryBuilder.getFieldName(); + return Map.of(sparseVectorQueryBuilder.getFieldName(), 1.0f); } @Override @@ -41,7 +41,11 @@ protected String getQuery(QueryBuilder queryBuilder) { } @Override - protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) { + protected QueryBuilder buildInferenceQuery( + QueryBuilder queryBuilder, + InferenceIndexInformationForField indexInformation, + Float fieldBoost + ) { Map> inferenceIdsIndices = indexInformation.getInferenceIdsIndices(); QueryBuilder finalQueryBuilder; if (inferenceIdsIndices.size() == 1) { @@ -53,7 +57,7 @@ protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceI finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices); } finalQueryBuilder.queryName(queryBuilder.queryName()); - finalQueryBuilder.boost(queryBuilder.boost()); + finalQueryBuilder.boost(queryBuilder.boost() * fieldBoost); return finalQueryBuilder; } @@ -76,7 +80,8 @@ private QueryBuilder buildInferenceQueryWithMultipleInferenceIds( @Override protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery( QueryBuilder queryBuilder, - InferenceIndexInformationForField indexInformation + InferenceIndexInformationForField indexInformation, + Float fieldBoost ) { assert (queryBuilder instanceof SparseVectorQueryBuilder); SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder; @@ -106,7 +111,7 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery( ) ); } - boolQueryBuilder.boost(queryBuilder.boost()); + boolQueryBuilder.boost(queryBuilder.boost() * fieldBoost); boolQueryBuilder.queryName(queryBuilder.queryName()); return boolQueryBuilder; }