diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
index 374770ad25eb1..56fb57a1446fc 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
@@ -95,6 +95,7 @@
 import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
 import org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor;
 import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
+import org.elasticsearch.xpack.inference.queries.SemanticMultiMatchQueryRewriteInterceptor;
 import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
 import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
 import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
@@ -569,7 +570,8 @@ public List<QueryRewriteInterceptor> getQueryRewriteInterceptors() {
         return List.of(
             new SemanticKnnVectorQueryRewriteInterceptor(),
             new SemanticMatchQueryRewriteInterceptor(),
-            new SemanticSparseVectorQueryRewriteInterceptor()
+            new SemanticSparseVectorQueryRewriteInterceptor(),
+            new SemanticMultiMatchQueryRewriteInterceptor()
         );
     }
 
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java
index b1f5c240371f8..05821effa2f3a 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java
@@ -33,10 +33,10 @@ public class SemanticKnnVectorQueryRewriteInterceptor extends SemanticQueryRewri
     public SemanticKnnVectorQueryRewriteInterceptor() {}
 
     @Override
-    protected String getFieldName(QueryBuilder queryBuilder) {
+    protected Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder) {
         assert (queryBuilder instanceof KnnVectorQueryBuilder);
         KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
-        return knnVectorQueryBuilder.getFieldName();
+        return Map.of(knnVectorQueryBuilder.getFieldName(), 1.0f);
     }
 
     @Override
@@ -48,7 +48,11 @@ protected String getQuery(QueryBuilder queryBuilder) {
     }
 
     @Override
-    protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
+    protected QueryBuilder buildInferenceQuery(
+        QueryBuilder queryBuilder,
+        InferenceIndexInformationForField indexInformation,
+        Float fieldBoost
+    ) {
         assert (queryBuilder instanceof KnnVectorQueryBuilder);
         KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
         Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
@@ -63,7 +67,7 @@ protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceI
             // Multiple inference IDs, construct a boolean query
             finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices);
         }
-        finalQueryBuilder.boost(queryBuilder.boost());
+        finalQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
         finalQueryBuilder.queryName(queryBuilder.queryName());
         return finalQueryBuilder;
     }
@@ -87,7 +91,8 @@ private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
     @Override
     protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
         QueryBuilder queryBuilder,
-        InferenceIndexInformationForField indexInformation
+        InferenceIndexInformationForField indexInformation,
+        Float fieldBoost
     ) {
         assert (queryBuilder instanceof KnnVectorQueryBuilder);
         KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
@@ -106,7 +111,7 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
                 )
             );
         }
-        boolQueryBuilder.boost(queryBuilder.boost());
+        boolQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
         boolQueryBuilder.queryName(queryBuilder.queryName());
         return boolQueryBuilder;
     }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java
index a6599afc66c3f..915234f4a5a66 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java
@@ -12,6 +12,8 @@
 import org.elasticsearch.index.query.MatchQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 
+import java.util.Map;
+
 public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {
 
     public static final NodeFeature SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
@@ -21,10 +23,10 @@ public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteIn
     public SemanticMatchQueryRewriteInterceptor() {}
 
     @Override
-    protected String getFieldName(QueryBuilder queryBuilder) {
+    protected Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder) {
         assert (queryBuilder instanceof MatchQueryBuilder);
         MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
-        return matchQueryBuilder.fieldName();
+        return Map.of(matchQueryBuilder.fieldName(), 1.0f);
     }
 
     @Override
@@ -35,9 +37,13 @@ protected String getQuery(QueryBuilder queryBuilder) {
     }
 
     @Override
-    protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
+    protected QueryBuilder buildInferenceQuery(
+        QueryBuilder queryBuilder,
+        InferenceIndexInformationForField indexInformation,
+        Float fieldBoost
+    ) {
         SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
-        semanticQueryBuilder.boost(queryBuilder.boost());
+        semanticQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
         semanticQueryBuilder.queryName(queryBuilder.queryName());
         return semanticQueryBuilder;
     }
@@ -45,7 +51,8 @@ protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceI
     @Override
     protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
         QueryBuilder queryBuilder,
-        InferenceIndexInformationForField indexInformation
+        InferenceIndexInformationForField indexInformation,
+        Float fieldBoost
     ) {
         assert (queryBuilder instanceof MatchQueryBuilder);
         MatchQueryBuilder originalMatchQueryBuilder = (MatchQueryBuilder) queryBuilder;
@@ -61,7 +68,7 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
             )
         );
         boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), matchQueryBuilder));
-        boolQueryBuilder.boost(queryBuilder.boost());
+        boolQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
         boolQueryBuilder.queryName(queryBuilder.queryName());
         return boolQueryBuilder;
     }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryRewriteInterceptor.java
new file mode 100644
index 0000000000000..39d4486232a9c
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryRewriteInterceptor.java
@@ -0,0 +1,82 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.queries;
+
+import org.elasticsearch.index.query.BoolQueryBuilder;
+import org.elasticsearch.index.query.MultiMatchQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
+
+import java.util.Map;
+
+/**
+ * Rewrites {@link MultiMatchQueryBuilder} queries so that fields identified as inference fields
+ * are queried through a {@link SemanticQueryBuilder}, honoring the per-field boosts of the query.
+ */
+public class SemanticMultiMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {
+    @Override
+    protected Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder) {
+        assert (queryBuilder instanceof MultiMatchQueryBuilder);
+        MultiMatchQueryBuilder multiMatchQueryBuilder = (MultiMatchQueryBuilder) queryBuilder;
+        return multiMatchQueryBuilder.fields();
+    }
+
+    @Override
+    protected String getQuery(QueryBuilder queryBuilder) {
+        assert (queryBuilder instanceof MultiMatchQueryBuilder);
+        MultiMatchQueryBuilder multiMatchQueryBuilder = (MultiMatchQueryBuilder) queryBuilder;
+        return (String) multiMatchQueryBuilder.value();
+    }
+
+    @Override
+    protected QueryBuilder buildInferenceQuery(
+        QueryBuilder queryBuilder,
+        InferenceIndexInformationForField indexInformation,
+        Float fieldBoost
+    ) {
+        SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
+        semanticQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
+        semanticQueryBuilder.queryName(queryBuilder.queryName());
+        return semanticQueryBuilder;
+    }
+
+    @Override
+    protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
+        QueryBuilder queryBuilder,
+        InferenceIndexInformationForField indexInformation,
+        Float fieldBoost
+    ) {
+        assert (queryBuilder instanceof MultiMatchQueryBuilder);
+        MultiMatchQueryBuilder multiMatchQueryBuilder = (MultiMatchQueryBuilder) queryBuilder;
+
+        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
+        boolQueryBuilder.should(
+            createSemanticSubQuery(
+                indexInformation.getInferenceIndices(),
+                indexInformation.fieldName(),
+                (String) multiMatchQueryBuilder.value()
+            )
+        );
+
+        boolQueryBuilder.should(
+            createMatchSubQuery(
+                indexInformation.nonInferenceIndices(),
+                indexInformation.fieldName(),
+                (String) multiMatchQueryBuilder.value()
+            )
+        );
+
+        boolQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
+        boolQueryBuilder.queryName(queryBuilder.queryName());
+        return boolQueryBuilder;
+    }
+
+    @Override
+    public String getQueryName() {
+        return MultiMatchQueryBuilder.NAME;
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java
index bb76ef0be24e9..85d5c0feee4b9 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java
@@ -12,6 +12,7 @@
 import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
 import org.elasticsearch.index.mapper.IndexFieldMapper;
 import org.elasticsearch.index.query.BoolQueryBuilder;
+import org.elasticsearch.index.query.MatchQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.index.query.TermsQueryBuilder;
@@ -33,7 +34,7 @@ public SemanticQueryRewriteInterceptor() {}
 
     @Override
     public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
-        String fieldName = getFieldName(queryBuilder);
+        Map<String, Float> fieldsWithBoosts = getFieldNamesWithBoosts(queryBuilder);
         ResolvedIndices resolvedIndices = context.getResolvedIndices();
 
         if (resolvedIndices == null) {
@@ -41,6 +42,18 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
             return queryBuilder;
         }
 
+        if (fieldsWithBoosts.isEmpty()) {
+            // No fields to inspect (iterator().next() below would throw), so return the original query.
+            return queryBuilder;
+        }
+
+        if (fieldsWithBoosts.size() > 1) {
+            // Multi-field query, so rewrite each field individually and combine the results.
+            return handleMultiFieldQuery(queryBuilder, fieldsWithBoosts, resolvedIndices);
+        }
+
+        String fieldName = fieldsWithBoosts.keySet().iterator().next();
+        Float weight = fieldsWithBoosts.get(fieldName);
         InferenceIndexInformationForField indexInformation = resolveIndicesForField(fieldName, resolvedIndices);
         if (indexInformation.getInferenceIndices().isEmpty()) {
             // No inference fields were identified, so return the original query.
@@ -49,19 +62,74 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
             // Combined case where the field name requested by this query contains both
             // semantic_text and non-inference fields, so we have to combine queries per index
             // containing each field type.
-            return buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation);
+            return buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation, weight);
         } else {
             // The only fields we've identified are inference fields (e.g. semantic_text),
             // so rewrite the entire query to work on a semantic_text field.
-            return buildInferenceQuery(queryBuilder, indexInformation);
+            return buildInferenceQuery(queryBuilder, indexInformation, weight);
         }
     }
 
     /**
-     * @param queryBuilder {@link QueryBuilder}
-     * @return The singular field name requested by the provided query builder.
+     * Rewrites a query that targets more than one field by building one sub-query per field
+     * and combining them as {@code should} clauses of a {@link BoolQueryBuilder}.
+     * Returns the original query unchanged when none of the fields is an inference field.
+     *
+     * @param queryBuilder          the original query
+     * @param fieldNamesWithWeights field names mapped to their per-field boosts
+     * @param resolvedIndices       the indices the request resolved to
+     * @return the combined query, or {@code queryBuilder} if no inference fields were found
      */
-    protected abstract String getFieldName(QueryBuilder queryBuilder);
+    private QueryBuilder handleMultiFieldQuery(
+        QueryBuilder queryBuilder,
+        Map<String, Float> fieldNamesWithWeights,
+        ResolvedIndices resolvedIndices
+    ) {
+        BoolQueryBuilder finalQueryBuilder = new BoolQueryBuilder();
+        boolean hasAnySemanticFields = false;
+
+        for (Map.Entry<String, Float> fieldEntry : fieldNamesWithWeights.entrySet()) {
+            String fieldName = fieldEntry.getKey();
+            Float fieldWeight = fieldEntry.getValue();
+            InferenceIndexInformationForField indexInformation = resolveIndicesForField(fieldName, resolvedIndices);
+
+            if (indexInformation.getInferenceIndices().isEmpty()) {
+                // Pure non-semantic field - create individual match query, boosted like the semantic branches
+                QueryBuilder nonSemanticQuery = createMatchSubQuery(
+                    indexInformation.nonInferenceIndices(),
+                    fieldName,
+                    getQuery(queryBuilder)
+                );
+                nonSemanticQuery.boost(queryBuilder.boost() * fieldWeight);
+                finalQueryBuilder.should(nonSemanticQuery);
+            } else if (indexInformation.nonInferenceIndices().isEmpty() == false) {
+                // Mixed semantic/non-semantic field - use combined approach
+                QueryBuilder combinedQuery = buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation, fieldWeight);
+                finalQueryBuilder.should(combinedQuery);
+                hasAnySemanticFields = true;
+            } else {
+                // Pure semantic field - create semantic query
+                QueryBuilder semanticQuery = buildInferenceQuery(queryBuilder, indexInformation, fieldWeight);
+                finalQueryBuilder.should(semanticQuery);
+                hasAnySemanticFields = true;
+            }
+        }
+
+        // If no semantic fields were found, return original query
+        if (hasAnySemanticFields == false) {
+            return queryBuilder;
+        }
+
+        return finalQueryBuilder;
+    }
+
+    /**
+     * Extracts field names and their associated boost values from the query builder.
+     *
+     * @param queryBuilder the query builder to extract field information from
+     * @return a map where keys are field names and values are their boost multipliers
+     */
+    protected abstract Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder);
 
     /**
      * @param queryBuilder {@link QueryBuilder}
@@ -74,20 +142,27 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
      *
      * @param queryBuilder {@link QueryBuilder}
      * @param indexInformation {@link InferenceIndexInformationForField}
+     * @param fieldBoost per field boost value
      * @return {@link QueryBuilder}
      */
-    protected abstract QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation);
+    protected abstract QueryBuilder buildInferenceQuery(
+        QueryBuilder queryBuilder,
+        InferenceIndexInformationForField indexInformation,
+        Float fieldBoost
+    );
 
     /**
      * Builds a combined inference and non-inference query,
      * which separates the different queries into appropriate indices based on field type.
      * @param queryBuilder {@link QueryBuilder}
      * @param indexInformation {@link InferenceIndexInformationForField}
+     * @param fieldBoost per field boost value
      * @return {@link QueryBuilder}
      */
     protected abstract QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
         QueryBuilder queryBuilder,
-        InferenceIndexInformationForField indexInformation
+        InferenceIndexInformationForField indexInformation,
+        Float fieldBoost
     );
 
     private InferenceIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) {
@@ -107,6 +182,22 @@ private InferenceIndexInformationForField resolveIndicesForField(String fieldNam
         return new InferenceIndexInformationForField(fieldName, inferenceIndicesMetadata, nonInferenceIndices);
     }
 
+    /**
+     * Builds a {@link MatchQueryBuilder} on {@code fieldName} restricted to the given indices.
+     *
+     * @param indices   the indices the match query should be limited to
+     * @param fieldName the field to match against
+     * @param value     the query text
+     * @return a bool query wrapping the match query with an index filter
+     */
+    protected QueryBuilder createMatchSubQuery(Collection<String> indices, String fieldName, String value) {
+        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
+        MatchQueryBuilder matchQuery = new MatchQueryBuilder(fieldName, value);
+        boolQueryBuilder.must(matchQuery);
+        boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
+        return boolQueryBuilder;
+    }
+
     protected QueryBuilder createSubQueryForIndices(Collection<String> indices, QueryBuilder queryBuilder) {
         BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
         boolQueryBuilder.must(queryBuilder);
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java
index c85a21f10301d..8230356190a7f 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java
@@ -27,10 +27,10 @@ public class SemanticSparseVectorQueryRewriteInterceptor extends SemanticQueryRe
     public SemanticSparseVectorQueryRewriteInterceptor() {}
 
     @Override
-    protected String getFieldName(QueryBuilder queryBuilder) {
+    protected Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder) {
         assert (queryBuilder instanceof SparseVectorQueryBuilder);
         SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
-        return sparseVectorQueryBuilder.getFieldName();
+        return Map.of(sparseVectorQueryBuilder.getFieldName(), 1.0f);
     }
 
     @Override
@@ -41,7 +41,11 @@ protected String getQuery(QueryBuilder queryBuilder) {
     }
 
     @Override
-    protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
+    protected QueryBuilder buildInferenceQuery(
+        QueryBuilder queryBuilder,
+        InferenceIndexInformationForField indexInformation,
+        Float fieldBoost
+    ) {
         Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
         QueryBuilder finalQueryBuilder;
         if (inferenceIdsIndices.size() == 1) {
@@ -53,7 +57,7 @@ protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceI
             finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices);
         }
         finalQueryBuilder.queryName(queryBuilder.queryName());
-        finalQueryBuilder.boost(queryBuilder.boost());
+        finalQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
         return finalQueryBuilder;
     }
 
@@ -73,10 +77,11 @@ private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
         return boolQueryBuilder;
     }
 
     @Override
     protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
         QueryBuilder queryBuilder,
-        InferenceIndexInformationForField indexInformation
+        InferenceIndexInformationForField indexInformation,
+        Float fieldBoost
     ) {
         assert (queryBuilder instanceof SparseVectorQueryBuilder);
         SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
@@ -106,7 +111,7 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
                 )
             );
         }
-        boolQueryBuilder.boost(queryBuilder.boost());
+        boolQueryBuilder.boost(queryBuilder.boost() * fieldBoost);
         boolQueryBuilder.queryName(queryBuilder.queryName());
         return boolQueryBuilder;
     }