|
16 | 16 | */ |
17 | 17 | package org.apache.lucene.search; |
18 | 18 |
|
19 | | -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; |
20 | | - |
21 | 19 | import java.io.IOException; |
22 | 20 | import java.util.ArrayList; |
23 | | -import java.util.Arrays; |
24 | | -import java.util.Comparator; |
25 | 21 | import java.util.HashMap; |
26 | 22 | import java.util.Iterator; |
27 | 23 | import java.util.List; |
@@ -142,7 +138,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { |
142 | 138 | if (topK.scoreDocs.length == 0) { |
143 | 139 | return new MatchNoDocsQuery(); |
144 | 140 | } |
145 | | - return createRewrittenQuery(reader, topK, reentryCount); |
| 141 | + return DocAndScoreQuery.createDocAndScoreQuery(reader, topK, reentryCount); |
146 | 142 | } |
147 | 143 |
|
148 | 144 | private TopDocs runSearchTasks( |
@@ -398,46 +394,6 @@ public KnnCollector newCollector( |
398 | 394 | } |
399 | 395 | } |
400 | 396 |
|
401 | | - /** Packs the collected top-K hits into a DocAndScoreQuery; note that topK.scoreDocs is re-sorted in place by doc id. */ protected Query createRewrittenQuery(IndexReader reader, TopDocs topK, int reentryCount) { |
402 | | - int len = topK.scoreDocs.length; |
403 | | - assert len > 0; /* scoreDocs[0] is read on the next row, so an empty result is not supported here */ |
404 | | - float maxScore = topK.scoreDocs[0].score; /* captured before the doc-id sort below; presumably topK arrives best-score-first -- TODO confirm */ |
405 | | - Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc)); /* mutates the caller's array: ascending doc id from here on */ |
406 | | - int[] docs = new int[len]; |
407 | | - float[] scores = new float[len]; |
408 | | - for (int i = 0; i < len; i++) { /* split the sorted hits into parallel doc/score arrays */ |
409 | | - docs[i] = topK.scoreDocs[i].doc; |
410 | | - scores[i] = topK.scoreDocs[i].score; |
411 | | - } |
412 | | - int[] segmentStarts = findSegmentStarts(reader.leaves(), docs); /* docs is sorted at this point, as the binary search inside requires */ |
413 | | - return new DocAndScoreQuery( |
414 | | - docs, |
415 | | - scores, |
416 | | - maxScore, |
417 | | - segmentStarts, |
418 | | - topK.totalHits.value(), /* fills the constructor's "visited" parameter */ |
419 | | - reader.getContext().id(), /* fills "contextIdentity"; compared by reference in createWeight */ |
420 | | - reentryCount); |
421 | | - } |
422 | | - |
423 | | - /** Returns one entry per leaf giving the index in docs of the first id >= that leaf's docBase, plus a final entry equal to docs.length; docs must be sorted ascending because it is binary-searched. */ static int[] findSegmentStarts(List<LeafReaderContext> leaves, int[] docs) { |
424 | | - int[] starts = new int[leaves.size() + 1]; /* starts[0] keeps its default value 0 */ |
425 | | - starts[starts.length - 1] = docs.length; /* exclusive end bound for the last leaf */ |
426 | | - if (starts.length == 2) { /* single leaf: the answer is simply {0, docs.length} */ |
427 | | - return starts; |
428 | | - } |
429 | | - int resultIndex = 0; /* carried across iterations so each search resumes from the previous leaf's start */ |
430 | | - for (int i = 1; i < starts.length - 1; i++) { |
431 | | - int upper = leaves.get(i).docBase; /* ids below this value are attributed to earlier leaves */ |
432 | | - resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper); |
433 | | - if (resultIndex < 0) { /* key absent: binarySearch returned -(insertionPoint) - 1 */ |
434 | | - resultIndex = -1 - resultIndex; /* recover the insertion point, i.e. the first id greater than upper */ |
435 | | - } |
436 | | - starts[i] = resultIndex; |
437 | | - } |
438 | | - return starts; |
439 | | - } |
440 | | - |
441 | 397 | @Override |
442 | 398 | public void visit(QueryVisitor visitor) { |
443 | 399 | if (visitor.acceptField(field)) { |
@@ -483,199 +439,6 @@ public Query getFilter() { |
483 | 439 | return filter; |
484 | 440 | } |
485 | 441 |
|
486 | | - /** Caches the results of a KnnVector search: parallel arrays of matching doc ids and their scores. */ |
487 | | - static class DocAndScoreQuery extends Query { |
488 | | - |
489 | | - private final int[] docs; |
490 | | - private final float[] scores; /* scores[i] belongs to docs[i] */ |
491 | | - private final float maxScore; |
492 | | - private final int[] segmentStarts; |
493 | | - private final long visited; |
494 | | - private final Object contextIdentity; /* compared by reference, never with equals() */ |
495 | | - private final int reentryCount; |
496 | | - |
497 | | - /** |
498 | | - * Constructor. Stores the given arrays as-is (no defensive copies are made). |
499 | | - * |
500 | | - * @param docs the global docids of documents that match, in ascending order |
501 | | - * @param scores the scores of the matching documents, parallel to docs |
502 | | - * @param maxScore the maximum of those scores; reported (times boost) by Scorer#getMaxScore |
503 | | - * @param segmentStarts the indexes in docs and scores corresponding to the first matching |
504 | | - * document in each segment. If a segment has no matching documents, it should be assigned |
505 | | - * the index of the next segment that does. There must be one extra final entry equal to |
506 | | - * docs.length, so segment i owns the half-open index range [starts[i], starts[i + 1]). |
507 | | - * @param visited the number of graph nodes that were visited, and for which vector distance |
508 | | - * scores were evaluated. |
509 | | - * @param contextIdentity an object identifying the reader context used to build this query |
510 | | - * @param reentryCount value handed back by reentryCount(); its meaning is defined by the caller |
511 | | - */ |
512 | | - DocAndScoreQuery( |
513 | | - int[] docs, |
514 | | - float[] scores, |
515 | | - float maxScore, |
516 | | - int[] segmentStarts, |
517 | | - long visited, |
518 | | - Object contextIdentity, |
519 | | - int reentryCount) { |
520 | | - this.docs = docs; |
521 | | - this.scores = scores; |
522 | | - this.maxScore = maxScore; |
523 | | - this.segmentStarts = segmentStarts; |
524 | | - this.visited = visited; |
525 | | - this.contextIdentity = contextIdentity; |
526 | | - this.reentryCount = reentryCount; |
527 | | - } |
528 | | - |
529 | | - /* NOTE(review): commented-out copy constructor below is dead code; remove it or reinstate it. |
530 | | - DocAndScoreQuery(DocAndScoreQuery other) { |
531 | | - this.docs = other.docs; |
532 | | - this.scores = other.scores; |
533 | | - this.maxScore = other.maxScore; |
534 | | - this.segmentStarts = other.segmentStarts; |
535 | | - this.visited = other.visited; |
536 | | - this.contextIdentity = other.contextIdentity; |
537 | | - this.reentryCount = other.reentryCount; |
538 | | - } |
539 | | - */ |
540 | | - |
541 | | - /** Returns the reentryCount value supplied at construction. */ int reentryCount() { |
542 | | - return reentryCount; |
543 | | - } |
544 | | - |
545 | | - @Override |
546 | | - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) |
547 | | - throws IOException { |
548 | | - if (searcher.getIndexReader().getContext().id() != contextIdentity) { /* reference check: refuse a searcher whose reader context is not the one recorded at construction */ |
549 | | - throw new IllegalStateException("This DocAndScore query was created by a different reader"); |
550 | | - } |
551 | | - return new Weight(this) { |
552 | | - @Override |
553 | | - public Explanation explain(LeafReaderContext context, int doc) { |
554 | | - int found = Arrays.binarySearch(docs, doc + context.docBase); /* docs holds global ids, so the leaf-local doc is shifted by docBase first */ |
555 | | - if (found < 0) { |
556 | | - return Explanation.noMatch("not in top " + docs.length + " docs"); |
557 | | - } |
558 | | - return Explanation.match(scores[found] * boost, "within top " + docs.length + " docs"); |
559 | | - } |
560 | | - |
561 | | - @Override |
562 | | - public int count(LeafReaderContext context) { |
563 | | - return segmentStarts[context.ord + 1] - segmentStarts[context.ord]; /* size of this leaf's [start, nextStart) index range */ |
564 | | - } |
565 | | - |
566 | | - @Override |
567 | | - public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { |
568 | | - if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) { /* empty range: none of the cached hits fall in this leaf */ |
569 | | - return null; |
570 | | - } |
571 | | - final var scorer = |
572 | | - new Scorer() { |
573 | | - final int lower = segmentStarts[context.ord]; |
574 | | - final int upper = segmentStarts[context.ord + 1]; /* exclusive bound */ |
575 | | - int upTo = -1; /* index into docs/scores of the current hit; -1 until nextDoc() is first called */ |
576 | | - |
577 | | - @Override |
578 | | - public DocIdSetIterator iterator() { /* every iterator returned here reads and advances the scorer's single upTo cursor */ |
579 | | - return new DocIdSetIterator() { |
580 | | - @Override |
581 | | - public int docID() { |
582 | | - return docIdNoShadow(); |
583 | | - } |
584 | | - |
585 | | - @Override |
586 | | - public int nextDoc() { |
587 | | - if (upTo == -1) { /* first call: jump to this leaf's first hit */ |
588 | | - upTo = lower; |
589 | | - } else { |
590 | | - ++upTo; |
591 | | - } |
592 | | - return docIdNoShadow(); |
593 | | - } |
594 | | - |
595 | | - @Override |
596 | | - public int advance(int target) throws IOException { |
597 | | - return slowAdvance(target); /* no skipping logic of its own: relies on the inherited slowAdvance helper */ |
598 | | - } |
599 | | - |
600 | | - @Override |
601 | | - public long cost() { |
602 | | - return upper - lower; /* number of cached hits in this leaf */ |
603 | | - } |
604 | | - }; |
605 | | - } |
606 | | - |
607 | | - @Override |
608 | | - public float getMaxScore(int docId) { |
609 | | - return maxScore * boost; /* same bound regardless of docId */ |
610 | | - } |
611 | | - |
612 | | - @Override |
613 | | - public float score() { |
614 | | - return scores[upTo] * boost; /* NOTE(review): no range check on upTo; only meaningful while positioned on a hit */ |
615 | | - } |
616 | | - |
617 | | - /** |
618 | | - * move the implementation of docID() into a differently-named method so we can call |
619 | | - * it from DocIdSetIterator.docID() even though this class is anonymous |
620 | | - * |
621 | | - * @return -1 before iteration starts, NO_MORE_DOCS once upTo reaches upper, else the leaf-local docid |
622 | | - */ |
623 | | - private int docIdNoShadow() { |
624 | | - if (upTo == -1) { |
625 | | - return -1; |
626 | | - } |
627 | | - if (upTo >= upper) { |
628 | | - return NO_MORE_DOCS; |
629 | | - } |
630 | | - return docs[upTo] - context.docBase; /* convert the stored global id back to a leaf-local id */ |
631 | | - } |
632 | | - |
633 | | - @Override |
634 | | - public int docID() { |
635 | | - return docIdNoShadow(); |
636 | | - } |
637 | | - }; |
638 | | - return new DefaultScorerSupplier(scorer); |
639 | | - } |
640 | | - |
641 | | - @Override |
642 | | - public boolean isCacheable(LeafReaderContext ctx) { |
643 | | - return true; |
644 | | - } |
645 | | - }; |
646 | | - } |
647 | | - |
648 | | - @Override |
649 | | - public String toString(String field) { |
650 | | - return "DocAndScoreQuery[" + docs[0] + ",...][" + scores[0] + ",...]," + maxScore; /* NOTE(review): indexes element 0, so this throws if the arrays are empty */ |
651 | | - } |
652 | | - |
653 | | - @Override |
654 | | - public void visit(QueryVisitor visitor) { |
655 | | - visitor.visitLeaf(this); |
656 | | - } |
657 | | - |
658 | | - /** Returns the visited count supplied at construction. */ public long visited() { |
659 | | - return visited; |
660 | | - } |
661 | | - |
662 | | - @Override |
663 | | - public boolean equals(Object obj) { |
664 | | - if (sameClassAs(obj) == false) { |
665 | | - return false; |
666 | | - } |
667 | | - return contextIdentity == ((DocAndScoreQuery) obj).contextIdentity /* reference equality; maxScore, segmentStarts, visited and reentryCount are not compared */ |
668 | | - && Arrays.equals(docs, ((DocAndScoreQuery) obj).docs) |
669 | | - && Arrays.equals(scores, ((DocAndScoreQuery) obj).scores); |
670 | | - } |
671 | | - |
672 | | - @Override |
673 | | - public int hashCode() { |
674 | | - return Objects.hash( /* hashes the same three components that equals() compares, plus the class hash */ |
675 | | - classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores)); |
676 | | - } |
677 | | - } |
678 | | - |
679 | 442 | /** Returns the value of this query's searchStrategy field. */ public KnnSearchStrategy getSearchStrategy() { |
680 | 443 | return searchStrategy; |
681 | 444 | } |
|
0 commit comments