diff --git a/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java b/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java index 9c998eb920dc9..fa50317353556 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java +++ b/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java @@ -586,8 +586,9 @@ public VectorScorer scorer(byte[] bytes) throws IOException { if (scorer == null) { return null; } + DocIdSetIterator scorerIterator = scorer.iterator(); return new VectorScorer() { - private final DocIdSetIterator iterator = new ExitableDocSetIterator(scorer.iterator(), queryCancellation); + private final DocIdSetIterator iterator = exitableIterator(scorerIterator, queryCancellation); @Override public float score() throws IOException { @@ -637,8 +638,9 @@ public VectorScorer scorer(float[] target) throws IOException { if (scorer == null) { return null; } + DocIdSetIterator scorerIterator = scorer.iterator(); return new VectorScorer() { - private final DocIdSetIterator iterator = new ExitableDocSetIterator(scorer.iterator(), queryCancellation); + private final DocIdSetIterator iterator = exitableIterator(scorerIterator, queryCancellation); @Override public float score() throws IOException { @@ -663,6 +665,15 @@ public FloatVectorValues copy() throws IOException { } } + /** Wraps the iterator in an exitable iterator, specializing for KnnVectorValues.DocIndexIterator. */ + static DocIdSetIterator exitableIterator(DocIdSetIterator iterator, QueryCancellation queryCancellation) { + if (iterator instanceof KnnVectorValues.DocIndexIterator docIndexIterator) { + return createExitableIterator(docIndexIterator, queryCancellation); + } else { + return new ExitableDocSetIterator(iterator, queryCancellation); + } + } + private static KnnVectorValues.DocIndexIterator createExitableIterator( KnnVectorValues.DocIndexIterator delegate, QueryCancellation queryCancellation