diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java index b1f5c240371f8..6c3ec2add4bc8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java @@ -33,10 +33,10 @@ public class SemanticKnnVectorQueryRewriteInterceptor extends SemanticQueryRewri public SemanticKnnVectorQueryRewriteInterceptor() {} @Override - protected String getFieldName(QueryBuilder queryBuilder) { + protected Map getFieldNamesWithBoosts(QueryBuilder queryBuilder) { assert (queryBuilder instanceof KnnVectorQueryBuilder); KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder; - return knnVectorQueryBuilder.getFieldName(); + return Map.of(knnVectorQueryBuilder.getFieldName(), 1.0f); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java index a6599afc66c3f..42bcdfaf88947 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java @@ -12,6 +12,8 @@ import org.elasticsearch.index.query.MatchQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import java.util.Map; + public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor { public static final NodeFeature SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature( @@ -21,10 +23,10 @@ public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteIn public SemanticMatchQueryRewriteInterceptor() {} @Override - protected String getFieldName(QueryBuilder queryBuilder) { + protected Map getFieldNamesWithBoosts(QueryBuilder queryBuilder) { assert (queryBuilder instanceof MatchQueryBuilder); MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder; - return matchQueryBuilder.fieldName(); + return Map.of(matchQueryBuilder.fieldName(), 1.0f); } @Override @@ -36,7 +38,7 @@ protected String getQuery(QueryBuilder queryBuilder) { @Override protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) { - SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false); + SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(getField(queryBuilder), getQuery(queryBuilder), false); semanticQueryBuilder.boost(queryBuilder.boost()); semanticQueryBuilder.queryName(queryBuilder.queryName()); return semanticQueryBuilder; @@ -71,6 +73,12 @@ public String getQueryName() { return MatchQueryBuilder.NAME; } + private String getField(QueryBuilder queryBuilder) { + assert (queryBuilder instanceof MatchQueryBuilder); + MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder; + return matchQueryBuilder.fieldName(); + } + private MatchQueryBuilder copyMatchQueryBuilder(MatchQueryBuilder queryBuilder) { MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder(queryBuilder.fieldName(), queryBuilder.value()); matchQueryBuilder.operator(queryBuilder.operator()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java index bb76ef0be24e9..a651828f6b446 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java @@ -11,6 +11,7 @@ import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.index.mapper.IndexFieldMapper; +import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; @@ -20,8 +21,10 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; /** @@ -33,7 +36,6 @@ public SemanticQueryRewriteInterceptor() {} @Override public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) { - String fieldName = getFieldName(queryBuilder); ResolvedIndices resolvedIndices = context.getResolvedIndices(); if (resolvedIndices == null) { @@ -41,7 +43,7 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde return queryBuilder; } - InferenceIndexInformationForField indexInformation = resolveIndicesForField(fieldName, resolvedIndices); + InferenceIndexInformationForField indexInformation = resolveIndicesForFields(queryBuilder, resolvedIndices); if (indexInformation.getInferenceIndices().isEmpty()) { // No inference fields were identified, so return the original query. return queryBuilder; @@ -58,10 +60,12 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde } /** - * @param queryBuilder {@link QueryBuilder} - * @return The singular field name requested by the provided query builder. + * Extracts field names and their associated boost values from the query builder. + * + * @param queryBuilder the query builder to extract field information from + * @return a map where keys are field names and values are their boost multipliers */ - protected abstract String getFieldName(QueryBuilder queryBuilder); + protected abstract Map getFieldNamesWithBoosts(QueryBuilder queryBuilder); /** * @param queryBuilder {@link QueryBuilder} @@ -90,21 +94,57 @@ protected abstract QueryBuilder buildCombinedInferenceAndNonInferenceQuery( InferenceIndexInformationForField indexInformation ); - private InferenceIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) { + private static void addToFieldBoostsMap(Map fieldBoosts, String field, Float boost) { + fieldBoosts.compute(field, (k, v) -> v == null ? boost : v * boost); + } + + protected InferenceIndexInformationForField resolveIndicesForFields(QueryBuilder queryBuilder, ResolvedIndices resolvedIndices) { + Map fieldsWithBoosts = getFieldNamesWithBoosts(queryBuilder); Collection indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values(); - Map inferenceIndicesMetadata = new HashMap<>(); - List nonInferenceIndices = new ArrayList<>(); + + Map> inferenceFieldsPerIndex = new HashMap<>(); + Map> nonInferenceFieldsPerIndex = new HashMap<>(); + Map allFieldBoosts = new HashMap<>(); + for (IndexMetadata indexMetadata : indexMetadataCollection) { String indexName = indexMetadata.getIndex().getName(); - InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName); - if (inferenceFieldMetadata != null) { - inferenceIndicesMetadata.put(indexName, inferenceFieldMetadata); - } else { - nonInferenceIndices.add(indexName); + Map indexInferenceFields = new HashMap<>(); + Map indexInferenceMetadata = indexMetadata.getInferenceFields(); + + // Collect resolved inference fields for this index + Set resolvedInferenceFields = new HashSet<>(); + + // Handle explicit inference fields only + for (Map.Entry entry : fieldsWithBoosts.entrySet()) { + String field = entry.getKey(); + Float boost = entry.getValue(); + + if (indexInferenceMetadata.containsKey(field)) { + indexInferenceFields.put(field, indexInferenceMetadata.get(field)); + resolvedInferenceFields.add(field); + addToFieldBoostsMap(allFieldBoosts, field, boost); + } + } + + // Non-inference fields + Set indexNonInferenceFields = new HashSet<>(fieldsWithBoosts.keySet()); + indexNonInferenceFields.removeAll(resolvedInferenceFields); + + // Store boosts for non-inference field patterns + for (String nonInferenceField : indexNonInferenceFields) { + addToFieldBoostsMap(allFieldBoosts, nonInferenceField, fieldsWithBoosts.get(nonInferenceField)); + } + + if (indexInferenceFields.isEmpty() == false) { + inferenceFieldsPerIndex.put(indexName, indexInferenceFields); + } + + if (indexNonInferenceFields.isEmpty() == false) { + nonInferenceFieldsPerIndex.put(indexName, indexNonInferenceFields); } } - return new InferenceIndexInformationForField(fieldName, inferenceIndicesMetadata, nonInferenceIndices); + return new InferenceIndexInformationForField(inferenceFieldsPerIndex, nonInferenceFieldsPerIndex, allFieldBoosts); } protected QueryBuilder createSubQueryForIndices(Collection indices, QueryBuilder queryBuilder) { @@ -122,27 +162,68 @@ protected QueryBuilder createSemanticSubQuery(Collection indices, String } /** - * Represents the indices and associated inference information for a field. + * Represents the indices and associated inference information for fields. */ public record InferenceIndexInformationForField( - String fieldName, - Map inferenceIndicesMetadata, - List nonInferenceIndices + // Map: IndexName -> (FieldName -> InferenceFieldMetadata) + Map> inferenceFieldsPerIndex, + // Map: IndexName -> Set - non-inference fields per index (boosts stored in fieldBoosts) + Map> nonInferenceFieldsPerIndex, + // Map: FieldName -> Boost - stores boosts for all fields (both inference and non-inference) + Map fieldBoosts ) { + public Set getAllInferenceFields() { + return inferenceFieldsPerIndex.values().stream().flatMap(fields -> fields.keySet().stream()).collect(Collectors.toSet()); + } + + public boolean hasInferenceFields() { + return inferenceFieldsPerIndex.isEmpty() == false; + } + + public boolean hasNonInferenceFields() { + return nonInferenceFieldsPerIndex.isEmpty() == false; + } + public Collection getInferenceIndices() { - return inferenceIndicesMetadata.keySet(); + return inferenceFieldsPerIndex.keySet(); + } + + public List nonInferenceIndices() { + return new ArrayList<>(nonInferenceFieldsPerIndex.keySet()); } public Map> getInferenceIdsIndices() { - return inferenceIndicesMetadata.entrySet() - .stream() - .collect( - Collectors.groupingBy( - entry -> entry.getValue().getSearchInferenceId(), - Collectors.mapping(Map.Entry::getKey, Collectors.toList()) - ) - ); + Map> result = new HashMap<>(); + for (Map.Entry> indexEntry : inferenceFieldsPerIndex.entrySet()) { + String indexName = indexEntry.getKey(); + for (InferenceFieldMetadata metadata : indexEntry.getValue().values()) { + String inferenceId = metadata.getSearchInferenceId(); + result.computeIfAbsent(inferenceId, k -> new ArrayList<>()).add(indexName); + } + } + return result; + } + + /** + * Returns the set of indices where the given field is a semantic field (has inference metadata). + */ + public Set getInferenceIndicesForField(String fieldName) { + Set indices = new HashSet<>(); + for (Map.Entry> entry : inferenceFieldsPerIndex.entrySet()) { + if (entry.getValue().containsKey(fieldName)) { + indices.add(entry.getKey()); + } + } + return indices; + } + + /** + * @param fieldName the field name + * @return the resolved boost for the field + */ + public float getFieldBoost(String fieldName) { + return fieldBoosts.getOrDefault(fieldName, AbstractQueryBuilder.DEFAULT_BOOST); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java index c85a21f10301d..b735c223d75fe 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java @@ -27,10 +27,10 @@ public class SemanticSparseVectorQueryRewriteInterceptor extends SemanticQueryRe public SemanticSparseVectorQueryRewriteInterceptor() {} @Override - protected String getFieldName(QueryBuilder queryBuilder) { + protected Map getFieldNamesWithBoosts(QueryBuilder queryBuilder) { assert (queryBuilder instanceof SparseVectorQueryBuilder); SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder; - return sparseVectorQueryBuilder.getFieldName(); + return Map.of(sparseVectorQueryBuilder.getFieldName(), 1.0f); } @Override