diff --git a/server/src/main/java/org/elasticsearch/index/store/AsyncDirectIOIndexInput.java b/server/src/main/java/org/elasticsearch/index/store/AsyncDirectIOIndexInput.java index 71f3335abef34..5af1ddaf5cc44 100644 --- a/server/src/main/java/org/elasticsearch/index/store/AsyncDirectIOIndexInput.java +++ b/server/src/main/java/org/elasticsearch/index/store/AsyncDirectIOIndexInput.java @@ -22,8 +22,6 @@ import com.carrotsearch.hppc.IntArrayDeque; import com.carrotsearch.hppc.IntDeque; -import com.carrotsearch.hppc.LongArrayDeque; -import com.carrotsearch.hppc.LongDeque; import org.apache.lucene.store.IndexInput; import org.elasticsearch.ExceptionsHelper; @@ -43,6 +41,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.TreeMap; import java.util.concurrent.ExecutionException; @@ -121,7 +120,7 @@ public AsyncDirectIOIndexInput(Path path, int blockSize, int bufferSize, int max super("DirectIOIndexInput(path=\"" + path + "\")"); this.channel = FileChannel.open(path, StandardOpenOption.READ, getDirectOpenOption()); this.blockSize = blockSize; - this.prefetcher = new DirectIOPrefetcher(blockSize, this.channel, bufferSize, maxPrefetches, maxPrefetches * 16); + this.prefetcher = new DirectIOPrefetcher(blockSize, this.channel, bufferSize, maxPrefetches); this.buffer = allocateBuffer(bufferSize, blockSize); this.isOpen = true; this.isClosable = true; @@ -139,13 +138,7 @@ private AsyncDirectIOIndexInput(String description, AsyncDirectIOIndexInput othe this.buffer = allocateBuffer(bufferSize, other.blockSize); this.blockSize = other.blockSize; this.channel = other.channel; - this.prefetcher = new DirectIOPrefetcher( - this.blockSize, - this.channel, - bufferSize, - other.prefetcher.maxConcurrentPrefetches, - other.prefetcher.maxTotalPrefetches - ); + this.prefetcher = new DirectIOPrefetcher(this.blockSize, this.channel, bufferSize, other.prefetcher.maxConcurrentPrefetches); this.isOpen = true; this.isClosable = false; this.length = length; @@ -170,11 +163,22 @@ public void prefetch(long pos, long length) throws IOException { if (pos < 0 || length < 0 || pos + length > this.length) { throw new IllegalArgumentException("Invalid prefetch range: pos=" + pos + ", length=" + length + ", fileLength=" + this.length); } - // check if our current buffer already contains the requested range + + // align to prefetch buffer final long absPos = pos + offset; - final long alignedPos = absPos - (absPos % blockSize); - // we only prefetch into a single buffer, even if length exceeds buffer size - // maybe we should improve this... + long alignedPos = absPos - absPos % blockSize; + + // check if our current buffer already contains the requested range + if (alignedPos >= filePos && alignedPos < filePos + buffer.capacity()) { + // The current buffer contains bytes of this request. + // Adjust the position and length accordingly to skip the current buffer. + alignedPos = filePos + buffer.capacity(); + length -= alignedPos - absPos; + } else { + // Add to the total length the bytes added by the alignment + length += absPos - alignedPos; + } + // do the prefetch prefetcher.prefetch(alignedPos, length); } @@ -396,12 +400,16 @@ public IndexInput slice(String sliceDescription, long offset, long length) throw return slice; } + // pkg private for testing + int prefetchSlots() { + return prefetcher.posToSlot.size(); + } + /** * A simple prefetcher that uses virtual threads to prefetch data into direct byte buffers. */ private static class DirectIOPrefetcher implements Closeable { private final int maxConcurrentPrefetches; - private final int maxTotalPrefetches; private final FileChannel channel; private final int blockSize; private final long[] prefetchPos; @@ -411,10 +419,9 @@ private static class DirectIOPrefetcher implements Closeable { private final IntDeque slots; private final ByteBuffer[] prefetchBuffers; private final int prefetchBytesSize; - private final LongDeque pendingPrefetches = new LongArrayDeque(); private final ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor(); - DirectIOPrefetcher(int blockSize, FileChannel channel, int prefetchBytesSize, int maxConcurrentPrefetches, int maxTotalPrefetches) { + DirectIOPrefetcher(int blockSize, FileChannel channel, int prefetchBytesSize, int maxConcurrentPrefetches) { this.blockSize = blockSize; this.channel = channel; this.maxConcurrentPrefetches = maxConcurrentPrefetches; @@ -428,48 +435,48 @@ private static class DirectIOPrefetcher implements Closeable { this.posToSlot = new TreeMap<>(); this.prefetchBuffers = new ByteBuffer[maxConcurrentPrefetches]; this.prefetchBytesSize = prefetchBytesSize; - this.maxTotalPrefetches = maxTotalPrefetches; } /** * Initiate prefetch of the given range. The range will be aligned to blockSize and * chopped up into chunks of prefetchBytesSize. - * If there are not enough slots available, the prefetch request will be queued - * until a slot becomes available. This throttling may occur if the number of - * concurrent prefetches is exceeded, or if there is significant IO pressure. + * If there are not enough slots available, the prefetch request will reuse the slot + * with the lowest file pointer. If that slot is still being prefetched, the prefetch request + * will be skipped. * @param pos the position to prefetch from, must be non-negative and within file length * @param length the length to prefetch, must be non-negative. */ void prefetch(long pos, long length) { + assert pos % blockSize == 0 : "prefetch pos [" + pos + "] must be aligned to block size [" + blockSize + "]"; // first determine how many slots we need given the length - int numSlots = (int) Math.min( - (length + prefetchBytesSize - 1) / prefetchBytesSize, - maxTotalPrefetches - (this.posToSlot.size() + this.pendingPrefetches.size()) - ); - while (numSlots > 0 && (this.posToSlot.size() + this.pendingPrefetches.size()) < maxTotalPrefetches) { - final int slot; - Integer existingSlot = this.posToSlot.get(pos); - if (existingSlot != null && prefetchThreads.get(existingSlot) != null) { - // already being prefetched and hasn't been consumed. - // return early - return; - } - if (this.posToSlot.size() < maxConcurrentPrefetches && slots.isEmpty() == false) { - slot = slots.removeFirst(); + while (length > 0) { + Map.Entry floor = this.posToSlot.floorEntry(pos); + if (floor == null || floor.getKey() + prefetchBytesSize <= pos) { + // check if there are any slots available. If not we will reuse the one with the + // lower file pointer. + if (slots.isEmpty()) { + assert this.posToSlot.size() == maxConcurrentPrefetches; + final int oldestSlot = posToSlot.firstEntry().getValue(); + if (prefetchThreads.get(oldestSlot).isDone() == false) { + // cannot reuse oldest slot. We are over-prefetching + LOGGER.debug("could not prefetch pos [{}] with length [{}]", pos, length); + return; + } + LOGGER.debug("prefetch on reused slot with pos [{}] with length [{}]", pos, length); + clearSlot(oldestSlot); + assert slots.isEmpty() == false; + } + final int slot = slots.removeFirst(); posToSlot.put(pos, slot); prefetchPos[slot] = pos; - } else { - slot = -1; - LOGGER.debug("queueing prefetch of pos [{}] with length [{}], waiting for open slot", pos, length); - pendingPrefetches.addLast(pos); - } - if (slot != -1) { startPrefetch(pos, slot); + length -= prefetchBytesSize; + pos += prefetchBytesSize; + } else { + length -= floor.getKey() + prefetchBytesSize - pos; + pos = floor.getKey() + prefetchBytesSize; } - pos += prefetchBytesSize; - numSlots--; } - } /** @@ -506,7 +513,7 @@ boolean readBytes(long pos, ByteBuffer slice, int delta) throws IOException { Thread.currentThread().interrupt(); } finally { if (prefetchBuffer == null) { - clearSlotAndMaybeStartPending(slot); + clearSlot(slot); } } if (prefetchBuffer == null) { @@ -518,22 +525,15 @@ boolean readBytes(long pos, ByteBuffer slice, int delta) throws IOException { slice.put(prefetchBuffer); slice.flip(); slice.position(Math.toIntExact(pos - prefetchedPos) + delta); - clearSlotAndMaybeStartPending(slot); + clearSlot(slot); return true; } - void clearSlotAndMaybeStartPending(int slot) { - assert prefetchThreads.get(slot) != null && prefetchThreads.get(slot).isDone(); + void clearSlot(int slot) { + assert prefetchThreads.get(slot) != null; prefetchThreads.set(slot, null); posToSlot.remove(prefetchPos[slot]); - if (pendingPrefetches.isEmpty()) { - slots.addLast(slot); - return; - } - final long req = pendingPrefetches.removeFirst(); - posToSlot.put(req, slot); - prefetchPos[slot] = req; - startPrefetch(req, slot); + slots.addLast(slot); } private boolean assertSlotsConsistent() { diff --git a/server/src/test/java/org/elasticsearch/index/store/AsyncDirectIOIndexInputTests.java b/server/src/test/java/org/elasticsearch/index/store/AsyncDirectIOIndexInputTests.java index a9beebf16e37e..8189d834f727b 100644 --- a/server/src/test/java/org/elasticsearch/index/store/AsyncDirectIOIndexInputTests.java +++ b/server/src/test/java/org/elasticsearch/index/store/AsyncDirectIOIndexInputTests.java @@ -12,15 +12,22 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.NIOFSDirectory; +import org.apache.lucene.util.hnsw.IntToIntFunction; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.test.ESTestCase; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.List; +import static org.hamcrest.Matchers.lessThanOrEqualTo; + public class AsyncDirectIOIndexInputTests extends ESTestCase { @SuppressForbidden(reason = "requires Files.getFileStore") @@ -169,4 +176,96 @@ public void testWriteThenReadBytesConsistency() throws IOException { } } + public void testPrefetchGetsCleanUp() throws IOException { + int numVectors = randomIntBetween(100, 1000); + int numDimensions = randomIntBetween(100, 2048); + Path path = createTempDir("testDirectIODirectory"); + byte[] bytes = new byte[numDimensions * Float.BYTES]; + ByteBuffer buffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN); + float[][] vectors = new float[numVectors][numDimensions]; + try (Directory dir = new NIOFSDirectory(path)) { + try (var output = dir.createOutput("test", org.apache.lucene.store.IOContext.DEFAULT)) { + for (int i = 0; i < numVectors; i++) { + random().nextBytes(bytes); + output.writeBytes(bytes, bytes.length); + buffer.asFloatBuffer().get(vectors[i]); + } + } + + final int blockSize = getBlockSize(path); + final int bufferSize = 8192; + // fetch all + try (AsyncDirectIOIndexInput actualInput = new AsyncDirectIOIndexInput(path.resolve("test"), blockSize, bufferSize, 64)) { + assertPrefetchSlots(actualInput, numDimensions, numVectors, i -> i, vectors, bufferSize); + } + // fetch all in slice + try (AsyncDirectIOIndexInput actualInput = new AsyncDirectIOIndexInput(path.resolve("test"), blockSize, bufferSize, 64)) { + int start = randomIntBetween(0, numVectors - 1); + float[][] vectorsSlice = Arrays.copyOfRange(vectors, start, numVectors); + long sliceStart = (long) start * bytes.length; + assertPrefetchSlots( + (AsyncDirectIOIndexInput) actualInput.slice("slice", sliceStart, actualInput.length() - sliceStart), + numDimensions, + vectorsSlice.length, + i -> i, + vectorsSlice, + bufferSize + ); + } + // random fetch + List tempList = new ArrayList<>(numVectors); + for (int i = 0; i < numVectors; i++) { + tempList.add(i); + } + Collections.shuffle(tempList, random()); + List subList = tempList.subList(0, randomIntBetween(1, numVectors)); + Collections.sort(subList); + try (AsyncDirectIOIndexInput actualInput = new AsyncDirectIOIndexInput(path.resolve("test"), blockSize, bufferSize, 64)) { + assertPrefetchSlots(actualInput, numDimensions, subList.size(), subList::get, vectors, bufferSize); + } + } + } + + private static void assertPrefetchSlots( + AsyncDirectIOIndexInput actualInput, + int numDimensions, + int numVectors, + IntToIntFunction ords, + float[][] vectors, + int bufferSize + ) throws IOException { + int prefetchSize = randomIntBetween(1, 64); + float[] floats = new float[numDimensions]; + long bytesLength = (long) numDimensions * Float.BYTES; + int limit = numVectors - prefetchSize + 1; + int i = 0; + for (; i < limit; i += prefetchSize) { + int ord = ords.apply(i); + for (int j = 0; j < prefetchSize; j++) { + actualInput.prefetch((ord + j) * bytesLength, bytesLength); + } + // check we prefetch enough data. We need to add 1 because of the current buffer. + assertThat(prefetchSize * bytesLength, lessThanOrEqualTo((long) (1 + actualInput.prefetchSlots()) * bufferSize)); + for (int j = 0; j < prefetchSize; j++) { + actualInput.seek((ord + j) * bytesLength); + actualInput.readFloats(floats, 0, floats.length); + assertArrayEquals(vectors[ord + j], floats, 0.0f); + } + // check we have freed all the slots + assertEquals(0, actualInput.prefetchSlots()); + } + for (int k = i; k < numVectors; k++) { + actualInput.prefetch(ords.apply(k) * bytesLength, bytesLength); + } + // check we prefetch enough data. We need to add 1 because of the current buffer. + assertThat((numVectors - i) * bytesLength, lessThanOrEqualTo((long) (1 + actualInput.prefetchSlots()) * bufferSize)); + for (; i < numVectors; i++) { + int ord = ords.apply(i); + actualInput.seek(ord * bytesLength); + actualInput.readFloats(floats, 0, floats.length); + assertArrayEquals(vectors[ord], floats, 0.0f); + } + // check we have freed all the slots + assertEquals(0, actualInput.prefetchSlots()); + } }