Merged
Changes from 13 commits
@@ -235,6 +235,7 @@ static TransportVersion def(int id) {
public static final TransportVersion SEARCH_FAILURE_STATS = def(8_759_00_0);
public static final TransportVersion INGEST_GEO_DATABASE_PROVIDERS = def(8_760_00_0);
public static final TransportVersion DATE_TIME_DOC_VALUES_LOCALES = def(8_761_00_0);
public static final TransportVersion TEXT_SIMILARITY_RERANKER_QUERY_REWRITE = def(8_762_00_0);

/*
* STOP! READ THIS FIRST! No, really,
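Side note (not part of the diff): a new TransportVersion constant such as TEXT_SIMILARITY_RERANKER_QUERY_REWRITE is normally consumed by gating newly serialized state on the receiver's version. A minimal sketch under that assumption, using a hypothetical writeable with a made-up `field` string and an optional `rankDoc`:

```java
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.rank.RankDoc;

import java.io.IOException;

// Hypothetical writeable, for illustration only; not a class from this PR.
record RerankerState(String field, RankDoc rankDoc) implements Writeable {
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(field);
        // Only send the new state to nodes that understand the new wire format.
        if (out.getTransportVersion().onOrAfter(TransportVersions.TEXT_SIMILARITY_RERANKER_QUERY_REWRITE)) {
            out.writeOptionalNamedWriteable(rankDoc);
        }
    }
}
```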
@@ -36,6 +36,7 @@
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.SearchSortValues;
import org.elasticsearch.search.aggregations.AggregationReduceContext;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.InternalAggregations;
@@ -51,6 +52,7 @@
import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
import org.elasticsearch.search.sort.ShardDocSortField;
import org.elasticsearch.search.suggest.Suggest;
import org.elasticsearch.search.suggest.Suggest.Suggestion;
import org.elasticsearch.search.suggest.completion.CompletionSuggestion;
@@ -464,6 +466,13 @@ private static SearchHits getHits(
assert shardDoc instanceof RankDoc;
searchHit.setRank(((RankDoc) shardDoc).rank);
searchHit.score(shardDoc.score);
long shardAndDoc = ShardDocSortField.encodeShardAndDoc(shardDoc.shardIndex, shardDoc.doc);
searchHit.sortValues(
new SearchSortValues(
new Object[] { shardDoc.score, shardAndDoc },
new DocValueFormat[] { DocValueFormat.RAW, DocValueFormat.RAW }
)
);
} else if (sortedTopDocs.isSortedByField) {
FieldDoc fieldDoc = (FieldDoc) shardDoc;
searchHit.sortValues(fieldDoc.fields, reducedQueryPhase.sortValueFormats);
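A quick sketch (not in the PR) of what the change above means for consumers: ranked hits now expose sort values in the form [score, shard-and-doc packed into a long], which can be decoded back with the ShardDocSortField helpers shown further down in this diff. Names below are illustrative only.

```java
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.ShardDocSortField;

// Illustrative helper, assuming the hit was produced by the ranked branch above.
final class RankSortValues {
    static void decode(SearchHit hit) {
        Object[] sort = hit.getSortValues();            // [score, shardAndDoc]
        float score = ((Number) sort[0]).floatValue();
        long shardAndDoc = ((Number) sort[1]).longValue();
        int shardRequestIndex = ShardDocSortField.decodeShardRequestIndex(shardAndDoc);
        int doc = ShardDocSortField.decodeDoc(shardAndDoc);
        System.out.println(score + " -> shard " + shardRequestIndex + ", doc " + doc);
    }
}
```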
11 changes: 9 additions & 2 deletions server/src/main/java/org/elasticsearch/search/rank/RankDoc.java
@@ -11,9 +11,11 @@

import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xcontent.XContentBuilder;

@@ -24,7 +26,7 @@
* {@code RankDoc} is the base class for all ranked results.
* Subclasses should extend this with additional information required for their global ranking method.
*/
public class RankDoc extends ScoreDoc implements NamedWriteable, ToXContentFragment, Comparable<RankDoc> {
public class RankDoc extends ScoreDoc implements VersionedNamedWriteable, ToXContentFragment, Comparable<RankDoc> {

public static final String NAME = "rank_doc";

@@ -40,6 +42,11 @@ public String getWriteableName() {
return NAME;
}

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.RANK_DOCS_RETRIEVER;
}

@Override
public final int compareTo(RankDoc other) {
if (score != other.score) {
@@ -160,7 +160,7 @@ public void onFailure(Exception e) {

@Override
public final QueryBuilder topDocsQuery() {
throw new IllegalStateException(getName() + " cannot be nested");
throw new IllegalStateException("Should not be called, missing a rewrite?");
}

@Override
@@ -208,7 +208,7 @@ public int doHashCode() {
return Objects.hash(innerRetrievers);
}

private SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
.trackTotalHits(false)
.storedFields(new StoredFieldsContext(false))
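For context (not part of the diff): relaxing createSearchSourceBuilder from private to protected lets a subclass adjust the per-retriever source it builds. A hedged sketch of such an override, assuming it lives in a hypothetical subclass of this compound retriever base class and using a made-up field name:

```java
@Override
protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
    // Reuse the base implementation, then request an extra field the subclass needs,
    // e.g. the text that will be sent to a reranking inference endpoint.
    return super.createSearchSourceBuilder(pit, retrieverBuilder).fetchField("my_text_field");
}
```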
@@ -64,7 +64,7 @@ public void setTopValue(Long value) {

@Override
public Long value(int slot) {
return (((long) shardRequestIndex) << 32) | (delegate.value(slot) & 0xFFFFFFFFL);
return encodeShardAndDoc(shardRequestIndex, delegate.value(slot));
}

@Override
@@ -87,4 +87,8 @@ public static int decodeDoc(long value) {
public static int decodeShardRequestIndex(long value) {
return (int) (value >> 32);
}

public static long encodeShardAndDoc(int shardIndex, int doc) {
return (((long) shardIndex) << 32) | (doc & 0xFFFFFFFFL);
}
}
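A round-trip sketch (not in the PR) of the packing introduced above: the shard request index occupies the high 32 bits and the doc id the low 32 bits, so the decode helpers recover both inputs, assuming decodeDoc mirrors the encoding by returning the low 32 bits.

```java
// Illustrative only; the values are arbitrary.
long packed = ShardDocSortField.encodeShardAndDoc(3, 42);
assert ShardDocSortField.decodeShardRequestIndex(packed) == 3; // high 32 bits
assert ShardDocSortField.decodeDoc(packed) == 42;              // low 32 bits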
@@ -0,0 +1,57 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.rank;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.test.AbstractWireSerializingTestCase;

import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import static org.hamcrest.Matchers.equalTo;

public abstract class AbstractRankDocWireSerializingTestCase<T extends RankDoc> extends AbstractWireSerializingTestCase<T> {

protected abstract T createTestRankDoc();

@Override
protected NamedWriteableRegistry writableRegistry() {
SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList());
List<NamedWriteableRegistry.Entry> entries = searchModule.getNamedWriteables();
entries.addAll(getAdditionalNamedWriteables());
return new NamedWriteableRegistry(entries);
}

protected abstract List<NamedWriteableRegistry.Entry> getAdditionalNamedWriteables();

@Override
protected T createTestInstance() {
return createTestRankDoc();
}

@SuppressWarnings({ "unchecked", "rawtypes" })
public void testRankDocSerialization() throws IOException {
int totalDocs = randomIntBetween(10, 100);
Set<T> docs = new HashSet<>();
for (int i = 0; i < totalDocs; i++) {
docs.add(createTestRankDoc());
}
RankDocsQueryBuilder rankDocsQueryBuilder = new RankDocsQueryBuilder(docs.toArray((T[]) new RankDoc[0]), null, randomBoolean());
RankDocsQueryBuilder copy = (RankDocsQueryBuilder) copyNamedWriteable(rankDocsQueryBuilder, writableRegistry(), QueryBuilder.class);
assertThat(rankDocsQueryBuilder, equalTo(copy));
}
}
@@ -9,27 +9,29 @@

package org.elasticsearch.search.rank;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;

import java.io.IOException;
import java.util.Collections;
import java.util.List;

public class RankDocTests extends AbstractWireSerializingTestCase<RankDoc> {
public class RankDocTests extends AbstractRankDocWireSerializingTestCase<RankDoc> {

static RankDoc createTestRankDoc() {
protected RankDoc createTestRankDoc() {
RankDoc rankDoc = new RankDoc(randomNonNegativeInt(), randomFloat(), randomIntBetween(0, 1));
rankDoc.rank = randomNonNegativeInt();
return rankDoc;
}

@Override
protected Writeable.Reader<RankDoc> instanceReader() {
return RankDoc::new;
protected List<NamedWriteableRegistry.Entry> getAdditionalNamedWriteables() {
return Collections.emptyList();
}

@Override
protected RankDoc createTestInstance() {
return createTestRankDoc();
protected Writeable.Reader<RankDoc> instanceReader() {
return RankDoc::new;
}

@Override
@@ -27,7 +27,8 @@ public Set<NodeFeature> getFeatures() {
TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED,
RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED,
SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID,
SemanticQueryBuilder.SEMANTIC_TEXT_INNER_HITS
SemanticQueryBuilder.SEMANTIC_TEXT_INNER_HITS,
TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED
);
}

@@ -36,6 +36,7 @@
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestHandler;
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
import org.elasticsearch.xcontent.ParseField;
@@ -66,6 +67,7 @@
import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankDoc;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceEndpointAction;
@@ -253,6 +255,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
var entries = new ArrayList<>(InferenceNamedWriteablesProvider.getNamedWriteables());
entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, TextSimilarityRankBuilder.NAME, TextSimilarityRankBuilder::new));
entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, RandomRankBuilder.NAME, RandomRankBuilder::new));
entries.add(new NamedWriteableRegistry.Entry(RankDoc.class, TextSimilarityRankDoc.NAME, TextSimilarityRankDoc::new));
return entries;
}

@@ -0,0 +1,103 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.rank.textsimilarity;

import org.apache.lucene.search.Explanation;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

public class TextSimilarityRankDoc extends RankDoc {

public static final String NAME = "text_similarity_rank_doc";

public final String inferenceId;
public final String field;

public TextSimilarityRankDoc(int doc, float score, int shardIndex, String inferenceId, String field) {
super(doc, score, shardIndex);
this.inferenceId = inferenceId;
this.field = field;
}

public TextSimilarityRankDoc(StreamInput in) throws IOException {
super(in);
inferenceId = in.readString();
field = in.readString();
}

@Override
public Explanation explain(Explanation[] sources, String[] queryNames) {
final String queryAlias = queryNames[0] == null ? "" : "[" + queryNames[0] + "]";
return Explanation.match(
score,
"text_similarity_reranker match using inference endpoint: ["
+ inferenceId
+ "] on document field: ["
+ field
+ "] matching on source query "
+ queryAlias,
sources
);
}

@Override
public void doWriteTo(StreamOutput out) throws IOException {
out.writeString(inferenceId);
out.writeString(field);
}

@Override
public boolean doEquals(RankDoc rd) {
TextSimilarityRankDoc tsrd = (TextSimilarityRankDoc) rd;
return Objects.equals(inferenceId, tsrd.inferenceId) && Objects.equals(field, tsrd.field);
}

@Override
public int doHashCode() {
return Objects.hash(inferenceId, field);
}

@Override
public String toString() {
return "TextSimilarityRankDoc{"
+ "doc="
+ doc
+ ", shardIndex="
+ shardIndex
+ ", score="
+ score
+ ", inferenceId="
+ inferenceId
+ ", field="
+ field
+ '}';
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
builder.field("inferenceId", inferenceId);
builder.field("field", field);
}

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.TEXT_SIMILARITY_RERANKER_QUERY_REWRITE;
}
}
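A serialization round-trip sketch (not part of the PR) for the new RankDoc subclass above, assuming a registry containing the RankDoc entry registered in InferencePlugin#getNamedWriteables earlier in this diff; the doc id, score, shard index, endpoint name, and field are made up.

```java
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankDoc;

import java.io.IOException;
import java.util.List;

final class TextSimilarityRankDocRoundTrip {
    static RankDoc roundTrip() throws IOException {
        NamedWriteableRegistry registry = new NamedWriteableRegistry(
            List.of(new NamedWriteableRegistry.Entry(RankDoc.class, TextSimilarityRankDoc.NAME, TextSimilarityRankDoc::new))
        );
        // doc=7, score=0.42, shardIndex=1, plus a made-up inference endpoint and field.
        RankDoc doc = new TextSimilarityRankDoc(7, 0.42f, 1, "my-rerank-endpoint", "content");
        try (BytesStreamOutput out = new BytesStreamOutput()) {
            out.writeNamedWriteable(doc);
            try (StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), registry)) {
                return in.readNamedWriteable(RankDoc.class); // equal to the original doc
            }
        }
    }
}
```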