Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@

import com.carrotsearch.hppc.IntArrayDeque;
import com.carrotsearch.hppc.IntDeque;
import com.carrotsearch.hppc.LongArrayDeque;
import com.carrotsearch.hppc.LongDeque;

import org.apache.lucene.store.IndexInput;
import org.elasticsearch.ExceptionsHelper;
Expand All @@ -43,6 +41,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -121,7 +120,7 @@ public AsyncDirectIOIndexInput(Path path, int blockSize, int bufferSize, int max
super("DirectIOIndexInput(path=\"" + path + "\")");
this.channel = FileChannel.open(path, StandardOpenOption.READ, getDirectOpenOption());
this.blockSize = blockSize;
this.prefetcher = new DirectIOPrefetcher(blockSize, this.channel, bufferSize, maxPrefetches, maxPrefetches * 16);
this.prefetcher = new DirectIOPrefetcher(blockSize, this.channel, bufferSize, maxPrefetches);
this.buffer = allocateBuffer(bufferSize, blockSize);
this.isOpen = true;
this.isClosable = true;
Expand All @@ -139,13 +138,7 @@ private AsyncDirectIOIndexInput(String description, AsyncDirectIOIndexInput othe
this.buffer = allocateBuffer(bufferSize, other.blockSize);
this.blockSize = other.blockSize;
this.channel = other.channel;
this.prefetcher = new DirectIOPrefetcher(
this.blockSize,
this.channel,
bufferSize,
other.prefetcher.maxConcurrentPrefetches,
other.prefetcher.maxTotalPrefetches
);
this.prefetcher = new DirectIOPrefetcher(this.blockSize, this.channel, bufferSize, other.prefetcher.maxConcurrentPrefetches);
this.isOpen = true;
this.isClosable = false;
this.length = length;
Expand All @@ -170,11 +163,22 @@ public void prefetch(long pos, long length) throws IOException {
if (pos < 0 || length < 0 || pos + length > this.length) {
throw new IllegalArgumentException("Invalid prefetch range: pos=" + pos + ", length=" + length + ", fileLength=" + this.length);
}
// check if our current buffer already contains the requested range

// align to prefetch buffer
final long absPos = pos + offset;
final long alignedPos = absPos - (absPos % blockSize);
// we only prefetch into a single buffer, even if length exceeds buffer size
// maybe we should improve this...
long alignedPos = absPos - absPos % blockSize;

// check if our current buffer already contains the requested range
if (alignedPos >= filePos && alignedPos < filePos + buffer.capacity()) {
// The current buffer contains bytes of this request.
// Adjust the position and length accordingly to skip the current buffer.
alignedPos = filePos + buffer.capacity();
length -= alignedPos - absPos;
} else {
// Add to the total length the bytes added by the alignment
length += absPos - alignedPos;
}
// do the prefetch
prefetcher.prefetch(alignedPos, length);
}

Expand Down Expand Up @@ -396,12 +400,16 @@ public IndexInput slice(String sliceDescription, long offset, long length) throw
return slice;
}

// pkg private for testing
int prefetchSlots() {
return prefetcher.posToSlot.size();
}

/**
* A simple prefetcher that uses virtual threads to prefetch data into direct byte buffers.
*/
private static class DirectIOPrefetcher implements Closeable {
private final int maxConcurrentPrefetches;
private final int maxTotalPrefetches;
private final FileChannel channel;
private final int blockSize;
private final long[] prefetchPos;
Expand All @@ -411,10 +419,9 @@ private static class DirectIOPrefetcher implements Closeable {
private final IntDeque slots;
private final ByteBuffer[] prefetchBuffers;
private final int prefetchBytesSize;
private final LongDeque pendingPrefetches = new LongArrayDeque();
private final ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();

DirectIOPrefetcher(int blockSize, FileChannel channel, int prefetchBytesSize, int maxConcurrentPrefetches, int maxTotalPrefetches) {
DirectIOPrefetcher(int blockSize, FileChannel channel, int prefetchBytesSize, int maxConcurrentPrefetches) {
this.blockSize = blockSize;
this.channel = channel;
this.maxConcurrentPrefetches = maxConcurrentPrefetches;
Expand All @@ -428,48 +435,48 @@ private static class DirectIOPrefetcher implements Closeable {
this.posToSlot = new TreeMap<>();
this.prefetchBuffers = new ByteBuffer[maxConcurrentPrefetches];
this.prefetchBytesSize = prefetchBytesSize;
this.maxTotalPrefetches = maxTotalPrefetches;
}

/**
* Initiate prefetch of the given range. The range will be aligned to blockSize and
* chopped up into chunks of prefetchBytesSize.
* If there are not enough slots available, the prefetch request will be queued
* until a slot becomes available. This throttling may occur if the number of
* concurrent prefetches is exceeded, or if there is significant IO pressure.
* If there are not enough slots available, the prefetch request will reuse the slot
* with the lowest file pointer. If that slot is still being prefetched, the prefetch request
* will be skipped.
* @param pos the position to prefetch from, must be non-negative and within file length
* @param length the length to prefetch, must be non-negative.
*/
void prefetch(long pos, long length) {
assert pos % blockSize == 0 : "prefetch pos [" + pos + "] must be aligned to block size [" + blockSize + "]";
// first determine how many slots we need given the length
int numSlots = (int) Math.min(
(length + prefetchBytesSize - 1) / prefetchBytesSize,
maxTotalPrefetches - (this.posToSlot.size() + this.pendingPrefetches.size())
);
while (numSlots > 0 && (this.posToSlot.size() + this.pendingPrefetches.size()) < maxTotalPrefetches) {
final int slot;
Integer existingSlot = this.posToSlot.get(pos);
if (existingSlot != null && prefetchThreads.get(existingSlot) != null) {
// already being prefetched and hasn't been consumed.
// return early
return;
}
if (this.posToSlot.size() < maxConcurrentPrefetches && slots.isEmpty() == false) {
slot = slots.removeFirst();
while (length > 0) {
Map.Entry<Long, Integer> floor = this.posToSlot.floorEntry(pos);
if (floor == null || floor.getKey() + prefetchBytesSize <= pos) {
// check if there are any slots available. If not we will reuse the one with the
// lower file pointer.
if (slots.isEmpty()) {
assert this.posToSlot.size() == maxConcurrentPrefetches;
final int oldestSlot = posToSlot.firstEntry().getValue();
if (prefetchThreads.get(oldestSlot).isDone() == false) {
// cannot reuse oldest slot. We are over-prefetching
LOGGER.debug("could not prefetch pos [{}] with length [{}]", pos, length);
return;
}
LOGGER.debug("prefetch on reused slot with pos [{}] with length [{}]", pos, length);
clearSlot(oldestSlot);
assert slots.isEmpty() == false;
}
final int slot = slots.removeFirst();
posToSlot.put(pos, slot);
prefetchPos[slot] = pos;
} else {
slot = -1;
LOGGER.debug("queueing prefetch of pos [{}] with length [{}], waiting for open slot", pos, length);
pendingPrefetches.addLast(pos);
}
if (slot != -1) {
startPrefetch(pos, slot);
length -= prefetchBytesSize;
pos += prefetchBytesSize;
} else {
length -= floor.getKey() + prefetchBytesSize - pos;
pos = floor.getKey() + prefetchBytesSize;
}
pos += prefetchBytesSize;
numSlots--;
}

}

/**
Expand Down Expand Up @@ -506,7 +513,7 @@ boolean readBytes(long pos, ByteBuffer slice, int delta) throws IOException {
Thread.currentThread().interrupt();
} finally {
if (prefetchBuffer == null) {
clearSlotAndMaybeStartPending(slot);
clearSlot(slot);
}
}
if (prefetchBuffer == null) {
Expand All @@ -518,22 +525,15 @@ boolean readBytes(long pos, ByteBuffer slice, int delta) throws IOException {
slice.put(prefetchBuffer);
slice.flip();
slice.position(Math.toIntExact(pos - prefetchedPos) + delta);
clearSlotAndMaybeStartPending(slot);
clearSlot(slot);
return true;
}

void clearSlotAndMaybeStartPending(int slot) {
assert prefetchThreads.get(slot) != null && prefetchThreads.get(slot).isDone();
void clearSlot(int slot) {
assert prefetchThreads.get(slot) != null;
prefetchThreads.set(slot, null);
posToSlot.remove(prefetchPos[slot]);
if (pendingPrefetches.isEmpty()) {
slots.addLast(slot);
return;
}
final long req = pendingPrefetches.removeFirst();
posToSlot.put(req, slot);
prefetchPos[slot] = req;
startPrefetch(req, slot);
slots.addLast(slot);
}

private boolean assertSlotsConsistent() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,22 @@
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.NIOFSDirectory;
import org.apache.lucene.util.hnsw.IntToIntFunction;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.test.ESTestCase;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.hamcrest.Matchers.lessThanOrEqualTo;

public class AsyncDirectIOIndexInputTests extends ESTestCase {

@SuppressForbidden(reason = "requires Files.getFileStore")
Expand Down Expand Up @@ -169,4 +176,96 @@ public void testWriteThenReadBytesConsistency() throws IOException {
}
}

public void testPrefetchGetsCleanUp() throws IOException {
int numVectors = randomIntBetween(100, 1000);
int numDimensions = randomIntBetween(100, 2048);
Path path = createTempDir("testDirectIODirectory");
byte[] bytes = new byte[numDimensions * Float.BYTES];
ByteBuffer buffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
float[][] vectors = new float[numVectors][numDimensions];
try (Directory dir = new NIOFSDirectory(path)) {
try (var output = dir.createOutput("test", org.apache.lucene.store.IOContext.DEFAULT)) {
for (int i = 0; i < numVectors; i++) {
random().nextBytes(bytes);
output.writeBytes(bytes, bytes.length);
buffer.asFloatBuffer().get(vectors[i]);
}
}

final int blockSize = getBlockSize(path);
final int bufferSize = 8192;
// fetch all
try (AsyncDirectIOIndexInput actualInput = new AsyncDirectIOIndexInput(path.resolve("test"), blockSize, bufferSize, 64)) {
assertPrefetchSlots(actualInput, numDimensions, numVectors, i -> i, vectors, bufferSize);
}
// fetch all in slice
try (AsyncDirectIOIndexInput actualInput = new AsyncDirectIOIndexInput(path.resolve("test"), blockSize, bufferSize, 64)) {
int start = randomIntBetween(0, numVectors - 1);
float[][] vectorsSlice = Arrays.copyOfRange(vectors, start, numVectors);
long sliceStart = (long) start * bytes.length;
assertPrefetchSlots(
(AsyncDirectIOIndexInput) actualInput.slice("slice", sliceStart, actualInput.length() - sliceStart),
numDimensions,
vectorsSlice.length,
i -> i,
vectorsSlice,
bufferSize
);
}
// random fetch
List<Integer> tempList = new ArrayList<>(numVectors);
for (int i = 0; i < numVectors; i++) {
tempList.add(i);
}
Collections.shuffle(tempList, random());
List<Integer> subList = tempList.subList(0, randomIntBetween(1, numVectors));
Collections.sort(subList);
try (AsyncDirectIOIndexInput actualInput = new AsyncDirectIOIndexInput(path.resolve("test"), blockSize, bufferSize, 64)) {
assertPrefetchSlots(actualInput, numDimensions, subList.size(), subList::get, vectors, bufferSize);
}
}
}

private static void assertPrefetchSlots(
AsyncDirectIOIndexInput actualInput,
int numDimensions,
int numVectors,
IntToIntFunction ords,
float[][] vectors,
int bufferSize
) throws IOException {
int prefetchSize = randomIntBetween(1, 64);
float[] floats = new float[numDimensions];
long bytesLength = (long) numDimensions * Float.BYTES;
int limit = numVectors - prefetchSize + 1;
int i = 0;
for (; i < limit; i += prefetchSize) {
int ord = ords.apply(i);
for (int j = 0; j < prefetchSize; j++) {
actualInput.prefetch((ord + j) * bytesLength, bytesLength);
}
// check we prefetch enough data. We need to add 1 because of the current buffer.
assertThat(prefetchSize * bytesLength, lessThanOrEqualTo((long) (1 + actualInput.prefetchSlots()) * bufferSize));
for (int j = 0; j < prefetchSize; j++) {
actualInput.seek((ord + j) * bytesLength);
actualInput.readFloats(floats, 0, floats.length);
assertArrayEquals(vectors[ord + j], floats, 0.0f);
}
// check we have freed all the slots
assertEquals(0, actualInput.prefetchSlots());
}
for (int k = i; k < numVectors; k++) {
actualInput.prefetch(ords.apply(k) * bytesLength, bytesLength);
}
// check we prefetch enough data. We need to add 1 because of the current buffer.
assertThat((numVectors - i) * bytesLength, lessThanOrEqualTo((long) (1 + actualInput.prefetchSlots()) * bufferSize));
for (; i < numVectors; i++) {
int ord = ords.apply(i);
actualInput.seek(ord * bytesLength);
actualInput.readFloats(floats, 0, floats.length);
assertArrayEquals(vectors[ord], floats, 0.0f);
}
// check we have freed all the slots
assertEquals(0, actualInput.prefetchSlots());
}
}