/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.chunking;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
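
/**
 * Chunks rerank inputs that may be too long for the underlying model and maps the
 * per-chunk relevance scores returned by the service back to one score per original document.
 *
 * <p>A rough usage sketch (illustrative only: {@code client} and the hypothetical
 * {@code buildRerankRequest} helper are assumptions about the calling code, not part of this class):
 *
 * <pre>{@code
 * var chunker = new RerankRequestChunker(documents);
 * var request = buildRerankRequest(query, chunker.getChunkedInputs());
 * client.execute(InferenceAction.INSTANCE, request, chunker.parseChunkedRerankResultsListener(listener));
 * }</pre>
 */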
public class RerankRequestChunker {

    private final ChunkingSettings chunkingSettings;
    private final List<String> inputs;
    // A list rather than a map keyed by chunk index: the reranker reports results by the
    // position of each chunk in the request, and a HashMap's iteration order is unspecified.
    private final List<RerankChunks> rerankChunks;

    public RerankRequestChunker(List<String> inputs) {
        // TODO: Make chunking settings dependent on the model being used.
        // There may be a way to do this dynamically, using the max token size for the model/service
        // and the query size, instead of hardcoding it on a model/service basis.
        this.chunkingSettings = new WordBoundaryChunkingSettings(100, 10);
        this.inputs = inputs;
        this.rerankChunks = chunk(inputs, chunkingSettings);
    }
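
    /**
     * Splits each input into chunks. The position of a chunk in the returned list is the index
     * the reranker will report back for it, and each chunk records the index of the document it came from.
     */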
    private List<RerankChunks> chunk(List<String> inputs, ChunkingSettings chunkingSettings) {
        var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy());
        var chunks = new ArrayList<RerankChunks>();
        for (int i = 0; i < inputs.size(); i++) {
            var chunksForInput = chunker.chunk(inputs.get(i), chunkingSettings);
            for (var chunk : chunksForInput) {
                chunks.add(new RerankChunks(i, inputs.get(i).substring(chunk.start(), chunk.end())));
            }
        }
        return chunks;
    }
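
    /**
     * Returns the chunked inputs in the order they should be sent to the reranker.
     */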
    public List<String> getChunkedInputs() {
        List<String> chunkedInputs = new ArrayList<>();
        for (RerankChunks chunk : rerankChunks) {
            chunkedInputs.add(chunk.chunkString());
        }
        // TODO: Score the inputs here and only return the top N chunks for each document
        return chunkedInputs;
    }
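
    /**
     * Wraps the given listener so that the per-chunk rerank results are collapsed back into
     * one result per original document before the listener is notified.
     */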
    public ActionListener<InferenceAction.Response> parseChunkedRerankResultsListener(ActionListener<InferenceAction.Response> listener) {
        return ActionListener.wrap(results -> {
            if (results.getResults() instanceof RankedDocsResults rankedDocsResults) {
                // TODO: Figure out if this correctly creates the response or if it loses any info
                listener.onResponse(new InferenceAction.Response(parseRankedDocResultsForChunks(rankedDocsResults)));
            } else {
                listener.onFailure(
                    new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getResults().getClass())
                );
            }
        }, listener::onFailure);
    }
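
    /**
     * Maps each ranked chunk back to its original document, keeps only the highest-scoring chunk
     * per document, and returns the documents sorted by that score (descending). For example,
     * chunk scores [(doc 0, 0.2), (doc 1, 0.5), (doc 0, 0.9)] collapse to [(doc 0, 0.9), (doc 1, 0.5)].
     */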
    private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults) {
        Map<Integer, RankedDocsResults.RankedDoc> bestRankedDocResultPerDoc = new HashMap<>();
        for (var rankedDoc : rankedDocsResults.getRankedDocs()) {
            int chunkIndex = rankedDoc.index();
            int docIndex = rerankChunks.get(chunkIndex).docIndex();
            RankedDocsResults.RankedDoc existingDoc = bestRankedDocResultPerDoc.get(docIndex);
            if (existingDoc == null || rankedDoc.relevanceScore() > existingDoc.relevanceScore()) {
                bestRankedDocResultPerDoc.put(
                    docIndex,
                    new RankedDocsResults.RankedDoc(docIndex, rankedDoc.relevanceScore(), inputs.get(docIndex))
                );
            }
        }
        var bestRankedDocResultPerDocList = new ArrayList<>(bestRankedDocResultPerDoc.values());
        bestRankedDocResultPerDocList.sort(
            (RankedDocsResults.RankedDoc d1, RankedDocsResults.RankedDoc d2) -> Float.compare(d2.relevanceScore(), d1.relevanceScore())
        );
        return new RankedDocsResults(bestRankedDocResultPerDocList);
    }

    public record RerankChunks(int docIndex, String chunkString) {}
}