diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index f5283510bd1c9..d8d14d10b247f 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -216,6 +216,7 @@ static TransportVersion def(int id) { public static final TransportVersion INITIAL_ELASTICSEARCH_8_19_1 = def(8_841_0_65); public static final TransportVersion INITIAL_ELASTICSEARCH_8_19_2 = def(8_841_0_66); public static final TransportVersion INITIAL_ELASTICSEARCH_8_19_3 = def(8_841_0_67); + public static final TransportVersion MULTI_MATCH_SEMANTIC_TEXT_SUPPORT_8_19 = def(8_841_0_68); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11); @@ -365,6 +366,7 @@ static TransportVersion def(int id) { public static final TransportVersion SIMULATE_INGEST_MAPPING_MERGE_TYPE = def(9_138_0_00); public static final TransportVersion ESQL_LOOKUP_JOIN_ON_MANY_FIELDS = def(9_139_0_00); public static final TransportVersion SIMULATE_INGEST_EFFECTIVE_MAPPING = def(9_140_0_00); + public static final TransportVersion MULTI_MATCH_SEMANTIC_TEXT_SUPPORT = def(9_141_0_00); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/server/src/main/java/org/elasticsearch/index/query/MultiMatchQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/MultiMatchQueryBuilder.java index cfd2fdcda853c..f7c721af82887 100644 --- a/server/src/main/java/org/elasticsearch/index/query/MultiMatchQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/MultiMatchQueryBuilder.java @@ -52,6 +52,7 @@ public final class MultiMatchQueryBuilder extends AbstractQueryBuilder getTestFeatures() { SemanticTextFieldMapper.SEMANTIC_TEXT_SKIP_INFERENCE_FIELDS, SEMANTIC_TEXT_HIGHLIGHTER, SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED, + SEMANTIC_MULTI_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED, SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED, SemanticInferenceMetadataFieldsMapper.EXPLICIT_NULL_FIXES, SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index c3ae4f0d9d6d6..63f061357b944 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -95,6 +95,7 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor; import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor; +import org.elasticsearch.xpack.inference.queries.SemanticMultiMatchQueryRewriteInterceptor; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor; import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder; @@ -581,6 +582,7 @@ public List getQueryRewriteInterceptors() { return 
List.of( new SemanticKnnVectorQueryRewriteInterceptor(), new SemanticMatchQueryRewriteInterceptor(), + new SemanticMultiMatchQueryRewriteInterceptor(), new SemanticSparseVectorQueryRewriteInterceptor() ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java index b1f5c240371f8..12f1d9fcc1a7e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java @@ -47,6 +47,11 @@ protected String getQuery(QueryBuilder queryBuilder) { return queryVectorBuilder != null ? queryVectorBuilder.getModelText() : null; } + @Override + protected boolean shouldResolveInferenceFieldWildcards(QueryBuilder queryBuilder) { + return false; + } + @Override protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) { assert (queryBuilder instanceof KnnVectorQueryBuilder); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java index a6599afc66c3f..0534da8f98eff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java @@ -34,9 +34,14 @@ protected String getQuery(QueryBuilder queryBuilder) { return (String) matchQueryBuilder.value(); } + @Override + protected boolean 
shouldResolveInferenceFieldWildcards(QueryBuilder queryBuilder) {
+        return false;
+    }
+
     @Override
     protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
-        SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
+        SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(getFieldName(queryBuilder), getQuery(queryBuilder), false);
         semanticQueryBuilder.boost(queryBuilder.boost());
         semanticQueryBuilder.queryName(queryBuilder.queryName());
         return semanticQueryBuilder;
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryRewriteInterceptor.java
new file mode 100644
index 0000000000000..6e06f9af82ae3
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryRewriteInterceptor.java
@@ -0,0 +1,408 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.queries;
+
+import org.elasticsearch.action.ResolvedIndices;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
+import org.elasticsearch.common.regex.Regex;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.features.NodeFeature;
+import org.elasticsearch.index.query.DisMaxQueryBuilder;
+import org.elasticsearch.index.query.MultiMatchQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.index.search.QueryParserHelper;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static org.elasticsearch.index.IndexSettings.DEFAULT_FIELD_SETTING;
+
+/**
+ * Rewrites {@code multi_match} queries whose requested fields include inference fields (e.g. {@code semantic_text}).
+ * Inference fields are queried through {@link SemanticQueryBuilder}; the remaining fields stay in a regular
+ * {@code multi_match} query that is scoped per index. Only the {@code best_fields} and
+ * {@code most_fields} types are accepted once an inference field is involved.
+ */
+public class SemanticMultiMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {
+
+    public static final NodeFeature SEMANTIC_MULTI_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
+        "search.semantic_multi_match_query_rewrite_interception_supported"
+    );
+
+    public SemanticMultiMatchQueryRewriteInterceptor() {}
+
+    /**
+     * Returns the only field requested by the query.
+     *
+     * @throws IllegalArgumentException if the query does not request exactly one field
+     */
+    @Override
+    protected String getFieldName(QueryBuilder queryBuilder) {
+        assert (queryBuilder instanceof MultiMatchQueryBuilder);
+        MultiMatchQueryBuilder multiMatchQuery = (MultiMatchQueryBuilder) queryBuilder;
+        Map<String, Float> fields = multiMatchQuery.fields();
+        if (fields.size() > 1) {
+            throw new IllegalArgumentException("getFieldName() called on MultiMatchQuery with multiple fields");
+        }
+        if (fields.isEmpty()) {
+            // Fail with a descriptive error instead of the NoSuchElementException that Iterator#next() would raise
+            throw new IllegalArgumentException("getFieldName() called on MultiMatchQuery with no fields");
+        }
+        return fields.keySet().iterator().next();
+    }
+
+    /**
+     * Returns the requested fields (or field patterns) mapped to their boosts.
+     */
+    @Override
+    protected Map<String, Float> getFieldsWithWeights(QueryBuilder queryBuilder) {
+        assert (queryBuilder instanceof MultiMatchQueryBuilder);
+        MultiMatchQueryBuilder multiMatchQuery = (MultiMatchQueryBuilder) queryBuilder;
+        return multiMatchQuery.fields();
+    }
+
+    @Override
+    protected String getQuery(QueryBuilder queryBuilder) {
+        assert (queryBuilder instanceof MultiMatchQueryBuilder);
+        MultiMatchQueryBuilder multiMatchQuery = (MultiMatchQueryBuilder) queryBuilder;
+        return (String) multiMatchQuery.value();
+    }
+
+    /**
+     * Builds the rewritten query when every resolved field is an inference field: a single semantic query if only
+     * one field was resolved, otherwise a {@code dis_max} over one semantic query per field.
+     */
+    @Override
+    protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
+        assert (queryBuilder instanceof MultiMatchQueryBuilder);
+        MultiMatchQueryBuilder originalQuery = (MultiMatchQueryBuilder) queryBuilder;
+        String queryValue = getQuery(queryBuilder);
+
+        validateQueryTypeSupported(originalQuery.type());
+        Set<String> inferenceFields = indexInformation.getAllInferenceFields();
+
+        if (inferenceFields.size() == 1) {
+            String fieldName = inferenceFields.iterator().next();
+            SemanticQueryBuilder semanticQuery = new SemanticQueryBuilder(fieldName, queryValue, false);
+
+            // There is no wrapping query in this case, so fold the per-field boost into the top-level boost
+            semanticQuery.boost(indexInformation.getFieldBoost(fieldName) * originalQuery.boost());
+            semanticQuery.queryName(originalQuery.queryName());
+            return semanticQuery;
+        } else {
+            return buildMultiFieldSemanticQuery(originalQuery, inferenceFields, queryValue, indexInformation);
+        }
+    }
+
+    /**
+     * Builds the rewritten query when the resolved fields are a mix of inference and non-inference fields.
+     */
+    @Override
+    protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
+        QueryBuilder queryBuilder,
+        InferenceIndexInformationForField indexInformation
+    ) {
+        assert (queryBuilder instanceof MultiMatchQueryBuilder);
+        MultiMatchQueryBuilder originalQuery = (MultiMatchQueryBuilder) queryBuilder;
+        String queryValue = getQuery(queryBuilder);
+
+        validateQueryTypeSupported(originalQuery.type());
+
+        return switch (originalQuery.type()) {
+            case BEST_FIELDS, MOST_FIELDS -> buildCombinedQuery(originalQuery, indexInformation, queryValue);
+            default -> throw new IllegalArgumentException("Unsupported query type [" + originalQuery.type() + "] for semantic_text fields");
+        };
+    }
+
+    @Override
+    public String getQueryName() {
+        return MultiMatchQueryBuilder.NAME;
+    }
+
+    @Override
+    public boolean shouldResolveInferenceFieldWildcards(QueryBuilder queryBuilder) {
+        assert (queryBuilder instanceof MultiMatchQueryBuilder);
+        MultiMatchQueryBuilder multiMatchQuery = (MultiMatchQueryBuilder) queryBuilder;
+        return multiMatchQuery.resolveInferenceFieldWildcards();
+    }
+
+    @Override
+    protected boolean shouldUseDefaultFields() {
+        return true;
+    }
+
+    /**
+     * Resolves, for every target index, which requested fields are inference fields and which are not, together
+     * with the boost to apply to each field.
+     * <p>
+     * Both the explicit-field and the wildcard-expanding paths are handled here so that a field's boost is
+     * registered once, no matter how many indices contain that field.
+     */
+    @Override
+    protected InferenceIndexInformationForField resolveIndicesForFields(
+        QueryBuilder queryBuilder,
+        ResolvedIndices resolvedIndices,
+        boolean resolveInferenceFieldWildcards
+    ) {
+        assert (queryBuilder instanceof MultiMatchQueryBuilder);
+        MultiMatchQueryBuilder multiMatchQuery = (MultiMatchQueryBuilder) queryBuilder;
+        boolean expandWildcards = resolveInferenceFieldWildcards && multiMatchQuery.resolveInferenceFieldWildcards();
+
+        Map<String, Float> fieldsWithWeights = getFieldsWithWeights(queryBuilder);
+        Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
+
+        // Boosts of the inference fields matched by the requested patterns; empty when wildcards are not expanded
+        Map<String, Float> globalInferenceFieldBoosts = expandWildcards
+            ? resolveInferenceFieldBoosts(fieldsWithWeights, indexMetadataCollection)
+            : Map.of();
+
+        Map<String, Map<String, InferenceFieldMetadata>> inferenceFieldsPerIndex = new HashMap<>();
+        Map<String, Set<String>> nonInferenceFieldsPerIndex = new HashMap<>();
+        Map<String, Float> allFieldBoosts = new HashMap<>(globalInferenceFieldBoosts);
+
+        for (IndexMetadata indexMetadata : indexMetadataCollection) {
+            String indexName = indexMetadata.getIndex().getName();
+            Map<String, InferenceFieldMetadata> indexInferenceMetadata = indexMetadata.getInferenceFields();
+            Map<String, Float> fieldsToProcess = resolveFieldsToProcess(fieldsWithWeights, indexMetadata);
+            Map<String, InferenceFieldMetadata> indexInferenceFields = new HashMap<>();
+
+            // Add wildcard-resolved inference fields that exist in this index
+            for (String inferenceField : globalInferenceFieldBoosts.keySet()) {
+                if (indexInferenceMetadata.containsKey(inferenceField)) {
+                    indexInferenceFields.put(inferenceField, indexInferenceMetadata.get(inferenceField));
+                }
+            }
+
+            // Add the explicitly requested inference fields. putIfAbsent (rather than multiplying) keeps the boost
+            // from being compounded once per index, or on top of the value already resolved from the patterns.
+            for (Map.Entry<String, Float> entry : fieldsToProcess.entrySet()) {
+                String field = entry.getKey();
+                if (indexInferenceMetadata.containsKey(field)) {
+                    indexInferenceFields.put(field, indexInferenceMetadata.get(field));
+                    allFieldBoosts.putIfAbsent(field, entry.getValue());
+                }
+            }
+
+            // Non-inference fields: all requested fields/patterns minus the inference fields resolved for this index
+            Set<String> indexNonInferenceFields = new HashSet<>(fieldsToProcess.keySet());
+            indexNonInferenceFields.removeAll(indexInferenceFields.keySet());
+            for (String nonInferenceField : indexNonInferenceFields) {
+                allFieldBoosts.putIfAbsent(nonInferenceField, fieldsToProcess.get(nonInferenceField));
+            }
+
+            if (indexInferenceFields.isEmpty() == false) {
+                inferenceFieldsPerIndex.put(indexName, indexInferenceFields);
+            }
+            if (indexNonInferenceFields.isEmpty() == false) {
+                nonInferenceFieldsPerIndex.put(indexName, indexNonInferenceFields);
+            }
+        }
+
+        return new InferenceIndexInformationForField(inferenceFieldsPerIndex, nonInferenceFieldsPerIndex, allFieldBoosts);
+    }
+
+    /**
+     * Returns the fields to resolve against the given index: the requested fields or, when the query specified
+     * none, the index's default fields as configured through {@code DEFAULT_FIELD_SETTING}.
+     */
+    private Map<String, Float> resolveFieldsToProcess(Map<String, Float> requestedFields, IndexMetadata indexMetadata) {
+        if (requestedFields.isEmpty() == false || shouldUseDefaultFields() == false) {
+            return requestedFields;
+        }
+        Settings settings = indexMetadata.getSettings();
+        List<String> defaultFields = settings.getAsList(DEFAULT_FIELD_SETTING.getKey(), DEFAULT_FIELD_SETTING.getDefault(settings));
+        return QueryParserHelper.parseFieldsAndWeights(defaultFields);
+    }
+
+    /**
+     * Matches every requested field or pattern against the inference fields of all target indices and returns the
+     * boost of each matched inference field. A field matched by several patterns gets the product of their boosts.
+     */
+    private static Map<String, Float> resolveInferenceFieldBoosts(
+        Map<String, Float> fieldsWithWeights,
+        Collection<IndexMetadata> indexMetadataCollection
+    ) {
+        // Get all unique inference fields across all indices
+        Set<String> allInferenceFields = indexMetadataCollection.stream()
+            .flatMap(idx -> idx.getInferenceFields().keySet().stream())
+            .collect(Collectors.toSet());
+
+        Map<String, Float> inferenceFieldBoosts = new HashMap<>();
+        for (String inferenceField : allInferenceFields) {
+            for (Map.Entry<String, Float> entry : fieldsWithWeights.entrySet()) {
+                String pattern = entry.getKey();
+                if (Regex.isMatchAllPattern(pattern)
+                    || (Regex.isSimpleMatchPattern(pattern) && Regex.simpleMatch(pattern, inferenceField))
+                    || pattern.equals(inferenceField)) {
+                    addToFieldBoostsMap(inferenceFieldBoosts, inferenceField, entry.getValue());
+                }
+            }
+        }
+        return inferenceFieldBoosts;
+    }
+
+    private static void addToFieldBoostsMap(Map<String, Float> fieldBoosts, String field, Float boost) {
+        fieldBoosts.compute(field, (k, v) -> v == null ? boost : v * boost);
+    }
+
+    private QueryBuilder buildMultiFieldSemanticQuery(
+        MultiMatchQueryBuilder originalQuery,
+        Set<String> inferenceFields,
+        String queryValue,
+        InferenceIndexInformationForField indexInformation
+    ) {
+        return switch (originalQuery.type()) {
+            case BEST_FIELDS, MOST_FIELDS -> buildSemanticQuery(originalQuery, indexInformation, inferenceFields, queryValue);
+            default -> throw new IllegalArgumentException("Unsupported query type [" + originalQuery.type() + "] for semantic_text fields");
+        };
+    }
+
+    /**
+     * Validates that the multi_match query type is supported for semantic_text fields.
+     * Throws IllegalArgumentException for unsupported types.
+     */
+    private void validateQueryTypeSupported(MultiMatchQueryBuilder.Type queryType) {
+        switch (queryType) {
+            case CROSS_FIELDS:
+                throw new IllegalArgumentException(
+                    "multi_match query with type [cross_fields] is not supported for semantic_text fields. "
+                        + "Use [best_fields] or [most_fields] instead."
+                );
+            case PHRASE:
+                throw new IllegalArgumentException(
+                    "multi_match query with type [phrase] is not supported for semantic_text fields. " + "Use [best_fields] instead."
+                );
+            case PHRASE_PREFIX:
+                throw new IllegalArgumentException(
+                    "multi_match query with type [phrase_prefix] is not supported for semantic_text fields. " + "Use [best_fields] instead."
+                );
+            case BOOL_PREFIX:
+                throw new IllegalArgumentException(
+                    "multi_match query with type [bool_prefix] is not supported for semantic_text fields. "
+                        + "Use [best_fields] or [most_fields] instead."
+                );
+        }
+    }
+
+    /**
+     * Creates a semantic query with field boost applied, supporting wildcard-resolved boosts.
+     */
+    private SemanticQueryBuilder createSemanticQuery(String fieldName, String queryValue, InferenceIndexInformationForField inferenceInfo) {
+        SemanticQueryBuilder semanticQuery = new SemanticQueryBuilder(fieldName, queryValue, false);
+        semanticQuery.boost(inferenceInfo.getFieldBoost(fieldName));
+        return semanticQuery;
+    }
+
+    /**
+     * Builds a semantic query for multiple fields using Dismax.
+     */
+    private QueryBuilder buildSemanticQuery(
+        MultiMatchQueryBuilder originalQuery,
+        InferenceIndexInformationForField indexInformation,
+        Set<String> inferenceFields,
+        String queryValue
+    ) {
+        DisMaxQueryBuilder disMaxQuery = QueryBuilders.disMaxQuery();
+        for (String fieldName : inferenceFields) {
+            disMaxQuery.add(createSemanticQuery(fieldName, queryValue, indexInformation));
+        }
+        return applyTopLevelSettings(disMaxQuery, originalQuery);
+    }
+
+    /**
+     * Builds a combined query for both inference and non-inference fields.
+     */
+    private QueryBuilder buildCombinedQuery(
+        MultiMatchQueryBuilder originalQuery,
+        InferenceIndexInformationForField inferenceInfo,
+        String queryValue
+    ) {
+        DisMaxQueryBuilder disMaxQuery = QueryBuilders.disMaxQuery();
+
+        // Add semantic queries
+        for (String fieldName : inferenceInfo.getAllInferenceFields()) {
+            Set<String> semanticIndices = inferenceInfo.getInferenceIndicesForField(fieldName);
+            if (semanticIndices.isEmpty() == false) {
+                disMaxQuery.add(
+                    createSemanticSubQuery(semanticIndices, fieldName, queryValue).boost(inferenceInfo.getFieldBoost(fieldName))
+                );
+            }
+        }
+
+        // Add non-inference queries
+        addNonInferenceQueries(disMaxQuery::add, originalQuery, inferenceInfo);
+
+        return applyTopLevelSettings(disMaxQuery, originalQuery);
+    }
+
+    /**
+     * Copies the tie breaker, boost and name of the original query onto the dis_max query that replaces it.
+     */
+    private static DisMaxQueryBuilder applyTopLevelSettings(DisMaxQueryBuilder disMaxQuery, MultiMatchQueryBuilder originalQuery) {
+        // Use the explicit tie_breaker or fall back to the default of the query type
+        Float tieBreaker = originalQuery.tieBreaker();
+        disMaxQuery.tieBreaker(Objects.requireNonNullElseGet(tieBreaker, () -> originalQuery.type().tieBreaker()));
+        disMaxQuery.boost(originalQuery.boost());
+        disMaxQuery.queryName(originalQuery.queryName());
+        return disMaxQuery;
+    }
+
+    private void addNonInferenceQueries(
+        java.util.function.Consumer<QueryBuilder> addQuery,
+        MultiMatchQueryBuilder originalQuery,
+        InferenceIndexInformationForField inferenceInfo
+    ) {
+        for (Map.Entry<String, Set<String>> entry : inferenceInfo.nonInferenceFieldsPerIndex().entrySet()) {
+            String indexName = entry.getKey();
+            Set<String> indexFieldNames = entry.getValue();
+
+            Map<String, Float> indexFields = new HashMap<>();
+            for (String fieldName : indexFieldNames) {
+                indexFields.put(fieldName, inferenceInfo.getFieldBoost(fieldName));
+            }
+
+            MultiMatchQueryBuilder indexQuery = new MultiMatchQueryBuilder(originalQuery.value());
+            indexQuery.fields(indexFields);
+            copyQueryProperties(originalQuery, indexQuery);
+
+            addQuery.accept(createSubQueryForIndices(List.of(indexName), indexQuery));
+        }
+    }
+
+    /**
+     * Copies all properties from original query to target query except fields.
+     */
+    private void copyQueryProperties(MultiMatchQueryBuilder original, MultiMatchQueryBuilder target) {
+        target.type(original.type());
+        target.operator(original.operator());
+        target.slop(original.slop());
+        target.analyzer(original.analyzer());
+        target.minimumShouldMatch(original.minimumShouldMatch());
+        target.fuzzyRewrite(original.fuzzyRewrite());
+        target.prefixLength(original.prefixLength());
+        target.maxExpansions(original.maxExpansions());
+        target.fuzzyTranspositions(original.fuzzyTranspositions());
+        target.lenient(original.lenient());
+        target.zeroTermsQuery(original.zeroTermsQuery());
+        target.autoGenerateSynonymsPhraseQuery(original.autoGenerateSynonymsPhraseQuery());
+        target.tieBreaker(original.tieBreaker());
+        target.resolveInferenceFieldWildcards(original.resolveInferenceFieldWildcards());
+
+        if (original.fuzziness() != null) {
+            target.fuzziness(original.fuzziness());
+        }
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java
index bb76ef0be24e9..b335f85e9897d 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java
@@ -10,20 +10,27 @@
 import org.elasticsearch.action.ResolvedIndices;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.mapper.IndexFieldMapper;
+import org.elasticsearch.index.query.AbstractQueryBuilder;
 import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryRewriteContext;
 import
org.elasticsearch.index.query.TermsQueryBuilder; +import org.elasticsearch.index.search.QueryParserHelper; import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; +import static org.elasticsearch.index.IndexSettings.DEFAULT_FIELD_SETTING; + /** * Intercepts and adapts a query to be rewritten to work seamlessly on a semantic_text field. */ @@ -33,7 +40,6 @@ public SemanticQueryRewriteInterceptor() {} @Override public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) { - String fieldName = getFieldName(queryBuilder); ResolvedIndices resolvedIndices = context.getResolvedIndices(); if (resolvedIndices == null) { @@ -41,18 +47,23 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde return queryBuilder; } - InferenceIndexInformationForField indexInformation = resolveIndicesForField(fieldName, resolvedIndices); - if (indexInformation.getInferenceIndices().isEmpty()) { + boolean resolveInferenceFieldWildcards = shouldResolveInferenceFieldWildcards(queryBuilder); + InferenceIndexInformationForField indexInformation = resolveIndicesForFields( + queryBuilder, + resolvedIndices, + resolveInferenceFieldWildcards + ); + if (indexInformation.hasInferenceFields() == false) { // No inference fields were identified, so return the original query. return queryBuilder; - } else if (indexInformation.nonInferenceIndices().isEmpty() == false) { - // Combined case where the field name requested by this query contains both + } else if (indexInformation.hasNonInferenceFields()) { + // Combined case where the field name(s) requested by this query contain both // semantic_text and non-inference fields, so we have to combine queries per index // containing each field type. 
return buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation); } else { // The only fields we've identified are inference fields (e.g. semantic_text), - // so rewrite the entire query to work on a semantic_text field. + // so rewrite the entire query to work on semantic_text field(s). return buildInferenceQuery(queryBuilder, indexInformation); } } @@ -63,12 +74,42 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde */ protected abstract String getFieldName(QueryBuilder queryBuilder); + /** + * @param queryBuilder {@link QueryBuilder} + * @return The field names with their weights requested by the provided query builder. + */ + protected Map getFieldsWithWeights(QueryBuilder queryBuilder) { + // Default implementation for single-field queries + String fieldName = getFieldName(queryBuilder); + return Map.of(fieldName, 1.0f); + } + /** * @param queryBuilder {@link QueryBuilder} * @return The text/query string requested by the provided query builder. */ protected abstract String getQuery(QueryBuilder queryBuilder); + /** + * Determines if inference field wildcards should be resolved. + * This is typically used to expand wildcard queries to all inference fields. + * + * @param queryBuilder {@link QueryBuilder} + * @return true if inference field wildcards should be resolved, false otherwise. + */ + protected abstract boolean shouldResolveInferenceFieldWildcards(QueryBuilder queryBuilder); + + /** + * Determines if this query type should use default fields when no fields are specified. + * This is typically only needed for multi_match queries. + * Default implementation returns false for most query types. + * + * @return true if default fields should be used when no fields are specified, false otherwise. 
+ */ + protected boolean shouldUseDefaultFields() { + return false; + } + /** * Builds the inference query * @@ -90,21 +131,70 @@ protected abstract QueryBuilder buildCombinedInferenceAndNonInferenceQuery( InferenceIndexInformationForField indexInformation ); - private InferenceIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) { + private static void addToFieldBoostsMap(Map fieldBoosts, String field, Float boost) { + fieldBoosts.compute(field, (k, v) -> v == null ? boost : v * boost); + } + + protected InferenceIndexInformationForField resolveIndicesForFields( + QueryBuilder queryBuilder, + ResolvedIndices resolvedIndices, + boolean resolveInferenceFieldWildcards + ) { + Map fieldsWithWeights = getFieldsWithWeights(queryBuilder); Collection indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values(); - Map inferenceIndicesMetadata = new HashMap<>(); - List nonInferenceIndices = new ArrayList<>(); + + // Simple implementation: only handle explicit inference fields (no wildcards) + Map> inferenceFieldsPerIndex = new HashMap<>(); + Map> nonInferenceFieldsPerIndex = new HashMap<>(); + Map allFieldBoosts = new HashMap<>(); + for (IndexMetadata indexMetadata : indexMetadataCollection) { String indexName = indexMetadata.getIndex().getName(); - InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName); - if (inferenceFieldMetadata != null) { - inferenceIndicesMetadata.put(indexName, inferenceFieldMetadata); - } else { - nonInferenceIndices.add(indexName); + Map indexInferenceFields = new HashMap<>(); + Map indexInferenceMetadata = indexMetadata.getInferenceFields(); + + // Handle default fields per index when no fields are specified (only for multi_match queries) + Map fieldsToProcess = fieldsWithWeights; + if (fieldsToProcess.isEmpty() && shouldUseDefaultFields()) { + Settings settings = indexMetadata.getSettings(); + List defaultFields = 
settings.getAsList(DEFAULT_FIELD_SETTING.getKey(), DEFAULT_FIELD_SETTING.getDefault(settings)); + fieldsToProcess = QueryParserHelper.parseFieldsAndWeights(defaultFields); + } + + // Collect resolved inference fields for this index + Set resolvedInferenceFields = new HashSet<>(); + + // Handle explicit inference fields only + for (Map.Entry entry : fieldsToProcess.entrySet()) { + String field = entry.getKey(); + Float boost = entry.getValue(); + + if (indexInferenceMetadata.containsKey(field)) { + indexInferenceFields.put(field, indexInferenceMetadata.get(field)); + resolvedInferenceFields.add(field); + addToFieldBoostsMap(allFieldBoosts, field, boost); + } + } + + // Non-inference fields + Set indexNonInferenceFields = new HashSet<>(fieldsToProcess.keySet()); + indexNonInferenceFields.removeAll(resolvedInferenceFields); + + // Store boosts for non-inference field patterns + for (String nonInferenceField : indexNonInferenceFields) { + addToFieldBoostsMap(allFieldBoosts, nonInferenceField, fieldsToProcess.get(nonInferenceField)); + } + + if (indexInferenceFields.isEmpty() == false) { + inferenceFieldsPerIndex.put(indexName, indexInferenceFields); + } + + if (indexNonInferenceFields.isEmpty() == false) { + nonInferenceFieldsPerIndex.put(indexName, indexNonInferenceFields); } } - return new InferenceIndexInformationForField(fieldName, inferenceIndicesMetadata, nonInferenceIndices); + return new InferenceIndexInformationForField(inferenceFieldsPerIndex, nonInferenceFieldsPerIndex, allFieldBoosts); } protected QueryBuilder createSubQueryForIndices(Collection indices, QueryBuilder queryBuilder) { @@ -122,27 +212,69 @@ protected QueryBuilder createSemanticSubQuery(Collection indices, String } /** - * Represents the indices and associated inference information for a field. + * Represents the indices and associated inference information for fields. 
*/ public record InferenceIndexInformationForField( - String fieldName, - Map inferenceIndicesMetadata, - List nonInferenceIndices + // Map: IndexName -> (FieldName -> InferenceFieldMetadata) + Map> inferenceFieldsPerIndex, + // Map: IndexName -> Set - non-inference fields per index (boosts stored in fieldBoosts) + Map> nonInferenceFieldsPerIndex, + // Map: FieldName -> Boost - stores boosts for all fields (both inference and non-inference) + Map fieldBoosts ) { + public Set getAllInferenceFields() { + return inferenceFieldsPerIndex.values().stream().flatMap(fields -> fields.keySet().stream()).collect(Collectors.toSet()); + } + + public boolean hasInferenceFields() { + return inferenceFieldsPerIndex.isEmpty() == false; + } + + public boolean hasNonInferenceFields() { + return nonInferenceFieldsPerIndex.isEmpty() == false; + } + public Collection getInferenceIndices() { - return inferenceIndicesMetadata.keySet(); + return inferenceFieldsPerIndex.keySet(); + } + + public List nonInferenceIndices() { + return new ArrayList<>(nonInferenceFieldsPerIndex.keySet()); } public Map> getInferenceIdsIndices() { - return inferenceIndicesMetadata.entrySet() - .stream() - .collect( - Collectors.groupingBy( - entry -> entry.getValue().getSearchInferenceId(), - Collectors.mapping(Map.Entry::getKey, Collectors.toList()) - ) - ); + Map> result = new HashMap<>(); + for (Map.Entry> indexEntry : inferenceFieldsPerIndex.entrySet()) { + String indexName = indexEntry.getKey(); + for (InferenceFieldMetadata metadata : indexEntry.getValue().values()) { + String inferenceId = metadata.getSearchInferenceId(); + result.computeIfAbsent(inferenceId, k -> new ArrayList<>()).add(indexName); + } + } + return result; } + + /** + * Returns the set of indices where the given field is a semantic field (has inference metadata). 
+ */ + public Set getInferenceIndicesForField(String fieldName) { + Set indices = new HashSet<>(); + for (Map.Entry> entry : inferenceFieldsPerIndex.entrySet()) { + if (entry.getValue().containsKey(fieldName)) { + indices.add(entry.getKey()); + } + } + return indices; + } + + /** + * @param fieldName the field name + * @return the resolved boost for the field + */ + public float getFieldBoost(String fieldName) { + return fieldBoosts.getOrDefault(fieldName, AbstractQueryBuilder.DEFAULT_BOOST); + } + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java index c85a21f10301d..e03d5590fdd40 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java @@ -40,6 +40,11 @@ protected String getQuery(QueryBuilder queryBuilder) { return sparseVectorQueryBuilder.getQuery(); } + @Override + protected boolean shouldResolveInferenceFieldWildcards(QueryBuilder queryBuilder) { + return false; + } + @Override protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) { Map> inferenceIdsIndices = indexInformation.getInferenceIdsIndices(); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/100_semantic_text_multi_match.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/100_semantic_text_multi_match.yml new file mode 100644 index 0000000000000..e80c326bf6f4d --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/100_semantic_text_multi_match.yml @@ -0,0 +1,345 @@ +setup: + - 
requires: + cluster_features: "search.semantic_multi_match_query_rewrite_interception_supported" + reason: semantic_text multi_match support introduced in 9.2.0 + + - do: + inference.put: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + inference.put: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "api_key": "abc64", + "similarity": "COSINE" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: test-semantic-index + body: + mappings: + properties: + title: + type: semantic_text + inference_id: sparse-inference-id + content: + type: semantic_text + inference_id: sparse-inference-id + summary: + type: semantic_text + inference_id: dense-inference-id + + - do: + indices.create: + index: test-mixed-index + body: + mappings: + properties: + title: + type: text + semantic_content: + type: semantic_text + inference_id: sparse-inference-id + tags: + type: keyword + + - do: + indices.create: + index: test-text-only-index + body: + mappings: + properties: + title: + type: text + content: + type: text + +--- +"Multi-match query on semantic_text fields": + - do: + index: + index: test-semantic-index + id: doc_1 + body: + title: "Machine learning algorithms" + content: "Deep neural networks for computer vision" + summary: "AI and machine learning fundamentals" + refresh: true + + - do: + search: + index: test-semantic-index + body: + query: + multi_match: + query: "machine learning neural networks" + fields: ["title", "content"] + type: "best_fields" + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + +--- +"Multi-match query with field boosts on semantic_text": + - do: + index: + index: test-semantic-index + id: doc_1 + 
body: + title: "Advanced algorithms" + content: "Machine learning and artificial intelligence" + summary: "Comprehensive AI guide" + refresh: true + + - do: + search: + index: test-semantic-index + body: + query: + multi_match: + query: "machine learning" + fields: ["title^2", "content^1.5", "summary"] + type: "best_fields" + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + +--- +"Multi-match query on mixed semantic_text and text fields": + - do: + index: + index: test-mixed-index + id: doc_1 + body: + title: "Quantum computing breakthrough" + semantic_content: "Revolutionary quantum algorithms for cryptography" + tags: ["quantum", "computing"] + refresh: true + + - do: + index: + index: test-mixed-index + id: doc_2 + body: + title: "AI research advances" + semantic_content: "Neural network architectures and deep learning" + tags: ["ai", "research"] + refresh: true + + - do: + search: + index: test-mixed-index + body: + query: + multi_match: + query: "quantum algorithms" + fields: ["title", "semantic_content"] + type: "best_fields" + + - match: { hits.total.value: 2 } + +--- +"Multi-match query with wildcard fields on semantic_text": + - do: + index: + index: test-semantic-index + id: doc_1 + body: + title: "Natural language processing" + content: "Text analysis and language models" + summary: "NLP fundamentals and applications" + refresh: true + + - do: + search: + index: test-semantic-index + body: + query: + multi_match: + query: "language processing" + fields: ["*"] + type: "most_fields" + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + +--- +"Multi-match most_fields query on semantic_text": + - do: + index: + index: test-semantic-index + id: doc_1 + body: + title: "Computer vision systems" + content: "Image recognition and computer vision algorithms" + summary: "Visual processing and pattern recognition" + refresh: true + + - do: + search: + index: test-semantic-index + body: + query: + multi_match: + 
query: "computer vision" + fields: ["title", "content", "summary"] + type: "most_fields" + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + +--- +"Multi-match query with overall boost on mixed fields": + - do: + index: + index: test-mixed-index + id: doc_1 + body: + title: "Blockchain technology" + semantic_content: "Distributed ledger systems and cryptocurrency" + tags: ["blockchain", "crypto"] + refresh: true + + - do: + search: + index: test-mixed-index + body: + query: + multi_match: + query: "blockchain systems" + fields: ["title^2", "semantic_content"] + type: "best_fields" + boost: 2.0 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + +--- +"Multi-match query across multiple indices with different field types": + - do: + index: + index: test-semantic-index + id: doc_1 + body: + title: "Robotics automation" + content: "Autonomous systems and robotic controls" + refresh: true + + - do: + index: + index: test-text-only-index + id: doc_2 + body: + title: "Robotics research" + content: "Industrial automation and robotics" + refresh: true + + - do: + search: + index: "test-semantic-index,test-text-only-index" + body: + query: + multi_match: + query: "robotics automation" + fields: ["title", "content"] + type: "best_fields" + + - match: { hits.total.value: 2 } + +--- +"Multi-match query with pattern fields and boosts": + - do: + index: + index: test-semantic-index + id: doc_1 + body: + title: "Artificial intelligence" + content: "AI algorithms and machine learning" + summary: "Comprehensive AI overview" + refresh: true + + - do: + search: + index: test-semantic-index + body: + query: + multi_match: + query: "artificial intelligence" + fields: ["title^3", "*content*^1.5", "summ*^2"] + type: "best_fields" + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + +--- +"Multi-match query error on unsupported type with multiple fields": + - do: + catch: /multi_match query with type 
\[cross_fields\] is not supported for semantic_text fields/ + search: + index: test-semantic-index + body: + query: + multi_match: + query: "test query" + fields: ["title", "content"] + type: "cross_fields" + +--- +"Multi-match query error on unsupported type with single field": + - do: + catch: /multi_match query with type \[phrase\] is not supported for semantic_text fields/ + search: + index: test-semantic-index + body: + query: + multi_match: + query: "test query" + fields: ["title"] + type: "phrase" + +--- +"Multi-match single field optimization": + - do: + index: + index: test-semantic-index + id: doc_1 + body: + title: "Single field test" + content: "This should be optimized to single semantic query" + refresh: true + + - do: + search: + index: test-semantic-index + body: + query: + multi_match: + query: "single field optimization" + fields: ["title^2"] + type: "best_fields" + boost: 1.5 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/MultiFieldsInnerRetrieverUtils.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/MultiFieldsInnerRetrieverUtils.java index 8aa5dbf366a7a..61f4d13a3f6bd 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/MultiFieldsInnerRetrieverUtils.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/MultiFieldsInnerRetrieverUtils.java @@ -207,7 +207,7 @@ private static List generateInnerRetrieversForIndex( if (nonInferenceFields.isEmpty() == false) { MultiMatchQueryBuilder nonInferenceFieldQueryBuilder = new MultiMatchQueryBuilder(query).type( MultiMatchQueryBuilder.Type.MOST_FIELDS - ).fields(nonInferenceFields); + ).fields(nonInferenceFields).resolveInferenceFieldWildcards(false); innerRetrievers.add(new StandardRetrieverBuilder(nonInferenceFieldQueryBuilder)); } if (inferenceFields.isEmpty() == false) { diff --git 
a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderTests.java index c211440d10bae..8c2a18812e001 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderTests.java @@ -242,6 +242,7 @@ private static void assertMultiFieldsParamsRewrite( new StandardRetrieverBuilder( new MultiMatchQueryBuilder(expectedQuery).type(MultiMatchQueryBuilder.Type.MOST_FIELDS) .fields(expectedNonInferenceFields) + .resolveInferenceFieldWildcards(false) ), 1.0f, expectedNormalizer diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java index 7885ac9df2aa8..8e8a152ea8f3c 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java @@ -322,6 +322,7 @@ private static void assertMultiFieldsParamsRewrite( new StandardRetrieverBuilder( new MultiMatchQueryBuilder(expectedQuery).type(MultiMatchQueryBuilder.Type.MOST_FIELDS) .fields(expectedNonInferenceFields) + .resolveInferenceFieldWildcards(false) ) ), Set.of(expectedInferenceFields.entrySet().stream().map(e -> {