diff --git a/docs/changelog/133576.yaml b/docs/changelog/133576.yaml
new file mode 100644
index 0000000000000..31b87f9fbebda
--- /dev/null
+++ b/docs/changelog/133576.yaml
@@ -0,0 +1,5 @@
+pr: 133576
+summary: Text similarity reranker chunks and scores snippets
+area: Relevance
+type: enhancement
+issues: []
diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java b/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java
index 5b18b2f2466d0..08b8e7e90720d 100644
--- a/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java
+++ b/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java
@@ -24,4 +24,6 @@ public interface ChunkingSettings extends ToXContentObject, VersionedNamedWriteable
      * @return The max chunk size specified, or null if not specified
      */
     Integer maxChunkSize();
+
+    default void validate() {}
 }
diff --git a/x-pack/plugin/core/src/main/java/module-info.java b/x-pack/plugin/core/src/main/java/module-info.java
index 33be0a0c43506..ae9a8910901ad 100644
--- a/x-pack/plugin/core/src/main/java/module-info.java
+++ b/x-pack/plugin/core/src/main/java/module-info.java
@@ -234,6 +234,7 @@
     exports org.elasticsearch.xpack.core.watcher.watch;
     exports org.elasticsearch.xpack.core.watcher;
     exports org.elasticsearch.xpack.core.security.authc.apikey;
+    exports org.elasticsearch.xpack.core.common.chunks;
 
     provides org.elasticsearch.action.admin.cluster.node.info.ComponentVersionNumber
         with
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/common/chunks/MemoryIndexChunkScorer.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/common/chunks/MemoryIndexChunkScorer.java
new file mode 100644
index 0000000000000..5b6a895e1e090
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/common/chunks/MemoryIndexChunkScorer.java
@@ -0,0 +1,98 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.common.chunks;
+
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.TextField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.store.ByteBuffersDirectory;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.QueryBuilder;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Utility class for scoring pre-determined chunks using an in-memory Lucene index.
+ */
+public class MemoryIndexChunkScorer {
+
+    private static final String CONTENT_FIELD = "content";
+
+    private final StandardAnalyzer analyzer;
+
+    public MemoryIndexChunkScorer() {
+        // TODO: Allow analyzer to be customizable and/or read from the field mapping
+        this.analyzer = new StandardAnalyzer();
+    }
+
+    /**
+     * Creates an in-memory index of the provided chunks and returns an ordered, scored list.
+     *
+     * @param chunks the list of text chunks to score
+     * @param inferenceText the query text to compare against
+     * @param maxResults maximum number of results to return
+     * @return list of scored chunks ordered by relevance
+     * @throws IOException on failure scoring chunks
+     */
+    public List<ScoredChunk> scoreChunks(List<String> chunks, String inferenceText, int maxResults) throws IOException {
+        if (chunks == null || chunks.isEmpty() || inferenceText == null || inferenceText.trim().isEmpty()) {
+            return new ArrayList<>();
+        }
+
+        try (Directory directory = new ByteBuffersDirectory()) {
+            IndexWriterConfig config = new IndexWriterConfig(analyzer);
+            try (IndexWriter writer = new IndexWriter(directory, config)) {
+                for (String chunk : chunks) {
+                    Document doc = new Document();
+                    doc.add(new TextField(CONTENT_FIELD, chunk, Field.Store.YES));
+                    writer.addDocument(doc);
+                }
+                writer.commit();
+            }
+
+            try (DirectoryReader reader = DirectoryReader.open(directory)) {
+                IndexSearcher searcher = new IndexSearcher(reader);
+
+                org.apache.lucene.util.QueryBuilder qb = new QueryBuilder(analyzer);
+                Query query = qb.createBooleanQuery(CONTENT_FIELD, inferenceText, BooleanClause.Occur.SHOULD);
+                int numResults = Math.min(maxResults, chunks.size());
+                TopDocs topDocs = searcher.search(query, numResults);
+
+                List<ScoredChunk> scoredChunks = new ArrayList<>();
+                for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
+                    Document doc = reader.storedFields().document(scoreDoc.doc);
+                    String content = doc.get(CONTENT_FIELD);
+                    scoredChunks.add(new ScoredChunk(content, scoreDoc.score));
+                }
+
+                // It's possible that no chunks were scorable (for example, a semantic match that does not have a lexical match).
+                // In this case, we'll return the first N chunks with a score of 0.
+                // TODO: consider parameterizing this
+                return scoredChunks.isEmpty() == false
+                    ? scoredChunks
+                    : chunks.subList(0, Math.min(maxResults, chunks.size())).stream().map(c -> new ScoredChunk(c, 0.0f)).toList();
+            }
+        }
+    }
+
+    /**
+     * Represents a chunk with its relevance score.
+     */
+    public record ScoredChunk(String content, float score) {}
+}
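For reference, a minimal usage sketch of the scorer added above (illustrative only — the chunk strings, query text, and example class name are made up and not part of this change):

import org.elasticsearch.xpack.core.common.chunks.MemoryIndexChunkScorer;

import java.util.List;

public class ChunkScorerUsageExample {
    public static void main(String[] args) throws Exception {
        MemoryIndexChunkScorer scorer = new MemoryIndexChunkScorer();
        List<String> chunks = List.of(
            "The moon orbits the earth roughly every 27 days.",
            "A solar eclipse happens when the moon passes between the earth and the sun."
        );
        // Index the chunks in memory, score them against the inference text,
        // and keep at most one result, best first.
        List<MemoryIndexChunkScorer.ScoredChunk> best = scorer.scoreChunks(chunks, "how often does the moon hide the sun", 1);
        for (MemoryIndexChunkScorer.ScoredChunk chunk : best) {
            System.out.println(chunk.score() + " -> " + chunk.content());
        }
    }
}
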
+ */ + +package org.elasticsearch.xpack.core.common.chunks; + +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; + +public class MemoryIndexChunkScorerTests extends ESTestCase { + + private static final List CHUNKS = Arrays.asList( + "Cats like to sleep all day and play with mice", + "Dogs are loyal companions and great pets", + "The weather today is very sunny and warm", + "Dogs love to play with toys and go for walks", + "Elasticsearch is a great search engine" + ); + + public void testScoreChunks() throws IOException { + MemoryIndexChunkScorer scorer = new MemoryIndexChunkScorer(); + + String inferenceText = "dogs play walk"; + int maxResults = 3; + + List scoredChunks = scorer.scoreChunks(CHUNKS, inferenceText, maxResults); + + assertEquals(maxResults, scoredChunks.size()); + + // The chunks about dogs should score highest, followed by the chunk about cats + MemoryIndexChunkScorer.ScoredChunk chunk = scoredChunks.getFirst(); + assertTrue(chunk.content().equalsIgnoreCase("Dogs love to play with toys and go for walks")); + assertThat(chunk.score(), greaterThan(0f)); + + chunk = scoredChunks.get(1); + assertTrue(chunk.content().equalsIgnoreCase("Dogs are loyal companions and great pets")); + assertThat(chunk.score(), greaterThan(0f)); + + chunk = scoredChunks.get(2); + assertTrue(chunk.content().equalsIgnoreCase("Cats like to sleep all day and play with mice")); + assertThat(chunk.score(), greaterThan(0f)); + + // Scores should be in descending order + for (int i = 1; i < scoredChunks.size(); i++) { + assertTrue(scoredChunks.get(i - 1).score() >= scoredChunks.get(i).score()); + } + } + + public void testEmptyChunks() throws IOException { + + int maxResults = 3; + + MemoryIndexChunkScorer scorer = new MemoryIndexChunkScorer(); + + // Zero results + List scoredChunks = scorer.scoreChunks(CHUNKS, "puggles", maxResults); + assertEquals(maxResults, scoredChunks.size()); + + // There were no results so we return the first N chunks in order + MemoryIndexChunkScorer.ScoredChunk chunk = scoredChunks.getFirst(); + assertTrue(chunk.content().equalsIgnoreCase("Cats like to sleep all day and play with mice")); + assertThat(chunk.score(), equalTo(0f)); + + chunk = scoredChunks.get(1); + assertTrue(chunk.content().equalsIgnoreCase("Dogs are loyal companions and great pets")); + assertThat(chunk.score(), equalTo(0f)); + + chunk = scoredChunks.get(2); + assertTrue(chunk.content().equalsIgnoreCase("The weather today is very sunny and warm")); + assertThat(chunk.score(), equalTo(0f)); + + // Null and Empty chunk input + scoredChunks = scorer.scoreChunks(List.of(), "puggles", maxResults); + assertTrue(scoredChunks.isEmpty()); + + scoredChunks = scorer.scoreChunks(CHUNKS, "", maxResults); + assertTrue(scoredChunks.isEmpty()); + + scoredChunks = scorer.scoreChunks(null, "puggles", maxResults); + assertTrue(scoredChunks.isEmpty()); + + scoredChunks = scorer.scoreChunks(CHUNKS, null, maxResults); + assertTrue(scoredChunks.isEmpty()); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java index e53ac64c4bc69..be6bdb6b16b1c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java +++ 
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java
index e53ac64c4bc69..be6bdb6b16b1c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java
@@ -52,6 +52,25 @@ public RecursiveChunkingSettings(StreamInput in) throws IOException {
         separators = in.readCollectionAsList(StreamInput::readString);
     }
 
+    @Override
+    public void validate() {
+        ValidationException validationException = new ValidationException();
+
+        if (maxChunkSize < MAX_CHUNK_SIZE_LOWER_LIMIT) {
+            validationException.addValidationError(
+                ChunkingSettingsOptions.MAX_CHUNK_SIZE + "[" + maxChunkSize + "] must be above " + MAX_CHUNK_SIZE_LOWER_LIMIT
+            );
+        }
+
+        if (separators != null && separators.isEmpty()) {
+            validationException.addValidationError("Recursive chunking settings can not have an empty list of separators");
+        }
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+    }
+
     public static RecursiveChunkingSettings fromMap(Map<String, Object> map) {
         ValidationException validationException = new ValidationException();
 
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java
index b0ea93a8ead1c..25b5d248f294b 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java
@@ -59,6 +59,27 @@ public Integer maxChunkSize() {
         return maxChunkSize;
     }
 
+    @Override
+    public void validate() {
+        ValidationException validationException = new ValidationException();
+
+        if (maxChunkSize < MAX_CHUNK_SIZE_LOWER_LIMIT) {
+            validationException.addValidationError(
+                ChunkingSettingsOptions.MAX_CHUNK_SIZE + "[" + maxChunkSize + "] must be above " + MAX_CHUNK_SIZE_LOWER_LIMIT
+            );
+        }
+
+        if (sentenceOverlap > 1 || sentenceOverlap < 0) {
+            validationException.addValidationError(
+                ChunkingSettingsOptions.SENTENCE_OVERLAP + "[" + sentenceOverlap + "] must be either 0 or 1"
+            );
+        }
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+    }
+
     @Override
     public Map<String, Object> asMap() {
         return Map.of(
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java
index 2b4e680d3fcc8..055df300bfd3e 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java
@@ -48,6 +48,27 @@ public WordBoundaryChunkingSettings(StreamInput in) throws IOException {
         overlap = in.readInt();
     }
 
+    @Override
+    public void validate() {
+        ValidationException validationException = new ValidationException();
+
+        if (maxChunkSize < MAX_CHUNK_SIZE_LOWER_LIMIT) {
+            validationException.addValidationError(
+                ChunkingSettingsOptions.MAX_CHUNK_SIZE + "[" + maxChunkSize + "] must be above " + MAX_CHUNK_SIZE_LOWER_LIMIT
+            );
+        }
+
+        if (overlap > maxChunkSize / 2) {
+            validationException.addValidationError(
+                ChunkingSettingsOptions.OVERLAP + "[" + overlap + "] must be less than or equal to half of max chunk size"
+            );
+        }
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw
validationException; + } + } + @Override public Map asMap() { return Map.of( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/ChunkScorerConfig.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/ChunkScorerConfig.java new file mode 100644 index 0000000000000..92c4eace01442 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/ChunkScorerConfig.java @@ -0,0 +1,100 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.rank.textsimilarity; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public class ChunkScorerConfig implements Writeable { + + public final Integer size; + private final String inferenceText; + private final ChunkingSettings chunkingSettings; + + public static final int DEFAULT_CHUNK_SIZE = 300; + public static final int DEFAULT_SIZE = 1; + + public static ChunkingSettings createChunkingSettings(Integer chunkSize) { + int chunkSizeOrDefault = chunkSize != null ? chunkSize : DEFAULT_CHUNK_SIZE; + ChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(chunkSizeOrDefault, 0); + chunkingSettings.validate(); + return chunkingSettings; + } + + public static ChunkingSettings chunkingSettingsFromMap(Map map) { + + if (map == null || map.isEmpty()) { + return createChunkingSettings(DEFAULT_CHUNK_SIZE); + } + + if (map.size() == 1 && map.containsKey("max_chunk_size")) { + return createChunkingSettings((Integer) map.get("max_chunk_size")); + } + + return ChunkingSettingsBuilder.fromMap(map); + } + + public ChunkScorerConfig(StreamInput in) throws IOException { + this.size = in.readOptionalVInt(); + this.inferenceText = in.readString(); + Map chunkingSettingsMap = in.readGenericMap(); + this.chunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap); + } + + public ChunkScorerConfig(Integer size, ChunkingSettings chunkingSettings) { + this(size, null, chunkingSettings); + } + + public ChunkScorerConfig(Integer size, String inferenceText, ChunkingSettings chunkingSettings) { + this.size = size; + this.inferenceText = inferenceText; + this.chunkingSettings = chunkingSettings; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalVInt(size); + out.writeString(inferenceText); + out.writeGenericMap(chunkingSettings.asMap()); + } + + public Integer size() { + return size; + } + + public String inferenceText() { + return inferenceText; + } + + public ChunkingSettings chunkingSettings() { + return chunkingSettings; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ChunkScorerConfig that = (ChunkScorerConfig) o; + return Objects.equals(size, that.size) + && Objects.equals(inferenceText, that.inferenceText) + && 
Objects.equals(chunkingSettings, that.chunkingSettings); + } + + @Override + public int hashCode() { + return Objects.hash(size, inferenceText, chunkingSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/SnippetConfig.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/SnippetConfig.java deleted file mode 100644 index f25ee40ca7ab1..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/SnippetConfig.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.rank.textsimilarity; - -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.index.query.QueryBuilder; - -import java.io.IOException; -import java.util.Objects; - -public class SnippetConfig implements Writeable { - - public final Integer numSnippets; - private final String inferenceText; - private final Integer tokenSizeLimit; - public final QueryBuilder snippetQueryBuilder; - - public static final int DEFAULT_NUM_SNIPPETS = 1; - - public SnippetConfig(StreamInput in) throws IOException { - this.numSnippets = in.readOptionalVInt(); - this.inferenceText = in.readString(); - this.tokenSizeLimit = in.readOptionalVInt(); - this.snippetQueryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class); - } - - public SnippetConfig(Integer numSnippets) { - this(numSnippets, null, null); - } - - public SnippetConfig(Integer numSnippets, String inferenceText, Integer tokenSizeLimit) { - this(numSnippets, inferenceText, tokenSizeLimit, null); - } - - public SnippetConfig(Integer numSnippets, String inferenceText, Integer tokenSizeLimit, QueryBuilder snippetQueryBuilder) { - this.numSnippets = numSnippets; - this.inferenceText = inferenceText; - this.tokenSizeLimit = tokenSizeLimit; - this.snippetQueryBuilder = snippetQueryBuilder; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalVInt(numSnippets); - out.writeString(inferenceText); - out.writeOptionalVInt(tokenSizeLimit); - out.writeOptionalNamedWriteable(snippetQueryBuilder); - } - - public Integer numSnippets() { - return numSnippets; - } - - public String inferenceText() { - return inferenceText; - } - - public Integer tokenSizeLimit() { - return tokenSizeLimit; - } - - public QueryBuilder snippetQueryBuilder() { - return snippetQueryBuilder; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - SnippetConfig that = (SnippetConfig) o; - return Objects.equals(numSnippets, that.numSnippets) - && Objects.equals(inferenceText, that.inferenceText) - && Objects.equals(tokenSizeLimit, that.tokenSizeLimit) - && Objects.equals(snippetQueryBuilder, that.snippetQueryBuilder); - } - - @Override - public int hashCode() { - return Objects.hash(numSnippets, inferenceText, tokenSizeLimit, snippetQueryBuilder); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java index eed06a577df05..2a8af99721064 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java @@ -15,9 +15,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.index.query.MatchQueryBuilder; -import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.license.License; import org.elasticsearch.license.LicensedFeature; import org.elasticsearch.search.rank.RankBuilder; @@ -33,12 +30,12 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.CHUNK_RESCORER_FIELD; import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.FAILURES_ALLOWED_FIELD; import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.FIELD_FIELD; import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_ID_FIELD; import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_TEXT_FIELD; import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.MIN_SCORE_FIELD; -import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.SNIPPETS_FIELD; /** * A {@code RankBuilder} that enables ranking with text similarity model inference. Supports parameters for configuring the inference call. @@ -47,11 +44,6 @@ public class TextSimilarityRankBuilder extends RankBuilder { public static final String NAME = "text_similarity_reranker"; - /** - * The default token size limit of the Elastic reranker is 512. 
- */ - private static final int DEFAULT_TOKEN_SIZE_LIMIT = 512; - public static final LicensedFeature.Momentary TEXT_SIMILARITY_RERANKER_FEATURE = LicensedFeature.momentary( null, "text-similarity-reranker", @@ -65,7 +57,7 @@ public class TextSimilarityRankBuilder extends RankBuilder { private final String field; private final Float minScore; private final boolean failuresAllowed; - private final SnippetConfig snippetConfig; + private final ChunkScorerConfig chunkScorerConfig; public TextSimilarityRankBuilder( String field, @@ -74,7 +66,7 @@ public TextSimilarityRankBuilder( int rankWindowSize, Float minScore, boolean failuresAllowed, - SnippetConfig snippetConfig + ChunkScorerConfig chunkScorerConfig ) { super(rankWindowSize); this.inferenceId = inferenceId; @@ -82,7 +74,7 @@ public TextSimilarityRankBuilder( this.field = field; this.minScore = minScore; this.failuresAllowed = failuresAllowed; - this.snippetConfig = snippetConfig; + this.chunkScorerConfig = chunkScorerConfig; } public TextSimilarityRankBuilder(StreamInput in) throws IOException { @@ -99,9 +91,9 @@ public TextSimilarityRankBuilder(StreamInput in) throws IOException { this.failuresAllowed = false; } if (in.getTransportVersion().supports(RERANK_SNIPPETS)) { - this.snippetConfig = in.readOptionalWriteable(SnippetConfig::new); + this.chunkScorerConfig = in.readOptionalWriteable(ChunkScorerConfig::new); } else { - this.snippetConfig = null; + this.chunkScorerConfig = null; } } @@ -127,7 +119,7 @@ public void doWriteTo(StreamOutput out) throws IOException { out.writeBoolean(failuresAllowed); } if (out.getTransportVersion().supports(RERANK_SNIPPETS)) { - out.writeOptionalWriteable(snippetConfig); + out.writeOptionalWriteable(chunkScorerConfig); } } @@ -144,53 +136,9 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio if (failuresAllowed) { builder.field(FAILURES_ALLOWED_FIELD.getPreferredName(), true); } - if (snippetConfig != null) { - builder.field(SNIPPETS_FIELD.getPreferredName(), snippetConfig); - } - } - - @Override - public RankBuilder rewrite(QueryRewriteContext queryRewriteContext) throws IOException { - TextSimilarityRankBuilder rewritten = this; - if (snippetConfig != null) { - QueryBuilder snippetQueryBuilder = snippetConfig.snippetQueryBuilder(); - if (snippetQueryBuilder == null) { - rewritten = new TextSimilarityRankBuilder( - field, - inferenceId, - inferenceText, - rankWindowSize(), - minScore, - failuresAllowed, - new SnippetConfig( - snippetConfig.numSnippets(), - snippetConfig.inferenceText(), - snippetConfig.tokenSizeLimit(), - new MatchQueryBuilder(field, inferenceText) - ) - ); - } else { - QueryBuilder rewrittenSnippetQueryBuilder = snippetQueryBuilder.rewrite(queryRewriteContext); - if (snippetQueryBuilder != rewrittenSnippetQueryBuilder) { - rewritten = new TextSimilarityRankBuilder( - field, - inferenceId, - inferenceText, - rankWindowSize(), - minScore, - failuresAllowed, - new SnippetConfig( - snippetConfig.numSnippets(), - snippetConfig.inferenceText(), - snippetConfig.tokenSizeLimit(), - rewrittenSnippetQueryBuilder - ) - ); - } - } + if (chunkScorerConfig != null) { + builder.field(CHUNK_RESCORER_FIELD.getPreferredName(), chunkScorerConfig); } - - return rewritten; } @Override @@ -237,7 +185,7 @@ public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int si @Override public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { - return new TextSimilarityRerankingRankFeaturePhaseRankShardContext(field, snippetConfig); + return new 
TextSimilarityRerankingRankFeaturePhaseRankShardContext(field, chunkScorerConfig); } @Override @@ -251,18 +199,12 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo inferenceText, minScore, failuresAllowed, - snippetConfig != null ? new SnippetConfig(snippetConfig.numSnippets, inferenceText, tokenSizeLimit(inferenceId)) : null + chunkScorerConfig != null + ? new ChunkScorerConfig(chunkScorerConfig.size, inferenceText, chunkScorerConfig.chunkingSettings()) + : null ); } - /** - * @return The token size limit to apply to this rerank context. - * TODO: This should be pulled from the inference endpoint when available, not hardcoded. - */ - public static Integer tokenSizeLimit(String inferenceId) { - return DEFAULT_TOKEN_SIZE_LIMIT; - } - public String field() { return field; } @@ -291,12 +233,12 @@ protected boolean doEquals(RankBuilder other) { && Objects.equals(field, that.field) && Objects.equals(minScore, that.minScore) && failuresAllowed == that.failuresAllowed - && Objects.equals(snippetConfig, that.snippetConfig); + && Objects.equals(chunkScorerConfig, that.chunkScorerConfig); } @Override protected int doHashCode() { - return Objects.hash(inferenceId, inferenceText, field, minScore, failuresAllowed, snippetConfig); + return Objects.hash(inferenceId, inferenceText, field, minScore, failuresAllowed, chunkScorerConfig); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index 0a47db4d2a519..725443bc01e0d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -40,7 +40,7 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe protected final String inferenceId; protected final String inferenceText; protected final Float minScore; - protected final SnippetConfig snippetConfig; + protected final ChunkScorerConfig chunkScorerConfig; public TextSimilarityRankFeaturePhaseRankCoordinatorContext( int size, @@ -51,14 +51,14 @@ public TextSimilarityRankFeaturePhaseRankCoordinatorContext( String inferenceText, Float minScore, boolean failuresAllowed, - @Nullable SnippetConfig snippetConfig + @Nullable ChunkScorerConfig chunkScorerConfig ) { super(size, from, rankWindowSize, failuresAllowed); this.client = client; this.inferenceId = inferenceId; this.inferenceText = inferenceText; this.minScore = minScore; - this.snippetConfig = snippetConfig; + this.chunkScorerConfig = chunkScorerConfig; } @Override @@ -80,8 +80,8 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener rankedDocs return scores; } - float[] extractScoresFromRankedSnippets(List rankedDocs, RankFeatureDoc[] featureDocs) { + float[] extractScoresFromRankedChunks(List rankedDocs, RankFeatureDoc[] featureDocs) { float[] scores = new float[featureDocs.length]; boolean[] hasScore = new boolean[featureDocs.length]; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index 18bbbd8a2c134..74e8ff2bd4042 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.util.FeatureFlag; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.search.builder.SearchSourceBuilder; @@ -26,6 +27,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Objects; import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE; @@ -50,8 +52,9 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder public static final ParseField INFERENCE_TEXT_FIELD = new ParseField("inference_text"); public static final ParseField FIELD_FIELD = new ParseField("field"); public static final ParseField FAILURES_ALLOWED_FIELD = new ParseField("allow_rerank_failures"); - public static final ParseField SNIPPETS_FIELD = new ParseField("snippets"); - public static final ParseField NUM_SNIPPETS_FIELD = new ParseField("num_snippets"); + public static final ParseField CHUNK_RESCORER_FIELD = new ParseField("chunk_rescorer"); + public static final ParseField CHUNK_SIZE_FIELD = new ParseField("size"); + public static final ParseField CHUNKING_SETTINGS_FIELD = new ParseField("chunking_settings"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TextSimilarityRankBuilder.NAME, args -> { @@ -61,7 +64,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder String field = (String) args[3]; int rankWindowSize = args[4] == null ? 
DEFAULT_RANK_WINDOW_SIZE : (int) args[4]; boolean failuresAllowed = args[5] != null && (Boolean) args[5]; - SnippetConfig snippets = (SnippetConfig) args[6]; + ChunkScorerConfig chunkScorerConfig = (ChunkScorerConfig) args[6]; return new TextSimilarityRankRetrieverBuilder( retrieverBuilder, @@ -70,18 +73,18 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder field, rankWindowSize, failuresAllowed, - snippets + chunkScorerConfig ); }); - private static final ConstructingObjectParser SNIPPETS_PARSER = new ConstructingObjectParser<>( - SNIPPETS_FIELD.getPreferredName(), - true, - args -> { - Integer numSnippets = (Integer) args[0]; - return new SnippetConfig(numSnippets); - } - ); + private static final ConstructingObjectParser CHUNK_SCORER_PARSER = + new ConstructingObjectParser<>(CHUNK_RESCORER_FIELD.getPreferredName(), true, args -> { + Integer size = (Integer) args[0]; + @SuppressWarnings("unchecked") + Map chunkingSettingsMap = (Map) args[1]; + ChunkingSettings chunkingSettings = ChunkScorerConfig.chunkingSettingsFromMap(chunkingSettingsMap); + return new ChunkScorerConfig(size, chunkingSettings); + }); static { PARSER.declareNamedObject(constructorArg(), (p, c, n) -> { @@ -94,9 +97,10 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder PARSER.declareString(constructorArg(), FIELD_FIELD); PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); PARSER.declareBoolean(optionalConstructorArg(), FAILURES_ALLOWED_FIELD); - PARSER.declareObject(optionalConstructorArg(), SNIPPETS_PARSER, SNIPPETS_FIELD); + PARSER.declareObject(optionalConstructorArg(), CHUNK_SCORER_PARSER, CHUNK_RESCORER_FIELD); if (RERANK_SNIPPETS.isEnabled()) { - SNIPPETS_PARSER.declareInt(optionalConstructorArg(), NUM_SNIPPETS_FIELD); + CHUNK_SCORER_PARSER.declareInt(optionalConstructorArg(), CHUNK_SIZE_FIELD); + CHUNK_SCORER_PARSER.declareObjectOrNull(optionalConstructorArg(), (p, c) -> p.map(), null, CHUNKING_SETTINGS_FIELD); } RetrieverBuilder.declareBaseParserFields(PARSER); @@ -117,7 +121,7 @@ public static TextSimilarityRankRetrieverBuilder fromXContent( private final String inferenceText; private final String field; private final boolean failuresAllowed; - private final SnippetConfig snippets; + private final ChunkScorerConfig chunkScorerConfig; public TextSimilarityRankRetrieverBuilder( RetrieverBuilder retrieverBuilder, @@ -126,14 +130,14 @@ public TextSimilarityRankRetrieverBuilder( String field, int rankWindowSize, boolean failuresAllowed, - SnippetConfig snippets + ChunkScorerConfig chunkScorerConfig ) { super(List.of(RetrieverSource.from(retrieverBuilder)), rankWindowSize); this.inferenceId = inferenceId; this.inferenceText = inferenceText; this.field = field; this.failuresAllowed = failuresAllowed; - this.snippets = snippets; + this.chunkScorerConfig = chunkScorerConfig; } public TextSimilarityRankRetrieverBuilder( @@ -146,14 +150,14 @@ public TextSimilarityRankRetrieverBuilder( boolean failuresAllowed, String retrieverName, List preFilterQueryBuilders, - SnippetConfig snippets + ChunkScorerConfig chunkScorerConfig ) { super(retrieverSource, rankWindowSize); if (retrieverSource.size() != 1) { throw new IllegalArgumentException("[" + getName() + "] retriever should have exactly one inner retriever"); } - if (snippets != null && snippets.numSnippets() != null && snippets.numSnippets() < 1) { - throw new IllegalArgumentException("num_snippets must be greater than 0, was: " + snippets.numSnippets()); + if (chunkScorerConfig != null && 
chunkScorerConfig.size() != null && chunkScorerConfig.size() < 1) { + throw new IllegalArgumentException("size must be greater than 0, was: " + chunkScorerConfig.size()); } this.inferenceId = inferenceId; this.inferenceText = inferenceText; @@ -162,7 +166,7 @@ public TextSimilarityRankRetrieverBuilder( this.failuresAllowed = failuresAllowed; this.retrieverName = retrieverName; this.preFilterQueryBuilders = preFilterQueryBuilders; - this.snippets = snippets; + this.chunkScorerConfig = chunkScorerConfig; } @Override @@ -180,7 +184,7 @@ protected TextSimilarityRankRetrieverBuilder clone( failuresAllowed, retrieverName, newPreFilterQueryBuilders, - snippets + chunkScorerConfig ); } @@ -215,8 +219,8 @@ protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBu rankWindowSize, minScore, failuresAllowed, - snippets != null - ? new SnippetConfig(snippets.numSnippets, inferenceText, TextSimilarityRankBuilder.tokenSizeLimit(inferenceId)) + chunkScorerConfig != null + ? new ChunkScorerConfig(chunkScorerConfig.size, inferenceText, chunkScorerConfig.chunkingSettings()) : null ) ); @@ -246,10 +250,13 @@ protected void doToXContent(XContentBuilder builder, Params params) throws IOExc if (failuresAllowed) { builder.field(FAILURES_ALLOWED_FIELD.getPreferredName(), failuresAllowed); } - if (snippets != null) { - builder.startObject(SNIPPETS_FIELD.getPreferredName()); - if (snippets.numSnippets() != null) { - builder.field(NUM_SNIPPETS_FIELD.getPreferredName(), snippets.numSnippets()); + if (chunkScorerConfig != null) { + builder.startObject(CHUNK_RESCORER_FIELD.getPreferredName()); + if (chunkScorerConfig.size() != null) { + builder.field(CHUNK_SIZE_FIELD.getPreferredName(), chunkScorerConfig.size()); + } + if (chunkScorerConfig.chunkingSettings() != null) { + builder.field(CHUNKING_SETTINGS_FIELD.getPreferredName(), chunkScorerConfig.chunkingSettings().asMap()); } builder.endObject(); } @@ -265,11 +272,11 @@ public boolean doEquals(Object other) { && rankWindowSize == that.rankWindowSize && Objects.equals(minScore, that.minScore) && failuresAllowed == that.failuresAllowed - && Objects.equals(snippets, that.snippets); + && Objects.equals(chunkScorerConfig, that.chunkScorerConfig); } @Override public int doHashCode() { - return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore, failuresAllowed, snippets); + return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore, failuresAllowed, chunkScorerConfig); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRerankingRankFeaturePhaseRankShardContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRerankingRankFeaturePhaseRankShardContext.java index 66fb4a366a757..60adefd3493d2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRerankingRankFeaturePhaseRankShardContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRerankingRankFeaturePhaseRankShardContext.java @@ -8,38 +8,34 @@ package org.elasticsearch.xpack.inference.rank.textsimilarity; import org.elasticsearch.common.document.DocumentField; -import org.elasticsearch.common.logging.HeaderWarning; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; 
-import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; -import org.elasticsearch.search.fetch.subphase.highlight.HighlightField; -import org.elasticsearch.search.fetch.subphase.highlight.SearchHighlightContext; -import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.rank.RankShardResult; import org.elasticsearch.search.rank.feature.RankFeatureDoc; import org.elasticsearch.search.rank.feature.RankFeatureShardResult; import org.elasticsearch.search.rank.rerank.RerankingRankFeaturePhaseRankShardContext; -import org.elasticsearch.xcontent.Text; +import org.elasticsearch.xpack.core.common.chunks.MemoryIndexChunkScorer; +import org.elasticsearch.xpack.inference.chunking.Chunker; +import org.elasticsearch.xpack.inference.chunking.ChunkerBuilder; import java.io.IOException; -import java.util.Arrays; import java.util.List; -import java.util.Map; -import static org.elasticsearch.xpack.inference.rank.textsimilarity.SnippetConfig.DEFAULT_NUM_SNIPPETS; +import static org.elasticsearch.xpack.inference.rank.textsimilarity.ChunkScorerConfig.DEFAULT_SIZE; public class TextSimilarityRerankingRankFeaturePhaseRankShardContext extends RerankingRankFeaturePhaseRankShardContext { - private final SnippetConfig snippetRankInput; + private final ChunkScorerConfig chunkScorerConfig; + private final ChunkingSettings chunkingSettings; + private final Chunker chunker; - // Rough approximation of token size vs. characters in highlight fragments. - // TODO: highlighter should be able to set fragment size by token not length - private static final int TOKEN_SIZE_LIMIT_MULTIPLIER = 5; - - public TextSimilarityRerankingRankFeaturePhaseRankShardContext(String field, @Nullable SnippetConfig snippetRankInput) { + public TextSimilarityRerankingRankFeaturePhaseRankShardContext(String field, @Nullable ChunkScorerConfig chunkScorerConfig) { super(field); - this.snippetRankInput = snippetRankInput; + this.chunkScorerConfig = chunkScorerConfig; + chunkingSettings = chunkScorerConfig != null ? chunkScorerConfig.chunkingSettings() : null; + chunker = chunkingSettings != null ? ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()) : null; } @Override @@ -49,49 +45,34 @@ public RankShardResult doBuildRankFeatureShardResult(SearchHits hits, int shardI rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId); SearchHit hit = hits.getHits()[i]; DocumentField docField = hit.field(field); - if (snippetRankInput == null && docField != null) { - rankFeatureDocs[i].featureData(List.of(docField.getValue().toString())); - } else { - Map highlightFields = hit.getHighlightFields(); - if (highlightFields != null && highlightFields.containsKey(field) && highlightFields.get(field).fragments().length > 0) { - List snippets = Arrays.stream(highlightFields.get(field).fragments()).map(Text::string).toList(); - rankFeatureDocs[i].featureData(snippets); - } else if (docField != null) { - // If we did not get highlighting results, backfill with the doc field value - // but pass in a warning because we are not reranking on snippets only + if (docField != null) { + if (chunkScorerConfig != null) { + int size = chunkScorerConfig.size() != null ? 
chunkScorerConfig.size() : DEFAULT_SIZE; + List chunkOffsets = chunker.chunk(docField.getValue().toString(), chunkingSettings); + List chunks = chunkOffsets.stream() + .map(offset -> { return docField.getValue().toString().substring(offset.start(), offset.end()); }) + .toList(); + + List bestChunks; + try { + MemoryIndexChunkScorer scorer = new MemoryIndexChunkScorer(); + List scoredChunks = scorer.scoreChunks( + chunks, + chunkScorerConfig.inferenceText(), + size + ); + bestChunks = scoredChunks.stream().map(MemoryIndexChunkScorer.ScoredChunk::content).limit(size).toList(); + } catch (IOException e) { + throw new IllegalStateException("Could not generate chunks for input to reranker", e); + } + rankFeatureDocs[i].featureData(bestChunks); + + } else { rankFeatureDocs[i].featureData(List.of(docField.getValue().toString())); - HeaderWarning.addWarning( - "Reranking on snippets requested, but no snippets were found for field [" + field + "]. Using field value instead." - ); } } } return new RankFeatureShardResult(rankFeatureDocs); } - @Override - public void prepareForFetch(SearchContext context) { - if (snippetRankInput != null) { - try { - HighlightBuilder highlightBuilder = new HighlightBuilder(); - highlightBuilder.highlightQuery(snippetRankInput.snippetQueryBuilder()); - // Stripping pre/post tags as they're not useful for snippet creation - highlightBuilder.field(field).preTags("").postTags(""); - // Return highest scoring fragments - highlightBuilder.order(HighlightBuilder.Order.SCORE); - int numSnippets = snippetRankInput.numSnippets() != null ? snippetRankInput.numSnippets() : DEFAULT_NUM_SNIPPETS; - highlightBuilder.numOfFragments(numSnippets); - // Rely on the model to determine the fragment size - int tokenSizeLimit = snippetRankInput.tokenSizeLimit(); - int fragmentSize = tokenSizeLimit * TOKEN_SIZE_LIMIT_MULTIPLIER; - highlightBuilder.fragmentSize(fragmentSize); - highlightBuilder.noMatchSize(fragmentSize); - SearchHighlightContext searchHighlightContext = highlightBuilder.build(context.getSearchExecutionContext()); - context.highlight(searchHighlightContext); - } catch (IOException e) { - throw new RuntimeException("Failed to generate snippet request", e); - } - } - } - } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java index 27aa8b6fb5b5a..bf2beaa131e67 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java @@ -39,7 +39,7 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContextTests extends E null ); - TextSimilarityRankFeaturePhaseRankCoordinatorContext withSnippets = new TextSimilarityRankFeaturePhaseRankCoordinatorContext( + TextSimilarityRankFeaturePhaseRankCoordinatorContext withChunks = new TextSimilarityRankFeaturePhaseRankCoordinatorContext( 10, 0, 100, @@ -48,7 +48,7 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContextTests extends E "some query", 0.0f, false, - new SnippetConfig(2, "some query", 10) + new ChunkScorerConfig(2, "some query", null) ); public void testComputeScores() { @@ -87,7 +87,7 @@ 
public void testExtractScoresFromRankedDocs() { assertArrayEquals(new float[] { 1.0f, 3.0f, 2.0f }, scores, 0.0f); } - public void testExtractScoresFromSingleSnippets() { + public void testExtractScoresFromSingleChunk() { List rankedDocs = List.of( new RankedDocsResults.RankedDoc(0, 1.0f, "text 1"), @@ -99,12 +99,12 @@ public void testExtractScoresFromSingleSnippets() { createRankFeatureDoc(1, 3.0f, 1, List.of("text 2")), createRankFeatureDoc(2, 2.0f, 0, List.of("text 3")) }; - float[] scores = withSnippets.extractScoresFromRankedSnippets(rankedDocs, featureDocs); - // Returned cores are from the snippet, not the whole text + float[] scores = withChunks.extractScoresFromRankedChunks(rankedDocs, featureDocs); + // Returned cores are from the chunk, not the whole text assertArrayEquals(new float[] { 1.0f, 2.5f, 1.5f }, scores, 0.0f); } - public void testExtractScoresFromMultipleSnippets() { + public void testExtractScoresFromMultipleChunks() { List rankedDocs = List.of( new RankedDocsResults.RankedDoc(0, 1.0f, "this is text 1"), @@ -119,8 +119,8 @@ public void testExtractScoresFromMultipleSnippets() { createRankFeatureDoc(1, 3.0f, 1, List.of("yet more text", "this is text 2")), createRankFeatureDoc(2, 2.0f, 0, List.of("this is text 3", "oh look, more text")) }; - float[] scores = withSnippets.extractScoresFromRankedSnippets(rankedDocs, featureDocs); - // Returned scores are from the best-ranking snippet, not the whole text + float[] scores = withChunks.extractScoresFromRankedChunks(rankedDocs, featureDocs); + // Returned scores are from the best-ranking chunk, not the whole text assertArrayEquals(new float[] { 2.5f, 3.0f, 2.0f }, scores, 0.0f); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java index c88ce1b65ee3d..cad05d791c9c6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java @@ -177,9 +177,9 @@ public ThrowingMockRequestActionBasedRankBuilder( Float minScore, boolean failuresAllowed, String throwingType, - SnippetConfig snippetConfig + ChunkScorerConfig chunkScorerConfig ) { - super(field, inferenceId, inferenceText, rankWindowSize, minScore, failuresAllowed, snippetConfig); + super(field, inferenceId, inferenceText, rankWindowSize, minScore, failuresAllowed, chunkScorerConfig); this.throwingRankBuilderType = AbstractRerankerIT.ThrowingRankBuilderType.valueOf(throwingType); } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml index 3dd85ef9e8658..d971aad2bbc4b 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml @@ -515,14 +515,14 @@ setup: --- -"Text similarity reranker specifying number of snippets must be > 0": +"Text similarity reranker specifying number of rescore_chunks must be > 0": - requires: cluster_features: "text_similarity_reranker_snippets" - reason: snippets 
introduced in 9.2.0 + reason: rescore_chunks introduced in 9.2.0 - do: - catch: /num_snippets must be greater than 0/ + catch: /size must be greater than 0/ search: index: test-index body: @@ -538,18 +538,18 @@ setup: inference_id: my-rerank-model inference_text: "How often does the moon hide the sun?" field: inference_text_field - snippets: - num_snippets: 0 + chunk_rescorer: + size: 0 size: 10 - match: { status: 400 } --- -"Reranking based on snippets": +"Reranking based on rescore_chunks": - requires: cluster_features: "text_similarity_reranker_snippets" - reason: snippets introduced in 9.2.0 + reason: rescore_chunks introduced in 9.2.0 - do: search: @@ -569,8 +569,8 @@ setup: inference_id: my-rerank-model inference_text: "How often does the moon hide the sun?" field: text - snippets: - num_snippets: 2 + chunk_rescorer: + size: 2 size: 10 - match: { hits.total.value: 2 } @@ -580,11 +580,11 @@ setup: - match: { hits.hits.1._id: "doc_2" } --- -"Reranking based on snippets using defaults": +"Reranking based on rescore_chunks using defaults": - requires: cluster_features: "text_similarity_reranker_snippets" - reason: snippets introduced in 9.2.0 + reason: rescore_chunks introduced in 9.2.0 - do: search: @@ -603,7 +603,7 @@ setup: inference_id: my-rerank-model inference_text: "How often does the moon hide the sun?" field: text - snippets: { } + chunk_rescorer: { } size: 10 - match: { hits.total.value: 2 } @@ -613,11 +613,11 @@ setup: - match: { hits.hits.1._id: "doc_2" } --- -"Reranking based on snippets on a semantic_text field": +"Reranking based on rescore_chunks on a semantic_text field": - requires: cluster_features: "text_similarity_reranker_snippets" - reason: snippets introduced in 9.2.0 + reason: rescore_chunks introduced in 9.2.0 - do: search: @@ -637,8 +637,8 @@ setup: inference_id: my-rerank-model inference_text: "how often does the moon hide the sun?" field: semantic_text_field - snippets: - num_snippets: 2 + chunk_rescorer: + size: 2 size: 10 - match: { hits.total.value: 2 } @@ -648,11 +648,11 @@ setup: - match: { hits.hits.1._id: "doc_2" } --- -"Reranking based on snippets on a semantic_text field using defaults": +"Reranking based on rescore_chunks on a semantic_text field using defaults": - requires: cluster_features: "text_similarity_reranker_snippets" - reason: snippets introduced in 9.2.0 + reason: rescore_chunks introduced in 9.2.0 - do: search: @@ -672,7 +672,7 @@ setup: inference_id: my-rerank-model inference_text: "how often does the moon hide the sun?" field: semantic_text_field - snippets: { } + chunk_rescorer: { } size: 10 - match: { hits.total.value: 2 } @@ -682,38 +682,180 @@ setup: - match: { hits.hits.1._id: "doc_2" } --- -"Reranking based on snippets when highlighter doesn't return results": +"Reranking based on rescore_chunks on a semantic_text field specifying chunking settings": - requires: - test_runner_features: allowed_warnings cluster_features: "text_similarity_reranker_snippets" - reason: snippets introduced in 9.2.0 + reason: rescore_chunks introduced in 9.2.0 - do: - allowed_warnings: - - "Reranking on snippets requested, but no snippets were found for field [inference_text_field]. Using field value instead." 
search: index: test-index body: track_total_hits: true - fields: [ "text", "topic" ] + fields: [ "text", "semantic_text_field", "topic" ] retriever: text_similarity_reranker: retriever: standard: query: - term: - topic: "science" + match: + topic: + query: "science" rank_window_size: 10 inference_id: my-rerank-model - inference_text: "How often does the moon hide the sun?" - field: inference_text_field - snippets: - num_snippets: 2 + inference_text: "how often does the moon hide the sun?" + field: semantic_text_field + chunk_rescorer: + chunking_settings: + strategy: sentence + max_chunk_size: 20 + sentence_overlap: 0 size: 10 - match: { hits.total.value: 2 } - length: { hits.hits: 2 } - - match: { hits.hits.0._id: "doc_2" } - - match: { hits.hits.1._id: "doc_1" } + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_2" } + +--- +"Reranking based on rescore_chunks on a semantic_text field specifying chunking settings requires valid chunking settings": + + - requires: + cluster_features: "text_similarity_reranker_snippets" + reason: rescore_chunks introduced in 9.2.0 + + - do: + catch: /Invalid value/ + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "semantic_text_field", "topic" ] + retriever: + text_similarity_reranker: + retriever: + standard: + query: + match: + topic: + query: "science" + rank_window_size: 10 + inference_id: my-rerank-model + inference_text: "how often does the moon hide the sun?" + field: semantic_text_field + chunk_rescorer: + chunk_size: 20 + chunking_settings: + strategy: sentence + max_chunk_size: 10 + sentence_overlap: 20 + size: 10 + +--- +"Reranking based on rescore_chunks on a semantic_text field specifying chunk size": + + - requires: + cluster_features: "text_similarity_reranker_snippets" + reason: rescore_chunks introduced in 9.2.0 + + - do: + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "semantic_text_field", "topic" ] + retriever: + text_similarity_reranker: + retriever: + standard: + query: + match: + topic: + query: "science" + rank_window_size: 10 + inference_id: my-rerank-model + inference_text: "how often does the moon hide the sun?" + field: semantic_text_field + chunk_rescorer: + chunk_size: 20 + size: 10 + + - match: { hits.total.value: 2 } + - length: { hits.hits: 2 } + + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_2" } + +--- +"Reranking based on chunk_rescorer specifying only max chunk size will default remaining chunking settings": + + - requires: + cluster_features: "text_similarity_reranker_snippets" + reason: rescore_chunks introduced in 9.2.0 + + - do: + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "semantic_text_field", "topic" ] + retriever: + text_similarity_reranker: + retriever: + standard: + query: + match: + topic: + query: "science" + rank_window_size: 10 + inference_id: my-rerank-model + inference_text: "how often does the moon hide the sun?" 
+ field: semantic_text_field + chunk_rescorer: + chunk_rescorer: 20 + chunking_settings: + max_chunk_size: 20 + size: 10 + + - match: { hits.total.value: 2 } + - length: { hits.hits: 2 } + + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_2" } + + +--- +"Reranking based on chunk_rescorer will send in first chunk if no text matches found": + + - requires: + cluster_features: "text_similarity_reranker_snippets" + reason: rescore_chunks introduced in 9.2.0 + + - do: + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "semantic_text_field", "topic" ] + retriever: + text_similarity_reranker: + retriever: + standard: + query: + match: + topic: + query: "science" + rank_window_size: 10 + inference_id: my-rerank-model + inference_text: "iamanonsensefieldthatshouldreturnnoresults" + field: semantic_text_field + chunk_rescorer: { } + size: 10 + + - match: { hits.total.value: 2 } + - length: { hits.hits: 2 } + + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_2" }
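
To round out the YAML tests above, an illustrative request using the new chunk_rescorer option end to end, in the same REST test format. The index name, field name, and query text below are placeholders rather than fixtures from this change:

  - do:
      search:
        index: my-index
        body:
          retriever:
            text_similarity_reranker:
              retriever:
                standard:
                  query:
                    match:
                      body_text: "solar eclipse frequency"
              rank_window_size: 10
              inference_id: my-rerank-model
              inference_text: "How often does the moon hide the sun?"
              field: body_text
              chunk_rescorer:
                size: 2
                chunking_settings:
                  strategy: sentence
                  max_chunk_size: 50
                  sentence_overlap: 0
          size: 10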