diff --git a/docs/changelog/129369.yaml b/docs/changelog/129369.yaml new file mode 100644 index 0000000000000..c53dc5ae0d25f --- /dev/null +++ b/docs/changelog/129369.yaml @@ -0,0 +1,6 @@ +pr: 129369 +summary: Support semantic reranking using contextual snippets instead of entire field + text +area: Relevance +type: enhancement +issues: [] diff --git a/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/CcsCommonYamlTestSuiteIT.java b/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/CcsCommonYamlTestSuiteIT.java index 37480ba3c14e0..0b7893fc26b1d 100644 --- a/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/CcsCommonYamlTestSuiteIT.java +++ b/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/CcsCommonYamlTestSuiteIT.java @@ -92,7 +92,8 @@ public class CcsCommonYamlTestSuiteIT extends ESClientYamlSuiteTestCase { .feature(FeatureFlag.TIME_SERIES_MODE) .feature(FeatureFlag.SUB_OBJECTS_AUTO_ENABLED) .feature(FeatureFlag.IVF_FORMAT) - .feature(FeatureFlag.SYNTHETIC_VECTORS); + .feature(FeatureFlag.SYNTHETIC_VECTORS) + .feature(FeatureFlag.RERANK_SNIPPETS); private static ElasticsearchCluster remoteCluster = ElasticsearchCluster.local() .name(REMOTE_CLUSTER_NAME) diff --git a/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/RcsCcsCommonYamlTestSuiteIT.java b/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/RcsCcsCommonYamlTestSuiteIT.java index a79aeba690f57..577bb4be8629b 100644 --- a/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/RcsCcsCommonYamlTestSuiteIT.java +++ b/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/RcsCcsCommonYamlTestSuiteIT.java @@ -94,6 +94,7 @@ public class RcsCcsCommonYamlTestSuiteIT extends ESClientYamlSuiteTestCase { .feature(FeatureFlag.SUB_OBJECTS_AUTO_ENABLED) .feature(FeatureFlag.IVF_FORMAT) .feature(FeatureFlag.SYNTHETIC_VECTORS) + 
.feature(FeatureFlag.RERANK_SNIPPETS) .user("test_admin", "x-pack-test-password"); private static ElasticsearchCluster fulfillingCluster = ElasticsearchCluster.local() diff --git a/qa/smoke-test-multinode/src/yamlRestTest/java/org/elasticsearch/smoketest/SmokeTestMultiNodeClientYamlTestSuiteIT.java b/qa/smoke-test-multinode/src/yamlRestTest/java/org/elasticsearch/smoketest/SmokeTestMultiNodeClientYamlTestSuiteIT.java index b50df4183e2ab..2be870dbf4ea5 100644 --- a/qa/smoke-test-multinode/src/yamlRestTest/java/org/elasticsearch/smoketest/SmokeTestMultiNodeClientYamlTestSuiteIT.java +++ b/qa/smoke-test-multinode/src/yamlRestTest/java/org/elasticsearch/smoketest/SmokeTestMultiNodeClientYamlTestSuiteIT.java @@ -40,6 +40,7 @@ public class SmokeTestMultiNodeClientYamlTestSuiteIT extends ESClientYamlSuiteTe .feature(FeatureFlag.USE_LUCENE101_POSTINGS_FORMAT) .feature(FeatureFlag.IVF_FORMAT) .feature(FeatureFlag.SYNTHETIC_VECTORS) + .feature(FeatureFlag.RERANK_SNIPPETS) .build(); public SmokeTestMultiNodeClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { diff --git a/rest-api-spec/src/yamlRestTest/java/org/elasticsearch/test/rest/ClientYamlTestSuiteIT.java b/rest-api-spec/src/yamlRestTest/java/org/elasticsearch/test/rest/ClientYamlTestSuiteIT.java index de2a0859dcf7b..739b6fd755aa8 100644 --- a/rest-api-spec/src/yamlRestTest/java/org/elasticsearch/test/rest/ClientYamlTestSuiteIT.java +++ b/rest-api-spec/src/yamlRestTest/java/org/elasticsearch/test/rest/ClientYamlTestSuiteIT.java @@ -40,6 +40,7 @@ public class ClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase { .feature(FeatureFlag.USE_LUCENE101_POSTINGS_FORMAT) .feature(FeatureFlag.IVF_FORMAT) .feature(FeatureFlag.SYNTHETIC_VECTORS) + .feature(FeatureFlag.RERANK_SNIPPETS) .build(); public ClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java 
b/server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java index cefabe277eb31..28177da0e341a 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java @@ -193,7 +193,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; for (int i = 0; i < hits.getHits().length; i++) { rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId); - rankFeatureDocs[i].featureData(hits.getHits()[i].field(field).getValue().toString()); + rankFeatureDocs[i].featureData(List.of(hits.getHits()[i].field(field).getValue().toString())); } return new RankFeatureShardResult(rankFeatureDocs); } catch (Exception ex) { @@ -210,7 +210,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { float[] scores = new float[featureDocs.length]; for (int i = 0; i < featureDocs.length; i++) { - scores[i] = Float.parseFloat(featureDocs[i].featureData); + scores[i] = Float.parseFloat(featureDocs[i].featureData.get(0)); } scoreListener.onResponse(scores); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/rank/MockedRequestActionBasedRerankerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/rank/MockedRequestActionBasedRerankerIT.java index fa9e58d179a26..465b853ae436d 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/rank/MockedRequestActionBasedRerankerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/rank/MockedRequestActionBasedRerankerIT.java @@ -275,7 +275,7 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener featureData = 
Arrays.stream(featureDocs).map(x -> x.featureData).toList(); + List featureData = Arrays.stream(featureDocs).map(x -> x.featureData).flatMap(List::stream).toList(); TestRerankingActionRequest request = generateRequest(featureData); try { ActionType action = actionType(); diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 362738bc9d60c..0fea2a42b7f3e 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -352,6 +352,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_SAMPLE_OPERATOR_STATUS = def(9_127_0_00); public static final TransportVersion ESQL_TOPN_TIMINGS = def(9_128_0_00); public static final TransportVersion NODE_WEIGHTS_ADDED_TO_NODE_BALANCE_STATS = def(9_129_0_00); + public static final TransportVersion RERANK_SNIPPETS = def(9_130_0_00); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java index a3701f20583db..f56d40713b22f 100644 --- a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java @@ -1187,9 +1187,11 @@ public SearchSourceBuilder rewrite(QueryRewriteContext context) throws IOExcepti sliceBuilder, sorts, rescoreBuilders, - highlightBuilder + highlightBuilder, + rankBuilder ) )); + if (retrieverBuilder != null) { var newRetriever = retrieverBuilder.rewrite(context); if (newRetriever != retrieverBuilder) { @@ -1205,6 +1207,11 @@ public SearchSourceBuilder rewrite(QueryRewriteContext context) throws IOExcepti } } + RankBuilder rankBuilder = null; + if (this.rankBuilder != null) { + rankBuilder = this.rankBuilder.rewrite(context); + } + List subSearchSourceBuilders = Rewriteable.rewrite(this.subSearchSourceBuilders, context); QueryBuilder postQueryBuilder = null; if (this.postQueryBuilder != null) { @@ -1229,7 +1236,8 @@ public SearchSourceBuilder rewrite(QueryRewriteContext context) throws IOExcepti || aggregations != this.aggregations || rescoreBuilders != this.rescoreBuilders || sorts != this.sorts - || this.highlightBuilder != highlightBuilder; + || this.highlightBuilder != highlightBuilder + || this.rankBuilder != rankBuilder; if (rewritten) { return shallowCopy( subSearchSourceBuilders, @@ -1239,7 +1247,8 @@ public SearchSourceBuilder rewrite(QueryRewriteContext context) throws IOExcepti this.sliceBuilder, sorts, rescoreBuilders, - highlightBuilder + highlightBuilder, + rankBuilder ); } return this; @@ -1257,7 +1266,8 @@ public SearchSourceBuilder shallowCopy() { sliceBuilder, sorts, rescoreBuilders, - highlightBuilder + highlightBuilder, + rankBuilder ); } @@ -1274,7 +1284,8 @@ private SearchSourceBuilder shallowCopy( SliceBuilder slice, List> sorts, 
List rescoreBuilders, - HighlightBuilder highlightBuilder + HighlightBuilder highlightBuilder, + RankBuilder rankBuilder ) { SearchSourceBuilder rewrittenBuilder = new SearchSourceBuilder(); rewrittenBuilder.aggregations = aggregations; diff --git a/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/HighlightBuilder.java b/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/HighlightBuilder.java index 1997601e73f6d..2a48e8338e377 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/HighlightBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/HighlightBuilder.java @@ -399,7 +399,7 @@ public Field(String name) { this.name = name; } - private Field(Field template, QueryBuilder builder) { + Field(Field template, QueryBuilder builder) { super(template, builder); name = template.name; fragmentOffset = template.fragmentOffset; diff --git a/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/SearchHighlightContext.java b/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/SearchHighlightContext.java index 631a75a355abf..a85ae92c24bcf 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/SearchHighlightContext.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/SearchHighlightContext.java @@ -40,7 +40,7 @@ public static class Field { private final String field; private final FieldOptions fieldOptions; - Field(String field, FieldOptions fieldOptions) { + public Field(String field, FieldOptions fieldOptions) { assert field != null; assert fieldOptions != null; this.field = field; diff --git a/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java b/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java index af53273f8bc93..437be05cd1e2f 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java +++ 
b/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java @@ -19,6 +19,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.UpdateForV10; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.search.SearchService; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; @@ -60,6 +61,10 @@ public final void writeTo(StreamOutput out) throws IOException { doWriteTo(out); } + public RankBuilder rewrite(QueryRewriteContext queryRewriteContext) throws IOException { + return this; + } + protected abstract void doWriteTo(StreamOutput out) throws IOException; @Override diff --git a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankShardContext.java b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankShardContext.java index 368661c66de5d..487278afba6e0 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankShardContext.java +++ b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankShardContext.java @@ -11,6 +11,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.rank.RankShardResult; /** @@ -37,4 +38,13 @@ public String getField() { */ @Nullable public abstract RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId); + + /** + * Prepares a SearchContext with any additional information needed before executing + * commands on shards. 
+ * @param context SearchContext + */ + public void prepareForFetch(SearchContext context) { + // Default no-op + } } diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java index cd8d9392aced8..afbb32fd829f7 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java @@ -10,12 +10,14 @@ package org.elasticsearch.search.rank.feature; import org.apache.lucene.search.Explanation; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.List; import java.util.Objects; /** @@ -26,7 +28,7 @@ public class RankFeatureDoc extends RankDoc { public static final String NAME = "rank_feature_doc"; // TODO: update to support more than 1 fields; and not restrict to string data - public String featureData; + public List featureData; public RankFeatureDoc(int doc, float score, int shardIndex) { super(doc, score, shardIndex); @@ -34,7 +36,12 @@ public RankFeatureDoc(int doc, float score, int shardIndex) { public RankFeatureDoc(StreamInput in) throws IOException { super(in); - featureData = in.readOptionalString(); + if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) { + featureData = in.readOptionalStringCollectionAsList(); + } else { + String featureDataString = in.readOptionalString(); + featureData = featureDataString == null ? 
null : List.of(featureDataString); + } + } @Override @@ -42,13 +49,18 @@ public Explanation explain(Explanation[] sources, String[] queryNames) { throw new UnsupportedOperationException("explain is not supported for {" + getClass() + "}"); } - public void featureData(String featureData) { + public void featureData(List featureData) { this.featureData = featureData; } @Override protected void doWriteTo(StreamOutput out) throws IOException { - out.writeOptionalString(featureData); + if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) { + out.writeOptionalStringCollection(featureData); + } else { + // featureData stays null when the hit had no value for the field; older nodes expect one optional string. + out.writeOptionalString(featureData == null || featureData.isEmpty() ? null : featureData.get(0)); + } } @Override @@ -59,7 +71,7 @@ protected boolean doEquals(RankDoc rd) { @Override protected int doHashCode() { - return Objects.hashCode(featureData); + return Objects.hash(featureData); } @Override @@ -69,6 +81,6 @@ public String getWriteableName() { @Override protected void doToXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("featureData", featureData); + builder.array("featureData", featureData); } } diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java index 4374c06da365d..c7a2d0d945420 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java @@ -48,10 +48,10 @@ public static void prepareForFetch(SearchContext searchContext, RankFeatureShard RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext = shardContext(searchContext); if (rankFeaturePhaseRankShardContext != null) { - assert rankFeaturePhaseRankShardContext.getField() != null : "field must not be null"; - searchContext.fetchFieldsContext( - new FetchFieldsContext(Collections.singletonList(new 
FieldAndFormat(rankFeaturePhaseRankShardContext.getField(), null))) - ); + String field = rankFeaturePhaseRankShardContext.getField(); + assert field != null : "field must not be null"; + searchContext.fetchFieldsContext(new FetchFieldsContext(Collections.singletonList(new FieldAndFormat(field, null)))); + rankFeaturePhaseRankShardContext.prepareForFetch(searchContext); searchContext.storedFieldsContext(StoredFieldsContext.fromList(Collections.singletonList(StoredFieldsContext._NONE_))); searchContext.addFetchResult(); Arrays.sort(request.getDocIds()); diff --git a/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java b/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java index 96867fb1d190b..7574315753894 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java +++ b/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java @@ -20,6 +20,7 @@ import org.elasticsearch.search.rank.feature.RankFeatureShardResult; import java.util.Arrays; +import java.util.List; /** * The {@code ReRankingRankFeaturePhaseRankShardContext} is handles the {@code SearchHits} generated from the {@code RankFeatureShardPhase} @@ -37,15 +38,7 @@ public RerankingRankFeaturePhaseRankShardContext(String field) { @Override public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { try { - RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; - for (int i = 0; i < hits.getHits().length; i++) { - rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId); - DocumentField docField = hits.getHits()[i].field(field); - if (docField != null) { - rankFeatureDocs[i].featureData(docField.getValue().toString()); - } - } - return new RankFeatureShardResult(rankFeatureDocs); + return 
doBuildRankFeatureShardResult(hits, shardId); } catch (Exception ex) { logger.warn( "Error while fetching feature data for {field: [" @@ -58,4 +51,16 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) return null; } } + + protected RankShardResult doBuildRankFeatureShardResult(SearchHits hits, int shardId) { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId); + DocumentField docField = hits.getHits()[i].field(field); + if (docField != null) { + rankFeatureDocs[i].featureData(List.of(docField.getValue().toString())); + } + } + return new RankFeatureShardResult(rankFeatureDocs); + } } diff --git a/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java index ae483a1dbe13a..c22069b6abf73 100644 --- a/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java @@ -73,7 +73,7 @@ public class RankFeaturePhaseTests extends ESTestCase { defaultRankFeaturePhaseRankCoordinatorContext(DEFAULT_SIZE, DEFAULT_FROM, DEFAULT_RANK_WINDOW_SIZE) ); - private record ExpectedRankFeatureDoc(int doc, int rank, float score, String featureData) {} + private record ExpectedRankFeatureDoc(int doc, int rank, float score, List featureData) {} public void testRankFeaturePhaseWith1Shard() { // request params used within SearchSourceBuilder and *RankContext classes @@ -145,8 +145,8 @@ public void sendExecuteRankFeature( SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0); List expectedShardResults = List.of( - new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1"), - new ExpectedRankFeatureDoc(2, 2, 109.0F, "ranked_2") + new ExpectedRankFeatureDoc(1, 
1, 110.0F, List.of("ranked_1")), + new ExpectedRankFeatureDoc(2, 2, 109.0F, List.of("ranked_2")) ); List expectedFinalResults = new ArrayList<>(expectedShardResults); assertShardResults(shard1Result, expectedShardResults); @@ -263,19 +263,19 @@ public void sendExecuteRankFeature( assertEquals(2, rankPhaseResults.getSuccessfulResults().count()); SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0); - List expectedShard1Results = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1")); + List expectedShard1Results = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, List.of("ranked_1"))); assertShardResults(shard1Result, expectedShard1Results); SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1); - List expectedShard2Results = List.of(new ExpectedRankFeatureDoc(2, 1, 109.0F, "ranked_2")); + List expectedShard2Results = List.of(new ExpectedRankFeatureDoc(2, 1, 109.0F, List.of("ranked_2"))); assertShardResults(shard2Result, expectedShard2Results); SearchPhaseResult shard3Result = rankPhaseResults.getAtomicArray().get(2); assertNull(shard3Result); List expectedFinalResults = List.of( - new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1"), - new ExpectedRankFeatureDoc(2, 2, 109.0F, "ranked_2") + new ExpectedRankFeatureDoc(1, 1, 110.0F, List.of("ranked_1")), + new ExpectedRankFeatureDoc(2, 2, 109.0F, List.of("ranked_2")) ); assertFinalResults(finalResults[0], expectedFinalResults); } finally { @@ -379,7 +379,7 @@ public void sendExecuteRankFeature( assertNull(shard1Result); SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1); - List expectedShard2Results = List.of(new ExpectedRankFeatureDoc(2, 1, 109.0F, "ranked_2")); + List expectedShard2Results = List.of(new ExpectedRankFeatureDoc(2, 1, 109.0F, List.of("ranked_2"))); List expectedFinalResults = new ArrayList<>(expectedShard2Results); assertShardResults(shard2Result, expectedShard2Results); assertFinalResults(finalResults[0], expectedFinalResults); @@ 
-609,22 +609,21 @@ public void sendExecuteRankFeature( assertEquals(2, rankPhaseResults.getSuccessfulResults().count()); SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0); - List expectedShard1Results = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1")); + List expectedShard1Results = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, List.of("ranked_1"))); assertShardResults(shard1Result, expectedShard1Results); SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1); List expectedShard2Results = List.of( - new ExpectedRankFeatureDoc(11, 1, 200.0F, "ranked_11"), - new ExpectedRankFeatureDoc(2, 2, 109.0F, "ranked_2"), - new ExpectedRankFeatureDoc(200, 3, 101.0F, "ranked_200") - + new ExpectedRankFeatureDoc(11, 1, 200.0F, List.of("ranked_11")), + new ExpectedRankFeatureDoc(2, 2, 109.0F, List.of("ranked_2")), + new ExpectedRankFeatureDoc(200, 3, 101.0F, List.of("ranked_200")) ); assertShardResults(shard2Result, expectedShard2Results); SearchPhaseResult shard3Result = rankPhaseResults.getAtomicArray().get(2); assertNull(shard3Result); - List expectedFinalResults = List.of(new ExpectedRankFeatureDoc(1, 2, 110.0F, "ranked_1")); + List expectedFinalResults = List.of(new ExpectedRankFeatureDoc(1, 2, 110.0F, List.of("ranked_1"))); assertFinalResults(finalResults[0], expectedFinalResults); } finally { rankFeaturePhase.rankPhaseResults.close(); @@ -748,19 +747,21 @@ public void sendExecuteRankFeature( assertEquals(2, rankPhaseResults.getSuccessfulResults().count()); SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0); - List expectedShardResults = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1")); + List expectedShardResults = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, List.of("ranked_1"))); assertShardResults(shard1Result, expectedShardResults); SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1); - List expectedShard2Results = List.of(new 
ExpectedRankFeatureDoc(11, 1, 200.0F, "ranked_11")); + List expectedShard2Results = List.of( + new ExpectedRankFeatureDoc(11, 1, 200.0F, List.of("ranked_11")) + ); assertShardResults(shard2Result, expectedShard2Results); SearchPhaseResult shard3Result = rankPhaseResults.getAtomicArray().get(2); assertNull(shard3Result); List expectedFinalResults = List.of( - new ExpectedRankFeatureDoc(11, 1, 200.0F, "ranked_11"), - new ExpectedRankFeatureDoc(1, 2, 110.0F, "ranked_1") + new ExpectedRankFeatureDoc(11, 1, 200.0F, List.of("ranked_11")), + new ExpectedRankFeatureDoc(1, 2, 110.0F, List.of("ranked_1")) ); assertFinalResults(finalResults[0], expectedFinalResults); } finally { @@ -813,7 +814,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) SearchHit hit = hits.getHits()[i]; rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); rankFeatureDocs[i].score += 100f; - rankFeatureDocs[i].featureData("ranked_" + hit.docId()); + rankFeatureDocs[i].featureData(List.of("ranked_" + hit.docId())); rankFeatureDocs[i].rank = i + 1; } return new RankFeatureShardResult(rankFeatureDocs); diff --git a/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java b/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java index e9dd4d8674dea..45922280b74d3 100644 --- a/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java @@ -523,7 +523,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) for (int i = 0; i < hits.getHits().length; i++) { SearchHit hit = hits.getHits()[i]; rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); - rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); + rankFeatureDocs[i].featureData(parseFeatureData(hit, rankFeatureFieldName)); rankFeatureDocs[i].score = (numDocs - 
i) + randomFloat(); rankFeatureDocs[i].rank = i + 1; } @@ -580,7 +580,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) assertEquals(sortedRankWindowDocs.size(), rankFeatureShardResult.rankFeatureDocs.length); for (int i = 0; i < sortedRankWindowDocs.size(); i++) { assertEquals((long) sortedRankWindowDocs.get(i), rankFeatureShardResult.rankFeatureDocs[i].doc); - assertEquals(rankFeatureShardResult.rankFeatureDocs[i].featureData, "aardvark_" + sortedRankWindowDocs.get(i)); + assertEquals(rankFeatureShardResult.rankFeatureDocs[i].featureData, List.of("aardvark_" + sortedRankWindowDocs.get(i))); } List globalTopKResults = randomNonEmptySubsetOf( @@ -760,7 +760,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) for (int i = 0; i < hits.getHits().length; i++) { SearchHit hit = hits.getHits()[i]; rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); - rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); + rankFeatureDocs[i].featureData(parseFeatureData(hit, rankFeatureFieldName)); rankFeatureDocs[i].score = randomFloat(); rankFeatureDocs[i].rank = i + 1; } @@ -887,7 +887,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) for (int i = 0; i < hits.getHits().length; i++) { SearchHit hit = hits.getHits()[i]; rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); - rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); + rankFeatureDocs[i].featureData(parseFeatureData(hit, rankFeatureFieldName)); rankFeatureDocs[i].score = randomFloat(); rankFeatureDocs[i].rank = i + 1; } @@ -1151,7 +1151,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) for (int i = 0; i < hits.getHits().length; i++) { SearchHit hit = hits.getHits()[i]; rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); - 
rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); + rankFeatureDocs[i].featureData(parseFeatureData(hit, rankFeatureFieldName)); rankFeatureDocs[i].score = randomFloat(); rankFeatureDocs[i].rank = i + 1; } @@ -2904,6 +2904,13 @@ private static ReaderContext createReaderContext(IndexService indexService, Inde ); } + private List parseFeatureData(SearchHit hit, String fieldName) { + Object fieldValue = hit.getFields().get(fieldName).getValue(); + @SuppressWarnings("unchecked") + List fieldValues = fieldValue instanceof List ? (List) fieldValue : List.of(String.valueOf(fieldValue)); + return fieldValues; + } + private static class TestRewriteCounterQueryBuilder extends AbstractQueryBuilder { final int asyncRewriteCount; diff --git a/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java b/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java index a0ccdacd62eec..1ffe05baf30a1 100644 --- a/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java @@ -160,7 +160,12 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) for (int i = 0; i < hits.getHits().length; i++) { SearchHit hit = hits.getHits()[i]; rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); - rankFeatureDocs[i].featureData(hit.getFields().get(field).getValue()); + Object fieldValue = hit.getFields().get(field).getValue(); + @SuppressWarnings("unchecked") + List featureData = fieldValue instanceof List + ? 
(List) fieldValue + : List.of(String.valueOf(fieldValue)); + rankFeatureDocs[i].featureData(featureData); rankFeatureDocs[i].rank = i + 1; } return new RankFeatureShardResult(rankFeatureDocs); @@ -279,7 +284,14 @@ public void testPrepareForFetchWhileTaskIsCancelled() { public void testProcessFetch() { final String fieldName = "some_field"; int numDocs = randomIntBetween(15, 30); - Map expectedFieldData = Map.of(4, "doc_4_aardvark", 9, "doc_9_aardvark", numDocs - 1, "last_doc_aardvark"); + Map> expectedFieldData = Map.of( + 4, + List.of("doc_4_aardvark"), + 9, + List.of("doc_9_aardvark"), + numDocs - 1, + List.of("last_doc_aardvark") + ); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.rankBuilder(getRankBuilder(fieldName)); diff --git a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java index d0d6a4825e1d1..888c4afbf3326 100644 --- a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java +++ b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java @@ -23,7 +23,8 @@ public enum FeatureFlag { IVF_FORMAT("es.ivf_format_feature_flag_enabled=true", Version.fromString("9.1.0"), null), LOGS_STREAM("es.logs_stream_feature_flag_enabled=true", Version.fromString("9.1.0"), null), PATTERNED_TEXT("es.patterned_text_feature_flag_enabled=true", Version.fromString("9.1.0"), null), - SYNTHETIC_VECTORS("es.mapping_synthetic_vectors=true", Version.fromString("9.2.0"), null); + SYNTHETIC_VECTORS("es.mapping_synthetic_vectors=true", Version.fromString("9.2.0"), null), + RERANK_SNIPPETS("es.text_similarity_reranker_snippets=true", Version.fromString("9.2.0"), null); public final String systemProperty; public final Version from; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index 00f40e903d1ff..c35ac4c413773 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -13,6 +13,7 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder; +import java.util.HashSet; import java.util.Set; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS; @@ -23,6 +24,8 @@ import static org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED; import static org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED; import static org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor.SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED; +import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.RERANK_SNIPPETS; +import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_SNIPPETS; /** * Provides inference features. 
@@ -46,35 +49,41 @@ public class InferenceFeatures implements FeatureSpecification { @Override public Set getTestFeatures() { - return Set.of( - SemanticTextFieldMapper.SEMANTIC_TEXT_IN_OBJECT_FIELD_FIX, - SemanticTextFieldMapper.SEMANTIC_TEXT_SINGLE_FIELD_UPDATE_FIX, - SemanticTextFieldMapper.SEMANTIC_TEXT_DELETE_FIX, - SemanticTextFieldMapper.SEMANTIC_TEXT_ZERO_SIZE_FIX, - SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX, - SemanticTextFieldMapper.SEMANTIC_TEXT_SKIP_INFERENCE_FIELDS, - SEMANTIC_TEXT_HIGHLIGHTER, - SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED, - SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED, - SemanticInferenceMetadataFieldsMapper.EXPLICIT_NULL_FIXES, - SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED, - TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_ALIAS_HANDLING_FIX, - TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_MINSCORE_FIX, - SemanticInferenceMetadataFieldsMapper.INFERENCE_METADATA_FIELDS_ENABLED_BY_DEFAULT, - SEMANTIC_TEXT_HIGHLIGHTER_DEFAULT, - SEMANTIC_KNN_FILTER_FIX, - TEST_RERANKING_SERVICE_PARSE_TEXT_AS_SCORE, - SemanticTextFieldMapper.SEMANTIC_TEXT_BIT_VECTOR_SUPPORT, - SemanticTextFieldMapper.SEMANTIC_TEXT_HANDLE_EMPTY_INPUT, - TEST_RULE_RETRIEVER_WITH_INDICES_THAT_DONT_RETURN_RANK_DOCS, - SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG, - SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER, - SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS, - SEMANTIC_TEXT_INDEX_OPTIONS, - COHERE_V2_API, - SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS, - SEMANTIC_QUERY_REWRITE_INTERCEPTORS_PROPAGATE_BOOST_AND_QUERY_NAME_FIX, - SEMANTIC_TEXT_HIGHLIGHTING_FLAT + var testFeatures = new HashSet<>( + Set.of( + SemanticTextFieldMapper.SEMANTIC_TEXT_IN_OBJECT_FIELD_FIX, + SemanticTextFieldMapper.SEMANTIC_TEXT_SINGLE_FIELD_UPDATE_FIX, + SemanticTextFieldMapper.SEMANTIC_TEXT_DELETE_FIX, + SemanticTextFieldMapper.SEMANTIC_TEXT_ZERO_SIZE_FIX, + SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX, + 
SemanticTextFieldMapper.SEMANTIC_TEXT_SKIP_INFERENCE_FIELDS, + SEMANTIC_TEXT_HIGHLIGHTER, + SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED, + SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED, + SemanticInferenceMetadataFieldsMapper.EXPLICIT_NULL_FIXES, + SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED, + TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_ALIAS_HANDLING_FIX, + TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_MINSCORE_FIX, + SemanticInferenceMetadataFieldsMapper.INFERENCE_METADATA_FIELDS_ENABLED_BY_DEFAULT, + SEMANTIC_TEXT_HIGHLIGHTER_DEFAULT, + SEMANTIC_KNN_FILTER_FIX, + TEST_RERANKING_SERVICE_PARSE_TEXT_AS_SCORE, + SemanticTextFieldMapper.SEMANTIC_TEXT_BIT_VECTOR_SUPPORT, + SemanticTextFieldMapper.SEMANTIC_TEXT_HANDLE_EMPTY_INPUT, + TEST_RULE_RETRIEVER_WITH_INDICES_THAT_DONT_RETURN_RANK_DOCS, + SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG, + SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER, + SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS, + SEMANTIC_TEXT_INDEX_OPTIONS, + COHERE_V2_API, + SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS, + SEMANTIC_QUERY_REWRITE_INTERCEPTORS_PROPAGATE_BOOST_AND_QUERY_NAME_FIX, + SEMANTIC_TEXT_HIGHLIGHTING_FLAT + ) ); + if (RERANK_SNIPPETS.isEnabled()) { + testFeatures.add(TEXT_SIMILARITY_RERANKER_SNIPPETS); + } + return testFeatures; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/SnippetConfig.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/SnippetConfig.java new file mode 100644 index 0000000000000..f25ee40ca7ab1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/SnippetConfig.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.rank.textsimilarity; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.index.query.QueryBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class SnippetConfig implements Writeable { + + public final Integer numSnippets; + private final String inferenceText; + private final Integer tokenSizeLimit; + public final QueryBuilder snippetQueryBuilder; + + public static final int DEFAULT_NUM_SNIPPETS = 1; + + public SnippetConfig(StreamInput in) throws IOException { + this.numSnippets = in.readOptionalVInt(); + this.inferenceText = in.readString(); + this.tokenSizeLimit = in.readOptionalVInt(); + this.snippetQueryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class); + } + + public SnippetConfig(Integer numSnippets) { + this(numSnippets, null, null); + } + + public SnippetConfig(Integer numSnippets, String inferenceText, Integer tokenSizeLimit) { + this(numSnippets, inferenceText, tokenSizeLimit, null); + } + + public SnippetConfig(Integer numSnippets, String inferenceText, Integer tokenSizeLimit, QueryBuilder snippetQueryBuilder) { + this.numSnippets = numSnippets; + this.inferenceText = inferenceText; + this.tokenSizeLimit = tokenSizeLimit; + this.snippetQueryBuilder = snippetQueryBuilder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalVInt(numSnippets); + out.writeString(inferenceText); + out.writeOptionalVInt(tokenSizeLimit); + out.writeOptionalNamedWriteable(snippetQueryBuilder); + } + + public Integer numSnippets() { + return numSnippets; + } + + public String inferenceText() { + return inferenceText; + } + + public Integer tokenSizeLimit() { + return tokenSizeLimit; + } 
+ + public QueryBuilder snippetQueryBuilder() { + return snippetQueryBuilder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SnippetConfig that = (SnippetConfig) o; + return Objects.equals(numSnippets, that.numSnippets) + && Objects.equals(inferenceText, that.inferenceText) + && Objects.equals(tokenSizeLimit, that.tokenSizeLimit) + && Objects.equals(snippetQueryBuilder, that.snippetQueryBuilder); + } + + @Override + public int hashCode() { + return Objects.hash(numSnippets, inferenceText, tokenSizeLimit, snippetQueryBuilder); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java index 71f3465f8dfc8..6e213c5906b23 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java @@ -12,8 +12,12 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.query.MatchQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.license.License; import org.elasticsearch.license.LicensedFeature; import org.elasticsearch.search.rank.RankBuilder; @@ -23,7 +27,6 @@ import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; 
import org.elasticsearch.search.rank.feature.RankFeatureDoc; -import org.elasticsearch.search.rank.rerank.RerankingRankFeaturePhaseRankShardContext; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; @@ -35,6 +38,7 @@ import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_ID_FIELD; import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_TEXT_FIELD; import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.MIN_SCORE_FIELD; +import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.SNIPPETS_FIELD; /** * A {@code RankBuilder} that enables ranking with text similarity model inference. Supports parameters for configuring the inference call. @@ -43,6 +47,11 @@ public class TextSimilarityRankBuilder extends RankBuilder { public static final String NAME = "text_similarity_reranker"; + /** + * The default token size limit of the Elastic reranker is 512. 
+ */ + private static final int DEFAULT_TOKEN_SIZE_LIMIT = 512; + public static final LicensedFeature.Momentary TEXT_SIMILARITY_RERANKER_FEATURE = LicensedFeature.momentary( null, "text-similarity-reranker", @@ -54,6 +63,7 @@ public class TextSimilarityRankBuilder extends RankBuilder { private final String field; private final Float minScore; private final boolean failuresAllowed; + private final SnippetConfig snippetConfig; public TextSimilarityRankBuilder( String field, @@ -61,7 +71,8 @@ public TextSimilarityRankBuilder( String inferenceText, int rankWindowSize, Float minScore, - boolean failuresAllowed + boolean failuresAllowed, + SnippetConfig snippetConfig ) { super(rankWindowSize); this.inferenceId = inferenceId; @@ -69,6 +80,7 @@ public TextSimilarityRankBuilder( this.field = field; this.minScore = minScore; this.failuresAllowed = failuresAllowed; + this.snippetConfig = snippetConfig; } public TextSimilarityRankBuilder(StreamInput in) throws IOException { @@ -84,6 +96,11 @@ public TextSimilarityRankBuilder(StreamInput in) throws IOException { } else { this.failuresAllowed = false; } + if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) { + this.snippetConfig = in.readOptionalWriteable(SnippetConfig::new); + } else { + this.snippetConfig = null; + } } @Override @@ -107,6 +124,9 @@ public void doWriteTo(StreamOutput out) throws IOException { || out.getTransportVersion().onOrAfter(TransportVersions.RERANKER_FAILURES_ALLOWED)) { out.writeBoolean(failuresAllowed); } + if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) { + out.writeOptionalWriteable(snippetConfig); + } } @Override @@ -122,6 +142,53 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio if (failuresAllowed) { builder.field(FAILURES_ALLOWED_FIELD.getPreferredName(), true); } + if (snippetConfig != null) { + builder.field(SNIPPETS_FIELD.getPreferredName(), snippetConfig); + } + } + + @Override + public RankBuilder 
rewrite(QueryRewriteContext queryRewriteContext) throws IOException { + TextSimilarityRankBuilder rewritten = this; + if (snippetConfig != null) { + QueryBuilder snippetQueryBuilder = snippetConfig.snippetQueryBuilder(); + if (snippetQueryBuilder == null) { + rewritten = new TextSimilarityRankBuilder( + field, + inferenceId, + inferenceText, + rankWindowSize(), + minScore, + failuresAllowed, + new SnippetConfig( + snippetConfig.numSnippets(), + snippetConfig.inferenceText(), + snippetConfig.tokenSizeLimit(), + new MatchQueryBuilder(field, inferenceText) + ) + ); + } else { + QueryBuilder rewrittenSnippetQueryBuilder = snippetQueryBuilder.rewrite(queryRewriteContext); + if (snippetQueryBuilder != rewrittenSnippetQueryBuilder) { + rewritten = new TextSimilarityRankBuilder( + field, + inferenceId, + inferenceText, + rankWindowSize(), + minScore, + failuresAllowed, + new SnippetConfig( + snippetConfig.numSnippets(), + snippetConfig.inferenceText(), + snippetConfig.tokenSizeLimit(), + rewrittenSnippetQueryBuilder + ) + ); + } + } + } + + return rewritten; } @Override @@ -168,7 +235,7 @@ public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int si @Override public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { - return new RerankingRankFeaturePhaseRankShardContext(field); + return new TextSimilarityRerankingRankFeaturePhaseRankShardContext(field, snippetConfig); } @Override @@ -181,10 +248,19 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo inferenceId, inferenceText, minScore, - failuresAllowed + failuresAllowed, + snippetConfig != null ? new SnippetConfig(snippetConfig.numSnippets, inferenceText, tokenSizeLimit(inferenceId)) : null ); } + /** + * @return The token size limit to apply to this rerank context. + * TODO: This should be pulled from the inference endpoint when available, not hardcoded. 
+ */ + public static Integer tokenSizeLimit(String inferenceId) { + return DEFAULT_TOKEN_SIZE_LIMIT; + } + public String field() { return field; } @@ -212,11 +288,17 @@ protected boolean doEquals(RankBuilder other) { && Objects.equals(inferenceText, that.inferenceText) && Objects.equals(field, that.field) && Objects.equals(minScore, that.minScore) - && failuresAllowed == that.failuresAllowed; + && failuresAllowed == that.failuresAllowed + && Objects.equals(snippetConfig, that.snippetConfig); } @Override protected int doHashCode() { - return Objects.hash(inferenceId, inferenceText, field, minScore, failuresAllowed); + return Objects.hash(inferenceId, inferenceText, field, minScore, failuresAllowed, snippetConfig); + } + + @Override + public String toString() { + return Strings.toString(this); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index 27221dc1f5caf..0a47db4d2a519 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.TaskType; @@ -39,6 +40,7 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe protected final String inferenceId; protected final String inferenceText; protected final Float minScore; + protected 
final SnippetConfig snippetConfig; public TextSimilarityRankFeaturePhaseRankCoordinatorContext( int size, @@ -48,39 +50,56 @@ public TextSimilarityRankFeaturePhaseRankCoordinatorContext( String inferenceId, String inferenceText, Float minScore, - boolean failuresAllowed + boolean failuresAllowed, + @Nullable SnippetConfig snippetConfig ) { super(size, from, rankWindowSize, failuresAllowed); this.client = client; this.inferenceId = inferenceId; this.inferenceText = inferenceText; this.minScore = minScore; + this.snippetConfig = snippetConfig; } @Override protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + // Wrap the provided rankListener to an ActionListener that would handle the response from the inference service // and then pass the results final ActionListener inferenceListener = scoreListener.delegateFailureAndWrap((l, r) -> { InferenceServiceResults results = r.getResults(); assert results instanceof RankedDocsResults; - // Ensure we get exactly as many scores as the number of docs we passed, otherwise we may return incorrect results + // If we have an empty list of ranked docs, simply return the original scores List rankedDocs = ((RankedDocsResults) results).getRankedDocs(); - - if (rankedDocs.size() != featureDocs.length) { - l.onFailure( - new IllegalStateException( - "Reranker input document count and returned score count mismatch: [" - + featureDocs.length - + "] vs [" - + rankedDocs.size() - + "]" - ) - ); + if (rankedDocs.isEmpty()) { + float[] originalScores = new float[featureDocs.length]; + for (int i = 0; i < featureDocs.length; i++) { + originalScores[i] = featureDocs[i].score; + } + l.onResponse(originalScores); } else { - float[] scores = extractScoresFromRankedDocs(rankedDocs); - l.onResponse(scores); + final float[] scores; + if (this.snippetConfig != null) { + scores = extractScoresFromRankedSnippets(rankedDocs, featureDocs); + } else { + scores = extractScoresFromRankedDocs(rankedDocs); + } + + // 
Ensure we get exactly as many final scores as the number of docs we passed, otherwise we may return incorrect results + if (scores.length != featureDocs.length) { + l.onFailure( + new IllegalStateException( + "Reranker input document count and returned score count mismatch: [" + + featureDocs.length + + "] vs [" + + scores.length + + "]" + ) + ); + } else { + l.onResponse(scores); + } } }); @@ -118,8 +137,11 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener featureData = Arrays.stream(featureDocs).map(x -> x.featureData).toList(); - InferenceAction.Request inferenceRequest = generateRequest(featureData); + List inferenceInputs = Arrays.stream(featureDocs) + .filter(featureDoc -> featureDoc.featureData != null) + .flatMap(featureDoc -> featureDoc.featureData.stream()) + .toList(); + InferenceAction.Request inferenceRequest = generateRequest(inferenceInputs); try { executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceRequest, inferenceListener); } finally { @@ -170,7 +192,7 @@ protected InferenceAction.Request generateRequest(List docFeatures) { ); } - private float[] extractScoresFromRankedDocs(List rankedDocs) { + float[] extractScoresFromRankedDocs(List rankedDocs) { float[] scores = new float[rankedDocs.size()]; for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) { scores[rankedDoc.index()] = rankedDoc.relevanceScore(); @@ -178,6 +200,33 @@ private float[] extractScoresFromRankedDocs(List ra return scores; } + float[] extractScoresFromRankedSnippets(List rankedDocs, RankFeatureDoc[] featureDocs) { + float[] scores = new float[featureDocs.length]; + boolean[] hasScore = new boolean[featureDocs.length]; + + // We need to correlate the index/doc values of each RankedDoc in correlation with its associated RankFeatureDoc. 
+ int[] rankedDocToFeatureDoc = Arrays.stream(featureDocs) + .flatMapToInt( + doc -> java.util.stream.IntStream.generate(() -> Arrays.asList(featureDocs).indexOf(doc)).limit(doc.featureData.size()) + ) + .limit(rankedDocs.size()) + .toArray(); + + for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) { + int docId = rankedDocToFeatureDoc[rankedDoc.index()]; + float score = rankedDoc.relevanceScore(); + scores[docId] = hasScore[docId] == false ? score : Math.max(scores[docId], score); + hasScore[docId] = true; + } + + float[] result = new float[featureDocs.length]; + for (int i = 0; i < featureDocs.length; i++) { + result[i] = hasScore[i] ? scores[i] : 0f; + } + + return result; + } + private static float normalizeScore(float score) { // As some models might produce negative scores, we want to ensure that all scores will be positive // so we will make use of the following normalization formula: diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index 2942381a1d181..18bbbd8a2c134 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.rank.textsimilarity; import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.common.util.FeatureFlag; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.license.LicenseUtils; @@ -41,12 +42,16 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder "text_similarity_reranker_alias_handling_fix" ); public static final 
NodeFeature TEXT_SIMILARITY_RERANKER_MINSCORE_FIX = new NodeFeature("text_similarity_reranker_minscore_fix"); + public static final NodeFeature TEXT_SIMILARITY_RERANKER_SNIPPETS = new NodeFeature("text_similarity_reranker_snippets"); + public static final FeatureFlag RERANK_SNIPPETS = new FeatureFlag("text_similarity_reranker_snippets"); public static final ParseField RETRIEVER_FIELD = new ParseField("retriever"); public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); public static final ParseField INFERENCE_TEXT_FIELD = new ParseField("inference_text"); public static final ParseField FIELD_FIELD = new ParseField("field"); public static final ParseField FAILURES_ALLOWED_FIELD = new ParseField("allow_rerank_failures"); + public static final ParseField SNIPPETS_FIELD = new ParseField("snippets"); + public static final ParseField NUM_SNIPPETS_FIELD = new ParseField("num_snippets"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TextSimilarityRankBuilder.NAME, args -> { @@ -56,6 +61,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder String field = (String) args[3]; int rankWindowSize = args[4] == null ? 
DEFAULT_RANK_WINDOW_SIZE : (int) args[4]; boolean failuresAllowed = args[5] != null && (Boolean) args[5]; + SnippetConfig snippets = (SnippetConfig) args[6]; return new TextSimilarityRankRetrieverBuilder( retrieverBuilder, @@ -63,10 +69,20 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder inferenceText, field, rankWindowSize, - failuresAllowed + failuresAllowed, + snippets ); }); + private static final ConstructingObjectParser SNIPPETS_PARSER = new ConstructingObjectParser<>( + SNIPPETS_FIELD.getPreferredName(), + true, + args -> { + Integer numSnippets = (Integer) args[0]; + return new SnippetConfig(numSnippets); + } + ); + static { PARSER.declareNamedObject(constructorArg(), (p, c, n) -> { RetrieverBuilder innerRetriever = p.namedObject(RetrieverBuilder.class, n, c); @@ -78,6 +94,10 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder PARSER.declareString(constructorArg(), FIELD_FIELD); PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); PARSER.declareBoolean(optionalConstructorArg(), FAILURES_ALLOWED_FIELD); + PARSER.declareObject(optionalConstructorArg(), SNIPPETS_PARSER, SNIPPETS_FIELD); + if (RERANK_SNIPPETS.isEnabled()) { + SNIPPETS_PARSER.declareInt(optionalConstructorArg(), NUM_SNIPPETS_FIELD); + } RetrieverBuilder.declareBaseParserFields(PARSER); } @@ -97,6 +117,7 @@ public static TextSimilarityRankRetrieverBuilder fromXContent( private final String inferenceText; private final String field; private final boolean failuresAllowed; + private final SnippetConfig snippets; public TextSimilarityRankRetrieverBuilder( RetrieverBuilder retrieverBuilder, @@ -104,13 +125,15 @@ public TextSimilarityRankRetrieverBuilder( String inferenceText, String field, int rankWindowSize, - boolean failuresAllowed + boolean failuresAllowed, + SnippetConfig snippets ) { super(List.of(RetrieverSource.from(retrieverBuilder)), rankWindowSize); this.inferenceId = inferenceId; this.inferenceText = 
inferenceText; this.field = field; this.failuresAllowed = failuresAllowed; + this.snippets = snippets; } public TextSimilarityRankRetrieverBuilder( @@ -122,12 +145,16 @@ public TextSimilarityRankRetrieverBuilder( Float minScore, boolean failuresAllowed, String retrieverName, - List preFilterQueryBuilders + List preFilterQueryBuilders, + SnippetConfig snippets ) { super(retrieverSource, rankWindowSize); if (retrieverSource.size() != 1) { throw new IllegalArgumentException("[" + getName() + "] retriever should have exactly one inner retriever"); } + if (snippets != null && snippets.numSnippets() != null && snippets.numSnippets() < 1) { + throw new IllegalArgumentException("num_snippets must be greater than 0, was: " + snippets.numSnippets()); + } this.inferenceId = inferenceId; this.inferenceText = inferenceText; this.field = field; @@ -135,6 +162,7 @@ public TextSimilarityRankRetrieverBuilder( this.failuresAllowed = failuresAllowed; this.retrieverName = retrieverName; this.preFilterQueryBuilders = preFilterQueryBuilders; + this.snippets = snippets; } @Override @@ -151,7 +179,8 @@ protected TextSimilarityRankRetrieverBuilder clone( minScore, failuresAllowed, retrieverName, - newPreFilterQueryBuilders + newPreFilterQueryBuilders, + snippets ); } @@ -179,7 +208,17 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults, b @Override protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) { sourceBuilder.rankBuilder( - new TextSimilarityRankBuilder(field, inferenceId, inferenceText, rankWindowSize, minScore, failuresAllowed) + new TextSimilarityRankBuilder( + field, + inferenceId, + inferenceText, + rankWindowSize, + minScore, + failuresAllowed, + snippets != null + ? 
new SnippetConfig(snippets.numSnippets, inferenceText, TextSimilarityRankBuilder.tokenSizeLimit(inferenceId)) + : null + ) ); return sourceBuilder; } @@ -207,6 +246,13 @@ protected void doToXContent(XContentBuilder builder, Params params) throws IOExc if (failuresAllowed) { builder.field(FAILURES_ALLOWED_FIELD.getPreferredName(), failuresAllowed); } + if (snippets != null) { + builder.startObject(SNIPPETS_FIELD.getPreferredName()); + if (snippets.numSnippets() != null) { + builder.field(NUM_SNIPPETS_FIELD.getPreferredName(), snippets.numSnippets()); + } + builder.endObject(); + } } @Override @@ -218,11 +264,12 @@ public boolean doEquals(Object other) { && Objects.equals(field, that.field) && rankWindowSize == that.rankWindowSize && Objects.equals(minScore, that.minScore) - && failuresAllowed == that.failuresAllowed; + && failuresAllowed == that.failuresAllowed + && Objects.equals(snippets, that.snippets); } @Override public int doHashCode() { - return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore, failuresAllowed); + return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore, failuresAllowed, snippets); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRerankingRankFeaturePhaseRankShardContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRerankingRankFeaturePhaseRankShardContext.java new file mode 100644 index 0000000000000..66fb4a366a757 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRerankingRankFeaturePhaseRankShardContext.java @@ -0,0 +1,97 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.rank.textsimilarity; + +import org.elasticsearch.common.document.DocumentField; +import org.elasticsearch.common.logging.HeaderWarning; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.elasticsearch.search.fetch.subphase.highlight.HighlightField; +import org.elasticsearch.search.fetch.subphase.highlight.SearchHighlightContext; +import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.search.rank.RankShardResult; +import org.elasticsearch.search.rank.feature.RankFeatureDoc; +import org.elasticsearch.search.rank.feature.RankFeatureShardResult; +import org.elasticsearch.search.rank.rerank.RerankingRankFeaturePhaseRankShardContext; +import org.elasticsearch.xcontent.Text; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.rank.textsimilarity.SnippetConfig.DEFAULT_NUM_SNIPPETS; + +public class TextSimilarityRerankingRankFeaturePhaseRankShardContext extends RerankingRankFeaturePhaseRankShardContext { + + private final SnippetConfig snippetRankInput; + + // Rough approximation of token size vs. characters in highlight fragments. 
+ // TODO: highlighter should be able to set fragment size by token not length + private static final int TOKEN_SIZE_LIMIT_MULTIPLIER = 5; + + public TextSimilarityRerankingRankFeaturePhaseRankShardContext(String field, @Nullable SnippetConfig snippetRankInput) { + super(field); + this.snippetRankInput = snippetRankInput; + } + + @Override + public RankShardResult doBuildRankFeatureShardResult(SearchHits hits, int shardId) { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId); + SearchHit hit = hits.getHits()[i]; + DocumentField docField = hit.field(field); + if (snippetRankInput == null && docField != null) { + rankFeatureDocs[i].featureData(List.of(docField.getValue().toString())); + } else { + Map highlightFields = hit.getHighlightFields(); + if (highlightFields != null && highlightFields.containsKey(field) && highlightFields.get(field).fragments().length > 0) { + List snippets = Arrays.stream(highlightFields.get(field).fragments()).map(Text::string).toList(); + rankFeatureDocs[i].featureData(snippets); + } else if (docField != null) { + // If we did not get highlighting results, backfill with the doc field value + // but pass in a warning because we are not reranking on snippets only + rankFeatureDocs[i].featureData(List.of(docField.getValue().toString())); + HeaderWarning.addWarning( + "Reranking on snippets requested, but no snippets were found for field [" + field + "]. Using field value instead." 
+ ); + } + } + } + return new RankFeatureShardResult(rankFeatureDocs); + } + + @Override + public void prepareForFetch(SearchContext context) { + if (snippetRankInput != null) { + try { + HighlightBuilder highlightBuilder = new HighlightBuilder(); + highlightBuilder.highlightQuery(snippetRankInput.snippetQueryBuilder()); + // Stripping pre/post tags as they're not useful for snippet creation + highlightBuilder.field(field).preTags("").postTags(""); + // Return highest scoring fragments + highlightBuilder.order(HighlightBuilder.Order.SCORE); + int numSnippets = snippetRankInput.numSnippets() != null ? snippetRankInput.numSnippets() : DEFAULT_NUM_SNIPPETS; + highlightBuilder.numOfFragments(numSnippets); + // Rely on the model to determine the fragment size + int tokenSizeLimit = snippetRankInput.tokenSizeLimit(); + int fragmentSize = tokenSizeLimit * TOKEN_SIZE_LIMIT_MULTIPLIER; + highlightBuilder.fragmentSize(fragmentSize); + highlightBuilder.noMatchSize(fragmentSize); + SearchHighlightContext searchHighlightContext = highlightBuilder.build(context.getSearchExecutionContext()); + context.highlight(searchHighlightContext); + } catch (IOException e) { + throw new RuntimeException("Failed to generate snippet request", e); + } + } + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java index 717fb9437ad52..27aa8b6fb5b5a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java @@ -12,6 +12,9 @@ import 
org.elasticsearch.search.rank.feature.RankFeatureDoc; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; + +import java.util.List; import static org.elasticsearch.action.support.ActionTestUtils.assertNoFailureListener; import static org.mockito.ArgumentMatchers.any; @@ -32,16 +35,29 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContextTests extends E "my-inference-id", "some query", 0.0f, - false + false, + null + ); + + TextSimilarityRankFeaturePhaseRankCoordinatorContext withSnippets = new TextSimilarityRankFeaturePhaseRankCoordinatorContext( + 10, + 0, + 100, + mockClient, + "my-inference-id", + "some query", + 0.0f, + false, + new SnippetConfig(2, "some query", 10) ); public void testComputeScores() { RankFeatureDoc featureDoc1 = new RankFeatureDoc(0, 1.0f, 0); - featureDoc1.featureData("text 1"); + featureDoc1.featureData(List.of("text 1")); RankFeatureDoc featureDoc2 = new RankFeatureDoc(1, 3.0f, 1); - featureDoc2.featureData("text 2"); + featureDoc2.featureData(List.of("text 2")); RankFeatureDoc featureDoc3 = new RankFeatureDoc(2, 2.0f, 0); - featureDoc3.featureData("text 3"); + featureDoc3.featureData(List.of("text 3")); RankFeatureDoc[] featureDocs = new RankFeatureDoc[] { featureDoc1, featureDoc2, featureDoc3 }; subject.computeScores(featureDocs, assertNoFailureListener(f -> assertArrayEquals(new float[] { 1.0f, 3.0f, 2.0f }, f, 0.0f))); @@ -61,4 +77,57 @@ public void testComputeScoresForEmpty() { ); } + public void testExtractScoresFromRankedDocs() { + List rankedDocs = List.of( + new RankedDocsResults.RankedDoc(0, 1.0f, "text 1"), + new RankedDocsResults.RankedDoc(1, 3.0f, "text 2"), + new RankedDocsResults.RankedDoc(2, 2.0f, "text 3") + ); + float[] scores = subject.extractScoresFromRankedDocs(rankedDocs); + assertArrayEquals(new float[] { 1.0f, 3.0f, 2.0f }, scores, 0.0f); + } + + public void 
testExtractScoresFromSingleSnippets() { + + List rankedDocs = List.of( + new RankedDocsResults.RankedDoc(0, 1.0f, "text 1"), + new RankedDocsResults.RankedDoc(1, 2.5f, "text 2"), + new RankedDocsResults.RankedDoc(2, 1.5f, "text 3") + ); + RankFeatureDoc[] featureDocs = new RankFeatureDoc[] { + createRankFeatureDoc(0, 1.0f, 0, List.of("text 1")), + createRankFeatureDoc(1, 3.0f, 1, List.of("text 2")), + createRankFeatureDoc(2, 2.0f, 0, List.of("text 3")) }; + + float[] scores = withSnippets.extractScoresFromRankedSnippets(rankedDocs, featureDocs); + // Returned scores are from the snippet, not the whole text + assertArrayEquals(new float[] { 1.0f, 2.5f, 1.5f }, scores, 0.0f); + } + + public void testExtractScoresFromMultipleSnippets() { + + List rankedDocs = List.of( + new RankedDocsResults.RankedDoc(0, 1.0f, "this is text 1"), + new RankedDocsResults.RankedDoc(1, 2.5f, "some more text"), + new RankedDocsResults.RankedDoc(2, 1.5f, "yet more text"), + new RankedDocsResults.RankedDoc(3, 3.0f, "this is text 2"), + new RankedDocsResults.RankedDoc(4, 2.0f, "this is text 3"), + new RankedDocsResults.RankedDoc(5, 1.5f, "oh look, more text") + ); + RankFeatureDoc[] featureDocs = new RankFeatureDoc[] { + createRankFeatureDoc(0, 1.0f, 0, List.of("this is text 1", "some more text")), + createRankFeatureDoc(1, 3.0f, 1, List.of("yet more text", "this is text 2")), + createRankFeatureDoc(2, 2.0f, 0, List.of("this is text 3", "oh look, more text")) }; + + float[] scores = withSnippets.extractScoresFromRankedSnippets(rankedDocs, featureDocs); + // Returned scores are from the best-ranking snippet, not the whole text + assertArrayEquals(new float[] { 2.5f, 3.0f, 2.0f }, scores, 0.0f); + } + + private RankFeatureDoc createRankFeatureDoc(int doc, float score, int shardIndex, List featureData) { + RankFeatureDoc featureDoc = new RankFeatureDoc(doc, score, shardIndex); + featureDoc.featureData(featureData); + return featureDoc; + } + + } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java index 4acebd9c956b1..8a6e13c3d8623 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java @@ -32,7 +32,7 @@ public class TextSimilarityRankMultiNodeTests extends AbstractRerankerIT { @Override protected RankBuilder getRankBuilder(int rankWindowSize, String rankFeatureField) { - return new TextSimilarityRankBuilder(rankFeatureField, inferenceId, inferenceText, rankWindowSize, minScore, false); + return new TextSimilarityRankBuilder(rankFeatureField, inferenceId, inferenceText, rankWindowSize, minScore, false, null); } @Override @@ -53,7 +53,8 @@ protected RankBuilder getThrowingRankBuilder( inferenceText, minScore, failuresAllowed, - type.name() + type.name(), + null ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java index 9977da9044d44..42063cf499fdc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java @@ -58,7 +58,8 @@ public static TextSimilarityRankRetrieverBuilder createRandomTextSimilarityRankR randomAlphaOfLength(20), randomAlphaOfLength(50), randomIntBetween(100, 10000), - randomBoolean() + 
randomBoolean(), + null ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java index 07474169fbe97..7f6bc6117561b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java @@ -139,7 +139,8 @@ public void testTelemetryForRRFRetriever() throws IOException { "some_inference_text", "some_field", 10, - false + false, + null ) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java index ae6e5fb5a53a4..b39dc74f6e72c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java @@ -57,7 +57,7 @@ public TopNConfigurationAcceptingTextSimilarityRankBuilder( Float minScore, int topN ) { - super(field, inferenceId + "-task-settings-top-" + topN, inferenceText, rankWindowSize, minScore, false); + super(field, inferenceId + "-task-settings-top-" + topN, inferenceText, rankWindowSize, minScore, false, null); } } @@ -76,7 +76,7 @@ public InferenceResultCountAcceptingTextSimilarityRankBuilder( Float minScore, int inferenceResultCount ) { - super(field, inferenceId, inferenceText, rankWindowSize, minScore, false); + super(field, inferenceId, inferenceText, rankWindowSize, minScore, false, null); 
this.inferenceResultCount = inferenceResultCount; } @@ -90,7 +90,8 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo inferenceId, inferenceText, minScore, - failuresAllowed() + failuresAllowed(), + null ) { @Override protected InferenceAction.Request generateRequest(List docFeatures) { @@ -136,7 +137,7 @@ public void testRerank() { ElasticsearchAssertions.assertNoFailuresAndResponse( // Execute search with text similarity reranking client.prepareSearch() - .setRankBuilder(new TextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 0.0f, false)) + .setRankBuilder(new TextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 0.0f, false, null)) .setQuery(QueryBuilders.matchAllQuery()), response -> { // Verify order, rank and score of results @@ -159,7 +160,7 @@ public void testRerankWithMinScore() { ElasticsearchAssertions.assertNoFailuresAndResponse( // Execute search with text similarity reranking client.prepareSearch() - .setRankBuilder(new TextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f, false)) + .setRankBuilder(new TextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f, false, null)) .setQuery(QueryBuilders.matchAllQuery()), response -> { // Verify order, rank and score of results @@ -183,7 +184,8 @@ public void testRerankInferenceFailure() { "my query", 0.7f, false, - AbstractRerankerIT.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT.name() + AbstractRerankerIT.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT.name(), + null ) ) .setQuery(QueryBuilders.matchAllQuery()), @@ -204,7 +206,8 @@ public void testRerankInferenceAllowedFailure() { "my query", null, true, - AbstractRerankerIT.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT.name() + AbstractRerankerIT.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT.name(), + null ) ) .setQuery( diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java index f8563aebe0764..c88ce1b65ee3d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java @@ -176,9 +176,10 @@ public ThrowingMockRequestActionBasedRankBuilder( String inferenceText, Float minScore, boolean failuresAllowed, - String throwingType + String throwingType, + SnippetConfig snippetConfig ) { - super(field, inferenceId, inferenceText, rankWindowSize, minScore, failuresAllowed); + super(field, inferenceId, inferenceText, rankWindowSize, minScore, failuresAllowed, snippetConfig); this.throwingRankBuilderType = AbstractRerankerIT.ThrowingRankBuilderType.valueOf(throwingType); } @@ -218,7 +219,8 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo inferenceId, inferenceText, minScore, - failuresAllowed() + failuresAllowed(), + null ) { @Override protected InferenceAction.Request generateRequest(List docFeatures) { diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml index 98e392ed1ccee..3dd85ef9e8658 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml @@ -20,6 +20,21 @@ setup: } } + - do: + inference.put: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + 
"service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + - do: indices.create: index: test-index @@ -28,12 +43,20 @@ setup: properties: text: type: text + copy_to: semantic_text_field topic: type: keyword subtopic: type: keyword inference_text_field: type: text + semantic_text_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 - do: index: @@ -298,8 +321,10 @@ setup: - match: { hits.hits.0._id: "doc_2" } - match: { hits.hits.1._id: "doc_1" } - - match: {hits.hits.0._explanation.description: "/text_similarity_reranker.match.using.inference.endpoint:.\\[my-rerank-model\\].on.document.field:.\\[inference_text_field\\].*/" } - - match: {hits.hits.0._explanation.details.0.details.0.description: "/subtopic.*astronomy.*/" } + - match: { hits.hits.0._explanation.description: "sum of:" } + - match: { hits.hits.0._explanation.details.0.description: "/text_similarity_reranker.match.using.inference.endpoint:.\\[my-rerank-model\\].on.document.field:.\\[inference_text_field\\].*/" } + - match: { hits.hits.0._explanation.details.0.details.0.details.0.description: "/subtopic.*astronomy.*/" } + - match: { hits.hits.0._explanation.details.1.description: "/match.on.required.clause,.product.of:*/" } --- "text similarity reranker properly handles aliases": @@ -448,7 +473,7 @@ setup: retriever: standard: query: - match_all: {} + match_all: { } rank_window_size: 10 inference_id: my-rerank-model inference_text: "How often does the moon hide the sun?" @@ -477,7 +502,7 @@ setup: retriever: standard: query: - match_all: {} + match_all: { } rank_window_size: 10 inference_id: my-rerank-model inference_text: "How often does the moon hide the sun?" 
@@ -487,3 +512,208 @@ setup: - match: { hits.total.value: 0 } - length: { hits.hits: 0 } + + +--- +"Text similarity reranker specifying number of snippets must be > 0": + + - requires: + cluster_features: "text_similarity_reranker_snippets" + reason: snippets introduced in 9.2.0 + + - do: + catch: /num_snippets must be greater than 0/ + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "topic" ] + retriever: + text_similarity_reranker: + retriever: + standard: + query: + match_all: { } + rank_window_size: 10 + inference_id: my-rerank-model + inference_text: "How often does the moon hide the sun?" + field: inference_text_field + snippets: + num_snippets: 0 + size: 10 + + - match: { status: 400 } + +--- +"Reranking based on snippets": + + - requires: + cluster_features: "text_similarity_reranker_snippets" + reason: snippets introduced in 9.2.0 + + - do: + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "topic" ] + retriever: + text_similarity_reranker: + retriever: + standard: + query: + match: + topic: + query: "science" + rank_window_size: 10 + inference_id: my-rerank-model + inference_text: "How often does the moon hide the sun?" + field: text + snippets: + num_snippets: 2 + size: 10 + + - match: { hits.total.value: 2 } + - length: { hits.hits: 2 } + + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_2" } + +--- +"Reranking based on snippets using defaults": + + - requires: + cluster_features: "text_similarity_reranker_snippets" + reason: snippets introduced in 9.2.0 + + - do: + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "topic" ] + retriever: + text_similarity_reranker: + retriever: + standard: + query: + term: + topic: "science" + rank_window_size: 10 + inference_id: my-rerank-model + inference_text: "How often does the moon hide the sun?" 
+ field: text + snippets: { } + size: 10 + + - match: { hits.total.value: 2 } + - length: { hits.hits: 2 } + + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_2" } + +--- +"Reranking based on snippets on a semantic_text field": + + - requires: + cluster_features: "text_similarity_reranker_snippets" + reason: snippets introduced in 9.2.0 + + - do: + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "semantic_text_field", "topic" ] + retriever: + text_similarity_reranker: + retriever: + standard: + query: + match: + topic: + query: "science" + rank_window_size: 10 + inference_id: my-rerank-model + inference_text: "how often does the moon hide the sun?" + field: semantic_text_field + snippets: + num_snippets: 2 + size: 10 + + - match: { hits.total.value: 2 } + - length: { hits.hits: 2 } + + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_2" } + +--- +"Reranking based on snippets on a semantic_text field using defaults": + + - requires: + cluster_features: "text_similarity_reranker_snippets" + reason: snippets introduced in 9.2.0 + + - do: + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "semantic_text_field", "topic" ] + retriever: + text_similarity_reranker: + retriever: + standard: + query: + match: + topic: + query: "science" + rank_window_size: 10 + inference_id: my-rerank-model + inference_text: "how often does the moon hide the sun?" 
+ field: semantic_text_field + snippets: { } + size: 10 + + - match: { hits.total.value: 2 } + - length: { hits.hits: 2 } + + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_2" } + +--- +"Reranking based on snippets when highlighter doesn't return results": + + - requires: + test_runner_features: allowed_warnings + cluster_features: "text_similarity_reranker_snippets" + reason: snippets introduced in 9.2.0 + + - do: + allowed_warnings: + - "Reranking on snippets requested, but no snippets were found for field [inference_text_field]. Using field value instead." + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "topic" ] + retriever: + text_similarity_reranker: + retriever: + standard: + query: + term: + topic: "science" + rank_window_size: 10 + inference_id: my-rerank-model + inference_text: "How often does the moon hide the sun?" + field: inference_text_field + snippets: + num_snippets: 2 + size: 10 + + - match: { hits.total.value: 2 } + - length: { hits.hits: 2 } + + - match: { hits.hits.0._id: "doc_2" } + - match: { hits.hits.1._id: "doc_1" } diff --git a/x-pack/qa/core-rest-tests-with-security/src/yamlRestTest/java/org/elasticsearch/xpack/security/CoreWithSecurityClientYamlTestSuiteIT.java b/x-pack/qa/core-rest-tests-with-security/src/yamlRestTest/java/org/elasticsearch/xpack/security/CoreWithSecurityClientYamlTestSuiteIT.java index 0291e55187278..88c754b257f5e 100644 --- a/x-pack/qa/core-rest-tests-with-security/src/yamlRestTest/java/org/elasticsearch/xpack/security/CoreWithSecurityClientYamlTestSuiteIT.java +++ b/x-pack/qa/core-rest-tests-with-security/src/yamlRestTest/java/org/elasticsearch/xpack/security/CoreWithSecurityClientYamlTestSuiteIT.java @@ -54,6 +54,7 @@ public class CoreWithSecurityClientYamlTestSuiteIT extends ESClientYamlSuiteTest .feature(FeatureFlag.USE_LUCENE101_POSTINGS_FORMAT) .feature(FeatureFlag.IVF_FORMAT) .feature(FeatureFlag.SYNTHETIC_VECTORS) + 
.feature(FeatureFlag.RERANK_SNIPPETS) .build(); public CoreWithSecurityClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {