Commit b3f4011

Authored by Kaival Parikh (kaivalnp)
Improve byte vector scoring at query time (on- vs. off heap) (#14874)
Co-authored-by: Kaival Parikh <[email protected]>
1 parent 46a4fee commit b3f4011

File tree: 3 files changed (+122, -50 lines)

lucene/CHANGES.txt
Lines changed: 2 additions & 0 deletions

@@ -55,6 +55,8 @@ Optimizations
 * GITHUB#14011: Reduce allocation rate in HNSW concurrent merge. (Viliam Durina)
 * GITHUB#14022: Optimize DFS marking of connected components in HNSW by reducing stack depth, improving performance and reducing allocations. (Viswanath Kuchibhotla)
 
+* GITHUB#14874: Improve off-heap KNN byte vector query performance in cases where indexing and search are performed by the same process. (Kaival Parikh)
+
 Bug Fixes
 ---------------------
 * GITHUB#14049: Randomize KNN codec params in RandomCodec. Fixes scalar quantization div-by-zero
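
For context: the changelog entry covers the query-time case where search runs in the same process that built the index, so the query vector is an on-heap byte[] while the stored vectors are read off-heap through MemorySegments. Before this change the query was wrapped via MemorySegment.ofArray(...); the diffs below add mixed-argument overloads (e.g. dotProduct(byte[], MemorySegment)) so scoring works directly against the array. A standalone scalar sketch of what such a mixed dot product computes (illustrative names only, not Lucene's internal API; requires java.lang.foreign, Java 22+):

    import java.lang.foreign.Arena;
    import java.lang.foreign.MemorySegment;
    import java.lang.foreign.ValueLayout;

    public class MixedDotProductSketch {
      // On-heap query vs. off-heap stored vector, element by element.
      static int dotProduct(byte[] query, MemorySegment stored) {
        assert query.length == stored.byteSize();
        int res = 0;
        for (int i = 0; i < query.length; i++) {
          res += query[i] * stored.get(ValueLayout.JAVA_BYTE, i);
        }
        return res;
      }

      public static void main(String[] args) {
        byte[] query = {1, -2, 3, -4};
        try (Arena arena = Arena.ofConfined()) {
          MemorySegment stored = arena.allocate(query.length);
          MemorySegment.copy(query, 0, stored, ValueLayout.JAVA_BYTE, 0, query.length);
          System.out.println(dotProduct(query, stored)); // 1 + 4 + 9 + 16 = 30
        }
      }
    }

The committed code computes the same quantity with the Panama Vector API, as shown in PanamaVectorUtilSupport.java below.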

lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java
Lines changed: 3 additions & 3 deletions

@@ -32,7 +32,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer
 
   final int vectorByteSize;
   final MemorySegmentAccessInput input;
-  final MemorySegment query;
+  final byte[] query;
   byte[] scratch;
 
   /**
@@ -61,7 +61,7 @@ public static Optional<Lucene99MemorySegmentByteVectorScorer> create(
     super(values);
     this.input = input;
     this.vectorByteSize = values.getVectorByteLength();
-    this.query = MemorySegment.ofArray(queryVector);
+    this.query = queryVector;
   }
 
   final MemorySegment getSegment(int ord) throws IOException {
@@ -113,7 +113,7 @@ public float score(int node) throws IOException {
    checkOrdinal(node);
    // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
    float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
-   return 0.5f + raw / (float) (query.byteSize() * (1 << 15));
+   return 0.5f + raw / (float) (query.length * (1 << 15));
  }
}
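
The score() change keeps the existing normalization, now deriving the vector length from the on-heap query array. The product of two signed bytes has absolute value at most 2^14 (from -128 * -128), so the raw dot product over len dimensions lies within +/- len * 2^14; dividing by len * 2^15 maps it into [-0.5, 0.5], and adding 0.5 yields a score in [0, 1]. A minimal standalone sketch of the same mapping (illustrative only):

    // raw is in [-len * 2^14, len * 2^14], so the result lands in [0, 1].
    static float dotProductToScore(float raw, int len) {
      return 0.5f + raw / (float) (len * (1 << 15));
    }
    // e.g. dotProductToScore(len * (1 << 14), len) == 1.0f (maximum-similarity case)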

lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
Lines changed: 117 additions & 47 deletions

@@ -309,45 +309,99 @@ private float squareDistanceBody(float[] a, float[] b, int limit) {
   // We also support 128 bit vectors, going 32 bits at a time.
   // This is slower but still faster than not vectorizing at all.
 
+  private interface ByteVectorLoader {
+    int length();
+
+    ByteVector load(VectorSpecies<Byte> species, int index);
+
+    byte tail(int index);
+  }
+
+  private record ArrayLoader(byte[] arr) implements ByteVectorLoader {
+    @Override
+    public int length() {
+      return arr.length;
+    }
+
+    @Override
+    public ByteVector load(VectorSpecies<Byte> species, int index) {
+      assert index + species.length() <= length();
+      return ByteVector.fromArray(species, arr, index);
+    }
+
+    @Override
+    public byte tail(int index) {
+      assert index <= length();
+      return arr[index];
+    }
+  }
+
+  private record MemorySegmentLoader(MemorySegment segment) implements ByteVectorLoader {
+    @Override
+    public int length() {
+      return Math.toIntExact(segment.byteSize());
+    }
+
+    @Override
+    public ByteVector load(VectorSpecies<Byte> species, int index) {
+      assert index + species.length() <= length();
+      return ByteVector.fromMemorySegment(species, segment, index, LITTLE_ENDIAN);
+    }
+
+    @Override
+    public byte tail(int index) {
+      assert index <= length();
+      return segment.get(JAVA_BYTE, index);
+    }
+  }
+
   @Override
   public int dotProduct(byte[] a, byte[] b) {
-    return dotProduct(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
+    return dotProductBody(new ArrayLoader(a), new ArrayLoader(b));
+  }
+
+  public static int dotProduct(byte[] a, MemorySegment b) {
+    return dotProductBody(new ArrayLoader(a), new MemorySegmentLoader(b));
   }
 
   public static int dotProduct(MemorySegment a, MemorySegment b) {
-    assert a.byteSize() == b.byteSize();
+    return dotProductBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
+  }
+
+  private static int dotProductBody(ByteVectorLoader a, ByteVectorLoader b) {
+    assert a.length() == b.length();
     int i = 0;
     int res = 0;
 
     // only vectorize if we'll at least enter the loop a single time
-    if (a.byteSize() >= 16) {
+    if (a.length() >= 16) {
       // compute vectorized dot product consistent with VPDPBUSD instruction
       if (VECTOR_BITSIZE >= 512) {
-        i += BYTE_SPECIES.loopBound(a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
         res += dotProductBody512(a, b, i);
       } else if (VECTOR_BITSIZE == 256) {
-        i += BYTE_SPECIES.loopBound(a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
         res += dotProductBody256(a, b, i);
       } else {
         // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
-        i += ByteVector.SPECIES_64.loopBound(a.byteSize() - ByteVector.SPECIES_64.length());
+        i += ByteVector.SPECIES_64.loopBound(a.length() - ByteVector.SPECIES_64.length());
         res += dotProductBody128(a, b, i);
       }
     }
 
     // scalar tail
-    for (; i < a.byteSize(); i++) {
-      res += b.get(JAVA_BYTE, i) * a.get(JAVA_BYTE, i);
+    for (; i < a.length(); i++) {
+      res += a.tail(i) * b.tail(i);
     }
     return res;
   }
 
   /** vectorized dot product body (512 bit vectors) */
-  private static int dotProductBody512(MemorySegment a, MemorySegment b, int limit) {
+  private static int dotProductBody512(ByteVectorLoader a, ByteVectorLoader b, int limit) {
    IntVector acc = IntVector.zero(INT_SPECIES);
    for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
-     ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
-     ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
+     ByteVector va8 = a.load(BYTE_SPECIES, i);
+     ByteVector vb8 = b.load(BYTE_SPECIES, i);
 
      // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
      Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
@@ -363,11 +417,11 @@ private static int dotProductBody512(MemorySegment a, MemorySegment b, int limit
   }
 
   /** vectorized dot product body (256 bit vectors) */
-  private static int dotProductBody256(MemorySegment a, MemorySegment b, int limit) {
+  private static int dotProductBody256(ByteVectorLoader a, ByteVectorLoader b, int limit) {
    IntVector acc = IntVector.zero(IntVector.SPECIES_256);
    for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
-     ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-     ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+     ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+     ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
      // 32-bit multiply and add into accumulator
      Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0);
@@ -379,13 +433,13 @@ private static int dotProductBody256(MemorySegment a, MemorySegment b, int limit
   }
 
   /** vectorized dot product body (128 bit vectors) */
-  private static int dotProductBody128(MemorySegment a, MemorySegment b, int limit) {
+  private static int dotProductBody128(ByteVectorLoader a, ByteVectorLoader b, int limit) {
    IntVector acc = IntVector.zero(IntVector.SPECIES_128);
    // 4 bytes at a time (re-loading half the vector each time!)
    for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
      // load 8 bytes
-     ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-     ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+     ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+     ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
      // process first "half" only: 16-bit multiply
      Vector<Short> va16 = va8.convert(B2S, 0);
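
The ByteVectorLoader abstraction above lets a single vectorized body read lanes either from a byte[] (ByteVector.fromArray) or from a MemorySegment (ByteVector.fromMemorySegment), with tail(int) handling the scalar remainder past the vectorized bound. That bound comes from VectorSpecies.loopBound, which rounds the length down to a multiple of the lane count; the 128-bit path subtracts one species length first because its "overlapping read" loads 8 bytes while advancing only 4 per iteration. A small standalone illustration of the split (needs --add-modules jdk.incubator.vector to run):

    import jdk.incubator.vector.ByteVector;
    import jdk.incubator.vector.VectorSpecies;

    public class LoopBoundSketch {
      public static void main(String[] args) {
        VectorSpecies<Byte> species = ByteVector.SPECIES_64; // 8 byte lanes
        int length = 67;
        int bound = species.loopBound(length);
        // Indices [0, bound) are handled by SIMD loads, [bound, length) by tail().
        System.out.println("vectorized: [0, " + bound + "), scalar tail: [" + bound + ", " + length + ")");
        // prints: vectorized: [0, 64), scalar tail: [64, 67)
      }
    }
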
@@ -577,27 +631,35 @@ private int int4DotProductBody128(byte[] a, byte[] b, int limit) {
 
   @Override
   public float cosine(byte[] a, byte[] b) {
-    return cosine(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
+    return cosineBody(new ArrayLoader(a), new ArrayLoader(b));
   }
 
   public static float cosine(MemorySegment a, MemorySegment b) {
+    return cosineBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
+  }
+
+  public static float cosine(byte[] a, MemorySegment b) {
+    return cosineBody(new ArrayLoader(a), new MemorySegmentLoader(b));
+  }
+
+  private static float cosineBody(ByteVectorLoader a, ByteVectorLoader b) {
     int i = 0;
     int sum = 0;
     int norm1 = 0;
     int norm2 = 0;
 
     // only vectorize if we'll at least enter the loop a single time
-    if (a.byteSize() >= 16) {
+    if (a.length() >= 16) {
       final float[] ret;
       if (VECTOR_BITSIZE >= 512) {
-        i += BYTE_SPECIES.loopBound((int) a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
         ret = cosineBody512(a, b, i);
       } else if (VECTOR_BITSIZE == 256) {
-        i += BYTE_SPECIES.loopBound((int) a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
         ret = cosineBody256(a, b, i);
       } else {
         // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
-        i += ByteVector.SPECIES_64.loopBound(a.byteSize() - ByteVector.SPECIES_64.length());
+        i += ByteVector.SPECIES_64.loopBound(a.length() - ByteVector.SPECIES_64.length());
         ret = cosineBody128(a, b, i);
       }
       sum += ret[0];
@@ -606,9 +668,9 @@ public static float cosine(MemorySegment a, MemorySegment b) {
     }
 
     // scalar tail
-    for (; i < a.byteSize(); i++) {
-      byte elem1 = a.get(JAVA_BYTE, i);
-      byte elem2 = b.get(JAVA_BYTE, i);
+    for (; i < a.length(); i++) {
+      byte elem1 = a.tail(i);
+      byte elem2 = b.tail(i);
       sum += elem1 * elem2;
       norm1 += elem1 * elem1;
       norm2 += elem2 * elem2;
@@ -617,13 +679,13 @@ public static float cosine(MemorySegment a, MemorySegment b) {
   }
 
   /** vectorized cosine body (512 bit vectors) */
-  private static float[] cosineBody512(MemorySegment a, MemorySegment b, int limit) {
+  private static float[] cosineBody512(ByteVectorLoader a, ByteVectorLoader b, int limit) {
    IntVector accSum = IntVector.zero(INT_SPECIES);
    IntVector accNorm1 = IntVector.zero(INT_SPECIES);
    IntVector accNorm2 = IntVector.zero(INT_SPECIES);
    for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
-     ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
-     ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
+     ByteVector va8 = a.load(BYTE_SPECIES, i);
+     ByteVector vb8 = b.load(BYTE_SPECIES, i);
 
      // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
      Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
@@ -647,13 +709,13 @@ private static float[] cosineBody512(MemorySegment a, MemorySegment b, int limit
   }
 
   /** vectorized cosine body (256 bit vectors) */
-  private static float[] cosineBody256(MemorySegment a, MemorySegment b, int limit) {
+  private static float[] cosineBody256(ByteVectorLoader a, ByteVectorLoader b, int limit) {
    IntVector accSum = IntVector.zero(IntVector.SPECIES_256);
    IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_256);
    IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_256);
    for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
-     ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-     ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+     ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+     ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
      // 16-bit multiply, and add into accumulators
      Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0);
@@ -672,13 +734,13 @@ private static float[] cosineBody256(MemorySegment a, MemorySegment b, int limit
   }
 
   /** vectorized cosine body (128 bit vectors) */
-  private static float[] cosineBody128(MemorySegment a, MemorySegment b, int limit) {
+  private static float[] cosineBody128(ByteVectorLoader a, ByteVectorLoader b, int limit) {
    IntVector accSum = IntVector.zero(IntVector.SPECIES_128);
    IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_128);
    IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_128);
    for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
-     ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-     ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+     ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+     ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
      // process first half only: 16-bit multiply
      Vector<Short> va16 = va8.convert(B2S, 0);
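
The cosine hunks change only how the three accumulators (the dot product and the two squared norms) are fed through ByteVectorLoader; the final combination lies outside the changed lines and is the standard cosine similarity. For reference, a standalone sketch of that combination (illustrative, not the committed code):

    // Combines the accumulators the cosine bodies produce:
    // sum = sum of a[i]*b[i], norm1 = sum of a[i]^2, norm2 = sum of b[i]^2.
    static float combineCosine(int sum, int norm1, int norm2) {
      return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
    }
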
@@ -700,39 +762,47 @@ private static float[] cosineBody128(MemorySegment a, MemorySegment b, int limit
 
   @Override
   public int squareDistance(byte[] a, byte[] b) {
-    return squareDistance(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
+    return squareDistanceBody(new ArrayLoader(a), new ArrayLoader(b));
   }
 
   public static int squareDistance(MemorySegment a, MemorySegment b) {
-    assert a.byteSize() == b.byteSize();
+    return squareDistanceBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
+  }
+
+  public static int squareDistance(byte[] a, MemorySegment b) {
+    return squareDistanceBody(new ArrayLoader(a), new MemorySegmentLoader(b));
+  }
+
+  private static int squareDistanceBody(ByteVectorLoader a, ByteVectorLoader b) {
+    assert a.length() == b.length();
     int i = 0;
     int res = 0;
 
     // only vectorize if we'll at least enter the loop a single time
-    if (a.byteSize() >= 16) {
+    if (a.length() >= 16) {
       if (VECTOR_BITSIZE >= 256) {
-        i += BYTE_SPECIES.loopBound((int) a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
         res += squareDistanceBody256(a, b, i);
       } else {
-        i += ByteVector.SPECIES_64.loopBound((int) a.byteSize());
+        i += ByteVector.SPECIES_64.loopBound(a.length());
         res += squareDistanceBody128(a, b, i);
       }
     }
 
     // scalar tail
-    for (; i < a.byteSize(); i++) {
-      int diff = a.get(JAVA_BYTE, i) - b.get(JAVA_BYTE, i);
+    for (; i < a.length(); i++) {
+      int diff = a.tail(i) - b.tail(i);
       res += diff * diff;
     }
     return res;
   }
 
   /** vectorized square distance body (256+ bit vectors) */
-  private static int squareDistanceBody256(MemorySegment a, MemorySegment b, int limit) {
+  private static int squareDistanceBody256(ByteVectorLoader a, ByteVectorLoader b, int limit) {
    IntVector acc = IntVector.zero(INT_SPECIES);
    for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
-     ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
-     ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
+     ByteVector va8 = a.load(BYTE_SPECIES, i);
+     ByteVector vb8 = b.load(BYTE_SPECIES, i);
 
      // 32-bit sub, multiply, and add into accumulators
      // TODO: uses AVX-512 heavy multiply on zmm, should we just use 256-bit vectors on AVX-512?
@@ -746,14 +816,14 @@ private static int squareDistanceBody256(MemorySegment a, MemorySegment b, int l
   }
 
   /** vectorized square distance body (128 bit vectors) */
-  private static int squareDistanceBody128(MemorySegment a, MemorySegment b, int limit) {
+  private static int squareDistanceBody128(ByteVectorLoader a, ByteVectorLoader b, int limit) {
    // 128-bit implementation, which must "split up" vectors due to widening conversions
    // it doesn't help to do the overlapping read trick, due to 32-bit multiply in the formula
    IntVector acc1 = IntVector.zero(IntVector.SPECIES_128);
    IntVector acc2 = IntVector.zero(IntVector.SPECIES_128);
    for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
-     ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-     ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+     ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+     ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
      // 16-bit sub
      Vector<Short> va16 = va8.convertShape(B2S, ShortVector.SPECIES_128, 0);
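
The squareDistance hunks follow the same pattern as dotProduct and cosine: the byte[] overload delegates through ArrayLoader, the MemorySegment and mixed overloads delegate through MemorySegmentLoader, and a shared squareDistanceBody performs the dispatch plus scalar tail. A standalone scalar reference for the quantity being computed, useful for cross-checking the vectorized bodies (illustrative, not the committed code):

    // Sum of squared differences of two equal-length signed-byte vectors.
    static int scalarSquareDistance(byte[] a, byte[] b) {
      assert a.length == b.length;
      int res = 0;
      for (int i = 0; i < a.length; i++) {
        int diff = a[i] - b[i];
        res += diff * diff;
      }
      return res;
    }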
