Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ Optimizations
* GITHUB#15151: Use `SimScorer#score` bulk API to compute impact scores per
block of postings. (Adrien Grand)

* GITHUB#14863: Perform scoring for 4 and 7 bit quantized vectors off-heap. (Kaival Parikh)

Bug Fixes
---------------------
* GITHUB#14161: PointInSetQuery's constructor now throws IllegalArgumentException
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,13 @@ static void compressBytes(byte[] raw, byte[] compressed) {
private byte[] bytesA;
private byte[] bytesB;
private byte[] halfBytesA;
private byte[] halfBytesAPacked;
private byte[] halfBytesB;
private byte[] halfBytesBPacked;
private float[] floatsA;
private float[] floatsB;
private int expectedhalfByteDotProduct;
private int expectedHalfByteDotProduct;
private int expectedHalfByteSquareDistance;

@Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
int size;
Expand All @@ -74,16 +76,23 @@ public void init() {
random.nextBytes(bytesB);
// random half byte arrays for binary methods
// this means that all values must be between 0 and 15
expectedhalfByteDotProduct = 0;
expectedHalfByteDotProduct = 0;
expectedHalfByteSquareDistance = 0;
halfBytesA = new byte[size];
halfBytesB = new byte[size];
for (int i = 0; i < size; ++i) {
halfBytesA[i] = (byte) random.nextInt(16);
halfBytesB[i] = (byte) random.nextInt(16);
expectedhalfByteDotProduct += halfBytesA[i] * halfBytesB[i];
expectedHalfByteDotProduct += halfBytesA[i] * halfBytesB[i];

int diff = halfBytesA[i] - halfBytesB[i];
expectedHalfByteSquareDistance += diff * diff;
}
// pack the half byte arrays
if (size % 2 == 0) {
halfBytesAPacked = new byte[(size + 1) >> 1];
compressBytes(halfBytesA, halfBytesAPacked);

halfBytesBPacked = new byte[(size + 1) >> 1];
compressBytes(halfBytesB, halfBytesBPacked);
}
Expand All @@ -108,6 +117,74 @@ public float binaryCosineVector() {
return VectorUtil.cosine(bytesA, bytesB);
}

@Benchmark
public int binarySquareScalar() {
return VectorUtil.squareDistance(bytesA, bytesB);
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binarySquareVector() {
return VectorUtil.squareDistance(bytesA, bytesB);
}

@Benchmark
public int binaryHalfByteSquareScalar() {
int v = VectorUtil.int4SquareDistance(halfBytesA, halfBytesB);
if (v != expectedHalfByteSquareDistance) {
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
}
return v;
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binaryHalfByteSquareVector() {
int v = VectorUtil.int4SquareDistance(halfBytesA, halfBytesB);
if (v != expectedHalfByteSquareDistance) {
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
}
return v;
}

@Benchmark
public int binaryHalfByteSquareSinglePackedScalar() {
int v = VectorUtil.int4SquareDistanceSinglePacked(halfBytesA, halfBytesBPacked);
if (v != expectedHalfByteSquareDistance) {
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
}
return v;
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binaryHalfByteSquareSinglePackedVector() {
int v = VectorUtil.int4SquareDistanceSinglePacked(halfBytesA, halfBytesBPacked);
if (v != expectedHalfByteSquareDistance) {
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
}
return v;
}

@Benchmark
public int binaryHalfByteSquareBothPackedScalar() {
int v = VectorUtil.int4SquareDistanceBothPacked(halfBytesAPacked, halfBytesBPacked);
if (v != expectedHalfByteSquareDistance) {
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
}
return v;
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binaryHalfByteSquareBothPackedVector() {
int v = VectorUtil.int4SquareDistanceBothPacked(halfBytesAPacked, halfBytesBPacked);
if (v != expectedHalfByteSquareDistance) {
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
}
return v;
}

@Benchmark
public int binaryDotProductScalar() {
return VectorUtil.dotProduct(bytesA, bytesB);
Expand All @@ -131,14 +208,22 @@ public int binaryDotProductUint8Vector() {
}

@Benchmark
public int binarySquareScalar() {
return VectorUtil.squareDistance(bytesA, bytesB);
public int binaryHalfByteDotProductScalar() {
int v = VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
if (v != expectedHalfByteDotProduct) {
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
}
return v;
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binarySquareVector() {
return VectorUtil.squareDistance(bytesA, bytesB);
public int binaryHalfByteDotProductVector() {
int v = VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
if (v != expectedHalfByteDotProduct) {
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
}
return v;
}

@Benchmark
Expand All @@ -153,37 +238,39 @@ public int binarySquareUint8Vector() {
}

@Benchmark
public int binaryHalfByteScalar() {
return VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
public int binaryHalfByteDotProductSinglePackedScalar() {
int v = VectorUtil.int4DotProductSinglePacked(halfBytesA, halfBytesBPacked);
if (v != expectedHalfByteDotProduct) {
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
}
return v;
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binaryHalfByteVector() {
return VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
public int binaryHalfByteDotProductSinglePackedVector() {
int v = VectorUtil.int4DotProductSinglePacked(halfBytesA, halfBytesBPacked);
if (v != expectedHalfByteDotProduct) {
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
}
return v;
}

@Benchmark
public int binaryHalfByteScalarPacked() {
if (size % 2 != 0) {
throw new RuntimeException("Size must be even for this benchmark");
}
int v = VectorUtil.int4DotProductPacked(halfBytesA, halfBytesBPacked);
if (v != expectedhalfByteDotProduct) {
throw new RuntimeException("Expected " + expectedhalfByteDotProduct + " but got " + v);
public int binaryHalfByteDotProductBothPackedScalar() {
int v = VectorUtil.int4DotProductBothPacked(halfBytesAPacked, halfBytesBPacked);
if (v != expectedHalfByteDotProduct) {
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
}
return v;
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binaryHalfByteVectorPacked() {
if (size % 2 != 0) {
throw new RuntimeException("Size must be even for this benchmark");
}
int v = VectorUtil.int4DotProductPacked(halfBytesA, halfBytesBPacked);
if (v != expectedhalfByteDotProduct) {
throw new RuntimeException("Expected " + expectedhalfByteDotProduct + " but got " + v);
public int binaryHalfByteDotProductBothPackedVector() {
int v = VectorUtil.int4DotProductBothPacked(halfBytesAPacked, halfBytesBPacked);
if (v != expectedHalfByteDotProduct) {
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
}
return v;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,8 @@ private FlatVectorScorerUtil() {}
public static FlatVectorsScorer getLucene99FlatVectorsScorer() {
return IMPL.getLucene99FlatVectorsScorer();
}

public static FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
return IMPL.getLucene99ScalarQuantizedVectorsScorer();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.FloatToFloatFunction;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
Expand Down Expand Up @@ -245,7 +246,7 @@ public float score(int vectorOrdinal) throws IOException {
values.getSlice().seek((long) vectorOrdinal * (values.getVectorByteLength() + Float.BYTES));
values.getSlice().readBytes(compressedVector, 0, compressedVector.length);
float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
int dotProduct = VectorUtil.int4DotProductPacked(targetBytes, compressedVector);
int dotProduct = VectorUtil.int4DotProductSinglePacked(targetBytes, compressedVector);
// For the current implementation of scalar quantization, all dotproducts should
// be >= 0;
assert dotProduct >= 0;
Expand Down Expand Up @@ -301,11 +302,6 @@ public void setScoringOrdinal(int node) throws IOException {
}
}

@FunctionalInterface
private interface FloatToFloatFunction {
float apply(float f);
}

private static final class ScalarQuantizedRandomVectorScorerSupplier
implements RandomVectorScorerSupplier {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
package org.apache.lucene.codecs.lucene99;

import java.io.IOException;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
Expand Down Expand Up @@ -68,7 +68,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {

final byte bits;
final boolean compress;
final Lucene99ScalarQuantizedVectorScorer flatVectorScorer;
final FlatVectorsScorer flatVectorScorer;

/** Constructs a format using default graph construction parameters */
public Lucene99ScalarQuantizedVectorsFormat() {
Expand Down Expand Up @@ -115,8 +115,7 @@ public Lucene99ScalarQuantizedVectorsFormat(
this.bits = (byte) bits;
this.confidenceInterval = confidenceInterval;
this.compress = compress;
this.flatVectorScorer =
new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
this.flatVectorScorer = FlatVectorScorerUtil.getLucene99ScalarQuantizedVectorsScorer();
}

public static float calculateDefaultConfidenceInterval(int vectorDimension) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,24 +164,35 @@ public int uint8DotProduct(byte[] a, byte[] b) {
}

@Override
public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) {
assert (apacked && bpacked) == false;
if (apacked || bpacked) {
byte[] packed = apacked ? a : b;
byte[] unpacked = apacked ? b : a;
int total = 0;
for (int i = 0; i < packed.length; i++) {
byte packedByte = packed[i];
byte unpacked1 = unpacked[i];
byte unpacked2 = unpacked[i + packed.length];
total += (packedByte & 0x0F) * unpacked2;
total += ((packedByte & 0xFF) >> 4) * unpacked1;
}
return total;
}
public int int4DotProduct(byte[] a, byte[] b) {
return dotProduct(a, b);
}

@Override
public int int4DotProductSinglePacked(byte[] unpacked, byte[] packed) {
int total = 0;
for (int i = 0; i < packed.length; i++) {
byte packedByte = packed[i];
byte unpacked1 = unpacked[i];
byte unpacked2 = unpacked[i + packed.length];
total += (packedByte & 0x0F) * unpacked2;
total += ((packedByte & 0xFF) >> 4) * unpacked1;
}
return total;
}

@Override
public int int4DotProductBothPacked(byte[] a, byte[] b) {
int total = 0;
for (int i = 0; i < a.length; i++) {
byte aByte = a[i];
byte bByte = b[i];
total += (aByte & 0x0F) * (bByte & 0x0F);
total += ((aByte & 0xFF) >> 4) * ((bByte & 0xFF) >> 4);
}
return total;
}

@Override
public float cosine(byte[] a, byte[] b) {
// Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
Expand Down Expand Up @@ -210,6 +221,42 @@ public int squareDistance(byte[] a, byte[] b) {
return squareSum;
}

@Override
public int int4SquareDistance(byte[] a, byte[] b) {
return squareDistance(a, b);
}

@Override
public int int4SquareDistanceSinglePacked(byte[] unpacked, byte[] packed) {
int total = 0;
for (int i = 0; i < packed.length; i++) {
byte packedByte = packed[i];
byte unpacked1 = unpacked[i];
byte unpacked2 = unpacked[i + packed.length];

int diff1 = (packedByte & 0x0F) - unpacked2;
int diff2 = ((packedByte & 0xFF) >> 4) - unpacked1;

total += diff1 * diff1 + diff2 * diff2;
}
return total;
}

@Override
public int int4SquareDistanceBothPacked(byte[] a, byte[] b) {
int total = 0;
for (int i = 0; i < a.length; i++) {
byte aByte = a[i];
byte bByte = b[i];

int diff1 = (aByte & 0x0F) - (bByte & 0x0F);
int diff2 = ((aByte & 0xFF) >> 4) - ((bByte & 0xFF) >> 4);

total += diff1 * diff1 + diff2 * diff2;
}
return total;
}

@Override
public int uint8SquareDistance(byte[] a, byte[] b) {
// Note: this will not overflow if dim < 2^16, since max(ubyte * ubyte) = 2^16.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer;
import org.apache.lucene.store.IndexInput;

/** Default provider returning scalar implementations. */
Expand All @@ -40,6 +41,11 @@ public FlatVectorsScorer getLucene99FlatVectorsScorer() {
return DefaultFlatVectorScorer.INSTANCE;
}

@Override
public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
return new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
}

@Override
public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) {
return new PostingDecodingUtil(input);
Expand Down
Loading
Loading