Add RerankRequestChunker #130485

@@ -0,0 +1,5 @@
pr: 130485
summary: Add `RerankRequestChunker`
area: Machine Learning
type: enhancement
issues: []

@@ -0,0 +1 @@
9168000

@@ -1 +1 @@
-index_request_include_tsid,9167000
+elastic_reranker_chunking_configuration,9168000

@@ -0,0 +1,99 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.chunking;

import com.ibm.icu.text.BreakIterator;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Splits rerank inputs into chunks sized for the Elastic reranker and maps the
 * chunk-level results returned by the service back to the original documents.
 */
public class RerankRequestChunker {
    private final List<String> inputs;
    private final List<RerankChunks> rerankChunks;

    public RerankRequestChunker(String query, List<String> inputs, Integer maxChunksPerDoc) {
        this.inputs = inputs;
        this.rerankChunks = chunk(inputs, buildChunkingSettingsForElasticRerank(query), maxChunksPerDoc);
    }

    private List<RerankChunks> chunk(List<String> inputs, ChunkingSettings chunkingSettings, Integer maxChunksPerDoc) {
        var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy());
        var chunks = new ArrayList<RerankChunks>();
        for (int i = 0; i < inputs.size(); i++) {
            var chunksForInput = chunker.chunk(inputs.get(i), chunkingSettings);
            if (maxChunksPerDoc != null && chunksForInput.size() > maxChunksPerDoc) {
                chunksForInput = chunksForInput.subList(0, maxChunksPerDoc);
            }

            for (var chunk : chunksForInput) {
                chunks.add(new RerankChunks(i, inputs.get(i).substring(chunk.start(), chunk.end())));
            }
        }
        return chunks;
    }

    /** Returns the chunked strings, in document order, to send to the reranker in place of the raw inputs. */
    public List<String> getChunkedInputs() {
        List<String> chunkedInputs = new ArrayList<>();
        for (RerankChunks chunk : rerankChunks) {
            chunkedInputs.add(chunk.chunkString());
        }

        return chunkedInputs;
    }

    /** Wraps the given listener so that chunk-level ranked results are translated back to document-level results. */
    public ActionListener<InferenceServiceResults> parseChunkedRerankResultsListener(ActionListener<InferenceServiceResults> listener) {
        return ActionListener.wrap(results -> {
            if (results instanceof RankedDocsResults rankedDocsResults) {
                listener.onResponse(parseRankedDocResultsForChunks(rankedDocsResults));
            } else {
                listener.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass()));
            }
        }, listener::onFailure);
    }

    private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults) {
        List<RankedDocsResults.RankedDoc> updatedRankedDocs = new ArrayList<>();
        Set<Integer> docIndicesSeen = new HashSet<>();
        for (RankedDocsResults.RankedDoc rankedDoc : rankedDocsResults.getRankedDocs()) {
            int chunkIndex = rankedDoc.index();
            int docIndex = rerankChunks.get(chunkIndex).docIndex();

            if (docIndicesSeen.contains(docIndex) == false) {
                // Create a ranked doc with the full input string and the index for the document instead of the chunk
                RankedDocsResults.RankedDoc updatedRankedDoc = new RankedDocsResults.RankedDoc(
                    docIndex,
                    rankedDoc.relevanceScore(),
                    inputs.get(docIndex)
                );
                updatedRankedDocs.add(updatedRankedDoc);
                docIndicesSeen.add(docIndex);
            }
        }
        updatedRankedDocs.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore()));

        return new RankedDocsResults(updatedRankedDocs);
    }

    public record RerankChunks(int docIndex, String chunkString) {};

    private ChunkingSettings buildChunkingSettingsForElasticRerank(String query) {
        var wordIterator = BreakIterator.getWordInstance();
        wordIterator.setText(query);
        var queryWordCount = ChunkerUtils.countWords(0, query.length(), wordIterator);
        return ChunkingSettingsBuilder.buildChunkingSettingsForElasticRerank(queryWordCount);
    }
}
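
For context, a minimal sketch of how a caller might use the chunking side of this class. The class name RerankChunkingExample, the documents, and the query are illustrative assumptions rather than part of this PR, and the snippet assumes the x-pack inference plugin classes referenced by RerankRequestChunker (ChunkerBuilder, ChunkingSettingsBuilder, ChunkerUtils) are on the classpath.

import java.util.List;

import org.elasticsearch.xpack.inference.chunking.RerankRequestChunker;

public class RerankChunkingExample {
    public static void main(String[] args) {
        // Hypothetical documents; real callers pass the rerank request's input texts.
        List<String> documents = List.of(
            "a long document that may be split into several chunks by the chunker",
            "a short document"
        );

        // maxChunksPerDoc is nullable: null keeps every chunk produced for a document,
        // while a value of N truncates each document to its first N chunks.
        RerankRequestChunker chunker = new RerankRequestChunker("what is the query about?", documents, null);

        // The chunked strings, in document order, are what would be sent to the
        // reranker in place of the raw inputs.
        for (String chunk : chunker.getChunkedInputs()) {
            System.out.println(chunk);
        }
    }
}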
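
A second sketch, under the same assumptions, illustrates the result-mapping side: parseChunkedRerankResultsListener wraps a caller's listener so that a chunk-indexed RankedDocsResults from the service is collapsed into one document-level entry per input, sorted by relevance score. The simulated response and the class name RerankChunkMappingExample are hypothetical; only the constructors and accessors visible in the diff above are used.

import java.util.ArrayList;
import java.util.List;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.chunking.RerankRequestChunker;

public class RerankChunkMappingExample {
    public static void main(String[] args) {
        List<String> documents = List.of("a long document that may be split into several chunks", "a short document");
        RerankRequestChunker chunker = new RerankRequestChunker("example query", documents, null);

        // The caller's listener sees document-level results: one entry per original document,
        // carrying the document index and full text, sorted by relevance score descending.
        ActionListener<InferenceServiceResults> docLevelListener = ActionListener.wrap(results -> {
            RankedDocsResults ranked = (RankedDocsResults) results;
            for (RankedDocsResults.RankedDoc doc : ranked.getRankedDocs()) {
                System.out.println(doc.index() + " -> " + doc.relevanceScore());
            }
        }, e -> { throw new AssertionError(e); });

        // The wrapped listener is what would be handed to the inference backend;
        // the indices in its response refer to chunks rather than documents.
        ActionListener<InferenceServiceResults> chunkLevelListener = chunker.parseChunkedRerankResultsListener(docLevelListener);

        // Simulate a backend response with one ranked entry per chunk. Since only the first
        // entry seen per document is kept, the response should be sorted by score (as rerankers
        // typically return it) for the kept score to be that document's best chunk.
        List<String> chunks = chunker.getChunkedInputs();
        List<RankedDocsResults.RankedDoc> simulated = new ArrayList<>();
        for (int i = 0; i < chunks.size(); i++) {
            simulated.add(new RankedDocsResults.RankedDoc(i, 1.0f - 0.1f * i, chunks.get(i)));
        }
        chunkLevelListener.onResponse(new RankedDocsResults(simulated));
    }
}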