diff --git a/docs/changelog/129200.yaml b/docs/changelog/129200.yaml new file mode 100644 index 0000000000000..c657283682c4e --- /dev/null +++ b/docs/changelog/129200.yaml @@ -0,0 +1,5 @@ +pr: 129200 +summary: Simplified Linear Retriever +area: Search +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index 3f9353251b920..8042be444292d 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -36,6 +36,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Objects; @@ -53,7 +54,11 @@ public abstract class CompoundRetrieverBuilder innerRetrievers; @@ -65,7 +70,7 @@ protected CompoundRetrieverBuilder(List innerRetrievers, int ra @SuppressWarnings("unchecked") public T addChild(RetrieverBuilder retrieverBuilder) { - innerRetrievers.add(new RetrieverSource(retrieverBuilder, null)); + innerRetrievers.add(RetrieverSource.from(retrieverBuilder)); return (T) this; } @@ -99,6 +104,11 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio throw new IllegalStateException("PIT is required"); } + RetrieverBuilder rewritten = doRewrite(ctx); + if (rewritten != this) { + return rewritten; + } + // Rewrite prefilters // We eagerly rewrite prefilters, because some of the innerRetrievers // could be compound too, so we want to propagate all the necessary filter information to them @@ -121,7 +131,7 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio } RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx); if (newRetriever != entry.retriever) { - newRetrievers.add(new RetrieverSource(newRetriever, null)); + 
newRetrievers.add(RetrieverSource.from(newRetriever)); hasChanged |= true; } else { var sourceBuilder = entry.source != null @@ -291,6 +301,10 @@ public int rankWindowSize() { return rankWindowSize; } + public List innerRetrievers() { + return Collections.unmodifiableList(innerRetrievers); + } + protected final SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit) .trackTotalHits(false) @@ -317,6 +331,16 @@ protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBu return sourceBuilder; } + /** + * Perform any custom rewrite logic necessary + * + * @param ctx The query rewrite context + * @return RetrieverBuilder the rewritten retriever + */ + protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) { + return this; + } + private RankDoc[] getRankDocs(SearchResponse searchResponse) { int size = searchResponse.getHits().getHits().length; RankDoc[] docs = new RankDoc[size]; diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java index 67a29383f3388..c2bda5587e1bb 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java @@ -78,7 +78,7 @@ public static RescorerRetrieverBuilder fromXContent(XContentParser parser, Retri private final List> rescorers; public RescorerRetrieverBuilder(RetrieverBuilder retriever, List> rescorers) { - super(List.of(new RetrieverSource(retriever, null)), extractMinWindowSize(rescorers)); + super(List.of(RetrieverSource.from(retriever)), extractMinWindowSize(rescorers)); if (rescorers.isEmpty()) { throw new IllegalArgumentException("Missing rescore definition"); } diff --git a/x-pack/plugin/inference/build.gradle 
b/x-pack/plugin/inference/build.gradle index 58aa9b29f8565..b58e1e941b168 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -403,5 +403,5 @@ tasks.named("thirdPartyAudit").configure { } tasks.named('yamlRestTest') { - usesDefaultDistribution("to be triaged") + usesDefaultDistribution("Uses the inference API") } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index 0a6ff009f367e..2942381a1d181 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -106,7 +106,7 @@ public TextSimilarityRankRetrieverBuilder( int rankWindowSize, boolean failuresAllowed ) { - super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize); + super(List.of(RetrieverSource.from(retrieverBuilder)), rankWindowSize); this.inferenceId = inferenceId; this.inferenceText = inferenceText; this.field = field; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryBuilderTests.java new file mode 100644 index 0000000000000..b54ca946e6179 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryBuilderTests.java @@ -0,0 +1,125 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.queries; + +import org.apache.lucene.index.Term; +import org.apache.lucene.search.DisjunctionMaxQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.MapperServiceTestCase; +import org.elasticsearch.index.mapper.ParsedDocument; +import org.elasticsearch.index.query.MultiMatchQueryBuilder; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ClusterServiceUtils; +import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.util.Collection; +import java.util.List; +import java.util.function.Supplier; + +public class SemanticMultiMatchQueryBuilderTests extends MapperServiceTestCase { + private static TestThreadPool threadPool; + private static ModelRegistry modelRegistry; + + private static class InferencePluginWithModelRegistry extends InferencePlugin { + InferencePluginWithModelRegistry(Settings settings) { + super(settings); + } + + @Override + protected Supplier getModelRegistry() { + return () -> modelRegistry; + } + } + + @BeforeClass + public static void startModelRegistry() { + threadPool = new TestThreadPool(SemanticMultiMatchQueryBuilderTests.class.getName()); + var clusterService = ClusterServiceUtils.createClusterService(threadPool); + modelRegistry = new ModelRegistry(clusterService, new 
NoOpClient(threadPool)); + modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) { + @Override + public boolean localNodeMaster() { + return false; + } + }); + } + + @AfterClass + public static void stopModelRegistry() { + IOUtils.closeWhileHandlingException(threadPool); + } + + @Override + protected Collection getPlugins() { + return List.of(new InferencePluginWithModelRegistry(Settings.EMPTY)); + } + + public void testResolveSemanticTextFieldFromWildcard() throws Exception { + MapperService mapperService = createMapperService(""" + { + "_doc" : { + "properties": { + "text_field": { "type": "text" }, + "keyword_field": { "type": "keyword" }, + "inference_field": { "type": "semantic_text", "inference_id": "test_service" } + } + } + } + """); + + ParsedDocument doc = mapperService.documentMapper().parse(source(""" + { + "text_field" : "foo", + "keyword_field" : "foo", + "inference_field" : "foo", + "_inference_fields": { + "inference_field": { + "inference": { + "inference_id": "test_service", + "model_settings": { + "task_type": "sparse_embedding" + }, + "chunks": { + "inference_field": [ + { + "start_offset": 0, + "end_offset": 3, + "embeddings": { + "foo": 1.0 + } + } + ] + } + } + } + } + } + """)); + + withLuceneIndex(mapperService, iw -> iw.addDocument(doc.rootDoc()), ir -> { + SearchExecutionContext context = createSearchExecutionContext(mapperService, newSearcher(ir)); + Query query = new MultiMatchQueryBuilder("foo", "*_field").toQuery(context); + Query expected = new DisjunctionMaxQuery( + List.of(new TermQuery(new Term("text_field", "foo")), new TermQuery(new Term("keyword_field", "foo"))), + 0f + ); + assertEquals(expected, query); + }); + } +} diff --git a/x-pack/plugin/rank-rrf/build.gradle b/x-pack/plugin/rank-rrf/build.gradle index fa598c6ef677a..bf8cbba1390a2 100644 --- a/x-pack/plugin/rank-rrf/build.gradle +++ b/x-pack/plugin/rank-rrf/build.gradle @@ -30,3 +30,7 @@ dependencies { clusterPlugins 
project(':x-pack:plugin:inference:qa:test-service-plugin') } + +tasks.named('yamlRestTest') { + usesDefaultDistribution("Uses the inference API") +} diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/MultiFieldsInnerRetrieverUtils.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/MultiFieldsInnerRetrieverUtils.java new file mode 100644 index 0000000000000..0715b4fa67544 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/MultiFieldsInnerRetrieverUtils.java @@ -0,0 +1,221 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.rank; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.common.regex.Regex; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.query.MatchQueryBuilder; +import org.elasticsearch.index.query.MultiMatchQueryBuilder; +import org.elasticsearch.index.search.QueryParserHelper; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverBuilder; +import org.elasticsearch.search.retriever.StandardRetrieverBuilder; +import org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; + +import static org.elasticsearch.action.ValidateActions.addValidationError; +import static 
org.elasticsearch.index.IndexSettings.DEFAULT_FIELD_SETTING; + +/** + * A utility class for managing and validating the multi-fields query format for the {@link LinearRetrieverBuilder} retriever. + */ +public class MultiFieldsInnerRetrieverUtils { + private MultiFieldsInnerRetrieverUtils() {} + + public record WeightedRetrieverSource(CompoundRetrieverBuilder.RetrieverSource retrieverSource, float weight) {} + + /** + * Validates the parameters related to the multi-fields query format. + * + * @param innerRetrievers The custom inner retrievers already defined + * @param fields The fields to query + * @param query The query text + * @param retrieverName The containing retriever name + * @param retrieversParamName The parameter name for custom inner retrievers + * @param fieldsParamName The parameter name for the fields to query + * @param queryParamName The parameter name for the query text + * @param validationException The validation exception + * @return The validation exception with any multi-fields query format validation errors appended + */ + public static ActionRequestValidationException validateParams( + List innerRetrievers, + List fields, + @Nullable String query, + String retrieverName, + String retrieversParamName, + String fieldsParamName, + String queryParamName, + ActionRequestValidationException validationException + ) { + if (fields.isEmpty() == false || query != null) { + // Using the multi-fields query format + if (query == null) { + // Return early here because the following validation checks assume a query param value is provided + return addValidationError( + String.format( + Locale.ROOT, + "[%s] [%s] must be provided when [%s] is specified", + retrieverName, + queryParamName, + fieldsParamName + ), + validationException + ); + } + + if (query.isEmpty()) { + validationException = addValidationError( + String.format(Locale.ROOT, "[%s] [%s] cannot be empty", retrieverName, queryParamName), + validationException + ); + } + + if 
(innerRetrievers.isEmpty() == false) { + validationException = addValidationError( + String.format(Locale.ROOT, "[%s] cannot combine [%s] and [%s]", retrieverName, retrieversParamName, queryParamName), + validationException + ); + } + } else if (innerRetrievers.isEmpty()) { + validationException = addValidationError( + String.format(Locale.ROOT, "[%s] must provide [%s] or [%s]", retrieverName, retrieversParamName, queryParamName), + validationException + ); + } + + return validationException; + } + + /** + * Generate the inner retriever tree for the given fields, weights, and query. The tree follows this structure: + * + *
+     * multi_match query on all lexical fields
+     * normalizer retriever
+     *   match query on semantic_text field A
+     *   match query on semantic_text field B
+     *   ...
+     *   match query on semantic_text field Z
+     *
+     * Where the normalizer retriever is constructed by the {@code innerNormalizerGenerator} function.
+     *
+     * This tree structure is repeated for each index in {@code indicesMetadata}. That is to say, for each index in
+     * {@code indicesMetadata}, (up to) a pair of retrievers will be added to the returned {@code RetrieverBuilder} list.
+     *
+ * + * @param fieldsAndWeights The fields to query and their respective weights, in "field^weight" format + * @param query The query text + * @param indicesMetadata The metadata for the indices to search + * @param innerNormalizerGenerator The inner normalizer retriever generator function + * @param weightValidator The field weight validator + * @return The inner retriever tree as a {@code RetrieverBuilder} list + */ + public static List generateInnerRetrievers( + List fieldsAndWeights, + String query, + Collection indicesMetadata, + Function, CompoundRetrieverBuilder> innerNormalizerGenerator, + @Nullable Consumer weightValidator + ) { + Map parsedFieldsAndWeights = QueryParserHelper.parseFieldsAndWeights(fieldsAndWeights); + if (weightValidator != null) { + parsedFieldsAndWeights.values().forEach(weightValidator); + } + + // We expect up to 2 inner retrievers to be generated for each index queried + List innerRetrievers = new ArrayList<>(indicesMetadata.size() * 2); + for (IndexMetadata indexMetadata : indicesMetadata) { + innerRetrievers.addAll( + generateInnerRetrieversForIndex(parsedFieldsAndWeights, query, indexMetadata, innerNormalizerGenerator, weightValidator) + ); + } + return innerRetrievers; + } + + private static List generateInnerRetrieversForIndex( + Map parsedFieldsAndWeights, + String query, + IndexMetadata indexMetadata, + Function, CompoundRetrieverBuilder> innerNormalizerGenerator, + @Nullable Consumer weightValidator + ) { + Map fieldsAndWeightsToQuery = parsedFieldsAndWeights; + if (fieldsAndWeightsToQuery.isEmpty()) { + Settings settings = indexMetadata.getSettings(); + List defaultFields = settings.getAsList(DEFAULT_FIELD_SETTING.getKey(), DEFAULT_FIELD_SETTING.getDefault(settings)); + fieldsAndWeightsToQuery = QueryParserHelper.parseFieldsAndWeights(defaultFields); + if (weightValidator != null) { + fieldsAndWeightsToQuery.values().forEach(weightValidator); + } + } + + Map inferenceFields = new HashMap<>(); + Map indexInferenceFields = 
indexMetadata.getInferenceFields(); + for (Map.Entry entry : fieldsAndWeightsToQuery.entrySet()) { + String field = entry.getKey(); + Float weight = entry.getValue(); + + if (Regex.isMatchAllPattern(field)) { + indexInferenceFields.keySet().forEach(f -> addToInferenceFieldsMap(inferenceFields, f, weight)); + } else if (Regex.isSimpleMatchPattern(field)) { + indexInferenceFields.keySet() + .stream() + .filter(f -> Regex.simpleMatch(field, f)) + .forEach(f -> addToInferenceFieldsMap(inferenceFields, f, weight)); + } else { + // No wildcards in field name + if (indexInferenceFields.containsKey(field)) { + addToInferenceFieldsMap(inferenceFields, field, weight); + } + } + } + + Map nonInferenceFields = new HashMap<>(fieldsAndWeightsToQuery); + nonInferenceFields.keySet().removeAll(inferenceFields.keySet()); // Remove all inference fields from non-inference fields map + + // TODO: Set index pre-filters on returned retrievers when we want to implement multi-index support + List innerRetrievers = new ArrayList<>(2); + if (nonInferenceFields.isEmpty() == false) { + MultiMatchQueryBuilder nonInferenceFieldQueryBuilder = new MultiMatchQueryBuilder(query).type( + MultiMatchQueryBuilder.Type.MOST_FIELDS + ).fields(nonInferenceFields); + innerRetrievers.add(new StandardRetrieverBuilder(nonInferenceFieldQueryBuilder)); + } + if (inferenceFields.isEmpty() == false) { + List inferenceFieldRetrievers = new ArrayList<>(inferenceFields.size()); + inferenceFields.forEach((f, w) -> { + RetrieverBuilder retrieverBuilder = new StandardRetrieverBuilder(new MatchQueryBuilder(f, query)); + inferenceFieldRetrievers.add( + new WeightedRetrieverSource(CompoundRetrieverBuilder.RetrieverSource.from(retrieverBuilder), w) + ); + }); + + innerRetrievers.add(innerNormalizerGenerator.apply(inferenceFieldRetrievers)); + } + return innerRetrievers; + } + + private static void addToInferenceFieldsMap(Map inferenceFields, String field, Float weight) { + inferenceFields.compute(field, (k, v) -> v == null 
? weight : v * weight); + } +} diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java index 00395ebd18239..0f11df321300b 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java @@ -9,6 +9,7 @@ import org.elasticsearch.features.FeatureSpecification; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder; import java.util.Set; @@ -32,7 +33,8 @@ public Set getTestFeatures() { INNER_RETRIEVERS_FILTER_SUPPORT, LINEAR_RETRIEVER_MINMAX_SINGLE_DOC_FIX, LINEAR_RETRIEVER_L2_NORM, - LINEAR_RETRIEVER_MINSCORE_FIX + LINEAR_RETRIEVER_MINSCORE_FIX, + LinearRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT ); } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java index 7631446ca71d0..f0c36f9819af8 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java @@ -8,10 +8,14 @@ package org.elasticsearch.xpack.rank.linear; import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ResolvedIndices; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.util.Maps; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.license.LicenseUtils; import 
org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rank.RankBuilder; @@ -19,20 +23,24 @@ import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.search.retriever.StandardRetrieverBuilder; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.XPackPlugin; +import org.elasticsearch.xpack.rank.MultiFieldsInnerRetrieverUtils; import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.Map; +import java.util.Objects; -import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.action.ValidateActions.addValidationError; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.xpack.rank.RankRRFFeatures.LINEAR_RETRIEVER_SUPPORTED; import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_WEIGHT; @@ -46,51 +54,69 @@ * */ public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder { + public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature( + "linear_retriever.multi_fields_query_format_support" + ); public static final NodeFeature LINEAR_RETRIEVER_MINSCORE_FIX = new NodeFeature("linear_retriever_minscore_fix"); public static final String NAME = "linear"; public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers"); + public static final ParseField FIELDS_FIELD = new ParseField("fields"); + public static final ParseField QUERY_FIELD = new 
ParseField("query"); + public static final ParseField NORMALIZER_FIELD = new ParseField("normalizer"); public static final float DEFAULT_SCORE = 0f; private final float[] weights; private final ScoreNormalizer[] normalizers; + private final List fields; + private final String query; + private final ScoreNormalizer normalizer; @SuppressWarnings("unchecked") static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( NAME, false, args -> { - List retrieverComponents = (List) args[0]; - int rankWindowSize = args[1] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1]; - List innerRetrievers = new ArrayList<>(); + List retrieverComponents = args[0] == null ? List.of() : (List) args[0]; + List fields = (List) args[1]; + String query = (String) args[2]; + ScoreNormalizer normalizer = args[3] == null ? null : ScoreNormalizer.valueOf((String) args[3]); + int rankWindowSize = args[4] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[4]; + + int index = 0; float[] weights = new float[retrieverComponents.size()]; ScoreNormalizer[] normalizers = new ScoreNormalizer[retrieverComponents.size()]; - int index = 0; + List innerRetrievers = new ArrayList<>(); for (LinearRetrieverComponent component : retrieverComponents) { - innerRetrievers.add(new RetrieverSource(component.retriever, null)); + innerRetrievers.add(RetrieverSource.from(component.retriever)); weights[index] = component.weight; normalizers[index] = component.normalizer; index++; } - return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers); + return new LinearRetrieverBuilder(innerRetrievers, fields, query, normalizer, rankWindowSize, weights, normalizers); } ); static { - PARSER.declareObjectArray(constructorArg(), LinearRetrieverComponent::fromXContent, RETRIEVERS_FIELD); + PARSER.declareObjectArray(optionalConstructorArg(), LinearRetrieverComponent::fromXContent, RETRIEVERS_FIELD); + PARSER.declareStringArray(optionalConstructorArg(), 
FIELDS_FIELD); + PARSER.declareString(optionalConstructorArg(), QUERY_FIELD); + PARSER.declareString(optionalConstructorArg(), NORMALIZER_FIELD); PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); RetrieverBuilder.declareBaseParserFields(PARSER); } - private static float[] getDefaultWeight(int size) { + private static float[] getDefaultWeight(List innerRetrievers) { + int size = innerRetrievers != null ? innerRetrievers.size() : 0; float[] weights = new float[size]; Arrays.fill(weights, DEFAULT_WEIGHT); return weights; } - private static ScoreNormalizer[] getDefaultNormalizers(int size) { + private static ScoreNormalizer[] getDefaultNormalizers(List innerRetrievers) { + int size = innerRetrievers != null ? innerRetrievers.size() : 0; ScoreNormalizer[] normalizers = new ScoreNormalizer[size]; Arrays.fill(normalizers, IdentityScoreNormalizer.INSTANCE); return normalizers; @@ -107,28 +133,48 @@ public static LinearRetrieverBuilder fromXContent(XContentParser parser, Retriev } LinearRetrieverBuilder(List innerRetrievers, int rankWindowSize) { - this(innerRetrievers, rankWindowSize, getDefaultWeight(innerRetrievers.size()), getDefaultNormalizers(innerRetrievers.size())); + this(innerRetrievers, null, null, null, rankWindowSize, getDefaultWeight(innerRetrievers), getDefaultNormalizers(innerRetrievers)); + } + + public LinearRetrieverBuilder( + List innerRetrievers, + int rankWindowSize, + float[] weights, + ScoreNormalizer[] normalizers + ) { + this(innerRetrievers, null, null, null, rankWindowSize, weights, normalizers); } public LinearRetrieverBuilder( List innerRetrievers, + List fields, + String query, + ScoreNormalizer normalizer, int rankWindowSize, float[] weights, ScoreNormalizer[] normalizers ) { - super(innerRetrievers, rankWindowSize); - if (weights.length != innerRetrievers.size()) { + // Use a mutable list for innerRetrievers so that we can use addChild + super(innerRetrievers == null ? 
new ArrayList<>() : new ArrayList<>(innerRetrievers), rankWindowSize); + if (weights.length != this.innerRetrievers.size()) { throw new IllegalArgumentException("The number of weights must match the number of inner retrievers"); } - if (normalizers.length != innerRetrievers.size()) { + if (normalizers.length != this.innerRetrievers.size()) { throw new IllegalArgumentException("The number of normalizers must match the number of inner retrievers"); } + + this.fields = fields == null ? List.of() : List.copyOf(fields); + this.query = query; + this.normalizer = normalizer; this.weights = weights; this.normalizers = normalizers; } public LinearRetrieverBuilder( List innerRetrievers, + List fields, + String query, + ScoreNormalizer normalizer, int rankWindowSize, float[] weights, ScoreNormalizer[] normalizers, @@ -136,7 +182,7 @@ public LinearRetrieverBuilder( String retrieverName, List preFilterQueryBuilders ) { - this(innerRetrievers, rankWindowSize, weights, normalizers); + this(innerRetrievers, fields, query, normalizer, rankWindowSize, weights, normalizers); this.minScore = minScore; if (minScore != null && minScore < 0) { throw new IllegalArgumentException("[min_score] must be greater than or equal to 0, was: [" + minScore + "]"); @@ -145,10 +191,59 @@ public LinearRetrieverBuilder( this.preFilterQueryBuilders = preFilterQueryBuilders; } + @Override + public ActionRequestValidationException validate( + SearchSourceBuilder source, + ActionRequestValidationException validationException, + boolean isScroll, + boolean allowPartialSearchResults + ) { + validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults); + validationException = MultiFieldsInnerRetrieverUtils.validateParams( + innerRetrievers, + fields, + query, + getName(), + RETRIEVERS_FIELD.getPreferredName(), + FIELDS_FIELD.getPreferredName(), + QUERY_FIELD.getPreferredName(), + validationException + ); + + if (query != null && normalizer == null) { + 
validationException = addValidationError( + String.format( + Locale.ROOT, + "[%s] [%s] must be provided when [%s] is specified", + getName(), + NORMALIZER_FIELD.getPreferredName(), + QUERY_FIELD.getPreferredName() + ), + validationException + ); + } else if (innerRetrievers.isEmpty() == false && normalizer != null) { + validationException = addValidationError( + String.format( + Locale.ROOT, + "[%s] [%s] cannot be provided when [%s] is specified", + getName(), + NORMALIZER_FIELD.getPreferredName(), + RETRIEVERS_FIELD.getPreferredName() + ), + validationException + ); + } + + return validationException; + } + @Override protected LinearRetrieverBuilder clone(List newChildRetrievers, List newPreFilterQueryBuilders) { return new LinearRetrieverBuilder( newChildRetrievers, + fields, + query, + normalizer, rankWindowSize, weights, normalizers, @@ -213,11 +308,83 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults, b return topResults; } + @Override + protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) { + RetrieverBuilder rewritten = this; + + ResolvedIndices resolvedIndices = ctx.getResolvedIndices(); + if (resolvedIndices != null && query != null) { + // Using the multi-fields query format + var localIndicesMetadata = resolvedIndices.getConcreteLocalIndicesMetadata(); + if (localIndicesMetadata.size() > 1) { + throw new IllegalArgumentException( + "[" + NAME + "] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying multiple indices" + ); + } else if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) { + throw new IllegalArgumentException( + "[" + NAME + "] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying remote indices" + ); + } + + List fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils.generateInnerRetrievers( + fields, + query, + localIndicesMetadata.values(), + r -> { + List retrievers = new ArrayList<>(r.size()); + float[] weights = new float[r.size()]; + ScoreNormalizer[] normalizers 
= new ScoreNormalizer[r.size()]; + + int index = 0; + for (var weightedRetriever : r) { + retrievers.add(weightedRetriever.retrieverSource()); + weights[index] = weightedRetriever.weight(); + normalizers[index] = normalizer; + index++; + } + + return new LinearRetrieverBuilder(retrievers, rankWindowSize, weights, normalizers); + }, + w -> { + if (w < 0) { + throw new IllegalArgumentException("[" + NAME + "] per-field weights must be non-negative"); + } + } + ).stream().map(RetrieverSource::from).toList(); + + if (fieldsInnerRetrievers.isEmpty() == false) { + float[] weights = new float[fieldsInnerRetrievers.size()]; + Arrays.fill(weights, DEFAULT_WEIGHT); + + ScoreNormalizer[] normalizers = new ScoreNormalizer[fieldsInnerRetrievers.size()]; + Arrays.fill(normalizers, normalizer); + + // TODO: This is a incomplete solution as it does not address other incomplete copy issues + // (such as dropping the retriever name and min score) + rewritten = new LinearRetrieverBuilder(fieldsInnerRetrievers, null, null, normalizer, rankWindowSize, weights, normalizers); + rewritten.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders); + } else { + // Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices + rewritten = new StandardRetrieverBuilder(new MatchNoneQueryBuilder()); + } + } + + return rewritten; + } + @Override public String getName() { return NAME; } + float[] getWeights() { + return weights; + } + + ScoreNormalizer[] getNormalizers() { + return normalizers; + } + public void doToXContent(XContentBuilder builder, Params params) throws IOException { int index = 0; if (innerRetrievers.isEmpty() == false) { @@ -232,6 +399,37 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept } builder.endArray(); } + + if (fields.isEmpty() == false) { + builder.startArray(FIELDS_FIELD.getPreferredName()); + for (String field : fields) { + builder.value(field); + } + builder.endArray(); + } + if (query != 
null) { + builder.field(QUERY_FIELD.getPreferredName(), query); + } + if (normalizer != null) { + builder.field(NORMALIZER_FIELD.getPreferredName(), normalizer.getName()); + } + builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); } + + @Override + public boolean doEquals(Object o) { + LinearRetrieverBuilder that = (LinearRetrieverBuilder) o; + return super.doEquals(o) + && Arrays.equals(weights, that.weights) + && Arrays.equals(normalizers, that.normalizers) + && Objects.equals(fields, that.fields) + && Objects.equals(query, that.query) + && Objects.equals(normalizer, that.normalizer); + } + + @Override + public int doHashCode() { + return Objects.hash(super.doHashCode(), Arrays.hashCode(weights), Arrays.hashCode(normalizers), fields, query, normalizer); + } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java index 91bc19a3e0903..b6ffbf8f3301e 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java @@ -202,7 +202,7 @@ public RetrieverBuilder toRetriever(SearchSourceBuilder source, Predicate { List childRetrievers = (List) args[0]; - List innerRetrievers = childRetrievers.stream().map(r -> new RetrieverSource(r, null)).toList(); + List innerRetrievers = childRetrievers.stream().map(RetrieverSource::from).toList(); int rankWindowSize = args[1] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1]; int rankConstant = args[2] == null ? 
DEFAULT_RANK_CONSTANT : (int) args[2]; return new RRFRetrieverBuilder(innerRetrievers, rankWindowSize, rankConstant); diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java index 5cc66c6f50d3c..74e18bf12fffc 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.rank.linear; +import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; @@ -40,9 +41,23 @@ public static void afterClass() throws Exception { xContentRegistryEntries = null; } + /** + * Creates a random {@link LinearRetrieverBuilder}. The created instance is not guaranteed to pass {@link SearchRequest} validation. + * This is purely for x-content testing. 
+ */ @Override protected LinearRetrieverBuilder createTestInstance() { int rankWindowSize = randomInt(100); + + List fields = null; + String query = null; + ScoreNormalizer normalizer = null; + if (randomBoolean()) { + fields = randomList(1, 10, () -> randomAlphaOfLengthBetween(1, 10)); + query = randomAlphaOfLengthBetween(1, 10); + normalizer = randomScoreNormalizer(); + } + int num = randomIntBetween(1, 3); List innerRetrievers = new ArrayList<>(); float[] weights = new float[num]; @@ -54,7 +69,8 @@ protected LinearRetrieverBuilder createTestInstance() { weights[i] = randomFloat(); normalizers[i] = randomScoreNormalizer(); } - return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers); + + return new LinearRetrieverBuilder(innerRetrievers, fields, query, normalizer, rankWindowSize, weights, normalizers); } @Override diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderTests.java new file mode 100644 index 0000000000000..c211440d10bae --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderTests.java @@ -0,0 +1,329 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.rank.linear; + +import org.elasticsearch.action.MockResolvedIndices; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.ResolvedIndices; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.query.MatchQueryBuilder; +import org.elasticsearch.index.query.MultiMatchQueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.search.builder.PointInTimeBuilder; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverBuilder; +import org.elasticsearch.search.retriever.StandardRetrieverBuilder; +import org.elasticsearch.test.ESTestCase; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE; + +public class LinearRetrieverBuilderTests extends ESTestCase { + public void testMultiFieldsParamsRewrite() { + final String indexName = "test-index"; + final List testInferenceFields = List.of("semantic_field_1", "semantic_field_2"); + final ResolvedIndices resolvedIndices = createMockResolvedIndices(indexName, testInferenceFields, null); + final QueryRewriteContext queryRewriteContext = new QueryRewriteContext( + parserConfig(), + null, + null, + resolvedIndices, + new PointInTimeBuilder(new BytesArray("pitid")), + null + ); + + // No wildcards, no per-field boosting + LinearRetrieverBuilder retriever = new LinearRetrieverBuilder( + null, + 
List.of("field_1", "field_2", "semantic_field_1", "semantic_field_2"), + "foo", + MinMaxScoreNormalizer.INSTANCE, + DEFAULT_RANK_WINDOW_SIZE, + new float[0], + new ScoreNormalizer[0] + ); + assertMultiFieldsParamsRewrite( + retriever, + queryRewriteContext, + Map.of("field_1", 1.0f, "field_2", 1.0f), + Map.of("semantic_field_1", 1.0f, "semantic_field_2", 1.0f), + "foo", + MinMaxScoreNormalizer.INSTANCE + ); + + // Non-default rank window size + retriever = new LinearRetrieverBuilder( + null, + List.of("field_1", "field_2", "semantic_field_1", "semantic_field_2"), + "foo2", + MinMaxScoreNormalizer.INSTANCE, + DEFAULT_RANK_WINDOW_SIZE * 2, + new float[0], + new ScoreNormalizer[0] + ); + assertMultiFieldsParamsRewrite( + retriever, + queryRewriteContext, + Map.of("field_1", 1.0f, "field_2", 1.0f), + Map.of("semantic_field_1", 1.0f, "semantic_field_2", 1.0f), + "foo2", + MinMaxScoreNormalizer.INSTANCE + ); + + // No wildcards, per-field boosting + retriever = new LinearRetrieverBuilder( + null, + List.of("field_1", "field_2^1.5", "semantic_field_1", "semantic_field_2^2"), + "bar", + MinMaxScoreNormalizer.INSTANCE, + DEFAULT_RANK_WINDOW_SIZE, + new float[0], + new ScoreNormalizer[0] + ); + assertMultiFieldsParamsRewrite( + retriever, + queryRewriteContext, + Map.of("field_1", 1.0f, "field_2", 1.5f), + Map.of("semantic_field_1", 1.0f, "semantic_field_2", 2.0f), + "bar", + MinMaxScoreNormalizer.INSTANCE + ); + + // Glob matching on inference and non-inference fields with per-field boosting + retriever = new LinearRetrieverBuilder( + null, + List.of("field_*^1.5", "*_field_1^2.5"), + "baz", + MinMaxScoreNormalizer.INSTANCE, + DEFAULT_RANK_WINDOW_SIZE, + new float[0], + new ScoreNormalizer[0] + ); + assertMultiFieldsParamsRewrite( + retriever, + queryRewriteContext, + Map.of("field_*", 1.5f, "*_field_1", 2.5f), + Map.of("semantic_field_1", 2.5f), + "baz", + MinMaxScoreNormalizer.INSTANCE + ); + + // Multiple boosts defined on the same field + retriever = new 
LinearRetrieverBuilder( + null, + List.of("field_*^1.5", "field_1^3.0", "*_field_1^2.5", "semantic_*^1.5"), + "baz2", + MinMaxScoreNormalizer.INSTANCE, + DEFAULT_RANK_WINDOW_SIZE, + new float[0], + new ScoreNormalizer[0] + ); + assertMultiFieldsParamsRewrite( + retriever, + queryRewriteContext, + Map.of("field_*", 1.5f, "field_1", 3.0f, "*_field_1", 2.5f, "semantic_*", 1.5f), + Map.of("semantic_field_1", 3.75f, "semantic_field_2", 1.5f), + "baz2", + MinMaxScoreNormalizer.INSTANCE + ); + + // All-fields wildcard + retriever = new LinearRetrieverBuilder( + null, + List.of("*"), + "qux", + MinMaxScoreNormalizer.INSTANCE, + DEFAULT_RANK_WINDOW_SIZE, + new float[0], + new ScoreNormalizer[0] + ); + assertMultiFieldsParamsRewrite( + retriever, + queryRewriteContext, + Map.of("*", 1.0f), + Map.of("semantic_field_1", 1.0f, "semantic_field_2", 1.0f), + "qux", + MinMaxScoreNormalizer.INSTANCE + ); + } + + public void testSearchRemoteIndex() { + final ResolvedIndices resolvedIndices = createMockResolvedIndices( + "local-index", + List.of(), + Map.of("remote-cluster", "remote-index") + ); + final QueryRewriteContext queryRewriteContext = new QueryRewriteContext( + parserConfig(), + null, + null, + resolvedIndices, + new PointInTimeBuilder(new BytesArray("pitid")), + null + ); + + LinearRetrieverBuilder retriever = new LinearRetrieverBuilder( + null, + null, + "foo", + MinMaxScoreNormalizer.INSTANCE, + DEFAULT_RANK_WINDOW_SIZE, + new float[0], + new ScoreNormalizer[0] + ); + + IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> retriever.doRewrite(queryRewriteContext)); + assertEquals("[linear] cannot specify [query] when querying remote indices", iae.getMessage()); + } + + private static ResolvedIndices createMockResolvedIndices( + String localIndexName, + List inferenceFields, + Map remoteIndexNames + ) { + Index index = new Index(localIndexName, randomAlphaOfLength(10)); + IndexMetadata.Builder indexMetadataBuilder = 
IndexMetadata.builder(index.getName()) + .settings( + Settings.builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()) + .put(IndexMetadata.SETTING_INDEX_UUID, index.getUUID()) + ) + .numberOfShards(1) + .numberOfReplicas(0); + + for (String inferenceField : inferenceFields) { + indexMetadataBuilder.putInferenceField( + new InferenceFieldMetadata(inferenceField, randomAlphaOfLengthBetween(3, 5), new String[] { inferenceField }, null) + ); + } + + Map remoteIndices = new HashMap<>(); + if (remoteIndexNames != null) { + for (Map.Entry entry : remoteIndexNames.entrySet()) { + remoteIndices.put(entry.getKey(), new OriginalIndices(new String[] { entry.getValue() }, IndicesOptions.DEFAULT)); + } + } + + return new MockResolvedIndices( + remoteIndices, + new OriginalIndices(new String[] { localIndexName }, IndicesOptions.DEFAULT), + Map.of(index, indexMetadataBuilder.build()) + ); + } + + private static void assertMultiFieldsParamsRewrite( + LinearRetrieverBuilder retriever, + QueryRewriteContext ctx, + Map expectedNonInferenceFields, + Map expectedInferenceFields, + String expectedQuery, + ScoreNormalizer expectedNormalizer + ) { + Set expectedInnerRetrievers = Set.of( + new InnerRetriever( + new StandardRetrieverBuilder( + new MultiMatchQueryBuilder(expectedQuery).type(MultiMatchQueryBuilder.Type.MOST_FIELDS) + .fields(expectedNonInferenceFields) + ), + 1.0f, + expectedNormalizer + ), + new InnerRetriever( + expectedInferenceFields.entrySet() + .stream() + .map( + e -> new InnerRetriever( + new StandardRetrieverBuilder(new MatchQueryBuilder(e.getKey(), expectedQuery)), + e.getValue(), + expectedNormalizer + ) + ) + .collect(Collectors.toSet()), + 1.0f, + expectedNormalizer + ) + ); + + RetrieverBuilder rewritten = retriever.doRewrite(ctx); + assertNotSame(retriever, rewritten); + assertTrue(rewritten instanceof LinearRetrieverBuilder); + + LinearRetrieverBuilder rewrittenLinear = (LinearRetrieverBuilder) rewritten; + 
assertEquals(retriever.rankWindowSize(), rewrittenLinear.rankWindowSize()); + assertEquals(expectedInnerRetrievers, getInnerRetrieversAsSet(rewrittenLinear)); + } + + private static Set getInnerRetrieversAsSet(LinearRetrieverBuilder retriever) { + float[] weights = retriever.getWeights(); + ScoreNormalizer[] normalizers = retriever.getNormalizers(); + + int i = 0; + Set innerRetrieversSet = new HashSet<>(); + for (CompoundRetrieverBuilder.RetrieverSource innerRetriever : retriever.innerRetrievers()) { + float weight = weights[i]; + ScoreNormalizer normalizer = normalizers[i]; + + if (innerRetriever.retriever() instanceof LinearRetrieverBuilder innerLinearRetriever) { + assertEquals(retriever.rankWindowSize(), innerLinearRetriever.rankWindowSize()); + innerRetrieversSet.add(new InnerRetriever(getInnerRetrieversAsSet(innerLinearRetriever), weight, normalizer)); + } else { + innerRetrieversSet.add(new InnerRetriever(innerRetriever.retriever(), weight, normalizer)); + } + + i++; + } + + return innerRetrieversSet; + } + + private static class InnerRetriever { + private final Object retriever; + private final float weight; + private final ScoreNormalizer normalizer; + + InnerRetriever(RetrieverBuilder retriever, float weight, ScoreNormalizer normalizer) { + this.retriever = retriever; + this.weight = weight; + this.normalizer = normalizer; + } + + InnerRetriever(Set innerRetrievers, float weight, ScoreNormalizer normalizer) { + this.retriever = innerRetrievers; + this.weight = weight; + this.normalizer = normalizer; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InnerRetriever that = (InnerRetriever) o; + return Float.compare(weight, that.weight) == 0 + && Objects.equals(retriever, that.retriever) + && Objects.equals(normalizer, that.normalizer); + } + + @Override + public int hashCode() { + return Objects.hash(retriever, weight, normalizer); + } + } +} diff --git 
a/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/LinearRankClientYamlTestSuiteIT.java b/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/LinearRankClientYamlTestSuiteIT.java index 8af4ae307a51a..00f756ff6ee3f 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/LinearRankClientYamlTestSuiteIT.java +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/LinearRankClientYamlTestSuiteIT.java @@ -11,6 +11,7 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; import org.junit.ClassRule; @@ -25,8 +26,12 @@ public class LinearRankClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase { .module("rank-rrf") .module("lang-painless") .module("x-pack-inference") + .systemProperty("tests.seed", System.getProperty("tests.seed")) + .setting("xpack.security.enabled", "false") + .setting("xpack.security.http.ssl.enabled", "false") .setting("xpack.license.self_generated.type", "trial") .plugin("inference-service-test") + .distribution(DistributionType.DEFAULT) .build(); public LinearRankClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/20_linear_retriever_simplified.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/20_linear_retriever_simplified.yml new file mode 100644 index 0000000000000..01cfa218c918d --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/20_linear_retriever_simplified.yml @@ -0,0 +1,474 @@ +setup: + - requires: + cluster_features: [ 
"linear_retriever.multi_fields_query_format_support" ] + reason: "Linear retriever multi-fields query format support" + test_runner_features: [ "close_to", "headers", "contains" ] + + - do: + inference.put: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + inference.put: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 128, + "similarity": "cosine", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: test-index + body: + mappings: + properties: + keyword: + type: keyword + dense_inference: + type: semantic_text + inference_id: dense-inference-id + sparse_inference: + type: semantic_text + inference_id: sparse-inference-id + text_1: + type: text + text_2: + type: text + timestamp: + type: date + dense_vector: + type: dense_vector + dims: 1 + index: true + similarity: l2_norm + index_options: + type: flat + sparse_vector: + type: sparse_vector + + - do: + bulk: + index: test-index + refresh: true + body: | + {"index": {"_id": "1"}} + { + "keyword": "keyword match 1", + "dense_inference": "you know", + "sparse_inference": "for testing", + "text_1": "foo match 1", + "text_2": "x match 2", + "timestamp": "2000-03-30", + "dense_vector": [1], + "sparse_vector": { + "foo": 1.0 + } + } + {"index": {"_id": "2"}} + { + "keyword": "keyword match 2", + "dense_inference": "ElasticSearch is an open source", + "sparse_inference": "distributed, RESTful, search engine", + "text_1": "bar match 3", + "text_2": "y match 4", + "timestamp": "2010-02-08", + "dense_vector": [2], + "sparse_vector": { + "bar": 1.0 + } + } + {"index": {"_id": "3"}} + { + "keyword": "keyword match 3", + "dense_inference": "which is built on top of Lucene internally", + 
"sparse_inference": "and enjoys all the features it provides", + "text_1": "baz match 5", + "text_2": "z match 6", + "timestamp": "2024-08-08", + "dense_vector": [3], + "sparse_vector": { + "baz": 1.0 + } + } + +--- +"Query all fields using the simplified format": + - do: + headers: + Content-Type: application/json + search: + index: test-index + body: + retriever: + linear: + query: "match" + normalizer: "minmax" + + - match: { hits.total.value: 3 } + - length: { hits.hits: 3 } + - match: { hits.hits.0._id: "3" } + - lte: { hits.hits.0._score: 2.0 } + - match: { hits.hits.1._id: "2" } + - lte: { hits.hits.1._score: 2.0 } + - match: { hits.hits.2._id: "1" } + - lte: { hits.hits.2._score: 2.0 } + +--- +"Lexical match per-field boosting using the simplified format": + - do: + headers: + Content-Type: application/json + search: + index: test-index + body: + retriever: + linear: + fields: [ "text_1", "text_2" ] + query: "foo 1 z" + normalizer: "minmax" + + # Lexical-only match, so max score is 1 + - match: { hits.total.value: 2 } + - length: { hits.hits: 2 } + - match: { hits.hits.0._id: "1" } + - close_to: { hits.hits.0._score: { value: 1.0, error: 0.0001 } } + - match: { hits.hits.1._id: "3" } + - lt: { hits.hits.1._score: 1.0 } + + - do: + headers: + Content-Type: application/json + search: + index: test-index + body: + retriever: + linear: + fields: ["text_1", "text_2^3"] + query: "foo 1 z" + normalizer: "minmax" + + # Lexical-only match, so max score is 1 + - match: { hits.total.value: 2 } + - length: { hits.hits: 2 } + - match: { hits.hits.0._id: "3" } + - close_to: { hits.hits.0._score: { value: 1.0, error: 0.0001 } } + - match: { hits.hits.1._id: "1" } + - lt: { hits.hits.1._score: 1.0 } + +--- +"Semantic match per-field boosting using the simplified format": + # The mock inference services generate synthetic vectors that don't accurately represent similarity to non-identical + # input, so it's hard to create a test that produces intuitive results. 
Instead, we rely on the fact that the inference + # services generate consistent vectors (i.e. same input -> same output) to demonstrate that per-field boosting on + # a semantic_text field can change the result order. + - do: + headers: + Content-Type: application/json + search: + index: test-index + body: + retriever: + linear: + fields: [ "dense_inference", "sparse_inference" ] + query: "distributed, RESTful, search engine" + normalizer: "minmax" + + # Semantic-only match, so max score is 1 + - match: { hits.total.value: 3 } + - length: { hits.hits: 3 } + - match: { hits.hits.0._id: "2" } + - close_to: { hits.hits.0._score: { value: 1.0, error: 0.0001 } } + - match: { hits.hits.1._id: "3" } + - lt: { hits.hits.1._score: 1.0 } + - match: { hits.hits.2._id: "1" } + - lt: { hits.hits.2._score: 1.0 } + + - do: + headers: + Content-Type: application/json + search: + index: test-index + body: + retriever: + linear: + fields: [ "dense_inference^3", "sparse_inference" ] + query: "distributed, RESTful, search engine" + normalizer: "minmax" + + # Semantic-only match, so max score is 1 + - match: { hits.total.value: 3 } + - length: { hits.hits: 3 } + - match: { hits.hits.0._id: "3" } + - close_to: { hits.hits.0._score: { value: 1.0, error: 0.0001 } } + - match: { hits.hits.1._id: "2" } + - lt: { hits.hits.1._score: 1.0 } + - match: { hits.hits.2._id: "1" } + - lt: { hits.hits.2._score: 1.0 } + +--- +"Can query keyword fields": + - do: + headers: + Content-Type: application/json + search: + index: test-index + body: + retriever: + linear: + fields: [ "keyword" ] + query: "keyword match 1" + normalizer: "minmax" + + # Lexical-only match, so max score is 1 + - match: { hits.total.value: 1 } + - length: { hits.hits: 1 } + - match: { hits.hits.0._id: "1" } + - close_to: { hits.hits.0._score: { value: 1.0, error: 0.0001 } } + +--- +"Can query date fields": + - do: + headers: + Content-Type: application/json + search: + index: test-index + body: + retriever: + linear: + fields: [ 
"timestamp" ] + query: "2010-02-08" + normalizer: "minmax" + + # Lexical-only match, so max score is 1 + - match: { hits.total.value: 1 } + - length: { hits.hits: 1 } + - match: { hits.hits.0._id: "2" } + - close_to: { hits.hits.0._score: { value: 1.0, error: 0.0001 } } + +--- +"Can query sparse vector fields": + - do: + search: + index: test-index + body: + retriever: + linear: + fields: [ "sparse_vector" ] + query: "foo" + normalizer: "minmax" + + # Lexical-only match, so max score is 1 + - match: { hits.total.value: 1 } + - length: { hits.hits: 1 } + - match: { hits.hits.0._id: "1" } + - close_to: { hits.hits.0._score: { value: 1.0, error: 0.0001 } } + +--- +"Cannot query dense vector fields": + - do: + catch: bad_request + search: + index: test-index + body: + retriever: + linear: + fields: [ "dense_vector" ] + query: "foo" + normalizer: "minmax" + + - contains: { error.root_cause.0.reason: "[linear] search failed - retrievers '[standard]' returned errors" } + - contains: { error.root_cause.0.suppressed.0.failed_shards.0.reason.reason: "Field [dense_vector] of type [dense_vector] does not support match queries" } + +--- +"Filters are propagated": + - do: + headers: + Content-Type: application/json + search: + index: test-index + body: + retriever: + linear: + query: "match" + normalizer: "minmax" + filter: + - term: + keyword: "keyword match 1" + + - match: { hits.total.value: 1 } + - length: { hits.hits: 1 } + - match: { hits.hits.0._id: "1" } + - close_to: { hits.hits.0._score: { value: 2.0, error: 0.0001 } } + +--- +"Wildcard index patterns that do not resolve to any index are handled gracefully": + - do: + search: + index: wildcard-* + body: + retriever: + linear: + query: "match" + normalizer: "minmax" + + - match: { hits.total.value: 0 } + - length: { hits.hits: 0 } + +--- +"Multi-index searches are not allowed": + - do: + indices.create: + index: test-index-2 + + - do: + catch: bad_request + search: + index: [ test-index, test-index-2 ] + body: + 
retriever: + linear: + query: "match" + normalizer: "minmax" + + - match: { error.root_cause.0.reason: "[linear] cannot specify [query] when querying multiple indices" } + + - do: + indices.put_alias: + index: test-index + name: test-alias + - do: + indices.put_alias: + index: test-index-2 + name: test-alias + + - do: + catch: bad_request + search: + index: test-alias + body: + retriever: + linear: + query: "match" + normalizer: "minmax" + + - match: { error.root_cause.0.reason: "[linear] cannot specify [query] when querying multiple indices" } + +--- +"Wildcard field patterns that do not resolve to any field are handled gracefully": + - do: + search: + index: test-index + body: + retriever: + linear: + fields: [ "wildcard-*" ] + query: "match" + normalizer: "minmax" + + - match: { hits.total.value: 0 } + - length: { hits.hits: 0 } + +--- +"Cannot mix simplified query format with custom sub-retrievers": + - do: + catch: bad_request + search: + index: test-index + body: + retriever: + linear: + query: "foo" + normalizer: "minmax" + retrievers: + - retriever: + standard: + query: + match: + keyword: "bar" + + - contains: { error.root_cause.0.reason: "[linear] cannot combine [retrievers] and [query]" } + +--- +"Cannot set top-level normalizer when using custom sub-retrievers": + - do: + catch: bad_request + search: + index: test-index + body: + retriever: + linear: + normalizer: "minmax" + retrievers: + - retriever: + standard: + query: + match: + keyword: "bar" + + - contains: { error.root_cause.0.reason: "[linear] [normalizer] cannot be provided when [retrievers] is specified" } + +--- +"Missing required params": + - do: + catch: bad_request + search: + index: test-index + body: + retriever: + linear: + query: "foo" + + - contains: { error.root_cause.0.reason: "[linear] [normalizer] must be provided when [query] is specified" } + + - do: + catch: bad_request + search: + index: test-index + body: + retriever: + linear: + fields: ["text_1", "text_2"] + + - contains: { 
error.root_cause.0.reason: "[linear] [query] must be provided when [fields] is specified" } + + - do: + catch: bad_request + search: + index: test-index + body: + retriever: + linear: + fields: [ "text_1", "text_2" ] + query: "" + + - contains: { error.root_cause.0.reason: "[linear] [query] cannot be empty" } + + - do: + catch: bad_request + search: + index: test-index + body: + retriever: + linear: {} + + - contains: { error.root_cause.0.reason: "[linear] must provide [retrievers] or [query]" }