Skip to content

Commit e745c92

Browse files
Backporting text_similarity_reranker retriever rework to be evaluated during rewrite phase to 8.x (elastic#114282)
* backporting text_similarity_reranker rework to 8.x * fixing comp --------- Co-authored-by: Elastic Machine <[email protected]>
1 parent 029c683 commit e745c92

File tree

19 files changed

+760
-196
lines changed

19 files changed

+760
-196
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ static TransportVersion def(int id) {
236236
public static final TransportVersion INGEST_GEO_DATABASE_PROVIDERS = def(8_760_00_0);
237237
public static final TransportVersion DATE_TIME_DOC_VALUES_LOCALES = def(8_761_00_0);
238238
public static final TransportVersion FAST_REFRESH_RCO = def(8_762_00_0);
239+
public static final TransportVersion TEXT_SIMILARITY_RERANKER_QUERY_REWRITE = def(8_763_00_0);
239240

240241
/*
241242
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.elasticsearch.search.SearchHits;
3737
import org.elasticsearch.search.SearchPhaseResult;
3838
import org.elasticsearch.search.SearchService;
39+
import org.elasticsearch.search.SearchSortValues;
3940
import org.elasticsearch.search.aggregations.AggregationReduceContext;
4041
import org.elasticsearch.search.aggregations.AggregatorFactories;
4142
import org.elasticsearch.search.aggregations.InternalAggregations;
@@ -51,6 +52,7 @@
5152
import org.elasticsearch.search.query.QuerySearchResult;
5253
import org.elasticsearch.search.rank.RankDoc;
5354
import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
55+
import org.elasticsearch.search.sort.ShardDocSortField;
5456
import org.elasticsearch.search.suggest.Suggest;
5557
import org.elasticsearch.search.suggest.Suggest.Suggestion;
5658
import org.elasticsearch.search.suggest.completion.CompletionSuggestion;
@@ -464,6 +466,13 @@ private static SearchHits getHits(
464466
assert shardDoc instanceof RankDoc;
465467
searchHit.setRank(((RankDoc) shardDoc).rank);
466468
searchHit.score(shardDoc.score);
469+
long shardAndDoc = ShardDocSortField.encodeShardAndDoc(shardDoc.shardIndex, shardDoc.doc);
470+
searchHit.sortValues(
471+
new SearchSortValues(
472+
new Object[] { shardDoc.score, shardAndDoc },
473+
new DocValueFormat[] { DocValueFormat.RAW, DocValueFormat.RAW }
474+
)
475+
);
467476
} else if (sortedTopDocs.isSortedByField) {
468477
FieldDoc fieldDoc = (FieldDoc) shardDoc;
469478
searchHit.sortValues(fieldDoc.fields, reducedQueryPhase.sortValueFormats);

server/src/main/java/org/elasticsearch/search/rank/RankDoc.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111

1212
import org.apache.lucene.search.Explanation;
1313
import org.apache.lucene.search.ScoreDoc;
14-
import org.elasticsearch.common.io.stream.NamedWriteable;
14+
import org.elasticsearch.TransportVersion;
15+
import org.elasticsearch.TransportVersions;
1516
import org.elasticsearch.common.io.stream.StreamInput;
1617
import org.elasticsearch.common.io.stream.StreamOutput;
18+
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
1719
import org.elasticsearch.xcontent.ToXContentFragment;
1820
import org.elasticsearch.xcontent.XContentBuilder;
1921

@@ -24,7 +26,7 @@
2426
* {@code RankDoc} is the base class for all ranked results.
2527
* Subclasses should extend this with additional information required for their global ranking method.
2628
*/
27-
public class RankDoc extends ScoreDoc implements NamedWriteable, ToXContentFragment, Comparable<RankDoc> {
29+
public class RankDoc extends ScoreDoc implements VersionedNamedWriteable, ToXContentFragment, Comparable<RankDoc> {
2830

2931
public static final String NAME = "rank_doc";
3032

@@ -40,6 +42,11 @@ public String getWriteableName() {
4042
return NAME;
4143
}
4244

45+
@Override
46+
public TransportVersion getMinimalSupportedVersion() {
47+
return TransportVersions.RANK_DOCS_RETRIEVER;
48+
}
49+
4350
@Override
4451
public final int compareTo(RankDoc other) {
4552
if (score != other.score) {

server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ public void onFailure(Exception e) {
160160

161161
@Override
162162
public final QueryBuilder topDocsQuery() {
163-
throw new IllegalStateException(getName() + " cannot be nested");
163+
throw new IllegalStateException("Should not be called, missing a rewrite?");
164164
}
165165

166166
@Override
@@ -208,7 +208,7 @@ public int doHashCode() {
208208
return Objects.hash(innerRetrievers);
209209
}
210210

211-
private SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
211+
protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
212212
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
213213
.trackTotalHits(false)
214214
.storedFields(new StoredFieldsContext(false))

server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ public void setTopValue(Long value) {
6464

6565
@Override
6666
public Long value(int slot) {
67-
return (((long) shardRequestIndex) << 32) | (delegate.value(slot) & 0xFFFFFFFFL);
67+
return encodeShardAndDoc(shardRequestIndex, delegate.value(slot));
6868
}
6969

7070
@Override
@@ -87,4 +87,8 @@ public static int decodeDoc(long value) {
8787
public static int decodeShardRequestIndex(long value) {
8888
return (int) (value >> 32);
8989
}
90+
91+
public static long encodeShardAndDoc(int shardIndex, int doc) {
92+
return (((long) shardIndex) << 32) | (doc & 0xFFFFFFFFL);
93+
}
9094
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.rank;
11+
12+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
13+
import org.elasticsearch.common.settings.Settings;
14+
import org.elasticsearch.index.query.QueryBuilder;
15+
import org.elasticsearch.search.SearchModule;
16+
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
17+
import org.elasticsearch.test.AbstractWireSerializingTestCase;
18+
19+
import java.io.IOException;
20+
import java.util.Collections;
21+
import java.util.HashSet;
22+
import java.util.List;
23+
import java.util.Set;
24+
25+
import static org.hamcrest.Matchers.equalTo;
26+
27+
public abstract class AbstractRankDocWireSerializingTestCase<T extends RankDoc> extends AbstractWireSerializingTestCase<T> {
28+
29+
protected abstract T createTestRankDoc();
30+
31+
@Override
32+
protected NamedWriteableRegistry writableRegistry() {
33+
SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList());
34+
List<NamedWriteableRegistry.Entry> entries = searchModule.getNamedWriteables();
35+
entries.addAll(getAdditionalNamedWriteables());
36+
return new NamedWriteableRegistry(entries);
37+
}
38+
39+
protected abstract List<NamedWriteableRegistry.Entry> getAdditionalNamedWriteables();
40+
41+
@Override
42+
protected T createTestInstance() {
43+
return createTestRankDoc();
44+
}
45+
46+
@SuppressWarnings({ "unchecked", "rawtypes" })
47+
public void testRankDocSerialization() throws IOException {
48+
int totalDocs = randomIntBetween(10, 100);
49+
Set<T> docs = new HashSet<>();
50+
for (int i = 0; i < totalDocs; i++) {
51+
docs.add(createTestRankDoc());
52+
}
53+
RankDocsQueryBuilder rankDocsQueryBuilder = new RankDocsQueryBuilder(docs.toArray((T[]) new RankDoc[0]), null, randomBoolean());
54+
RankDocsQueryBuilder copy = (RankDocsQueryBuilder) copyNamedWriteable(rankDocsQueryBuilder, writableRegistry(), QueryBuilder.class);
55+
assertThat(rankDocsQueryBuilder, equalTo(copy));
56+
}
57+
}

server/src/test/java/org/elasticsearch/search/rank/RankDocTests.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,29 @@
99

1010
package org.elasticsearch.search.rank;
1111

12+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1213
import org.elasticsearch.common.io.stream.Writeable;
13-
import org.elasticsearch.test.AbstractWireSerializingTestCase;
1414

1515
import java.io.IOException;
16+
import java.util.Collections;
17+
import java.util.List;
1618

17-
public class RankDocTests extends AbstractWireSerializingTestCase<RankDoc> {
19+
public class RankDocTests extends AbstractRankDocWireSerializingTestCase<RankDoc> {
1820

19-
static RankDoc createTestRankDoc() {
21+
protected RankDoc createTestRankDoc() {
2022
RankDoc rankDoc = new RankDoc(randomNonNegativeInt(), randomFloat(), randomIntBetween(0, 1));
2123
rankDoc.rank = randomNonNegativeInt();
2224
return rankDoc;
2325
}
2426

2527
@Override
26-
protected Writeable.Reader<RankDoc> instanceReader() {
27-
return RankDoc::new;
28+
protected List<NamedWriteableRegistry.Entry> getAdditionalNamedWriteables() {
29+
return Collections.emptyList();
2830
}
2931

3032
@Override
31-
protected RankDoc createTestInstance() {
32-
return createTestRankDoc();
33+
protected Writeable.Reader<RankDoc> instanceReader() {
34+
return RankDoc::new;
3335
}
3436

3537
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ public Set<NodeFeature> getFeatures() {
2525
return Set.of(
2626
TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED,
2727
RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED,
28-
SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID
28+
SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID,
29+
TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED
2930
);
3031
}
3132

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.elasticsearch.rest.RestController;
3737
import org.elasticsearch.rest.RestHandler;
3838
import org.elasticsearch.search.rank.RankBuilder;
39+
import org.elasticsearch.search.rank.RankDoc;
3940
import org.elasticsearch.threadpool.ExecutorBuilder;
4041
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
4142
import org.elasticsearch.xcontent.ParseField;
@@ -66,6 +67,7 @@
6667
import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
6768
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
6869
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder;
70+
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankDoc;
6971
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
7072
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
7173
import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceEndpointAction;
@@ -253,6 +255,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
253255
var entries = new ArrayList<>(InferenceNamedWriteablesProvider.getNamedWriteables());
254256
entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, TextSimilarityRankBuilder.NAME, TextSimilarityRankBuilder::new));
255257
entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, RandomRankBuilder.NAME, RandomRankBuilder::new));
258+
entries.add(new NamedWriteableRegistry.Entry(RankDoc.class, TextSimilarityRankDoc.NAME, TextSimilarityRankDoc::new));
256259
return entries;
257260
}
258261

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.rank.textsimilarity;
9+
10+
import org.apache.lucene.search.Explanation;
11+
import org.elasticsearch.TransportVersion;
12+
import org.elasticsearch.TransportVersions;
13+
import org.elasticsearch.common.io.stream.StreamInput;
14+
import org.elasticsearch.common.io.stream.StreamOutput;
15+
import org.elasticsearch.search.rank.RankDoc;
16+
import org.elasticsearch.xcontent.XContentBuilder;
17+
18+
import java.io.IOException;
19+
import java.util.Objects;
20+
21+
public class TextSimilarityRankDoc extends RankDoc {
22+
23+
public static final String NAME = "text_similarity_rank_doc";
24+
25+
public final String inferenceId;
26+
public final String field;
27+
28+
public TextSimilarityRankDoc(int doc, float score, int shardIndex, String inferenceId, String field) {
29+
super(doc, score, shardIndex);
30+
this.inferenceId = inferenceId;
31+
this.field = field;
32+
}
33+
34+
public TextSimilarityRankDoc(StreamInput in) throws IOException {
35+
super(in);
36+
inferenceId = in.readString();
37+
field = in.readString();
38+
}
39+
40+
@Override
41+
public Explanation explain(Explanation[] sources, String[] queryNames) {
42+
final String queryAlias = queryNames[0] == null ? "" : "[" + queryNames[0] + "]";
43+
return Explanation.match(
44+
score,
45+
"text_similarity_reranker match using inference endpoint: ["
46+
+ inferenceId
47+
+ "] on document field: ["
48+
+ field
49+
+ "] matching on source query "
50+
+ queryAlias,
51+
sources
52+
);
53+
}
54+
55+
@Override
56+
public void doWriteTo(StreamOutput out) throws IOException {
57+
out.writeString(inferenceId);
58+
out.writeString(field);
59+
}
60+
61+
@Override
62+
public boolean doEquals(RankDoc rd) {
63+
TextSimilarityRankDoc tsrd = (TextSimilarityRankDoc) rd;
64+
return Objects.equals(inferenceId, tsrd.inferenceId) && Objects.equals(field, tsrd.field);
65+
}
66+
67+
@Override
68+
public int doHashCode() {
69+
return Objects.hash(inferenceId, field);
70+
}
71+
72+
@Override
73+
public String toString() {
74+
return "TextSimilarityRankDoc{"
75+
+ "doc="
76+
+ doc
77+
+ ", shardIndex="
78+
+ shardIndex
79+
+ ", score="
80+
+ score
81+
+ ", inferenceId="
82+
+ inferenceId
83+
+ ", field="
84+
+ field
85+
+ '}';
86+
}
87+
88+
@Override
89+
public String getWriteableName() {
90+
return NAME;
91+
}
92+
93+
@Override
94+
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
95+
builder.field("inferenceId", inferenceId);
96+
builder.field("field", field);
97+
}
98+
99+
@Override
100+
public TransportVersion getMinimalSupportedVersion() {
101+
return TransportVersions.TEXT_SIMILARITY_RERANKER_QUERY_REWRITE;
102+
}
103+
}

0 commit comments

Comments
 (0)