|
| 1 | +/* |
| 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one |
| 3 | + * or more contributor license agreements. Licensed under the Elastic License |
| 4 | + * 2.0; you may not use this file except in compliance with the Elastic License |
| 5 | + * 2.0. |
| 6 | + */ |
| 7 | + |
| 8 | +package org.elasticsearch.compute.lucene; |
| 9 | + |
| 10 | +import org.apache.lucene.index.LeafReaderContext; |
| 11 | +import org.apache.lucene.search.BulkScorer; |
| 12 | +import org.apache.lucene.search.IndexSearcher; |
| 13 | +import org.apache.lucene.search.LeafCollector; |
| 14 | +import org.apache.lucene.search.Query; |
| 15 | +import org.apache.lucene.search.Scorable; |
| 16 | +import org.apache.lucene.search.ScoreMode; |
| 17 | +import org.apache.lucene.search.Scorer; |
| 18 | +import org.apache.lucene.search.Weight; |
| 19 | +import org.apache.lucene.util.ArrayUtil; |
| 20 | +import org.apache.lucene.util.Bits; |
| 21 | +import org.elasticsearch.compute.data.Block; |
| 22 | +import org.elasticsearch.compute.data.BlockFactory; |
| 24 | +import org.elasticsearch.compute.data.DocBlock; |
| 25 | +import org.elasticsearch.compute.data.DocVector; |
| 26 | +import org.elasticsearch.compute.data.DoubleVector; |
| 27 | +import org.elasticsearch.compute.data.IntVector; |
| 28 | +import org.elasticsearch.compute.data.Page; |
| 29 | +import org.elasticsearch.compute.operator.DriverContext; |
| 30 | +import org.elasticsearch.compute.operator.EvalOperator; |
| 31 | +import org.elasticsearch.core.Releasable; |
| 32 | +import org.elasticsearch.core.Releasables; |
| 33 | + |
| 34 | +import java.io.IOException; |
| 35 | +import java.io.UncheckedIOException; |
| 36 | + |
| 37 | +/** |
| 38 | + * {@link EvalOperator.ExpressionEvaluator} to run a Lucene {@link Query} during |
| 39 | + * the compute engine's normal execution, yielding each document's score (or |
| 40 | + * {@link #NO_MATCH_SCORE} for documents that don't match) into a {@link DoubleVector}. |
| 41 | + * It's much faster to push these queries to the {@link LuceneSourceOperator} or the |
| 42 | + * like, but sometimes this isn't possible. So this evaluator is here to save the day. |
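| | + * <p> |
| | + * A minimal usage sketch. This is illustrative only; {@code query}, {@code searcher}, |
| | + * {@code driverContext} and {@code page} are assumed to be supplied by the caller: |
| | + * <pre>{@code |
| | + * ShardConfig[] shards = new ShardConfig[] { new ShardConfig(query, searcher) }; |
| | + * EvalOperator.ExpressionEvaluator.Factory factory = new LuceneQueryScoreEvaluator.Factory(shards); |
| | + * try (EvalOperator.ExpressionEvaluator scores = factory.get(driverContext)) { |
| | + *     Block scoreBlock = scores.eval(page); // the page's first block must be a DocBlock |
| | + * } |
| | + * }</pre> |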
| 43 | + */ |
| 44 | +public class LuceneQueryScoreEvaluator implements EvalOperator.ExpressionEvaluator { |
| 45 | + |
| 46 | + public static final double NO_MATCH_SCORE = 0.0; |
| 47 | + |
| 48 | + public record ShardConfig(Query query, IndexSearcher searcher) {} |
| 49 | + |
| 50 | + private final BlockFactory blockFactory; |
| 51 | + private final ShardConfig[] shards; |
| 52 | + |
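| | + /** Per-shard state, created lazily the first time a shard is seen and cached for reuse. */ |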
| 53 | + private ShardState[] perShardState = EMPTY_SHARD_STATES; |
| 54 | + |
| 55 | + public LuceneQueryScoreEvaluator(BlockFactory blockFactory, ShardConfig[] shards) { |
| 56 | + this.blockFactory = blockFactory; |
| 57 | + this.shards = shards; |
| 58 | + } |
| 59 | + |
| 60 | + @Override |
| 61 | + public Block eval(Page page) { |
| 62 | + // Lucene-based operators put the DocVector in the first block of each page |
| 63 | + Block block = page.getBlock(0); |
| 64 | + assert block instanceof DocBlock : "LuceneQueryScoreEvaluator expects DocBlock as input"; |
| 65 | + DocVector docs = (DocVector) block.asVector(); |
| 66 | + try { |
| 67 | + if (docs.singleSegmentNonDecreasing()) { |
| 68 | + return evalSingleSegmentNonDecreasing(docs).asBlock(); |
| 69 | + } else { |
| 70 | + return evalSlow(docs).asBlock(); |
| 71 | + } |
| 72 | + } catch (IOException e) { |
| 73 | + throw new UncheckedIOException(e); |
| 74 | + } |
| 75 | + } |
| 76 | + |
| 77 | + /** |
| 78 | + * Evaluate {@link DocVector#singleSegmentNonDecreasing()} documents. |
| 79 | + * <p> |
| 80 | + * ESQL receives documents in DocVector, and they can be in one of two |
| 81 | + * states. Either the DocVector contains documents from a single segment in |
| 82 | + * non-decreasing order, or it doesn't. That first case is much more like |
| 83 | + * how Lucene likes to process documents, and it's much more common. So we |
| 84 | + * optimize for it. |
| 85 | + * <p> |
| 86 | + * Vectors that are {@link DocVector#singleSegmentNonDecreasing()} |
| 87 | + * represent many documents from a single Lucene segment. In Elasticsearch |
| 88 | + * terms that's a segment in a single shard. And the document ids are in |
| 89 | + * non-decreasing order. Probably just {@code 0, 1, 2, 3, 4, 5...}. |
| 90 | + * But maybe something like {@code 0, 5, 6, 10, 10, 10}. Both of those are |
| 91 | + * very much like how Lucene "natively" processes documents, and this optimizes |
| 92 | + * those accesses. |
| 93 | + * </p> |
| 94 | + * <p> |
| 95 | + * If the documents are literally {@code 0, 1, ... n} then we use |
| 96 | + * {@link BulkScorer#score(LeafCollector, Bits, int, int)} which processes |
| 97 | + * a whole range. This should be quite common in the case where we don't |
| 98 | + * have deleted documents because that's the order that |
| 99 | + * {@link LuceneSourceOperator} produces them. |
| 100 | + * </p> |
| 101 | + * <p> |
| 102 | + * If there are gaps in the sequence we use {@link Scorer} calls to |
| 103 | + * score the sequence. This'll be slower, but it isn't going to be particularly |
| 104 | + * common. |
| 105 | + * </p> |
| 106 | + */ |
| 107 | + private DoubleVector evalSingleSegmentNonDecreasing(DocVector docs) throws IOException { |
| 108 | + ShardState shardState = shardState(docs.shards().getInt(0)); |
| 109 | + SegmentState segmentState = shardState.segmentState(docs.segments().getInt(0)); |
| 110 | + int min = docs.docs().getInt(0); |
| 111 | + int max = docs.docs().getInt(docs.getPositionCount() - 1); |
| 112 | + int length = max - min + 1; |
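| | + // The range is dense: every doc id between min and max appears exactly once, so we can bulk score the whole range |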
| 113 | + if (length == docs.getPositionCount() && length > 1) { |
| 114 | + return segmentState.scoreDense(min, max); |
| 115 | + } |
| 116 | + return segmentState.scoreSparse(docs.docs()); |
| 117 | + } |
| 118 | + |
| 119 | + /** |
| 120 | + * Evaluate non-{@link DocVector#singleSegmentNonDecreasing()} documents. See |
| 121 | + * {@link #evalSingleSegmentNonDecreasing} for the meaning of |
| 122 | + * {@link DocVector#singleSegmentNonDecreasing()} and how we can efficiently |
| 123 | + * evaluate those segments. |
| 124 | + * <p> |
| 125 | + * This processes the worst case blocks of documents. These can be from any |
| 126 | + * number of shards and any number of segments and in any order. We do this |
| 127 | + * by iterating the docs in {@code shard ASC, segment ASC, doc ASC} order. |
| 128 | + * So, that's segment by segment, docs ascending. We build a vector of scores |
| 129 | + * out of that. Then we <strong>sort</strong> it to put the scores back in |
| 130 | + * the order that the {@link DocVector} came in. |
| 131 | + * </p> |
| 132 | + */ |
| 133 | + private DoubleVector evalSlow(DocVector docs) throws IOException { |
| 134 | + int[] map = docs.shardSegmentDocMapForwards(); |
| 135 | + // Track the previous shard and segment so consecutive docs from the same segment reuse its state |
| 136 | + int prevShard = -1; |
| 137 | + int prevSegment = -1; |
| 138 | + SegmentState segmentState = null; |
| 139 | + try (DoubleVector.Builder builder = blockFactory.newDoubleVectorFixedBuilder(docs.getPositionCount())) { |
| 140 | + for (int i = 0; i < docs.getPositionCount(); i++) { |
| 141 | + int shard = docs.shards().getInt(map[i]); |
| 142 | + int segment = docs.segments().getInt(map[i]); |
| 143 | + if (segmentState == null || prevShard != shard || prevSegment != segment) { |
| 144 | + segmentState = shardState(shard).segmentState(segment); |
| 145 | + segmentState.initScorer(docs.docs().getInt(map[i])); |
| 146 | + prevShard = shard; |
| 147 | + prevSegment = segment; |
| 148 | + } |
| 149 | + if (segmentState.noMatch) { |
| 150 | + builder.appendDouble(NO_MATCH_SCORE); |
| 151 | + } else { |
| 152 | + segmentState.scoreSingleDocWithScorer(builder, docs.docs().getInt(map[i])); |
| 153 | + } |
| 154 | + } |
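| | + // The scores were appended in shard/segment/doc order; map them back to the order the DocVector arrived in |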
| 155 | + try (DoubleVector outOfOrder = builder.build()) { |
| 156 | + return outOfOrder.filter(docs.shardSegmentDocMapBackwards()); |
| 157 | + } |
| 158 | + } |
| 159 | + } |
| 160 | + |
| 161 | + @Override |
| 162 | + public void close() {} |
| 165 | + |
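| | + /** |
| | + * Fetch the cached per-shard state, building and caching it the first time the shard is seen. |
| | + */ |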
| 166 | + private ShardState shardState(int shard) throws IOException { |
| 167 | + if (shard >= perShardState.length) { |
| 168 | + perShardState = ArrayUtil.grow(perShardState, shard + 1); |
| 169 | + } else if (perShardState[shard] != null) { |
| 170 | + return perShardState[shard]; |
| 171 | + } |
| 172 | + perShardState[shard] = new ShardState(shards[shard]); |
| 173 | + return perShardState[shard]; |
| 174 | + } |
| 175 | + |
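| | + /** |
| | + * Per-shard state: the {@link Weight} built from this shard's query and searcher, plus lazily built per-segment state. |
| | + */ |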
| 176 | + private class ShardState { |
| 177 | + private final Weight weight; |
| 178 | + private final IndexSearcher searcher; |
| 179 | + private SegmentState[] perSegmentState = EMPTY_SEGMENT_STATES; |
| 180 | + |
| 181 | + ShardState(ShardConfig config) throws IOException { |
| 182 | + weight = config.searcher.createWeight(config.query, ScoreMode.COMPLETE, 1.0f); |
| 183 | + searcher = config.searcher; |
| 184 | + } |
| 185 | + |
| 186 | + SegmentState segmentState(int segment) throws IOException { |
| 187 | + if (segment >= perSegmentState.length) { |
| 188 | + perSegmentState = ArrayUtil.grow(perSegmentState, segment + 1); |
| 189 | + } else if (perSegmentState[segment] != null) { |
| 190 | + return perSegmentState[segment]; |
| 191 | + } |
| 192 | + perSegmentState[segment] = new SegmentState(weight, searcher.getLeafContexts().get(segment)); |
| 193 | + return perSegmentState[segment]; |
| 194 | + } |
| 195 | + } |
| 196 | + |
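| | + /** |
| | + * Per-segment state, holding the lazily built {@link Scorer} and {@link BulkScorer} used to score docs in one leaf. |
| | + */ |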
| 197 | + private class SegmentState { |
| 198 | + private final Weight weight; |
| 199 | + private final LeafReaderContext ctx; |
| 200 | + |
| 201 | + /** |
| 202 | + * Lazily initialized {@link Scorer} for this segment. {@code null} here means uninitialized |
| 203 | + * <strong>or</strong> that {@link #noMatch} is true. |
| 204 | + */ |
| 205 | + private Scorer scorer; |
| 206 | + |
| 207 | + /** |
| 208 | + * Thread that initialized the {@link #scorer}. |
| 209 | + */ |
| 210 | + private Thread scorerThread; |
| 211 | + |
| 212 | + /** |
| 213 | + * Lazily initialized {@link BulkScorer} for this segment. {@code null} here means uninitialized |
| 214 | + * <strong>or</strong> that {@link #noMatch} is true. |
| 215 | + */ |
| 216 | + private BulkScorer bulkScorer; |
| 217 | + |
| 218 | + /** |
| 219 | + * Thread that initialized the {@link #bulkScorer}. |
| 220 | + */ |
| 221 | + private Thread bulkScorerThread; |
| 222 | + |
| 223 | + /** |
| 224 | + * Set to {@code true} if, in the process of building a {@link Scorer} or {@link BulkScorer}, |
| 225 | + * the {@link Weight} tells us there aren't any matches. |
| 226 | + */ |
| 227 | + private boolean noMatch; |
| 228 | + |
| 229 | + private SegmentState(Weight weight, LeafReaderContext ctx) { |
| 230 | + this.weight = weight; |
| 231 | + this.ctx = ctx; |
| 232 | + } |
| 233 | + |
| 234 | + /** |
| 235 | + * Score a range using the {@link BulkScorer}. This should be faster |
| 236 | + * than using {@link #scoreSparse} for dense doc ids. |
| 237 | + */ |
| 238 | + DoubleVector scoreDense(int min, int max) throws IOException { |
| 239 | + int length = max - min + 1; |
| 240 | + if (noMatch) { |
| 241 | + return blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, length); |
| 242 | + } |
| 243 | + if (bulkScorer == null || // The bulkScorer wasn't initialized |
| 244 | + Thread.currentThread() != bulkScorerThread // The bulkScorer was initialized on a different thread |
| 245 | + ) { |
| 246 | + bulkScorerThread = Thread.currentThread(); |
| 247 | + bulkScorer = weight.bulkScorer(ctx); |
| 248 | + if (bulkScorer == null) { |
| 249 | + noMatch = true; |
| 250 | + return blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, length); |
| 251 | + } |
| 252 | + } |
| 253 | + try (DenseCollector collector = new DenseCollector(blockFactory, min, max)) { |
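| | + // BulkScorer#score treats the upper bound as exclusive, so pass max + 1 to include max itself |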
| 254 | + bulkScorer.score(collector, ctx.reader().getLiveDocs(), min, max + 1); |
| 255 | + return collector.build(); |
| 256 | + } |
| 257 | + } |
| 258 | + |
| 259 | + /** |
| 260 | + * Score a vector of doc ids using {@link Scorer}. If you have a dense range of |
| 261 | + * doc ids it'd be faster to use {@link #scoreDense}. |
| 262 | + */ |
| 263 | + DoubleVector scoreSparse(IntVector docs) throws IOException { |
| 264 | + initScorer(docs.getInt(0)); |
| 265 | + if (noMatch) { |
| 266 | + return blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, docs.getPositionCount()); |
| 267 | + } |
| 268 | + try (DoubleVector.Builder builder = blockFactory.newDoubleVectorFixedBuilder(docs.getPositionCount())) { |
| 269 | + for (int i = 0; i < docs.getPositionCount(); i++) { |
| 270 | + scoreSingleDocWithScorer(builder, docs.getInt(i)); |
| 271 | + } |
| 272 | + return builder.build(); |
| 273 | + } |
| 274 | + } |
| 275 | + |
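| | + /** |
| | + * Build the {@link Scorer} if it hasn't been built on this thread, or if it has already iterated past |
| | + * {@code minDocId}; scorers can only advance forwards. |
| | + */ |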
| 276 | + private void initScorer(int minDocId) throws IOException { |
| 277 | + if (noMatch) { |
| 278 | + return; |
| 279 | + } |
| 280 | + if (scorer == null || // Scorer not initialized |
| 281 | + scorerThread != Thread.currentThread() || // Scorer initialized on a different thread |
| 282 | + scorer.iterator().docID() > minDocId // The previous block came "after" this one |
| 283 | + ) { |
| 284 | + scorerThread = Thread.currentThread(); |
| 285 | + scorer = weight.scorer(ctx); |
| 286 | + if (scorer == null) { |
| 287 | + noMatch = true; |
| 288 | + } |
| 289 | + } |
| 290 | + } |
| 291 | + |
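| | + /** |
| | + * Score a single doc with the current {@link Scorer}: if the iterator is already on the doc use its score, if it |
| | + * has moved past the doc then the doc didn't match, otherwise advance and check whether we land on the doc. |
| | + */ |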
| 292 | + private void scoreSingleDocWithScorer(DoubleVector.Builder builder, int doc) throws IOException { |
| 293 | + if (scorer.iterator().docID() == doc) { |
| 294 | + builder.appendDouble(scorer.score()); |
| 295 | + } else if (scorer.iterator().docID() > doc) { |
| 296 | + builder.appendDouble(NO_MATCH_SCORE); |
| 297 | + } else { |
| 298 | + builder.appendDouble(scorer.iterator().advance(doc) == doc ? scorer.score() : NO_MATCH_SCORE); |
| 299 | + } |
| 300 | + } |
| 301 | + } |
| 302 | + |
| 303 | + private static final ShardState[] EMPTY_SHARD_STATES = new ShardState[0]; |
| 304 | + private static final SegmentState[] EMPTY_SEGMENT_STATES = new SegmentState[0]; |
| 305 | + |
| 306 | + /** |
| 307 | + * Collects scores for a dense range of doc ids. This assumes that |
| 308 | + * doc ids are sent to {@link LeafCollector#collect(int)} in ascending order |
| 309 | + * which isn't documented, but @jpountz swears is true. |
| 310 | + */ |
| 311 | + static class DenseCollector implements LeafCollector, Releasable { |
| 312 | + private final DoubleVector.FixedBuilder builder; |
| 313 | + private final int max; |
| 314 | + private Scorable scorer; |
| 315 | + |
| 316 | + int next; |
| 317 | + |
| 318 | + DenseCollector(BlockFactory blockFactory, int min, int max) { |
| 319 | + this.builder = blockFactory.newDoubleVectorFixedBuilder(max - min + 1); |
| 320 | + this.max = max; |
| 321 | + next = min; |
| 322 | + } |
| 323 | + |
| 324 | + @Override |
| 325 | + public void setScorer(Scorable scorable) { |
| 326 | + this.scorer = scorable; |
| 327 | + } |
| 328 | + |
| 329 | + @Override |
| 330 | + public void collect(int doc) throws IOException { |
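| | + // Any docs the scorer skipped didn't match; fill the gap with NO_MATCH_SCORE before recording this doc's score |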
| 331 | + while (next++ < doc) { |
| 332 | + builder.appendDouble(NO_MATCH_SCORE); |
| 333 | + } |
| 334 | + builder.appendDouble(scorer.score()); |
| 335 | + } |
| 336 | + |
| 337 | + public DoubleVector build() { |
| 338 | + return builder.build(); |
| 339 | + } |
| 340 | + |
| 341 | + @Override |
| 342 | + public void finish() { |
| 343 | + while (next++ <= max) { |
| 344 | + builder.appendDouble(NO_MATCH_SCORE); |
| 345 | + } |
| 346 | + } |
| 347 | + |
| 348 | + @Override |
| 349 | + public void close() { |
| 350 | + Releasables.closeExpectNoException(builder); |
| 351 | + } |
| 352 | + } |
| 353 | + |
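| | + /** |
| | + * Builds a {@link LuceneQueryScoreEvaluator} for each driver, all sharing the same shard configurations. |
| | + */ |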
| 354 | + public static class Factory implements EvalOperator.ExpressionEvaluator.Factory { |
| 355 | + private final ShardConfig[] shardConfigs; |
| 356 | + |
| 357 | + public Factory(ShardConfig[] shardConfigs) { |
| 358 | + this.shardConfigs = shardConfigs; |
| 359 | + } |
| 360 | + |
| 361 | + @Override |
| 362 | + public EvalOperator.ExpressionEvaluator get(DriverContext context) { |
| 363 | + return new LuceneQueryScoreEvaluator(context.blockFactory(), shardConfigs); |
| 364 | + } |
| 365 | + } |
| 366 | +} |