Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
import org.elasticsearch.xpack.inference.mapper.SemanticInferenceMetadataFieldsMapper;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticMatchAllQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
Expand Down Expand Up @@ -547,6 +548,7 @@ public List<QueryRewriteInterceptor> getQueryRewriteInterceptors() {
return List.of(
new SemanticKnnVectorQueryRewriteInterceptor(),
new SemanticMatchQueryRewriteInterceptor(),
new SemanticMatchAllQueryRewriteInterceptor(),
new SemanticSparseVectorQueryRewriteInterceptor()
);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.queries;

import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;

public class SemanticMatchAllQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {

public static final NodeFeature SEMANTIC_MATCH_ALL_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
"search.semantic_match_all_query_rewrite_interception_supported"
);

public SemanticMatchAllQueryRewriteInterceptor() {}

@Override
protected String getFieldName(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof MatchAllQueryBuilder);
return null; // MatchAllQueryBuilder does not have a field name, it matches all documents
}

@Override
protected String getQuery(QueryBuilder queryBuilder) {
return "*"; // MatchAllQueryBuilder does not have a specific query, it matches all documents
}

@Override
protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
return new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
}

@Override
protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
QueryBuilder queryBuilder,
InferenceIndexInformationForField indexInformation
) {
assert (queryBuilder instanceof MatchAllQueryBuilder);
MatchAllQueryBuilder matchAllQueryBuilder = (MatchAllQueryBuilder) queryBuilder;
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();

boolQueryBuilder.should(
createSemanticSubQuery(indexInformation.getInferenceIndices(), indexInformation.fieldName(), getQuery(queryBuilder))
);
boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), matchAllQueryBuilder));
return boolQueryBuilder;
}

@Override
public String getQueryName() {
return MatchAllQueryBuilder.NAME;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.index.mapper.IndexFieldMapper;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.TermsQueryBuilder;
Expand All @@ -22,6 +23,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
Expand All @@ -41,6 +43,10 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
return queryBuilder;
}

if (fieldName == null && getQueryName().equals(MatchAllQueryBuilder.NAME)) {
return handleMatchAllQuery(queryBuilder, resolvedIndices);
}

InferenceIndexInformationForField indexInformation = resolveIndicesForField(fieldName, resolvedIndices);
if (indexInformation.getInferenceIndices().isEmpty()) {
// No inference fields were identified, so return the original query.
Expand All @@ -57,6 +63,30 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
}
}

private QueryBuilder handleMatchAllQuery(QueryBuilder queryBuilder, ResolvedIndices resolvedIndices) {
if (getInferenceFieldsFromResolveIndices(resolvedIndices).isEmpty()) {
// No inference fields were identified, so return the original query.
return queryBuilder;
}

List<String> fieldList = new ArrayList<>(getFieldsFromResolveIndices(resolvedIndices));
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
for (String field : fieldList) {
InferenceIndexInformationForField indexInformation = resolveIndicesForField(field, resolvedIndices);
if (indexInformation.nonInferenceIndices().isEmpty() == false) {
// Combined case where the field name requested by this query contains both
// semantic_text and non-inference fields, so we have to combine queries per index
// containing each field type.
boolQueryBuilder.should(buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation));
} else {
// The only fields we've identified are inference fields (e.g. semantic_text),
// so rewrite this semantic_text field into a semantic query
boolQueryBuilder.should(buildInferenceQuery(queryBuilder, indexInformation));
}
}
return boolQueryBuilder.should().isEmpty() ? queryBuilder : boolQueryBuilder;
}

/**
* @param queryBuilder {@link QueryBuilder}
* @return The singular field name requested by the provided query builder.
Expand Down Expand Up @@ -90,6 +120,36 @@ protected abstract QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
InferenceIndexInformationForField indexInformation
);

private List<String> getFieldsFromResolveIndices(ResolvedIndices resolvedIndices) {
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
List<String> fields = new ArrayList<>();
for (IndexMetadata indexMetadata : indexMetadataCollection) {
if (indexMetadata.mapping() == null || indexMetadata.mapping().sourceAsMap() == null) {
// No mapping, so no fields.
continue;
}
Collection<Object> mappingSource = indexMetadata.mapping().getSourceAsMap().values();
@SuppressWarnings("unchecked")
Set<String> fieldNames = mappingSource.stream()
.filter(obj -> obj instanceof Map<?, ?>)
.map(obj -> (Map<String, Object>) obj)
.flatMap(map -> map.keySet().stream())
.collect(Collectors.toSet());
fields.addAll(fieldNames);
}
return fields;
}

private List<InferenceFieldMetadata> getInferenceFieldsFromResolveIndices(ResolvedIndices resolvedIndices) {
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
List<InferenceFieldMetadata> inferenceIndicesMetadata = new ArrayList<>();
for (IndexMetadata indexMetadata : indexMetadataCollection) {
inferenceIndicesMetadata.addAll(indexMetadata.getInferenceFields().values());
}

return inferenceIndicesMetadata;
}

private InferenceIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) {
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
Map<String, InferenceFieldMetadata> inferenceIndicesMetadata = new HashMap<>();
Expand Down