Skip to content

Commit f245bed

Browse files
authored
Refactor GroupVIntUtil functional interface lambda which does not inline correctly in MemorySegmentIndexInput (#15089)
Refactor GroupVIntUtil functional interface lambda which does not inline correctly in MemorySegmentIndexInput. This (hopefully) fixes #15079
1 parent c3f3db6 commit f245bed

File tree

5 files changed

+69
-28
lines changed

5 files changed

+69
-28
lines changed

lucene/CHANGES.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,9 @@ Improvements
166166

167167
* GITHUB#14816: Expose search strategy in KNN query
168168

169+
* GITHUB#15089, GITHUB#15079: Refactor GroupVIntUtil functional interface lambda
170+
which does not inline correctly in MemorySegmentIndexInput (Uwe Schindler, Robert Muir)
171+
169172
Optimizations
170173
---------------------
171174
* GITHUB#14932: Switched to GroupVarInt Encoding for HNSW Graph edges, added backwards compatibility (Akira Lonske)

lucene/core/src/java/org/apache/lucene/store/BufferedIndexInput.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,13 @@ public final int readInt() throws IOException {
154154
public void readGroupVInt(int[] dst, int offset) throws IOException {
155155
final int len =
156156
GroupVIntUtil.readGroupVInt(
157-
this, buffer.remaining(), p -> buffer.getInt((int) p), buffer.position(), dst, offset);
157+
this,
158+
buffer.remaining(),
159+
GroupVIntUtil.VH_BUFFER_GET_INT,
160+
buffer,
161+
buffer.position(),
162+
dst,
163+
offset);
158164
if (len > 0) {
159165
buffer.position(buffer.position() + len);
160166
}

lucene/core/src/java/org/apache/lucene/store/ByteBuffersDataInput.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
*/
3939
public final class ByteBuffersDataInput extends DataInput
4040
implements Accountable, RandomAccessInput {
41+
4142
private final ByteBuffer[] blocks;
4243
private final FloatBuffer[] floatBuffers;
4344
private final LongBuffer[] longBuffers;
@@ -215,7 +216,8 @@ public void readGroupVInt(int[] dst, int offset) throws IOException {
215216
GroupVIntUtil.readGroupVInt(
216217
this,
217218
block.limit() - blockOffset,
218-
p -> block.getInt((int) p),
219+
GroupVIntUtil.VH_BUFFER_GET_INT,
220+
block,
219221
blockOffset,
220222
dst,
221223
offset);

lucene/core/src/java/org/apache/lucene/store/MemorySegmentIndexInput.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
import java.lang.foreign.Arena;
2424
import java.lang.foreign.MemorySegment;
2525
import java.lang.foreign.ValueLayout;
26+
import java.lang.invoke.MethodHandles;
27+
import java.lang.invoke.MethodType;
28+
import java.lang.invoke.VarHandle;
2629
import java.nio.ByteOrder;
2730
import java.util.Arrays;
2831
import java.util.Objects;
@@ -52,6 +55,13 @@ abstract class MemorySegmentIndexInput extends IndexInput implements MemorySegme
5255
ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
5356
private static final Optional<NativeAccess> NATIVE_ACCESS = NativeAccess.getImplementation();
5457

58+
private static final VarHandle VH_MEMSEG_GET_INT =
59+
MethodHandles.filterCoordinates(
60+
LAYOUT_LE_INT.varHandle(),
61+
0,
62+
MethodHandles.identity(Object.class)
63+
.asType(MethodType.methodType(MemorySegment.class, Object.class)));
64+
5565
final long length;
5666
final long chunkSizeMask;
5767
final int chunkSizePower;
@@ -454,7 +464,8 @@ public void readGroupVInt(int[] dst, int offset) throws IOException {
454464
GroupVIntUtil.readGroupVInt(
455465
this,
456466
curSegment.byteSize() - curPosition,
457-
p -> curSegment.get(LAYOUT_LE_INT, p),
467+
VH_MEMSEG_GET_INT,
468+
curSegment,
458469
curPosition,
459470
dst,
460471
offset);

lucene/core/src/java/org/apache/lucene/util/GroupVIntUtil.java

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717
package org.apache.lucene.util;
1818

1919
import java.io.IOException;
20+
import java.lang.invoke.MethodHandles;
21+
import java.lang.invoke.MethodType;
22+
import java.lang.invoke.VarHandle;
23+
import java.nio.ByteBuffer;
24+
import java.nio.ByteOrder;
2025
import org.apache.lucene.store.DataInput;
2126
import org.apache.lucene.store.DataOutput;
2227

@@ -33,6 +38,23 @@ public final class GroupVIntUtil {
3338
private static final long[] LONG_MASKS = new long[] {0xFFL, 0xFFFFL, 0xFFFFFFL, 0xFFFFFFFFL};
3439
private static final int[] INT_MASKS = new int[] {0xFF, 0xFFFF, 0xFFFFFF, ~0};
3540

41+
/**
42+
* A {@link VarHandle} that allows reading ints from a {@link ByteBuffer} using {@code long}
43+
* offsets. The handle can be used with the {@code readGroupVInt()} methods taking a {@code
44+
* VarHandle} and {@code ByteBuffer} storage parameter.
45+
*
46+
* @see #readGroupVInt(DataInput, long, VarHandle, Object, long, int[], int)
47+
* @see #readGroupVInt(DataInput, long, VarHandle, Object, long, long[], int)
48+
*/
49+
public static final VarHandle VH_BUFFER_GET_INT =
50+
MethodHandles.filterCoordinates(
51+
MethodHandles.byteBufferViewVarHandle(int[].class, ByteOrder.LITTLE_ENDIAN),
52+
0,
53+
MethodHandles.identity(Object.class)
54+
.asType(MethodType.methodType(ByteBuffer.class, Object.class)),
55+
MethodHandles.explicitCastArguments(
56+
MethodHandles.identity(long.class), MethodType.methodType(int.class, long.class)));
57+
3658
/**
3759
* Read all the group varints, including the tail vints. We need a long[] because this is what
3860
* postings are using; all longs are actually required to be integers.
@@ -126,21 +148,15 @@ private static int readIntInGroup(DataInput in, int numBytesMinus1) throws IOExc
126148
}
127149

128150
/**
129-
* Provides an abstraction for read int values, so that decoding logic can be reused in different
130-
* DataInput.
131-
*/
132-
@FunctionalInterface
133-
public static interface IntReader {
134-
int read(long v);
135-
}
136-
137-
/**
138-
* Faster implementation of read single group, It read values from the buffer that would not cross
139-
* boundaries.
151+
* Faster implementation of reading a single group. It reads values from a {@link VarHandle} that
152+
* would not cross boundaries.
140153
*
141154
* @param in the input to use to read data.
142155
* @param remaining the number of remaining bytes allowed to read for current block/segment.
143-
* @param reader the supplier of read int.
156+
* @param vh the varhandle which has the coordinates {@code (Object, long)}. The first coordinate
157+
* must accept the {@code storage} parameter, the second coordinate must be the long offset.
158+
* @param storage the reference to the backing storage (e.g., one of {@code byte[], ByteBuffer,
159+
* MemorySegment})
144160
* @param pos the start position to read from in the {@code storage}.
145161
* @param dst the array to read ints into.
146162
* @param offset the offset in the array to start storing ints.
@@ -149,7 +165,7 @@ public static interface IntReader {
149165
* #MAX_LENGTH_PER_GROUP}
150166
*/
151167
public static int readGroupVInt(
152-
DataInput in, long remaining, IntReader reader, long pos, long[] dst, int offset)
168+
DataInput in, long remaining, VarHandle vh, Object storage, long pos, long[] dst, int offset)
153169
throws IOException {
154170
if (remaining < MAX_LENGTH_PER_GROUP) {
155171
readGroupVInt(in, dst, offset);
@@ -163,24 +179,27 @@ public static int readGroupVInt(
163179
final int n4Minus1 = flag & 0x03;
164180

165181
// This code path has fewer conditionals and tends to be significantly faster in benchmarks
166-
dst[offset] = reader.read(pos) & LONG_MASKS[n1Minus1];
182+
dst[offset] = (int) vh.get(storage, pos) & LONG_MASKS[n1Minus1];
167183
pos += 1 + n1Minus1;
168-
dst[offset + 1] = reader.read(pos) & LONG_MASKS[n2Minus1];
184+
dst[offset + 1] = (int) vh.get(storage, pos) & LONG_MASKS[n2Minus1];
169185
pos += 1 + n2Minus1;
170-
dst[offset + 2] = reader.read(pos) & LONG_MASKS[n3Minus1];
186+
dst[offset + 2] = (int) vh.get(storage, pos) & LONG_MASKS[n3Minus1];
171187
pos += 1 + n3Minus1;
172-
dst[offset + 3] = reader.read(pos) & LONG_MASKS[n4Minus1];
188+
dst[offset + 3] = (int) vh.get(storage, pos) & LONG_MASKS[n4Minus1];
173189
pos += 1 + n4Minus1;
174190
return (int) (pos - posStart);
175191
}
176192

177193
/**
178-
* Faster implementation of read single group, It read values from the buffer that would not cross
179-
* boundaries.
194+
* Faster implementation of reading a single group. It reads values from a {@link VarHandle} that
195+
* would not cross boundaries.
180196
*
181197
* @param in the input to use to read data.
182198
* @param remaining the number of remaining bytes allowed to read for current block/segment.
183-
* @param reader the supplier of read int.
199+
* @param vh the varhandle which has the coordinates {@code (Object, long)}. The first coordinate
200+
* must accept the {@code storage} parameter, the second coordinate must be the long offset.
201+
* @param storage the reference to the backing storage (e.g., one of {@code byte[], ByteBuffer,
202+
* MemorySegment})
184203
* @param pos the start position to read from in the {@code storage}.
185204
* @param dst the array to read ints into.
186205
* @param offset the offset in the array to start storing ints.
@@ -189,7 +208,7 @@ public static int readGroupVInt(
189208
* #MAX_LENGTH_PER_GROUP}
190209
*/
191210
public static int readGroupVInt(
192-
DataInput in, long remaining, IntReader reader, long pos, int[] dst, int offset)
211+
DataInput in, long remaining, VarHandle vh, Object storage, long pos, int[] dst, int offset)
193212
throws IOException {
194213
if (remaining < MAX_LENGTH_PER_GROUP) {
195214
readGroupVInt(in, dst, offset);
@@ -203,13 +222,13 @@ public static int readGroupVInt(
203222
final int n4Minus1 = flag & 0x03;
204223

205224
// This code path has fewer conditionals and tends to be significantly faster in benchmarks
206-
dst[offset] = reader.read(pos) & INT_MASKS[n1Minus1];
225+
dst[offset] = (int) vh.get(storage, pos) & INT_MASKS[n1Minus1];
207226
pos += 1 + n1Minus1;
208-
dst[offset + 1] = reader.read(pos) & INT_MASKS[n2Minus1];
227+
dst[offset + 1] = (int) vh.get(storage, pos) & INT_MASKS[n2Minus1];
209228
pos += 1 + n2Minus1;
210-
dst[offset + 2] = reader.read(pos) & INT_MASKS[n3Minus1];
229+
dst[offset + 2] = (int) vh.get(storage, pos) & INT_MASKS[n3Minus1];
211230
pos += 1 + n3Minus1;
212-
dst[offset + 3] = reader.read(pos) & INT_MASKS[n4Minus1];
231+
dst[offset + 3] = (int) vh.get(storage, pos) & INT_MASKS[n4Minus1];
213232
pos += 1 + n4Minus1;
214233
return (int) (pos - posStart);
215234
}

0 commit comments

Comments
 (0)