Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ static TransportVersion def(int id) {
public static final TransportVersion RRF_QUERY_REWRITE = def(8_758_00_0);
public static final TransportVersion SEARCH_FAILURE_STATS = def(8_759_00_0);
public static final TransportVersion INGEST_GEO_DATABASE_PROVIDERS = def(8_760_00_0);
public static final TransportVersion TEXT_SIMILARITY_RERANKER_QUERY_REWRITE = def(8_761_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.SearchSortValues;
import org.elasticsearch.search.aggregations.AggregationReduceContext;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.InternalAggregations;
Expand All @@ -51,6 +52,7 @@
import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
import org.elasticsearch.search.sort.ShardDocSortField;
import org.elasticsearch.search.suggest.Suggest;
import org.elasticsearch.search.suggest.Suggest.Suggestion;
import org.elasticsearch.search.suggest.completion.CompletionSuggestion;
Expand Down Expand Up @@ -464,6 +466,13 @@ private static SearchHits getHits(
assert shardDoc instanceof RankDoc;
searchHit.setRank(((RankDoc) shardDoc).rank);
searchHit.score(shardDoc.score);
long shardAndDoc = ShardDocSortField.encodeShardAndDoc(shardDoc.shardIndex, shardDoc.doc);
searchHit.sortValues(
new SearchSortValues(
new Object[] { shardDoc.score, shardAndDoc },
new DocValueFormat[] { DocValueFormat.RAW, DocValueFormat.RAW }
)
);
} else if (sortedTopDocs.isSortedByField) {
FieldDoc fieldDoc = (FieldDoc) shardDoc;
searchHit.sortValues(fieldDoc.fields, reducedQueryPhase.sortValueFormats);
Expand Down
11 changes: 9 additions & 2 deletions server/src/main/java/org/elasticsearch/search/rank/RankDoc.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@

import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xcontent.XContentBuilder;

Expand All @@ -24,7 +26,7 @@
* {@code RankDoc} is the base class for all ranked results.
* Subclasses should extend this with additional information required for their global ranking method.
*/
public class RankDoc extends ScoreDoc implements NamedWriteable, ToXContentFragment, Comparable<RankDoc> {
public class RankDoc extends ScoreDoc implements VersionedNamedWriteable, ToXContentFragment, Comparable<RankDoc> {

public static final String NAME = "rank_doc";

Expand All @@ -40,6 +42,11 @@ public String getWriteableName() {
return NAME;
}

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.RANK_DOCS_RETRIEVER;
}

@Override
public final int compareTo(RankDoc other) {
if (score != other.score) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ public void onFailure(Exception e) {

@Override
public final QueryBuilder topDocsQuery() {
throw new IllegalStateException(getName() + " cannot be nested");
throw new IllegalStateException("Should not be called, missing a rewrite?");
}

@Override
Expand Down Expand Up @@ -208,7 +208,7 @@ public int doHashCode() {
return Objects.hash(innerRetrievers);
}

private SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
.trackTotalHits(false)
.storedFields(new StoredFieldsContext(false))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public void setTopValue(Long value) {

@Override
public Long value(int slot) {
return (((long) shardRequestIndex) << 32) | (delegate.value(slot) & 0xFFFFFFFFL);
return encodeShardAndDoc(shardRequestIndex, delegate.value(slot));
}

@Override
Expand All @@ -87,4 +87,8 @@ public static int decodeDoc(long value) {
public static int decodeShardRequestIndex(long value) {
return (int) (value >> 32);
}

public static long encodeShardAndDoc(int shardIndex, int doc) {
return (((long) shardIndex) << 32) | (doc & 0xFFFFFFFFL);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ public Set<NodeFeature> getFeatures() {
TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED,
RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED,
SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID,
SemanticQueryBuilder.SEMANTIC_TEXT_INNER_HITS
SemanticQueryBuilder.SEMANTIC_TEXT_INNER_HITS,
TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestHandler;
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
import org.elasticsearch.xcontent.ParseField;
Expand Down Expand Up @@ -66,6 +67,7 @@
import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankDoc;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceEndpointAction;
Expand Down Expand Up @@ -250,6 +252,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
var entries = new ArrayList<>(InferenceNamedWriteablesProvider.getNamedWriteables());
entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, TextSimilarityRankBuilder.NAME, TextSimilarityRankBuilder::new));
entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, RandomRankBuilder.NAME, RandomRankBuilder::new));
entries.add(new NamedWriteableRegistry.Entry(RankDoc.class, TextSimilarityRankDoc.NAME, TextSimilarityRankDoc::new));
return entries;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.rank.textsimilarity;

import org.apache.lucene.search.Explanation;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

public class TextSimilarityRankDoc extends RankDoc {

public static final String NAME = "text_similarity_rank_doc";

public final String inferenceId;
public final String field;

public TextSimilarityRankDoc(int doc, float score, int shardIndex, String inferenceId, String field) {
super(doc, score, shardIndex);
this.inferenceId = inferenceId;
this.field = field;
}

public TextSimilarityRankDoc(StreamInput in) throws IOException {
super(in);
inferenceId = in.readString();
field = in.readString();
}

@Override
public Explanation explain(Explanation[] sources, String[] queryNames) {
final String queryAlias = queryNames[0] == null ? "" : "[" + queryNames[0] + "]";
return Explanation.match(
score,
"text_similarity_reranker match using inference endpoint: ["
+ inferenceId
+ "] on document field: ["
+ field
+ "] matching on source query "
+ queryAlias,
sources
);
}

@Override
public void doWriteTo(StreamOutput out) throws IOException {
out.writeString(inferenceId);
out.writeString(field);
}

@Override
public boolean doEquals(RankDoc rd) {
TextSimilarityRankDoc tsrd = (TextSimilarityRankDoc) rd;
return Objects.equals(inferenceId, tsrd.inferenceId) && Objects.equals(field, tsrd.field);
}

@Override
public int doHashCode() {
return Objects.hash(inferenceId, field);
}

@Override
public String toString() {
return "TextSimilarityRankDoc{"
+ "doc="
+ doc
+ ", shardIndex="
+ shardIndex
+ ", score="
+ score
+ ", inferenceId="
+ inferenceId
+ ", field="
+ field
+ '}';
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
builder.field("inferenceId", inferenceId);
builder.field("field", field);
}

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.TEXT_SIMILARITY_RERANKER_QUERY_REWRITE;
}
}
Loading