Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/129200.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 129200
summary: Simplified Linear Retriever
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be updated? Probably something like Add simplified syntax and hybrid support to linear retriever.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this description is still succinct and accurate, good as is

area: Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
Expand Down Expand Up @@ -65,7 +66,7 @@ protected CompoundRetrieverBuilder(List<RetrieverSource> innerRetrievers, int ra

@SuppressWarnings("unchecked")
public T addChild(RetrieverBuilder retrieverBuilder) {
innerRetrievers.add(new RetrieverSource(retrieverBuilder, null));
innerRetrievers.add(convertToRetrieverSource(retrieverBuilder));
return (T) this;
}

Expand Down Expand Up @@ -99,6 +100,11 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio
throw new IllegalStateException("PIT is required");
}

RetrieverBuilder rewritten = doRewrite(ctx);
if (rewritten != this) {
return rewritten;
}

// Rewrite prefilters
// We eagerly rewrite prefilters, because some of the innerRetrievers
// could be compound too, so we want to propagate all the necessary filter information to them
Expand Down Expand Up @@ -290,6 +296,14 @@ public int rankWindowSize() {
return rankWindowSize;
}

public List<RetrieverSource> innerRetrievers() {
return Collections.unmodifiableList(innerRetrievers);
}

public static RetrieverSource convertToRetrieverSource(RetrieverBuilder retrieverBuilder) {
return new RetrieverSource(retrieverBuilder, null);
}

protected final SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
.trackTotalHits(false)
Expand All @@ -316,6 +330,16 @@ protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBu
return sourceBuilder;
}

/**
* Perform any custom rewrite logic necessary
*
* @param ctx The query rewrite context
* @return RetrieverBuilder the rewritten retriever
*/
protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
return this;
}

private RankDoc[] getRankDocs(SearchResponse searchResponse) {
int size = searchResponse.getHits().getHits().length;
RankDoc[] docs = new RankDoc[size];
Expand Down
2 changes: 1 addition & 1 deletion x-pack/plugin/inference/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -403,5 +403,5 @@ tasks.named("thirdPartyAudit").configure {
}

tasks.named('yamlRestTest') {
usesDefaultDistribution("to be triaged")
usesDefaultDistribution("Uses the inference API")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.queries;

import org.apache.lucene.index.Term;
import org.apache.lucene.search.DisjunctionMaxQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.MapperServiceTestCase;
import org.elasticsearch.index.mapper.ParsedDocument;
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ClusterServiceUtils;
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.junit.AfterClass;
import org.junit.BeforeClass;

import java.util.Collection;
import java.util.List;
import java.util.function.Supplier;

public class SemanticMultiMatchQueryBuilderTests extends MapperServiceTestCase {
private static TestThreadPool threadPool;
private static ModelRegistry modelRegistry;

private static class InferencePluginWithModelRegistry extends InferencePlugin {
InferencePluginWithModelRegistry(Settings settings) {
super(settings);
}

@Override
protected Supplier<ModelRegistry> getModelRegistry() {
return () -> modelRegistry;
}
}

@BeforeClass
public static void startModelRegistry() {
threadPool = new TestThreadPool(SemanticMultiMatchQueryBuilderTests.class.getName());
var clusterService = ClusterServiceUtils.createClusterService(threadPool);
modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool));
modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) {
@Override
public boolean localNodeMaster() {
return false;
}
});
}

@AfterClass
public static void stopModelRegistry() {
IOUtils.closeWhileHandlingException(threadPool);
}

@Override
protected Collection<? extends Plugin> getPlugins() {
return List.of(new InferencePluginWithModelRegistry(Settings.EMPTY));
}

public void testResolveSemanticTextFieldFromWildcard() throws Exception {
MapperService mapperService = createMapperService("""
{
"_doc" : {
"properties": {
"text_field": { "type": "text" },
"keyword_field": { "type": "keyword" },
"inference_field": { "type": "semantic_text", "inference_id": "test_service" }
}
}
}
""");

ParsedDocument doc = mapperService.documentMapper().parse(source("""
{
"text_field" : "foo",
"keyword_field" : "foo",
"inference_field" : "foo",
"_inference_fields": {
"inference_field": {
"inference": {
"inference_id": "test_service",
"model_settings": {
"task_type": "sparse_embedding"
},
"chunks": {
"inference_field": [
{
"start_offset": 0,
"end_offset": 3,
"embeddings": {
"foo": 1.0
}
}
]
}
}
}
}
}
"""));

withLuceneIndex(mapperService, iw -> iw.addDocument(doc.rootDoc()), ir -> {
SearchExecutionContext context = createSearchExecutionContext(mapperService, newSearcher(ir));
Query query = new MultiMatchQueryBuilder("foo", "*_field").toQuery(context);
Query expected = new DisjunctionMaxQuery(
List.of(new TermQuery(new Term("text_field", "foo")), new TermQuery(new Term("keyword_field", "foo"))),
0f
);
assertEquals(expected, query);
});
}
}
4 changes: 4 additions & 0 deletions x-pack/plugin/rank-rrf/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,7 @@ dependencies {

clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
}

tasks.named('yamlRestTest') {
usesDefaultDistribution("Uses the inference API")
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
public class RankRRFFeatures implements FeatureSpecification {

public static final NodeFeature LINEAR_RETRIEVER_SUPPORTED = new NodeFeature("linear_retriever_supported");
public static final NodeFeature SIMPLIFIED_RETRIEVER_FORMAT = new NodeFeature("simplified_retriever_format");

@Override
public Set<NodeFeature> getFeatures() {
Expand All @@ -26,6 +27,6 @@ public Set<NodeFeature> getFeatures() {

@Override
public Set<NodeFeature> getTestFeatures() {
return Set.of(INNER_RETRIEVERS_FILTER_SUPPORT, LINEAR_RETRIEVER_MINMAX_SINGLE_DOC_FIX);
return Set.of(INNER_RETRIEVERS_FILTER_SUPPORT, LINEAR_RETRIEVER_MINMAX_SINGLE_DOC_FIX, SIMPLIFIED_RETRIEVER_FORMAT);
}
}
Loading
Loading