Commit 54002a4

Backport change to improve off-heap byte vector scoring at query time (#15010)
* Improve off-heap byte vector scoring at query time

  Cherry-pick of b3f4011

* Add CHANGES.txt entry

---------

Co-authored-by: Kaival Parikh <kaivalp2000@gmail.com>
1 parent 3e42687 commit 54002a4
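
In essence, the change scores a heap-resident byte[] query directly against an off-heap document vector instead of first wrapping the query in a MemorySegment. Below is a minimal standalone sketch of that call shape, not the Lucene classes themselves; plain scalar Java using java.lang.foreign (preview in Java 21, final in Java 22):

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;

// Standalone illustration only: a heap byte[] query scored against an
// off-heap MemorySegment document vector, with no intermediate copy or wrap.
public class MixedDotProductSketch {
  static int dotProduct(byte[] query, MemorySegment doc) {
    int res = 0;
    for (int i = 0; i < query.length; i++) {
      res += query[i] * doc.get(ValueLayout.JAVA_BYTE, i);
    }
    return res;
  }

  public static void main(String[] args) {
    byte[] query = {1, 2, 3, 4};
    try (Arena arena = Arena.ofConfined()) {
      MemorySegment doc = arena.allocate(query.length);
      for (int i = 0; i < query.length; i++) {
        doc.set(ValueLayout.JAVA_BYTE, i, (byte) (i + 1));
      }
      System.out.println(dotProduct(query, doc)); // prints 30
    }
  }
}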

File tree

3 files changed: +122 -50 lines changed


lucene/CHANGES.txt

Lines changed: 2 additions & 0 deletions
@@ -128,6 +128,8 @@ Optimizations
 
 * GITHUB#14991: Refactor for loop at PointRangeQuery hot path. (Ge Song)
 
+* GITHUB#15010: Improve off-heap KNN byte vector query performance in cases where indexing and search are performed by the same process. (Kaival Parikh)
+
 Changes in Runtime Behavior
 ---------------------
 * GITHUB#14823: Decrease TieredMergePolicy's default number of segments per

lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java

Lines changed: 3 additions & 3 deletions
@@ -32,7 +32,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer
 
   final int vectorByteSize;
   final MemorySegmentAccessInput input;
-  final MemorySegment query;
+  final byte[] query;
   byte[] scratch;
 
   /**
@@ -61,7 +61,7 @@ public static Optional<Lucene99MemorySegmentByteVectorScorer> create(
     super(values);
     this.input = input;
    this.vectorByteSize = values.getVectorByteLength();
-    this.query = MemorySegment.ofArray(queryVector);
+    this.query = queryVector;
  }
 
  final MemorySegment getSegment(int ord) throws IOException {
@@ -113,7 +113,7 @@ public float score(int node) throws IOException {
    checkOrdinal(node);
    // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
    float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
-    return 0.5f + raw / (float) (query.byteSize() * (1 << 15));
+    return 0.5f + raw / (float) (query.length * (1 << 15));
  }
}
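
The updated score line keeps the same normalization as before: each product of two signed bytes has absolute value at most 2^14 (128 * 128), so the raw dot product lies in [-len * 2^14, len * 2^14], and dividing by len * 2^15 before adding 0.5 maps the score into [0, 1]. A scalar sketch of the same math (illustration only, not the vectorized Lucene code):

// Scalar illustration of the DOT_PRODUCT-to-score normalization above.
static float dotProductScore(byte[] query, byte[] doc) {
  int raw = 0;
  for (int i = 0; i < query.length; i++) {
    raw += query[i] * doc[i];
  }
  // |query[i] * doc[i]| <= 128 * 128 = 2^14, so |raw| <= query.length * 2^14
  // and the returned value always falls in [0, 1].
  return 0.5f + raw / (float) (query.length * (1 << 15));
}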

lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java

Lines changed: 117 additions & 47 deletions
@@ -309,46 +309,100 @@ private float squareDistanceBody(float[] a, float[] b, int limit) {
   // We also support 128 bit vectors, going 32 bits at a time.
   // This is slower but still faster than not vectorizing at all.
 
+  private interface ByteVectorLoader {
+    int length();
+
+    ByteVector load(VectorSpecies<Byte> species, int index);
+
+    byte tail(int index);
+  }
+
+  private record ArrayLoader(byte[] arr) implements ByteVectorLoader {
+    @Override
+    public int length() {
+      return arr.length;
+    }
+
+    @Override
+    public ByteVector load(VectorSpecies<Byte> species, int index) {
+      assert index + species.length() <= length();
+      return ByteVector.fromArray(species, arr, index);
+    }
+
+    @Override
+    public byte tail(int index) {
+      assert index <= length();
+      return arr[index];
+    }
+  }
+
+  private record MemorySegmentLoader(MemorySegment segment) implements ByteVectorLoader {
+    @Override
+    public int length() {
+      return Math.toIntExact(segment.byteSize());
+    }
+
+    @Override
+    public ByteVector load(VectorSpecies<Byte> species, int index) {
+      assert index + species.length() <= length();
+      return ByteVector.fromMemorySegment(species, segment, index, LITTLE_ENDIAN);
+    }
+
+    @Override
+    public byte tail(int index) {
+      assert index <= length();
+      return segment.get(JAVA_BYTE, index);
+    }
+  }
+
   @Override
   public int dotProduct(byte[] a, byte[] b) {
-    return dotProduct(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
+    return dotProductBody(new ArrayLoader(a), new ArrayLoader(b));
+  }
+
+  public static int dotProduct(byte[] a, MemorySegment b) {
+    return dotProductBody(new ArrayLoader(a), new MemorySegmentLoader(b));
   }
 
   public static int dotProduct(MemorySegment a, MemorySegment b) {
-    assert a.byteSize() == b.byteSize();
+    return dotProductBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
+  }
+
+  private static int dotProductBody(ByteVectorLoader a, ByteVectorLoader b) {
+    assert a.length() == b.length();
     int i = 0;
     int res = 0;
 
     // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
     // vectors (256-bit on intel to dodge performance landmines)
-    if (a.byteSize() >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
+    if (a.length() >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
       // compute vectorized dot product consistent with VPDPBUSD instruction
       if (VECTOR_BITSIZE >= 512) {
-        i += BYTE_SPECIES.loopBound(a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
         res += dotProductBody512(a, b, i);
       } else if (VECTOR_BITSIZE == 256) {
-        i += BYTE_SPECIES.loopBound(a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
         res += dotProductBody256(a, b, i);
       } else {
         // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
-        i += ByteVector.SPECIES_64.loopBound(a.byteSize() - ByteVector.SPECIES_64.length());
+        i += ByteVector.SPECIES_64.loopBound(a.length() - ByteVector.SPECIES_64.length());
         res += dotProductBody128(a, b, i);
       }
     }
 
     // scalar tail
-    for (; i < a.byteSize(); i++) {
-      res += b.get(JAVA_BYTE, i) * a.get(JAVA_BYTE, i);
+    for (; i < a.length(); i++) {
+      res += a.tail(i) * b.tail(i);
     }
     return res;
   }
 
   /** vectorized dot product body (512 bit vectors) */
-  private static int dotProductBody512(MemorySegment a, MemorySegment b, int limit) {
+  private static int dotProductBody512(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector acc = IntVector.zero(INT_SPECIES);
     for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
-      ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(BYTE_SPECIES, i);
+      ByteVector vb8 = b.load(BYTE_SPECIES, i);
 
       // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
       Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
@@ -364,11 +418,11 @@ private static int dotProductBody512(MemorySegment a, MemorySegment b, int limit
   }
 
   /** vectorized dot product body (256 bit vectors) */
-  private static int dotProductBody256(MemorySegment a, MemorySegment b, int limit) {
+  private static int dotProductBody256(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector acc = IntVector.zero(IntVector.SPECIES_256);
     for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
-      ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+      ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
       // 32-bit multiply and add into accumulator
       Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0);
@@ -380,13 +434,13 @@ private static int dotProductBody256(MemorySegment a, MemorySegment b, int limit
   }
 
   /** vectorized dot product body (128 bit vectors) */
-  private static int dotProductBody128(MemorySegment a, MemorySegment b, int limit) {
+  private static int dotProductBody128(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector acc = IntVector.zero(IntVector.SPECIES_128);
     // 4 bytes at a time (re-loading half the vector each time!)
     for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
       // load 8 bytes
-      ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+      ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
       // process first "half" only: 16-bit multiply
       Vector<Short> va16 = va8.convert(B2S, 0);
@@ -578,28 +632,36 @@ private int int4DotProductBody128(byte[] a, byte[] b, int limit) {
 
   @Override
   public float cosine(byte[] a, byte[] b) {
-    return cosine(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
+    return cosineBody(new ArrayLoader(a), new ArrayLoader(b));
   }
 
   public static float cosine(MemorySegment a, MemorySegment b) {
+    return cosineBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
+  }
+
+  public static float cosine(byte[] a, MemorySegment b) {
+    return cosineBody(new ArrayLoader(a), new MemorySegmentLoader(b));
+  }
+
+  private static float cosineBody(ByteVectorLoader a, ByteVectorLoader b) {
     int i = 0;
     int sum = 0;
     int norm1 = 0;
     int norm2 = 0;
 
     // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
     // vectors (256-bit on intel to dodge performance landmines)
-    if (a.byteSize() >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
+    if (a.length() >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
       final float[] ret;
       if (VECTOR_BITSIZE >= 512) {
-        i += BYTE_SPECIES.loopBound((int) a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
         ret = cosineBody512(a, b, i);
       } else if (VECTOR_BITSIZE == 256) {
-        i += BYTE_SPECIES.loopBound((int) a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
         ret = cosineBody256(a, b, i);
       } else {
         // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
-        i += ByteVector.SPECIES_64.loopBound(a.byteSize() - ByteVector.SPECIES_64.length());
+        i += ByteVector.SPECIES_64.loopBound(a.length() - ByteVector.SPECIES_64.length());
         ret = cosineBody128(a, b, i);
       }
       sum += ret[0];
@@ -608,9 +670,9 @@ public static float cosine(MemorySegment a, MemorySegment b) {
     }
 
     // scalar tail
-    for (; i < a.byteSize(); i++) {
-      byte elem1 = a.get(JAVA_BYTE, i);
-      byte elem2 = b.get(JAVA_BYTE, i);
+    for (; i < a.length(); i++) {
+      byte elem1 = a.tail(i);
+      byte elem2 = b.tail(i);
       sum += elem1 * elem2;
       norm1 += elem1 * elem1;
       norm2 += elem2 * elem2;
@@ -619,13 +681,13 @@ public static float cosine(MemorySegment a, MemorySegment b) {
   }
 
   /** vectorized cosine body (512 bit vectors) */
-  private static float[] cosineBody512(MemorySegment a, MemorySegment b, int limit) {
+  private static float[] cosineBody512(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector accSum = IntVector.zero(INT_SPECIES);
     IntVector accNorm1 = IntVector.zero(INT_SPECIES);
     IntVector accNorm2 = IntVector.zero(INT_SPECIES);
     for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
-      ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(BYTE_SPECIES, i);
+      ByteVector vb8 = b.load(BYTE_SPECIES, i);
 
       // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
       Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
@@ -649,13 +711,13 @@ private static float[] cosineBody512(MemorySegment a, MemorySegment b, int limit
   }
 
   /** vectorized cosine body (256 bit vectors) */
-  private static float[] cosineBody256(MemorySegment a, MemorySegment b, int limit) {
+  private static float[] cosineBody256(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector accSum = IntVector.zero(IntVector.SPECIES_256);
     IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_256);
     IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_256);
     for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
-      ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+      ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
       // 16-bit multiply, and add into accumulators
       Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0);
@@ -674,13 +736,13 @@ private static float[] cosineBody256(MemorySegment a, MemorySegment b, int limit
   }
 
   /** vectorized cosine body (128 bit vectors) */
-  private static float[] cosineBody128(MemorySegment a, MemorySegment b, int limit) {
+  private static float[] cosineBody128(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector accSum = IntVector.zero(IntVector.SPECIES_128);
     IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_128);
     IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_128);
     for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
-      ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+      ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
       // process first half only: 16-bit multiply
       Vector<Short> va16 = va8.convert(B2S, 0);
@@ -702,40 +764,48 @@ private static float[] cosineBody128(MemorySegment a, MemorySegment b, int limit
 
   @Override
   public int squareDistance(byte[] a, byte[] b) {
-    return squareDistance(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
+    return squareDistanceBody(new ArrayLoader(a), new ArrayLoader(b));
   }
 
   public static int squareDistance(MemorySegment a, MemorySegment b) {
-    assert a.byteSize() == b.byteSize();
+    return squareDistanceBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
+  }
+
+  public static int squareDistance(byte[] a, MemorySegment b) {
+    return squareDistanceBody(new ArrayLoader(a), new MemorySegmentLoader(b));
+  }
+
+  private static int squareDistanceBody(ByteVectorLoader a, ByteVectorLoader b) {
+    assert a.length() == b.length();
     int i = 0;
     int res = 0;
 
     // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
     // vectors (256-bit on intel to dodge performance landmines)
-    if (a.byteSize() >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
+    if (a.length() >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
      if (VECTOR_BITSIZE >= 256) {
-        i += BYTE_SPECIES.loopBound((int) a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
        res += squareDistanceBody256(a, b, i);
      } else {
-        i += ByteVector.SPECIES_64.loopBound((int) a.byteSize());
+        i += ByteVector.SPECIES_64.loopBound(a.length());
        res += squareDistanceBody128(a, b, i);
      }
    }
 
    // scalar tail
-    for (; i < a.byteSize(); i++) {
-      int diff = a.get(JAVA_BYTE, i) - b.get(JAVA_BYTE, i);
+    for (; i < a.length(); i++) {
+      int diff = a.tail(i) - b.tail(i);
      res += diff * diff;
    }
    return res;
  }
 
  /** vectorized square distance body (256+ bit vectors) */
-  private static int squareDistanceBody256(MemorySegment a, MemorySegment b, int limit) {
+  private static int squareDistanceBody256(ByteVectorLoader a, ByteVectorLoader b, int limit) {
    IntVector acc = IntVector.zero(INT_SPECIES);
    for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
-      ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(BYTE_SPECIES, i);
+      ByteVector vb8 = b.load(BYTE_SPECIES, i);
 
      // 32-bit sub, multiply, and add into accumulators
      // TODO: uses AVX-512 heavy multiply on zmm, should we just use 256-bit vectors on AVX-512?
@@ -749,14 +819,14 @@ private static int squareDistanceBody256(MemorySegment a, MemorySegment b, int l
  }
 
  /** vectorized square distance body (128 bit vectors) */
-  private static int squareDistanceBody128(MemorySegment a, MemorySegment b, int limit) {
+  private static int squareDistanceBody128(ByteVectorLoader a, ByteVectorLoader b, int limit) {
    // 128-bit implementation, which must "split up" vectors due to widening conversions
    // it doesn't help to do the overlapping read trick, due to 32-bit multiply in the formula
    IntVector acc1 = IntVector.zero(IntVector.SPECIES_128);
    IntVector acc2 = IntVector.zero(IntVector.SPECIES_128);
    for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
-      ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+      ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
      // 16-bit sub
      Vector<Short> va16 = va8.convertShape(B2S, ShortVector.SPECIES_128, 0);
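
The key piece of the diff above is the ByteVectorLoader abstraction: one shared dot-product, cosine, and square-distance body can read either a heap byte[] (via ByteVector.fromArray) or a MemorySegment (via ByteVector.fromMemorySegment), so a byte[] query can be scored against off-heap vectors without first wrapping it in a MemorySegment. A stripped-down scalar sketch of the same pattern follows; ByteLoader and LoaderDotProduct are illustrative names, not Lucene classes, and the real implementation additionally exposes species-based SIMD loads:

import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;

// Scalar mini-version of the loader pattern from the diff above.
interface ByteLoader {
  int length();

  byte get(int index);

  static ByteLoader of(byte[] arr) {
    return new ByteLoader() {
      @Override
      public int length() {
        return arr.length;
      }

      @Override
      public byte get(int index) {
        return arr[index];
      }
    };
  }

  static ByteLoader of(MemorySegment segment) {
    return new ByteLoader() {
      @Override
      public int length() {
        return Math.toIntExact(segment.byteSize());
      }

      @Override
      public byte get(int index) {
        return segment.get(ValueLayout.JAVA_BYTE, index);
      }
    };
  }
}

final class LoaderDotProduct {
  // Works for byte[] x byte[], byte[] x MemorySegment and
  // MemorySegment x MemorySegment, all without copying either side.
  static int dotProduct(ByteLoader a, ByteLoader b) {
    assert a.length() == b.length();
    int res = 0;
    for (int i = 0; i < a.length(); i++) {
      res += a.get(i) * b.get(i);
    }
    return res;
  }
}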

0 commit comments
