Skip to content

Commit 1d9ab12

Browse files
authored
Optimize DirectIO prefetch for monotonically increasing access (elastic#136946)
1 parent f969d7d commit 1d9ab12

File tree

2 files changed

+155
-56
lines changed

2 files changed

+155
-56
lines changed

server/src/main/java/org/elasticsearch/index/store/AsyncDirectIOIndexInput.java

Lines changed: 56 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222

2323
import com.carrotsearch.hppc.IntArrayDeque;
2424
import com.carrotsearch.hppc.IntDeque;
25-
import com.carrotsearch.hppc.LongArrayDeque;
26-
import com.carrotsearch.hppc.LongDeque;
2725

2826
import org.apache.lucene.store.IndexInput;
2927
import org.elasticsearch.ExceptionsHelper;
@@ -43,6 +41,7 @@
4341
import java.util.ArrayList;
4442
import java.util.Arrays;
4543
import java.util.List;
44+
import java.util.Map;
4645
import java.util.Objects;
4746
import java.util.TreeMap;
4847
import java.util.concurrent.ExecutionException;
@@ -121,7 +120,7 @@ public AsyncDirectIOIndexInput(Path path, int blockSize, int bufferSize, int max
121120
super("DirectIOIndexInput(path=\"" + path + "\")");
122121
this.channel = FileChannel.open(path, StandardOpenOption.READ, getDirectOpenOption());
123122
this.blockSize = blockSize;
124-
this.prefetcher = new DirectIOPrefetcher(blockSize, this.channel, bufferSize, maxPrefetches, maxPrefetches * 16);
123+
this.prefetcher = new DirectIOPrefetcher(blockSize, this.channel, bufferSize, maxPrefetches);
125124
this.buffer = allocateBuffer(bufferSize, blockSize);
126125
this.isOpen = true;
127126
this.isClosable = true;
@@ -139,13 +138,7 @@ private AsyncDirectIOIndexInput(String description, AsyncDirectIOIndexInput othe
139138
this.buffer = allocateBuffer(bufferSize, other.blockSize);
140139
this.blockSize = other.blockSize;
141140
this.channel = other.channel;
142-
this.prefetcher = new DirectIOPrefetcher(
143-
this.blockSize,
144-
this.channel,
145-
bufferSize,
146-
other.prefetcher.maxConcurrentPrefetches,
147-
other.prefetcher.maxTotalPrefetches
148-
);
141+
this.prefetcher = new DirectIOPrefetcher(this.blockSize, this.channel, bufferSize, other.prefetcher.maxConcurrentPrefetches);
149142
this.isOpen = true;
150143
this.isClosable = false;
151144
this.length = length;
@@ -170,11 +163,22 @@ public void prefetch(long pos, long length) throws IOException {
170163
if (pos < 0 || length < 0 || pos + length > this.length) {
171164
throw new IllegalArgumentException("Invalid prefetch range: pos=" + pos + ", length=" + length + ", fileLength=" + this.length);
172165
}
173-
// check if our current buffer already contains the requested range
166+
167+
// align to prefetch buffer
174168
final long absPos = pos + offset;
175-
final long alignedPos = absPos - (absPos % blockSize);
176-
// we only prefetch into a single buffer, even if length exceeds buffer size
177-
// maybe we should improve this...
169+
long alignedPos = absPos - absPos % blockSize;
170+
171+
// check if our current buffer already contains the requested range
172+
if (alignedPos >= filePos && alignedPos < filePos + buffer.capacity()) {
173+
// The current buffer contains bytes of this request.
174+
// Adjust the position and length accordingly to skip the current buffer.
175+
alignedPos = filePos + buffer.capacity();
176+
length -= alignedPos - absPos;
177+
} else {
178+
// Add to the total length the bytes added by the alignment
179+
length += absPos - alignedPos;
180+
}
181+
// do the prefetch
178182
prefetcher.prefetch(alignedPos, length);
179183
}
180184

@@ -396,12 +400,16 @@ public IndexInput slice(String sliceDescription, long offset, long length) throw
396400
return slice;
397401
}
398402

403+
// pkg private for testing
404+
int prefetchSlots() {
405+
return prefetcher.posToSlot.size();
406+
}
407+
399408
/**
400409
* A simple prefetcher that uses virtual threads to prefetch data into direct byte buffers.
401410
*/
402411
private static class DirectIOPrefetcher implements Closeable {
403412
private final int maxConcurrentPrefetches;
404-
private final int maxTotalPrefetches;
405413
private final FileChannel channel;
406414
private final int blockSize;
407415
private final long[] prefetchPos;
@@ -411,10 +419,9 @@ private static class DirectIOPrefetcher implements Closeable {
411419
private final IntDeque slots;
412420
private final ByteBuffer[] prefetchBuffers;
413421
private final int prefetchBytesSize;
414-
private final LongDeque pendingPrefetches = new LongArrayDeque();
415422
private final ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
416423

417-
DirectIOPrefetcher(int blockSize, FileChannel channel, int prefetchBytesSize, int maxConcurrentPrefetches, int maxTotalPrefetches) {
424+
DirectIOPrefetcher(int blockSize, FileChannel channel, int prefetchBytesSize, int maxConcurrentPrefetches) {
418425
this.blockSize = blockSize;
419426
this.channel = channel;
420427
this.maxConcurrentPrefetches = maxConcurrentPrefetches;
@@ -428,48 +435,48 @@ private static class DirectIOPrefetcher implements Closeable {
428435
this.posToSlot = new TreeMap<>();
429436
this.prefetchBuffers = new ByteBuffer[maxConcurrentPrefetches];
430437
this.prefetchBytesSize = prefetchBytesSize;
431-
this.maxTotalPrefetches = maxTotalPrefetches;
432438
}
433439

434440
/**
435441
* Initiate prefetch of the given range. The range will be aligned to blockSize and
436442
* chopped up into chunks of prefetchBytesSize.
437-
* If there are not enough slots available, the prefetch request will be queued
438-
* until a slot becomes available. This throttling may occur if the number of
439-
* concurrent prefetches is exceeded, or if there is significant IO pressure.
443+
* If there are not enough slots available, the prefetch request will reuse the slot
444+
* with the lowest file pointer. If that slot is still being prefetched, the prefetch request
445+
* will be skipped.
440446
* @param pos the position to prefetch from, must be non-negative and within file length
441447
* @param length the length to prefetch, must be non-negative.
442448
*/
443449
void prefetch(long pos, long length) {
450+
assert pos % blockSize == 0 : "prefetch pos [" + pos + "] must be aligned to block size [" + blockSize + "]";
444451
// first determine how many slots we need given the length
445-
int numSlots = (int) Math.min(
446-
(length + prefetchBytesSize - 1) / prefetchBytesSize,
447-
maxTotalPrefetches - (this.posToSlot.size() + this.pendingPrefetches.size())
448-
);
449-
while (numSlots > 0 && (this.posToSlot.size() + this.pendingPrefetches.size()) < maxTotalPrefetches) {
450-
final int slot;
451-
Integer existingSlot = this.posToSlot.get(pos);
452-
if (existingSlot != null && prefetchThreads.get(existingSlot) != null) {
453-
// already being prefetched and hasn't been consumed.
454-
// return early
455-
return;
456-
}
457-
if (this.posToSlot.size() < maxConcurrentPrefetches && slots.isEmpty() == false) {
458-
slot = slots.removeFirst();
452+
while (length > 0) {
453+
Map.Entry<Long, Integer> floor = this.posToSlot.floorEntry(pos);
454+
if (floor == null || floor.getKey() + prefetchBytesSize <= pos) {
455+
// check if there are any slots available. If not we will reuse the one with the
456+
// lower file pointer.
457+
if (slots.isEmpty()) {
458+
assert this.posToSlot.size() == maxConcurrentPrefetches;
459+
final int oldestSlot = posToSlot.firstEntry().getValue();
460+
if (prefetchThreads.get(oldestSlot).isDone() == false) {
461+
// cannot reuse oldest slot. We are over-prefetching
462+
LOGGER.debug("could not prefetch pos [{}] with length [{}]", pos, length);
463+
return;
464+
}
465+
LOGGER.debug("prefetch on reused slot with pos [{}] with length [{}]", pos, length);
466+
clearSlot(oldestSlot);
467+
assert slots.isEmpty() == false;
468+
}
469+
final int slot = slots.removeFirst();
459470
posToSlot.put(pos, slot);
460471
prefetchPos[slot] = pos;
461-
} else {
462-
slot = -1;
463-
LOGGER.debug("queueing prefetch of pos [{}] with length [{}], waiting for open slot", pos, length);
464-
pendingPrefetches.addLast(pos);
465-
}
466-
if (slot != -1) {
467472
startPrefetch(pos, slot);
473+
length -= prefetchBytesSize;
474+
pos += prefetchBytesSize;
475+
} else {
476+
length -= floor.getKey() + prefetchBytesSize - pos;
477+
pos = floor.getKey() + prefetchBytesSize;
468478
}
469-
pos += prefetchBytesSize;
470-
numSlots--;
471479
}
472-
473480
}
474481

475482
/**
@@ -506,7 +513,7 @@ boolean readBytes(long pos, ByteBuffer slice, int delta) throws IOException {
506513
Thread.currentThread().interrupt();
507514
} finally {
508515
if (prefetchBuffer == null) {
509-
clearSlotAndMaybeStartPending(slot);
516+
clearSlot(slot);
510517
}
511518
}
512519
if (prefetchBuffer == null) {
@@ -518,22 +525,15 @@ boolean readBytes(long pos, ByteBuffer slice, int delta) throws IOException {
518525
slice.put(prefetchBuffer);
519526
slice.flip();
520527
slice.position(Math.toIntExact(pos - prefetchedPos) + delta);
521-
clearSlotAndMaybeStartPending(slot);
528+
clearSlot(slot);
522529
return true;
523530
}
524531

525-
void clearSlotAndMaybeStartPending(int slot) {
526-
assert prefetchThreads.get(slot) != null && prefetchThreads.get(slot).isDone();
532+
void clearSlot(int slot) {
533+
assert prefetchThreads.get(slot) != null;
527534
prefetchThreads.set(slot, null);
528535
posToSlot.remove(prefetchPos[slot]);
529-
if (pendingPrefetches.isEmpty()) {
530-
slots.addLast(slot);
531-
return;
532-
}
533-
final long req = pendingPrefetches.removeFirst();
534-
posToSlot.put(req, slot);
535-
prefetchPos[slot] = req;
536-
startPrefetch(req, slot);
536+
slots.addLast(slot);
537537
}
538538

539539
private boolean assertSlotsConsistent() {

server/src/test/java/org/elasticsearch/index/store/AsyncDirectIOIndexInputTests.java

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,22 @@
1212
import org.apache.lucene.store.Directory;
1313
import org.apache.lucene.store.IndexInput;
1414
import org.apache.lucene.store.NIOFSDirectory;
15+
import org.apache.lucene.util.hnsw.IntToIntFunction;
1516
import org.elasticsearch.core.SuppressForbidden;
1617
import org.elasticsearch.test.ESTestCase;
1718

1819
import java.io.IOException;
20+
import java.nio.ByteBuffer;
21+
import java.nio.ByteOrder;
1922
import java.nio.file.Files;
2023
import java.nio.file.Path;
2124
import java.util.ArrayList;
25+
import java.util.Arrays;
26+
import java.util.Collections;
2227
import java.util.List;
2328

29+
import static org.hamcrest.Matchers.lessThanOrEqualTo;
30+
2431
public class AsyncDirectIOIndexInputTests extends ESTestCase {
2532

2633
@SuppressForbidden(reason = "requires Files.getFileStore")
@@ -169,4 +176,96 @@ public void testWriteThenReadBytesConsistency() throws IOException {
169176
}
170177
}
171178

179+
public void testPrefetchGetsCleanUp() throws IOException {
180+
int numVectors = randomIntBetween(100, 1000);
181+
int numDimensions = randomIntBetween(100, 2048);
182+
Path path = createTempDir("testDirectIODirectory");
183+
byte[] bytes = new byte[numDimensions * Float.BYTES];
184+
ByteBuffer buffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
185+
float[][] vectors = new float[numVectors][numDimensions];
186+
try (Directory dir = new NIOFSDirectory(path)) {
187+
try (var output = dir.createOutput("test", org.apache.lucene.store.IOContext.DEFAULT)) {
188+
for (int i = 0; i < numVectors; i++) {
189+
random().nextBytes(bytes);
190+
output.writeBytes(bytes, bytes.length);
191+
buffer.asFloatBuffer().get(vectors[i]);
192+
}
193+
}
194+
195+
final int blockSize = getBlockSize(path);
196+
final int bufferSize = 8192;
197+
// fetch all
198+
try (AsyncDirectIOIndexInput actualInput = new AsyncDirectIOIndexInput(path.resolve("test"), blockSize, bufferSize, 64)) {
199+
assertPrefetchSlots(actualInput, numDimensions, numVectors, i -> i, vectors, bufferSize);
200+
}
201+
// fetch all in slice
202+
try (AsyncDirectIOIndexInput actualInput = new AsyncDirectIOIndexInput(path.resolve("test"), blockSize, bufferSize, 64)) {
203+
int start = randomIntBetween(0, numVectors - 1);
204+
float[][] vectorsSlice = Arrays.copyOfRange(vectors, start, numVectors);
205+
long sliceStart = (long) start * bytes.length;
206+
assertPrefetchSlots(
207+
(AsyncDirectIOIndexInput) actualInput.slice("slice", sliceStart, actualInput.length() - sliceStart),
208+
numDimensions,
209+
vectorsSlice.length,
210+
i -> i,
211+
vectorsSlice,
212+
bufferSize
213+
);
214+
}
215+
// random fetch
216+
List<Integer> tempList = new ArrayList<>(numVectors);
217+
for (int i = 0; i < numVectors; i++) {
218+
tempList.add(i);
219+
}
220+
Collections.shuffle(tempList, random());
221+
List<Integer> subList = tempList.subList(0, randomIntBetween(1, numVectors));
222+
Collections.sort(subList);
223+
try (AsyncDirectIOIndexInput actualInput = new AsyncDirectIOIndexInput(path.resolve("test"), blockSize, bufferSize, 64)) {
224+
assertPrefetchSlots(actualInput, numDimensions, subList.size(), subList::get, vectors, bufferSize);
225+
}
226+
}
227+
}
228+
229+
private static void assertPrefetchSlots(
230+
AsyncDirectIOIndexInput actualInput,
231+
int numDimensions,
232+
int numVectors,
233+
IntToIntFunction ords,
234+
float[][] vectors,
235+
int bufferSize
236+
) throws IOException {
237+
int prefetchSize = randomIntBetween(1, 64);
238+
float[] floats = new float[numDimensions];
239+
long bytesLength = (long) numDimensions * Float.BYTES;
240+
int limit = numVectors - prefetchSize + 1;
241+
int i = 0;
242+
for (; i < limit; i += prefetchSize) {
243+
int ord = ords.apply(i);
244+
for (int j = 0; j < prefetchSize; j++) {
245+
actualInput.prefetch((ord + j) * bytesLength, bytesLength);
246+
}
247+
// check we prefetch enough data. We need to add 1 because of the current buffer.
248+
assertThat(prefetchSize * bytesLength, lessThanOrEqualTo((long) (1 + actualInput.prefetchSlots()) * bufferSize));
249+
for (int j = 0; j < prefetchSize; j++) {
250+
actualInput.seek((ord + j) * bytesLength);
251+
actualInput.readFloats(floats, 0, floats.length);
252+
assertArrayEquals(vectors[ord + j], floats, 0.0f);
253+
}
254+
// check we have freed all the slots
255+
assertEquals(0, actualInput.prefetchSlots());
256+
}
257+
for (int k = i; k < numVectors; k++) {
258+
actualInput.prefetch(ords.apply(k) * bytesLength, bytesLength);
259+
}
260+
// check we prefetch enough data. We need to add 1 because of the current buffer.
261+
assertThat((numVectors - i) * bytesLength, lessThanOrEqualTo((long) (1 + actualInput.prefetchSlots()) * bufferSize));
262+
for (; i < numVectors; i++) {
263+
int ord = ords.apply(i);
264+
actualInput.seek(ord * bytesLength);
265+
actualInput.readFloats(floats, 0, floats.length);
266+
assertArrayEquals(vectors[ord], floats, 0.0f);
267+
}
268+
// check we have freed all the slots
269+
assertEquals(0, actualInput.prefetchSlots());
270+
}
172271
}

0 commit comments

Comments
 (0)