Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ public class SemanticKnnVectorQueryRewriteInterceptor extends SemanticQueryRewri
public SemanticKnnVectorQueryRewriteInterceptor() {}

@Override
protected String getFieldName(QueryBuilder queryBuilder) {
protected Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof KnnVectorQueryBuilder);
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
return knnVectorQueryBuilder.getFieldName();
return Map.of(knnVectorQueryBuilder.getFieldName(), 1.0f);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.elasticsearch.index.query.MatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;

import java.util.Map;

public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {

public static final NodeFeature SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
Expand All @@ -21,10 +23,10 @@ public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteIn
public SemanticMatchQueryRewriteInterceptor() {}

@Override
protected String getFieldName(QueryBuilder queryBuilder) {
protected Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
return matchQueryBuilder.fieldName();
return Map.of(matchQueryBuilder.fieldName(), 1.0f);
}

@Override
Expand All @@ -36,7 +38,7 @@ protected String getQuery(QueryBuilder queryBuilder) {

@Override
protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(getField(queryBuilder), getQuery(queryBuilder), false);
semanticQueryBuilder.boost(queryBuilder.boost());
semanticQueryBuilder.queryName(queryBuilder.queryName());
return semanticQueryBuilder;
Expand Down Expand Up @@ -71,6 +73,12 @@ public String getQueryName() {
return MatchQueryBuilder.NAME;
}

private String getField(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
return matchQueryBuilder.fieldName();
}

private MatchQueryBuilder copyMatchQueryBuilder(MatchQueryBuilder queryBuilder) {
MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder(queryBuilder.fieldName(), queryBuilder.value());
matchQueryBuilder.operator(queryBuilder.operator());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.index.mapper.IndexFieldMapper;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
Expand All @@ -20,8 +21,10 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
Expand All @@ -33,15 +36,14 @@ public SemanticQueryRewriteInterceptor() {}

@Override
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes to this method are not effectively preparing for the changes that will actually necessary to support multi_match here. I suggest taking some time to think about how you will change this method to properly support multi_match, making an implementation plan, and then determining which parts of that plan can and should be done as part of this initial refactor PR.

String fieldName = getFieldName(queryBuilder);
ResolvedIndices resolvedIndices = context.getResolvedIndices();

if (resolvedIndices == null) {
// No resolved indices, so return the original query.
return queryBuilder;
}

InferenceIndexInformationForField indexInformation = resolveIndicesForField(fieldName, resolvedIndices);
InferenceIndexInformationForField indexInformation = resolveIndicesForFields(queryBuilder, resolvedIndices);
if (indexInformation.getInferenceIndices().isEmpty()) {
// No inference fields were identified, so return the original query.
return queryBuilder;
Expand All @@ -58,10 +60,12 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
}

/**
* @param queryBuilder {@link QueryBuilder}
* @return The singular field name requested by the provided query builder.
* Extracts field names and their associated boost values from the query builder.
*
* @param queryBuilder the query builder to extract field information from
* @return a map where keys are field names and values are their boost multipliers
*/
protected abstract String getFieldName(QueryBuilder queryBuilder);
protected abstract Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder);

/**
* @param queryBuilder {@link QueryBuilder}
Expand Down Expand Up @@ -90,21 +94,57 @@ protected abstract QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
InferenceIndexInformationForField indexInformation
);

private InferenceIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) {
private static void addToFieldBoostsMap(Map<String, Float> fieldBoosts, String field, Float boost) {
fieldBoosts.compute(field, (k, v) -> v == null ? boost : v * boost);
}

protected InferenceIndexInformationForField resolveIndicesForFields(QueryBuilder queryBuilder, ResolvedIndices resolvedIndices) {
Map<String, Float> fieldsWithBoosts = getFieldNamesWithBoosts(queryBuilder);
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
Map<String, InferenceFieldMetadata> inferenceIndicesMetadata = new HashMap<>();
List<String> nonInferenceIndices = new ArrayList<>();

Map<String, Map<String, InferenceFieldMetadata>> inferenceFieldsPerIndex = new HashMap<>();
Map<String, Set<String>> nonInferenceFieldsPerIndex = new HashMap<>();
Map<String, Float> allFieldBoosts = new HashMap<>();

for (IndexMetadata indexMetadata : indexMetadataCollection) {
String indexName = indexMetadata.getIndex().getName();
InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName);
if (inferenceFieldMetadata != null) {
inferenceIndicesMetadata.put(indexName, inferenceFieldMetadata);
} else {
nonInferenceIndices.add(indexName);
Map<String, InferenceFieldMetadata> indexInferenceFields = new HashMap<>();
Map<String, InferenceFieldMetadata> indexInferenceMetadata = indexMetadata.getInferenceFields();

// Collect resolved inference fields for this index
Set<String> resolvedInferenceFields = new HashSet<>();

// Handle explicit inference fields only
for (Map.Entry<String, Float> entry : fieldsWithBoosts.entrySet()) {
String field = entry.getKey();
Float boost = entry.getValue();

if (indexInferenceMetadata.containsKey(field)) {
indexInferenceFields.put(field, indexInferenceMetadata.get(field));
resolvedInferenceFields.add(field);
addToFieldBoostsMap(allFieldBoosts, field, boost);
}
}

// Non-inference fields
Set<String> indexNonInferenceFields = new HashSet<>(fieldsWithBoosts.keySet());
indexNonInferenceFields.removeAll(resolvedInferenceFields);

// Store boosts for non-inference field patterns
for (String nonInferenceField : indexNonInferenceFields) {
addToFieldBoostsMap(allFieldBoosts, nonInferenceField, fieldsWithBoosts.get(nonInferenceField));
}

if (indexInferenceFields.isEmpty() == false) {
inferenceFieldsPerIndex.put(indexName, indexInferenceFields);
}

if (indexNonInferenceFields.isEmpty() == false) {
nonInferenceFieldsPerIndex.put(indexName, indexNonInferenceFields);
}
}

return new InferenceIndexInformationForField(fieldName, inferenceIndicesMetadata, nonInferenceIndices);
return new InferenceIndexInformationForField(inferenceFieldsPerIndex, nonInferenceFieldsPerIndex, allFieldBoosts);
}

protected QueryBuilder createSubQueryForIndices(Collection<String> indices, QueryBuilder queryBuilder) {
Expand All @@ -122,27 +162,68 @@ protected QueryBuilder createSemanticSubQuery(Collection<String> indices, String
}

/**
* Represents the indices and associated inference information for a field.
* Represents the indices and associated inference information for fields.
*/
public record InferenceIndexInformationForField(
String fieldName,
Map<String, InferenceFieldMetadata> inferenceIndicesMetadata,
List<String> nonInferenceIndices
// Map: IndexName -> (FieldName -> InferenceFieldMetadata)
Map<String, Map<String, InferenceFieldMetadata>> inferenceFieldsPerIndex,
// Map: IndexName -> Set<FieldName> - non-inference fields per index (boosts stored in fieldBoosts)
Map<String, Set<String>> nonInferenceFieldsPerIndex,
// Map: FieldName -> Boost - stores boosts for all fields (both inference and non-inference)
Map<String, Float> fieldBoosts
) {

public Set<String> getAllInferenceFields() {
return inferenceFieldsPerIndex.values().stream().flatMap(fields -> fields.keySet().stream()).collect(Collectors.toSet());
}

public boolean hasInferenceFields() {
return inferenceFieldsPerIndex.isEmpty() == false;
}

public boolean hasNonInferenceFields() {
return nonInferenceFieldsPerIndex.isEmpty() == false;
}

public Collection<String> getInferenceIndices() {
return inferenceIndicesMetadata.keySet();
return inferenceFieldsPerIndex.keySet();
}

public List<String> nonInferenceIndices() {
return new ArrayList<>(nonInferenceFieldsPerIndex.keySet());
}

public Map<String, List<String>> getInferenceIdsIndices() {
return inferenceIndicesMetadata.entrySet()
.stream()
.collect(
Collectors.groupingBy(
entry -> entry.getValue().getSearchInferenceId(),
Collectors.mapping(Map.Entry::getKey, Collectors.toList())
)
);
Map<String, List<String>> result = new HashMap<>();
for (Map.Entry<String, Map<String, InferenceFieldMetadata>> indexEntry : inferenceFieldsPerIndex.entrySet()) {
String indexName = indexEntry.getKey();
for (InferenceFieldMetadata metadata : indexEntry.getValue().values()) {
String inferenceId = metadata.getSearchInferenceId();
result.computeIfAbsent(inferenceId, k -> new ArrayList<>()).add(indexName);
}
}
return result;
}

/**
* Returns the set of indices where the given field is a semantic field (has inference metadata).
*/
public Set<String> getInferenceIndicesForField(String fieldName) {
Set<String> indices = new HashSet<>();
for (Map.Entry<String, Map<String, InferenceFieldMetadata>> entry : inferenceFieldsPerIndex.entrySet()) {
if (entry.getValue().containsKey(fieldName)) {
indices.add(entry.getKey());
}
}
return indices;
}

/**
* @param fieldName the field name
* @return the resolved boost for the field
*/
public float getFieldBoost(String fieldName) {
return fieldBoosts.getOrDefault(fieldName, AbstractQueryBuilder.DEFAULT_BOOST);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ public class SemanticSparseVectorQueryRewriteInterceptor extends SemanticQueryRe
public SemanticSparseVectorQueryRewriteInterceptor() {}

@Override
protected String getFieldName(QueryBuilder queryBuilder) {
protected Map<String, Float> getFieldNamesWithBoosts(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof SparseVectorQueryBuilder);
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
return sparseVectorQueryBuilder.getFieldName();
return Map.of(sparseVectorQueryBuilder.getFieldName(), 1.0f);
}

@Override
Expand Down