Skip to content

Commit 86ffba6

Browse files
authored
Fix native implementations of PQ assembleAndSum and pqDecodedCosineSimilarity (#420)
* Fix native implementations of PQ assembleAndSum and pqDecodedCosineSimilarity * Avoid additional slice by passing baseOffsetsOffset across native boundaries for assemble and sum functions
1 parent 06ac8aa commit 86ffba6

File tree

4 files changed

+32
-22
lines changed

4 files changed

+32
-22
lines changed

jvector-native/src/main/c/jvector_simd.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,12 +289,13 @@ float euclidean_f32(int preferred_size, const float* a, int aoffset, const float
289289
: euclidean_f32_256(a, aoffset, b, boffset, length);
290290
}
291291

292-
float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsLength) {
292+
float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) {
293293
__m512 sum = _mm512_setzero_ps();
294294
int i = 0;
295295
int limit = baseOffsetsLength - (baseOffsetsLength % 16);
296296
__m512i indexRegister = initialIndexRegister;
297297
__m512i dataBaseVec = _mm512_set1_epi32(dataBase);
298+
baseOffsets = baseOffsets + baseOffsetsOffset;
298299

299300
for (; i < limit; i += 16) {
300301
__m128i baseOffsetsRaw = _mm_loadu_si128((__m128i *)(baseOffsets + i));
@@ -319,13 +320,14 @@ float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned c
319320
return res;
320321
}
321322

322-
float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) {
323+
float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) {
323324
__m512 sum = _mm512_setzero_ps();
324325
__m512 vaMagnitude = _mm512_setzero_ps();
325326
int i = 0;
326327
int limit = baseOffsetsLength - (baseOffsetsLength % 16);
327328
__m512i indexRegister = initialIndexRegister;
328329
__m512i scale = _mm512_set1_epi32(clusterCount);
330+
baseOffsets = baseOffsets + baseOffsetsOffset;
329331

330332

331333
for (; i < limit; i += 16) {

jvector-native/src/main/c/jvector_simd.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ float euclidean_f32(int preferred_size, const float* a, int aoffset, const float
2828
void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results);
2929
void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results);
3030
void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartialSums, float sumDelta, float minDistance, const char* quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float* results);
31-
float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsLength);
32-
float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude);
31+
float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsOffset, int baseOffsetsLength);
32+
float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude);
3333
void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums);
3434
void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums);
3535
void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances);

jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,16 +117,15 @@ public void minInPlace(VectorFloat<?> v1, VectorFloat<?> v2) {
117117

118118
@Override
119119
public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets) {
120-
assert baseOffsets.offset() == 0 : "Base offsets are expected to have an offset of 0. Found: " + baseOffsets.offset();
121-
return NativeSimdOps.assemble_and_sum_f32_512(((MemorySegmentVectorFloat)data).get(), dataBase, ((MemorySegmentByteSequence)baseOffsets).get(), baseOffsets.length());
120+
return assembleAndSum(data, dataBase, baseOffsets, 0, baseOffsets.length());
122121
}
123122

124123
@Override
125124
public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets, int baseOffsetsOffset, int baseOffsetsLength)
126125
{
127-
assert baseOffsetsOffset == 0;
128-
assert baseOffsetsLength == baseOffsets.length();
129-
return assembleAndSum(data, dataBase, baseOffsets);
126+
assert baseOffsets.offset() == 0 : "Base offsets are expected to have an offset of 0. Found: " + baseOffsets.offset();
127+
// baseOffsets is a pointer into a PQ chunk - we need to index into it by baseOffsetsOffset and provide baseOffsetsLength to the native code
128+
return NativeSimdOps.assemble_and_sum_f32_512(((MemorySegmentVectorFloat) data).get(), dataBase, ((MemorySegmentByteSequence) baseOffsets).get(), baseOffsetsOffset, baseOffsetsLength);
130129
}
131130

132131
@Override
@@ -189,9 +188,16 @@ public void bulkShuffleQuantizedSimilarityCosine(ByteSequence<?> shuffles, int c
189188

190189
@Override
191190
public float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
191+
{
192+
return pqDecodedCosineSimilarity(encoded, 0, encoded.length(), clusterCount, partialSums, aMagnitude, bMagnitude);
193+
}
194+
195+
@Override
196+
public float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int encodedOffset, int encodedLength, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
192197
{
193198
assert encoded.offset() == 0 : "Bulk shuffle shuffles are expected to have an offset of 0. Found: " + encoded.offset();
194-
return NativeSimdOps.pq_decoded_cosine_similarity_f32_512(((MemorySegmentByteSequence) encoded).get(), encoded.length(), clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude);
199+
// encoded is a pointer into a PQ chunk - we need to index into it by encodedOffset and provide encodedLength to the native code
200+
return NativeSimdOps.pq_decoded_cosine_similarity_f32_512(((MemorySegmentByteSequence) encoded).get(), encodedOffset, encodedLength, clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude);
195201
}
196202

197203
@Override

jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,7 @@ private static class assemble_and_sum_f32_512 {
408408
NativeSimdOps.C_POINTER,
409409
NativeSimdOps.C_INT,
410410
NativeSimdOps.C_POINTER,
411+
NativeSimdOps.C_INT,
411412
NativeSimdOps.C_INT
412413
);
413414

@@ -419,7 +420,7 @@ private static class assemble_and_sum_f32_512 {
419420
/**
420421
* Function descriptor for:
421422
* {@snippet lang=c :
422-
* float assemble_and_sum_f32_512(const float *data, int dataBase, const unsigned char *baseOffsets, int baseOffsetsLength)
423+
* float assemble_and_sum_f32_512(const float *data, int dataBase, const unsigned char *baseOffsets, int baseOffsetsOffset, int baseOffsetsLength)
423424
* }
424425
*/
425426
public static FunctionDescriptor assemble_and_sum_f32_512$descriptor() {
@@ -429,24 +430,24 @@ private static class assemble_and_sum_f32_512 {
429430
/**
430431
* Downcall method handle for:
431432
* {@snippet lang=c :
432-
* float assemble_and_sum_f32_512(const float *data, int dataBase, const unsigned char *baseOffsets, int baseOffsetsLength)
433+
* float assemble_and_sum_f32_512(const float *data, int dataBase, const unsigned char *baseOffsets, int baseOffsetsOffset, int baseOffsetsLength)
433434
* }
434435
*/
435436
public static MethodHandle assemble_and_sum_f32_512$handle() {
436437
return assemble_and_sum_f32_512.HANDLE;
437438
}
438439
/**
439440
* {@snippet lang=c :
440-
* float assemble_and_sum_f32_512(const float *data, int dataBase, const unsigned char *baseOffsets, int baseOffsetsLength)
441+
* float assemble_and_sum_f32_512(const float *data, int dataBase, const unsigned char *baseOffsets, int baseOffsetsOffset, int baseOffsetsLength)
441442
* }
442443
*/
443-
public static float assemble_and_sum_f32_512(MemorySegment data, int dataBase, MemorySegment baseOffsets, int baseOffsetsLength) {
444+
public static float assemble_and_sum_f32_512(MemorySegment data, int dataBase, MemorySegment baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) {
444445
var mh$ = assemble_and_sum_f32_512.HANDLE;
445446
try {
446447
if (TRACE_DOWNCALLS) {
447-
traceDowncall("assemble_and_sum_f32_512", data, dataBase, baseOffsets, baseOffsetsLength);
448+
traceDowncall("assemble_and_sum_f32_512", data, dataBase, baseOffsets, baseOffsetsOffset, baseOffsetsLength);
448449
}
449-
return (float)mh$.invokeExact(data, dataBase, baseOffsets, baseOffsetsLength);
450+
return (float)mh$.invokeExact(data, dataBase, baseOffsets, baseOffsetsOffset, baseOffsetsLength);
450451
} catch (Throwable ex$) {
451452
throw new AssertionError("should not reach here", ex$);
452453
}
@@ -458,6 +459,7 @@ private static class pq_decoded_cosine_similarity_f32_512 {
458459
NativeSimdOps.C_POINTER,
459460
NativeSimdOps.C_INT,
460461
NativeSimdOps.C_INT,
462+
NativeSimdOps.C_INT,
461463
NativeSimdOps.C_POINTER,
462464
NativeSimdOps.C_POINTER,
463465
NativeSimdOps.C_FLOAT
@@ -471,7 +473,7 @@ private static class pq_decoded_cosine_similarity_f32_512 {
471473
/**
472474
* Function descriptor for:
473475
* {@snippet lang=c :
474-
* float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude)
476+
* float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude)
475477
* }
476478
*/
477479
public static FunctionDescriptor pq_decoded_cosine_similarity_f32_512$descriptor() {
@@ -481,24 +483,24 @@ private static class pq_decoded_cosine_similarity_f32_512 {
481483
/**
482484
* Downcall method handle for:
483485
* {@snippet lang=c :
484-
* float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude)
486+
* float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude)
485487
* }
486488
*/
487489
public static MethodHandle pq_decoded_cosine_similarity_f32_512$handle() {
488490
return pq_decoded_cosine_similarity_f32_512.HANDLE;
489491
}
490492
/**
491493
* {@snippet lang=c :
492-
* float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude)
494+
* float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude)
493495
* }
494496
*/
495-
public static float pq_decoded_cosine_similarity_f32_512(MemorySegment baseOffsets, int baseOffsetsLength, int clusterCount, MemorySegment partialSums, MemorySegment aMagnitude, float bMagnitude) {
497+
public static float pq_decoded_cosine_similarity_f32_512(MemorySegment baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, MemorySegment partialSums, MemorySegment aMagnitude, float bMagnitude) {
496498
var mh$ = pq_decoded_cosine_similarity_f32_512.HANDLE;
497499
try {
498500
if (TRACE_DOWNCALLS) {
499-
traceDowncall("pq_decoded_cosine_similarity_f32_512", baseOffsets, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude);
501+
traceDowncall("pq_decoded_cosine_similarity_f32_512", baseOffsets, baseOffsetsOffset, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude);
500502
}
501-
return (float)mh$.invokeExact(baseOffsets, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude);
503+
return (float)mh$.invokeExact(baseOffsets, baseOffsetsOffset, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude);
502504
} catch (Throwable ex$) {
503505
throw new AssertionError("should not reach here", ex$);
504506
}

0 commit comments

Comments
 (0)