5 changes: 5 additions & 0 deletions docs/changelog/129200.yaml
@@ -0,0 +1,5 @@
pr: 129200
summary: Simplified Linear Retriever
area: Search
type: enhancement
issues: []
CompoundRetrieverBuilder.java
@@ -36,6 +36,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
@@ -53,7 +54,11 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde

public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");

public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {}
public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {
public static RetrieverSource from(RetrieverBuilder retriever) {
return new RetrieverSource(retriever, null);
}
}

protected final int rankWindowSize;
protected final List<RetrieverSource> innerRetrievers;
@@ -65,7 +70,7 @@ protected CompoundRetrieverBuilder(List<RetrieverSource> innerRetrievers, int ra

@SuppressWarnings("unchecked")
public T addChild(RetrieverBuilder retrieverBuilder) {
innerRetrievers.add(new RetrieverSource(retrieverBuilder, null));
innerRetrievers.add(RetrieverSource.from(retrieverBuilder));
return (T) this;
}

@@ -99,6 +104,11 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio
throw new IllegalStateException("PIT is required");
}

RetrieverBuilder rewritten = doRewrite(ctx);
if (rewritten != this) {
return rewritten;
}

// Rewrite prefilters
// We eagerly rewrite prefilters, because some of the innerRetrievers
// could be compound too, so we want to propagate all the necessary filter information to them
@@ -121,7 +131,7 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio
}
RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx);
if (newRetriever != entry.retriever) {
newRetrievers.add(new RetrieverSource(newRetriever, null));
newRetrievers.add(RetrieverSource.from(newRetriever));
hasChanged |= true;
} else {
var sourceBuilder = entry.source != null
@@ -291,6 +301,10 @@ public int rankWindowSize() {
return rankWindowSize;
}

public List<RetrieverSource> innerRetrievers() {
return Collections.unmodifiableList(innerRetrievers);
}

protected final SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
.trackTotalHits(false)
@@ -317,6 +331,16 @@ protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBu
return sourceBuilder;
}

/**
 * Perform any custom rewrite logic necessary.
 *
 * @param ctx the query rewrite context
 * @return the rewritten retriever, or this instance if no rewrite was needed
*/
protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
return this;
}

private RankDoc[] getRankDocs(SearchResponse searchResponse) {
int size = searchResponse.getHits().getHits().length;
RankDoc[] docs = new RankDoc[size];
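The substantive change in CompoundRetrieverBuilder above is the new doRewrite hook: rewrite(ctx) now gives the concrete compound retriever a chance to replace itself before prefilters and inner retrievers are rewritten. Below is a minimal standalone sketch of that pattern in plain Java — the types and the single-child unwrap are illustrative stand-ins, not the Elasticsearch classes or this PR's linear-retriever logic.

import java.util.List;

// Standalone model of the rewrite hook added above; names are hypothetical.
interface Retriever {
    default Retriever rewrite() {
        return this; // leaf retrievers rewrite to themselves in this sketch
    }
}

record LeafRetriever(String name) implements Retriever {}

record CompoundRetriever(List<Retriever> children) implements Retriever {
    @Override
    public Retriever rewrite() {
        Retriever rewritten = doRewrite();
        if (rewritten != this) {
            // Hand control to the replacement, mirroring the early return in rewrite(ctx) above.
            return rewritten;
        }
        // The real implementation would continue with prefilter and child rewrites here.
        return this;
    }

    // Hook mirroring the new doRewrite(ctx): the default is a no-op, subclasses may substitute themselves.
    // The single-child unwrap below is only an example of a custom rewrite.
    Retriever doRewrite() {
        return children.size() == 1 ? children.get(0) : this;
    }
}

class RewriteHookDemo {
    public static void main(String[] args) {
        Retriever compound = new CompoundRetriever(List.of(new LeafRetriever("only-child")));
        System.out.println(compound.rewrite()); // prints LeafRetriever[name=only-child]
    }
}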
RescorerRetrieverBuilder.java
@@ -78,7 +78,7 @@ public static RescorerRetrieverBuilder fromXContent(XContentParser parser, Retri
private final List<RescorerBuilder<?>> rescorers;

public RescorerRetrieverBuilder(RetrieverBuilder retriever, List<RescorerBuilder<?>> rescorers) {
super(List.of(new RetrieverSource(retriever, null)), extractMinWindowSize(rescorers));
super(List.of(RetrieverSource.from(retriever)), extractMinWindowSize(rescorers));
if (rescorers.isEmpty()) {
throw new IllegalArgumentException("Missing rescore definition");
}
2 changes: 1 addition & 1 deletion x-pack/plugin/inference/build.gradle
@@ -404,5 +404,5 @@ tasks.named("thirdPartyAudit").configure {
}

tasks.named('yamlRestTest') {
usesDefaultDistribution()
usesDefaultDistribution("Uses the inference API")
}
TextSimilarityRankRetrieverBuilder.java
@@ -124,7 +124,7 @@ public TextSimilarityRankRetrieverBuilder(
int rankWindowSize,
boolean failuresAllowed
) {
super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize);
super(List.of(RetrieverSource.from(retrieverBuilder)), rankWindowSize);
this.inferenceId = inferenceId;
this.inferenceText = inferenceText;
this.field = field;
SemanticMultiMatchQueryBuilderTests.java
@@ -0,0 +1,125 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.queries;

import org.apache.lucene.index.Term;
import org.apache.lucene.search.DisjunctionMaxQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.MapperServiceTestCase;
import org.elasticsearch.index.mapper.ParsedDocument;
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ClusterServiceUtils;
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.junit.AfterClass;
import org.junit.BeforeClass;

import java.util.Collection;
import java.util.List;
import java.util.function.Supplier;

public class SemanticMultiMatchQueryBuilderTests extends MapperServiceTestCase {
private static TestThreadPool threadPool;
private static ModelRegistry modelRegistry;

private static class InferencePluginWithModelRegistry extends InferencePlugin {
InferencePluginWithModelRegistry(Settings settings) {
super(settings);
}

@Override
protected Supplier<ModelRegistry> getModelRegistry() {
return () -> modelRegistry;
}
}

@BeforeClass
public static void startModelRegistry() {
threadPool = new TestThreadPool(SemanticMultiMatchQueryBuilderTests.class.getName());
var clusterService = ClusterServiceUtils.createClusterService(threadPool);
modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool));
modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) {
@Override
public boolean localNodeMaster() {
return false;
}
});
}

@AfterClass
public static void stopModelRegistry() {
IOUtils.closeWhileHandlingException(threadPool);
}

@Override
protected Collection<? extends Plugin> getPlugins() {
return List.of(new InferencePluginWithModelRegistry(Settings.EMPTY));
}

public void testResolveSemanticTextFieldFromWildcard() throws Exception {
MapperService mapperService = createMapperService("""
{
"_doc" : {
"properties": {
"text_field": { "type": "text" },
"keyword_field": { "type": "keyword" },
"inference_field": { "type": "semantic_text", "inference_id": "test_service" }
}
}
}
""");

ParsedDocument doc = mapperService.documentMapper().parse(source("""
{
"text_field" : "foo",
"keyword_field" : "foo",
"inference_field" : "foo",
"_inference_fields": {
"inference_field": {
"inference": {
"inference_id": "test_service",
"model_settings": {
"task_type": "sparse_embedding"
},
"chunks": {
"inference_field": [
{
"start_offset": 0,
"end_offset": 3,
"embeddings": {
"foo": 1.0
}
}
]
}
}
}
}
}
"""));

withLuceneIndex(mapperService, iw -> iw.addDocument(doc.rootDoc()), ir -> {
SearchExecutionContext context = createSearchExecutionContext(mapperService, newSearcher(ir));
Query query = new MultiMatchQueryBuilder("foo", "*_field").toQuery(context);
Query expected = new DisjunctionMaxQuery(
List.of(new TermQuery(new Term("text_field", "foo")), new TermQuery(new Term("keyword_field", "foo"))),
0f
);
assertEquals(expected, query);
});
}
}
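For reference, a minimal standalone sketch of the Lucene query the test above expects: the *_field wildcard resolves to the lexical text_field and keyword_field, and no clause is produced for the semantic_text field. The field names simply mirror the test mapping; this only restates the assertion.

import org.apache.lucene.index.Term;
import org.apache.lucene.search.DisjunctionMaxQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;

import java.util.List;

class ExpectedMultiMatchQuerySketch {
    public static void main(String[] args) {
        // Same structure the test asserts: a dis-max over the two lexical fields,
        // with a tie-breaker of 0 and no clause for the semantic_text field.
        Query expected = new DisjunctionMaxQuery(
            List.of(new TermQuery(new Term("text_field", "foo")), new TermQuery(new Term("keyword_field", "foo"))),
            0f
        );
        System.out.println(expected); // (text_field:foo | keyword_field:foo)
    }
}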
4 changes: 4 additions & 0 deletions x-pack/plugin/rank-rrf/build.gradle
@@ -30,3 +30,7 @@ dependencies {

clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
}

tasks.named('yamlRestTest') {
usesDefaultDistribution("Uses the inference API")
}