Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,7 @@ public float scoreBulk(
quantizeScoreBulk(q, bulkSize, scores);
in.readFloats(lowerIntervals, 0, bulkSize);
in.readFloats(upperIntervals, 0, bulkSize);
for (int i = 0; i < bulkSize; i++) {
targetComponentSums[i] = in.readInt();
}
in.readInts(targetComponentSums, 0, bulkSize);
in.readFloats(additionalCorrections, 0, bulkSize);
float maxScore = Float.NEGATIVE_INFINITY;
for (int i = 0; i < bulkSize; i++) {
Expand Down Expand Up @@ -326,9 +324,7 @@ public float scoreBulkOffsets(
quantizeScoreBulkOffsets(q, offsets, offsetsCount, scores, count);
in.readFloats(lowerIntervals, 0, count);
in.readFloats(upperIntervals, 0, count);
for (int i = 0; i < count; i++) {
targetComponentSums[i] = in.readInt();
}
in.readInts(targetComponentSums, 0, count);
in.readFloats(additionalCorrections, 0, count);
float maxScore = Float.NEGATIVE_INFINITY;
int offsetIndex = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,12 @@
import org.elasticsearch.simdvec.internal.IndexInputUtils;

import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteOrder;

import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
import static org.elasticsearch.simdvec.internal.Similarities.dotProductD1Q4;
import static org.elasticsearch.simdvec.internal.Similarities.dotProductD1Q4Bulk;
import static org.elasticsearch.simdvec.internal.Similarities.dotProductD1Q4BulkWithOffsets;

/** Panamized scorer for quantized vectors stored as a {@link MemorySegment}. */
final class MSBitToInt4ESNextOSQVectorsScorer extends MemorySegmentESNextOSQVectorsScorer.MemorySegmentScorer {
Expand All @@ -42,40 +38,16 @@ final class MSBitToInt4ESNextOSQVectorsScorer extends MemorySegmentESNextOSQVect
@Override
public long quantizeScore(byte[] q) throws IOException {
assert q.length == length * 4;
// 128 / 8 == 16
if (length >= 16) {
if (NATIVE_SUPPORTED) {
return nativeQuantizeScore(q);
} else if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
return quantizeScore256(q);
} else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
return quantizeScore128(q);
}
if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
return quantizeScore256(q);
} else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
return quantizeScore128(q);
}
}
return Long.MIN_VALUE;
}

/**
 * Scores the query {@code q} through the native scoring path.
 *
 * <p>Asks {@link IndexInputUtils#withSlice} for a segment over {@code in} sized by {@code length}
 * (using {@code getScratch} as the scratch supplier) and forwards that segment to
 * {@link #nativeQuantizeScoreImpl}.
 *
 * @param q the quantized query bytes
 * @return the value produced by {@code nativeQuantizeScoreImpl} for the sliced data
 * @throws IOException declared for the slice access on {@code in}
 */
private long nativeQuantizeScore(byte[] q) throws IOException {
    return IndexInputUtils.withSlice(
        in,
        length,
        this::getScratch,
        datasetSegment -> nativeQuantizeScoreImpl(q, datasetSegment, length)
    );
}

/**
 * Runs {@code dotProductD1Q4} between a dataset segment and the query bytes.
 *
 * <p>When {@code SUPPORTS_HEAP_SEGMENTS} is set the query array is wrapped in place with
 * {@link MemorySegment#ofArray(byte[])}. Otherwise the query is first copied into a
 * 32-byte-aligned segment allocated from a confined {@link Arena}, which is released as soon as
 * the score has been computed.
 *
 * @param q the quantized query bytes
 * @param datasetMemorySegment segment holding the stored vector data
 * @param length length value forwarded unchanged to {@code dotProductD1Q4}
 * @return the score returned by {@code dotProductD1Q4}
 */
private static long nativeQuantizeScoreImpl(byte[] q, MemorySegment datasetMemorySegment, int length) {
    if (SUPPORTS_HEAP_SEGMENTS) {
        // Heap arrays can be passed directly; no copy is needed.
        return dotProductD1Q4(datasetMemorySegment, MemorySegment.ofArray(q), length);
    }
    try (Arena confined = Arena.ofConfined()) {
        // Stage the query in arena-allocated memory before calling the kernel.
        MemorySegment stagedQuery = confined.allocate(q.length, 32);
        MemorySegment.copy(q, 0, stagedQuery, ValueLayout.JAVA_BYTE, 0, q.length);
        return dotProductD1Q4(datasetMemorySegment, stagedQuery, length);
    }
}

/**
 * Scores the query {@code q} through the 256-bit implementation.
 *
 * <p>Asks {@link IndexInputUtils#withSlice} for a segment over {@code in} sized by {@code length}
 * (using {@code getScratch} as the scratch supplier) and forwards that segment to
 * {@code quantizeScore256Impl}.
 *
 * @param q the quantized query bytes
 * @return the value produced by {@code quantizeScore256Impl} for the sliced data
 * @throws IOException declared for the slice access on {@code in}
 */
private long quantizeScore256(byte[] q) throws IOException {
    return IndexInputUtils.withSlice(
        in,
        length,
        this::getScratch,
        dataSegment -> quantizeScore256Impl(q, dataSegment, length)
    );
}
Expand Down Expand Up @@ -214,44 +186,18 @@ private static long quantizeScore128Impl(byte[] q, MemorySegment memorySegment,
@Override
public boolean quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException {
assert q.length == length * 4;
// 128 / 8 == 16
if (length >= 16) {
if (NATIVE_SUPPORTED) {
if (SUPPORTS_HEAP_SEGMENTS) {
var querySegment = MemorySegment.ofArray(q);
var scoresSegment = MemorySegment.ofArray(scores);
nativeQuantizeScoreBulk(querySegment, count, scoresSegment);
} else {
try (var arena = Arena.ofConfined()) {
var querySegment = arena.allocate(q.length, 32);
var scoresSegment = arena.allocate((long) scores.length * Float.BYTES, 32);
MemorySegment.copy(q, 0, querySegment, ValueLayout.JAVA_BYTE, 0, q.length);
nativeQuantizeScoreBulk(querySegment, count, scoresSegment);
MemorySegment.copy(scoresSegment, ValueLayout.JAVA_FLOAT, 0, scores, 0, scores.length);
}
}
if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
quantizeScore256Bulk(q, count, scores);
return true;
} else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
quantizeScore128Bulk(q, count, scores);
return true;
} else if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
quantizeScore256Bulk(q, count, scores);
return true;
} else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
quantizeScore128Bulk(q, count, scores);
return true;
}
}
}
return false;
}

/**
 * Scores {@code count} stored vectors against the query in a single {@code dotProductD1Q4Bulk} call.
 *
 * <p>The slice requested from {@code in} is {@code length * count} bytes; the product is computed
 * in {@code long} arithmetic so it cannot overflow an {@code int}.
 *
 * @param querySegment segment holding the quantized query
 * @param count number of vectors to score
 * @param scoresSegment segment passed to {@code dotProductD1Q4Bulk} as its scores argument
 * @throws IOException declared for the slice access on {@code in}
 */
private void nativeQuantizeScoreBulk(MemorySegment querySegment, int count, MemorySegment scoresSegment) throws IOException {
    final long sliceBytes = (long) length * count;
    IndexInputUtils.withSlice(in, sliceBytes, this::getScratch, vectors -> {
        dotProductD1Q4Bulk(vectors, querySegment, length, count, scoresSegment);
        // The slice callback must yield a value; there is nothing meaningful to return here.
        return null;
    });
}

private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IOException {
var datasetLengthInBytes = (long) length * count;
IndexInputUtils.withSlice(in, datasetLengthInBytes, this::getScratch, segment -> {
Expand Down Expand Up @@ -403,154 +349,67 @@ private static void quantizeScore256BulkImpl(byte[] q, MemorySegment memorySegme

/**
 * Scores the vectors selected by {@code offsets} against the query through the native path.
 *
 * <p>When {@code NATIVE_SUPPORTED} is not set, nothing is computed and {@code false} is returned.
 * Otherwise the query, offsets and scores are handed to {@code nativeQuantizeScoreBulkOffsets},
 * either as heap segments wrapping the arrays directly ({@code SUPPORTS_HEAP_SEGMENTS}) or
 * through 32-byte-aligned copies in a confined {@link Arena}, whose scores are copied back into
 * {@code scores}. Finally {@code repositionScoresMatchingOffsets} is applied to {@code scores}.
 *
 * @param q the quantized query bytes; asserted to be {@code length * 4} long
 * @param offsets offsets of the vectors to score
 * @param offsetsCount number of entries of {@code offsets} that are used
 * @param scores output array for the computed scores
 * @param count total vector count forwarded to {@code nativeQuantizeScoreBulkOffsets}
 * @return {@code true} if the native path ran, {@code false} otherwise
 * @throws IOException declared for reading the underlying input
 */
@Override
public boolean quantizeScoreBulkOffsets(byte[] q, int[] offsets, int offsetsCount, float[] scores, int count) throws IOException {
    assert q.length == length * 4;
    if (NATIVE_SUPPORTED == false) {
        return false;
    }
    if (SUPPORTS_HEAP_SEGMENTS) {
        // Heap arrays are wrapped in place, so the results land directly in `scores`.
        nativeQuantizeScoreBulkOffsets(
            MemorySegment.ofArray(q),
            MemorySegment.ofArray(offsets),
            MemorySegment.ofArray(scores),
            offsetsCount,
            count
        );
    } else {
        try (Arena confined = Arena.ofConfined()) {
            MemorySegment stagedQuery = confined.allocate(q.length, 32);
            MemorySegment stagedOffsets = confined.allocate((long) offsetsCount * Integer.BYTES, 32);
            MemorySegment stagedScores = confined.allocate((long) scores.length * Float.BYTES, 32);
            // Copy the inputs in, run the kernel, then copy the scores back out.
            MemorySegment.copy(q, 0, stagedQuery, ValueLayout.JAVA_BYTE, 0, q.length);
            MemorySegment.copy(offsets, 0, stagedOffsets, ValueLayout.JAVA_INT, 0, offsetsCount);
            nativeQuantizeScoreBulkOffsets(stagedQuery, stagedOffsets, stagedScores, offsetsCount, count);
            MemorySegment.copy(stagedScores, ValueLayout.JAVA_FLOAT, 0, scores, 0, scores.length);
        }
    }
    repositionScoresMatchingOffsets(offsets, offsetsCount, scores);
    return true;
}

/**
 * Invokes {@code dotProductD1Q4BulkWithOffsets} over a slice of {@code in} covering
 * {@code length * totalCount} bytes (computed in {@code long} arithmetic).
 *
 * <p>NOTE(review): {@code length} is passed twice to the kernel; presumably once as the vector
 * length and once as the pitch between vectors — confirm against the kernel's signature.
 *
 * @param querySegment segment holding the quantized query
 * @param offsetsSegment segment holding the offsets of the vectors to score
 * @param scoresSegment segment passed to the kernel as its scores argument
 * @param offsetsCount number of offsets to process
 * @param totalCount number of vectors used to size the slice
 * @throws IOException declared for the slice access on {@code in}
 */
private void nativeQuantizeScoreBulkOffsets(
    MemorySegment querySegment,
    MemorySegment offsetsSegment,
    MemorySegment scoresSegment,
    int offsetsCount,
    int totalCount
) throws IOException {
    final long sliceBytes = (long) length * totalCount;
    IndexInputUtils.withSlice(in, sliceBytes, this::getScratch, vectors -> {
        dotProductD1Q4BulkWithOffsets(vectors, querySegment, length, length, offsetsSegment, offsetsCount, scoresSegment);
        // The slice callback must yield a value; there is nothing meaningful to return here.
        return null;
    });
}

@Override
public float scoreBulk(
float scoreBulkOffsets(
byte[] q,
float queryLowerInterval,
float queryUpperInterval,
int queryComponentSum,
float queryAdditionalCorrection,
VectorSimilarityFunction similarityFunction,
float centroidDp,
int[] offsets,
int offsetsCount,
float[] scores,
int bulkSize
) throws IOException {
assert q.length == length * 4;
// 128 / 8 == 16
if (length >= 16) {
if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
if (NATIVE_SUPPORTED) {
if (SUPPORTS_HEAP_SEGMENTS) {
var querySegment = MemorySegment.ofArray(q);
var scoresSegment = MemorySegment.ofArray(scores);
nativeQuantizeScoreBulk(querySegment, bulkSize, scoresSegment);
return nativeApplyCorrectionsBulk(
queryLowerInterval,
queryUpperInterval,
queryComponentSum,
queryAdditionalCorrection,
similarityFunction,
centroidDp,
scoresSegment,
bulkSize
);
} else {
try (var arena = Arena.ofConfined()) {
var querySegment = arena.allocate(q.length, 32);
var scoresSegment = arena.allocate((long) scores.length * Float.BYTES, 32);
MemorySegment.copy(q, 0, querySegment, ValueLayout.JAVA_BYTE, 0, q.length);
nativeQuantizeScoreBulk(querySegment, bulkSize, scoresSegment);
var maxScore = nativeApplyCorrectionsBulk(
queryLowerInterval,
queryUpperInterval,
queryComponentSum,
queryAdditionalCorrection,
similarityFunction,
centroidDp,
scoresSegment,
bulkSize
);
MemorySegment.copy(scoresSegment, ValueLayout.JAVA_FLOAT, 0, scores, 0, scores.length);
return maxScore;
}
}
} else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
quantizeScore256Bulk(q, bulkSize, scores);
return applyCorrections256Bulk(
queryLowerInterval,
queryUpperInterval,
queryComponentSum,
queryAdditionalCorrection,
similarityFunction,
centroidDp,
scores,
bulkSize
);
} else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
quantizeScore128Bulk(q, bulkSize, scores);
return applyCorrections128Bulk(
queryLowerInterval,
queryUpperInterval,
queryComponentSum,
queryAdditionalCorrection,
similarityFunction,
centroidDp,
scores,
bulkSize
);
}
}
}
int count
) {
return Float.NEGATIVE_INFINITY;
}

private float nativeApplyCorrectionsBulk(
@Override
public float scoreBulk(
byte[] q,
float queryLowerInterval,
float queryUpperInterval,
int queryComponentSum,
float queryAdditionalCorrection,
VectorSimilarityFunction similarityFunction,
float centroidDp,
MemorySegment scoresSegment,
float[] scores,
int bulkSize
) throws IOException {
return IndexInputUtils.withSlice(
in,
16L * bulkSize,
this::getScratch,
seg -> ScoreCorrections.nativeApplyCorrectionsBulk(
similarityFunction,
seg,
bulkSize,
dimensions,
queryLowerInterval,
queryUpperInterval,
queryComponentSum,
queryAdditionalCorrection,
FOUR_BIT_SCALE,
ONE_BIT_SCALE,
centroidDp,
scoresSegment
)
);
assert q.length == length * 4;
if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
quantizeScore256Bulk(q, bulkSize, scores);
return applyCorrections256Bulk(
queryLowerInterval,
queryUpperInterval,
queryComponentSum,
queryAdditionalCorrection,
similarityFunction,
centroidDp,
scores,
bulkSize
);
} else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
quantizeScore128Bulk(q, bulkSize, scores);
return applyCorrections128Bulk(
queryLowerInterval,
queryUpperInterval,
queryComponentSum,
queryAdditionalCorrection,
similarityFunction,
centroidDp,
scores,
bulkSize
);
}
}
return Float.NEGATIVE_INFINITY;
}

private float applyCorrections128Bulk(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
import org.elasticsearch.simdvec.internal.MemorySegmentES92Int7VectorsScorer;

import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;

import static org.elasticsearch.simdvec.internal.Similarities.dotProductI7uBulkWithOffsets;

Expand Down Expand Up @@ -44,23 +42,11 @@ boolean quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOExceptio
@Override
public boolean quantizeScoreBulkOffsets(byte[] q, int[] offsets, int offsetsCount, float[] scores, int count) throws IOException {
assert q.length == length;
if (NATIVE_SUPPORTED) {
if (SUPPORTS_HEAP_SEGMENTS) {
var querySegment = MemorySegment.ofArray(q);
var offsetsSegment = MemorySegment.ofArray(offsets);
var scoresSegment = MemorySegment.ofArray(scores);
nativeQuantizeScoreBulkOffsets(querySegment, offsetsSegment, scoresSegment, offsetsCount, count);
} else {
try (var arena = Arena.ofConfined()) {
var querySegment = arena.allocate(q.length, 32);
var offsetsSegment = arena.allocate((long) offsetsCount * Integer.BYTES, 32);
var scoresSegment = arena.allocate((long) scores.length * Float.BYTES, 32);
MemorySegment.copy(q, 0, querySegment, ValueLayout.JAVA_BYTE, 0, q.length);
MemorySegment.copy(offsets, 0, offsetsSegment, ValueLayout.JAVA_INT, 0, offsetsCount);
nativeQuantizeScoreBulkOffsets(querySegment, offsetsSegment, scoresSegment, offsetsCount, count);
MemorySegment.copy(scoresSegment, ValueLayout.JAVA_FLOAT, 0, scores, 0, scores.length);
}
}
if (NATIVE_SUPPORTED && SUPPORTS_HEAP_SEGMENTS) {
var querySegment = MemorySegment.ofArray(q);
var offsetsSegment = MemorySegment.ofArray(offsets);
var scoresSegment = MemorySegment.ofArray(scores);
nativeQuantizeScoreBulkOffsets(querySegment, offsetsSegment, scoresSegment, offsetsCount, count);
repositionScoresMatchingOffsets(offsets, offsetsCount, scores);
return true;
}
Expand All @@ -81,6 +67,23 @@ private void nativeQuantizeScoreBulkOffsets(
});
}

/**
 * Stub override: ignores every argument and always returns {@link Float#NEGATIVE_INFINITY}.
 *
 * <p>No data is read and {@code scores} is left untouched.
 * NOTE(review): presumably {@code NEGATIVE_INFINITY} tells the caller that this scorer has no
 * specialised offsets-based bulk path — confirm against the overridden method's contract.
 *
 * @return always {@code Float.NEGATIVE_INFINITY}
 */
@Override
float scoreBulkOffsets(
byte[] q,
float queryLowerInterval,
float queryUpperInterval,
int queryComponentSum,
float queryAdditionalCorrection,
VectorSimilarityFunction similarityFunction,
float centroidDp,
int[] offsets,
int offsetsCount,
float[] scores,
int count
) {
// Nothing is computed here; the constant is the whole result.
return Float.NEGATIVE_INFINITY;
}

@Override
float scoreBulk(
byte[] q,
Expand Down
Loading
Loading