Skip to content

Commit 4675586

Browse files
committed
Add some tests
1 parent cf2c8b3 commit 4675586

File tree

5 files changed

+111
-7
lines changed

5 files changed

+111
-7
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import static org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
2424
import static org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
2525
import static org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor.SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
26+
import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_SNIPPETS;
2627

2728
/**
2829
* Provides inference features.
@@ -68,7 +69,8 @@ public Set<NodeFeature> getTestFeatures() {
6869
SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS,
6970
SEMANTIC_TEXT_INDEX_OPTIONS,
7071
COHERE_V2_API,
71-
SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS
72+
SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS,
73+
TEXT_SIMILARITY_RERANKER_SNIPPETS
7274
);
7375
}
7476
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,15 +209,15 @@ protected InferenceAction.Request generateRequest(List<String> docFeatures) {
209209
);
210210
}
211211

212-
private float[] extractScoresFromRankedDocs(List<RankedDocsResults.RankedDoc> rankedDocs) {
212+
float[] extractScoresFromRankedDocs(List<RankedDocsResults.RankedDoc> rankedDocs) {
213213
float[] scores = new float[rankedDocs.size()];
214214
for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) {
215215
scores[rankedDoc.index()] = rankedDoc.relevanceScore();
216216
}
217217
return scores;
218218
}
219219

220-
private float[] extractScoresFromRankedSnippets(List<RankedDocsResults.RankedDoc> rankedDocs, RankFeatureDoc[] featureDocs) {
220+
float[] extractScoresFromRankedSnippets(List<RankedDocsResults.RankedDoc> rankedDocs, RankFeatureDoc[] featureDocs) {
221221
float[] scores = new float[featureDocs.length];
222222
boolean[] hasScore = new boolean[featureDocs.length];
223223

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
4242
"text_similarity_reranker_alias_handling_fix"
4343
);
4444
public static final NodeFeature TEXT_SIMILARITY_RERANKER_MINSCORE_FIX = new NodeFeature("text_similarity_reranker_minscore_fix");
45+
public static final NodeFeature TEXT_SIMILARITY_RERANKER_SNIPPETS = new NodeFeature("text_similarity_reranker_snippets");
4546

4647
public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
4748
public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id");
@@ -145,6 +146,9 @@ public TextSimilarityRankRetrieverBuilder(
145146
if (retrieverSource.size() != 1) {
146147
throw new IllegalArgumentException("[" + getName() + "] retriever should have exactly one inner retriever");
147148
}
149+
if (snippets != null && snippets.numSnippets() != null && snippets.numSnippets() < 1) {
150+
throw new IllegalArgumentException("num_snippets must be greater than 0, was: " + snippets.numSnippets());
151+
}
148152
this.inferenceId = inferenceId;
149153
this.inferenceText = inferenceText;
150154
this.field = field;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
import org.elasticsearch.client.internal.Client;
1111
import org.elasticsearch.inference.TaskType;
1212
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
13+
import org.elasticsearch.search.rank.feature.RerankSnippetInput;
1314
import org.elasticsearch.test.ESTestCase;
1415
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
16+
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
1517

1618
import java.util.List;
1719

@@ -38,6 +40,18 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContextTests extends E
3840
null
3941
);
4042

43+
TextSimilarityRankFeaturePhaseRankCoordinatorContext withSnippets = new TextSimilarityRankFeaturePhaseRankCoordinatorContext(
44+
10,
45+
0,
46+
100,
47+
mockClient,
48+
"my-inference-id",
49+
"some query",
50+
0.0f,
51+
false,
52+
new RerankSnippetInput(2)
53+
);
54+
4155
public void testComputeScores() {
4256
RankFeatureDoc featureDoc1 = new RankFeatureDoc(0, 1.0f, 0);
4357
featureDoc1.featureData(List.of("text 1"));
@@ -64,4 +78,57 @@ public void testComputeScoresForEmpty() {
6478
);
6579
}
6680

81+
public void testExtractScoresFromRankedDocs() {
82+
List<RankedDocsResults.RankedDoc> rankedDocs = List.of(
83+
new RankedDocsResults.RankedDoc(0, 1.0f, "text 1"),
84+
new RankedDocsResults.RankedDoc(1, 3.0f, "text 2"),
85+
new RankedDocsResults.RankedDoc(2, 2.0f, "text 3")
86+
);
87+
float[] scores = subject.extractScoresFromRankedDocs(rankedDocs);
88+
assertArrayEquals(new float[] { 1.0f, 3.0f, 2.0f }, scores, 0.0f);
89+
}
90+
91+
public void testExtractScoresFromSingleSnippets() {
92+
93+
List<RankedDocsResults.RankedDoc> rankedDocs = List.of(
94+
new RankedDocsResults.RankedDoc(0, 1.0f, "text 1"),
95+
new RankedDocsResults.RankedDoc(1, 2.5f, "text 2"),
96+
new RankedDocsResults.RankedDoc(2, 1.5f, "text 3")
97+
);
98+
RankFeatureDoc[] featureDocs = new RankFeatureDoc[] {
99+
createRankFeatureDoc(0, 1.0f, 0, List.of("text 1")),
100+
createRankFeatureDoc(1, 3.0f, 1, List.of("text 2")),
101+
createRankFeatureDoc(2, 2.0f, 0, List.of("text 3")) };
102+
103+
float[] scores = withSnippets.extractScoresFromRankedSnippets(rankedDocs, featureDocs);
104+
// Returned cores are from the snippet, not the whole text
105+
assertArrayEquals(new float[] { 1.0f, 2.5f, 1.5f }, scores, 0.0f);
106+
}
107+
108+
public void testExtractScoresFromMultipleSnippets() {
109+
110+
List<RankedDocsResults.RankedDoc> rankedDocs = List.of(
111+
new RankedDocsResults.RankedDoc(0, 1.0f, "this is text 1"),
112+
new RankedDocsResults.RankedDoc(1, 2.5f, "some more text"),
113+
new RankedDocsResults.RankedDoc(2, 1.5f, "yet more text"),
114+
new RankedDocsResults.RankedDoc(3, 3.0f, "this is text 2"),
115+
new RankedDocsResults.RankedDoc(4, 2.0f, "this is text 3"),
116+
new RankedDocsResults.RankedDoc(5, 1.5f, "oh look, more text")
117+
);
118+
RankFeatureDoc[] featureDocs = new RankFeatureDoc[] {
119+
createRankFeatureDoc(0, 1.0f, 0, List.of("this is text 1", "some more text")),
120+
createRankFeatureDoc(1, 3.0f, 1, List.of("yet more text", "this is text 2")),
121+
createRankFeatureDoc(2, 2.0f, 0, List.of("this is text 3", "oh look, more text")) };
122+
123+
float[] scores = withSnippets.extractScoresFromRankedSnippets(rankedDocs, featureDocs);
124+
// Returned scores are from the best-ranking snippet, not the whole text
125+
assertArrayEquals(new float[] { 2.5f, 3.0f, 2.0f }, scores, 0.0f);
126+
}
127+
128+
private RankFeatureDoc createRankFeatureDoc(int doc, float score, int shardIndex, List<String> featureData) {
129+
RankFeatureDoc featureDoc = new RankFeatureDoc(doc, score, shardIndex);
130+
featureDoc.featureData(featureData);
131+
return featureDoc;
132+
}
133+
67134
}

x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,8 @@ setup:
298298
- match: { hits.hits.0._id: "doc_2" }
299299
- match: { hits.hits.1._id: "doc_1" }
300300

301-
- match: {hits.hits.0._explanation.description: "/text_similarity_reranker.match.using.inference.endpoint:.\\[my-rerank-model\\].on.document.field:.\\[inference_text_field\\].*/" }
302-
- match: {hits.hits.0._explanation.details.0.details.0.description: "/subtopic.*astronomy.*/" }
301+
- match: { hits.hits.0._explanation.description: "/text_similarity_reranker.match.using.inference.endpoint:.\\[my-rerank-model\\].on.document.field:.\\[inference_text_field\\].*/" }
302+
- match: { hits.hits.0._explanation.details.0.details.0.description: "/subtopic.*astronomy.*/" }
303303

304304
---
305305
"text similarity reranker properly handles aliases":
@@ -448,7 +448,7 @@ setup:
448448
retriever:
449449
standard:
450450
query:
451-
match_all: {}
451+
match_all: { }
452452
rank_window_size: 10
453453
inference_id: my-rerank-model
454454
inference_text: "How often does the moon hide the sun?"
@@ -477,7 +477,7 @@ setup:
477477
retriever:
478478
standard:
479479
query:
480-
match_all: {}
480+
match_all: { }
481481
rank_window_size: 10
482482
inference_id: my-rerank-model
483483
inference_text: "How often does the moon hide the sun?"
@@ -487,3 +487,34 @@ setup:
487487

488488
- match: { hits.total.value: 0 }
489489
- length: { hits.hits: 0 }
490+
491+
492+
---
493+
"Text similarity reranker specifying number of snippets must be > 0":
494+
495+
- requires:
496+
cluster_features: "text_similarity_reranker_snippets"
497+
reason: snippets introduced in 9.2.0
498+
499+
- do:
500+
catch: /num_snippets must be greater than 0/
501+
search:
502+
index: test-index
503+
body:
504+
track_total_hits: true
505+
fields: [ "text", "topic" ]
506+
retriever:
507+
text_similarity_reranker:
508+
retriever:
509+
standard:
510+
query:
511+
match_all: { }
512+
rank_window_size: 10
513+
inference_id: my-rerank-model
514+
inference_text: "How often does the moon hide the sun?"
515+
field: inference_text_field
516+
snippets:
517+
num_snippets: 0
518+
size: 10
519+
520+
- match: { status: 400 }

0 commit comments

Comments
 (0)