Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.benchmark.vector;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

/**
 * JMH throughput benchmark for {@code OptimizedScalarQuantizer#scalarQuantize}.
 * <p>
 * Results are reported in operations per millisecond. Every combination of {@link #dims} and
 * {@link #bits} is measured. The two benchmark methods execute the same call; they differ only in
 * the JVM arguments of their forks: {@link #vector()} prepends
 * {@code --add-modules=jdk.incubator.vector}, {@link #scalar()} does not.
 * NOTE(review): presumably the extra module is what allows a Panama-vectorized code path to be
 * selected at runtime, so the pair compares scalar vs. vectorized implementations - confirm.
 */
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Warmup(iterations = 3, time = 1)
@Measurement(iterations = 5, time = 1)
@Fork(value = 3)
public class OptimizedScalarQuantizerBenchmark {
static {
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
}
// Vector dimensionality under test; also the length of every array allocated in init().
@Param({ "384", "702", "1024" })
int dims;

// Input vector to quantize, regenerated for every iteration.
float[] vector;
// Centroid passed to the quantizer, regenerated for every iteration.
float[] centroid;
// Output buffer the quantizer writes into; one byte per dimension.
byte[] destination;

// Number of bits passed to scalarQuantize.
@Param({ "1", "4", "7" })
byte bits;

// Quantizer under test, configured for dot-product similarity.
OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(VectorSimilarityFunction.DOT_PRODUCT);

/**
 * Allocates fresh input/output arrays and fills the inputs with random data.
 * Runs before each benchmark iteration (warmup and measurement).
 */
@Setup(Level.Iteration)
public void init() {
ThreadLocalRandom random = ThreadLocalRandom.current();
// allocate the output buffer and the float inputs, all sized by the current dims parameter
destination = new byte[dims];
vector = new float[dims];
centroid = new float[dims];
// fill the vector and the centroid with independent random floats from nextFloat(), i.e. in [0, 1)
for (int i = 0; i < dims; ++i) {
vector[i] = random.nextFloat();
centroid[i] = random.nextFloat();
}
}

/**
 * Quantizes {@link #vector} into {@link #destination} in a fork started with the default JVM arguments.
 * @return the destination buffer, returned so the JIT cannot discard the work as dead code
 */
@Benchmark
public byte[] scalar() {
osq.scalarQuantize(vector, destination, bits, centroid);
return destination;
}

/**
 * Same call as {@link #scalar()}, but the fork is started with
 * {@code --add-modules=jdk.incubator.vector} prepended to the JVM arguments.
 * @return the destination buffer, returned so the JIT cannot discard the work as dead code
 */
@Benchmark
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public byte[] vector() {
osq.scalarQuantize(vector, destination, bits, centroid);
return destination;
}
}
5 changes: 5 additions & 0 deletions docs/changelog/127118.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 127118
summary: Panama vector accelerated optimized scalar quantization
area: Vector Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,71 @@ static int andBitCountLong(byte[] a, byte[] b) {
}
return distance;
}

/**
 * Calculate the loss for optimized-scalar quantization for the given parameters.
 * <p>
 * The quantization step is derived as {@code (interval[1] - interval[0]) / (points - 1)} and is
 * handed, together with its reciprocal, to the active implementation.
 * @param target The vector being quantized, assumed to be centered
 * @param interval The interval for which to calculate the loss, must be of length 2 (lower, upper)
 * @param points the quantization points
 * @param norm2 The norm squared of the target vector
 * @param lambda The lambda parameter for controlling anisotropic loss calculation
 * @return The loss for the given parameters
 */
public static float calculateOSQLoss(float[] target, float[] interval, int points, float norm2, float lambda) {
    assert interval.length == 2;
    final float width = interval[1] - interval[0];
    final float gaps = points - 1.0F;
    // distance between two adjacent quantization points
    final float step = width / gaps;
    return IMPL.calculateOSQLoss(target, interval, step, 1f / step, norm2, lambda);
}

/**
 * Calculate the grid points for optimized-scalar quantization.
 * <p>
 * The inverse step handed to the active implementation is
 * {@code (points - 1) / (interval[1] - interval[0])}.
 * @param target The vector being quantized, assumed to be centered
 * @param interval The interval for which to calculate the grid points, must be of length 2 (lower, upper)
 * @param points the quantization points
 * @param pts The array to store the grid points, must be of length 5
 */
public static void calculateOSQGridPoints(float[] target, float[] interval, int points, float[] pts) {
    assert interval.length == 2;
    assert pts.length == 5;
    final float gaps = points - 1.0F;
    final float width = interval[1] - interval[0];
    IMPL.calculateOSQGridPoints(target, interval, points, gaps / width, pts);
}

/**
 * Center the target vector and calculate the optimized-scalar quantization statistics.
 * <p>
 * Length mismatches between the vectors are always reported as {@link IllegalArgumentException},
 * independent of whether assertions are enabled, matching {@code centerAndCalculateOSQStatsDp}.
 * @param target The vector being quantized
 * @param centroid The centroid of the target vector, must have the same length as {@code target}
 * @param centered The destination of the centered vector, will be overwritten, must have the same
 *                 length as {@code target}
 * @param stats The array to store the statistics, must be of length 5
 * @throws IllegalArgumentException if {@code centroid} or {@code centered} differ in length from {@code target}
 */
public static void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats) {
    // Validate arguments with real exceptions first. A preceding `assert` on the same condition would
    // turn the failure into an AssertionError whenever the JVM runs with -ea.
    if (target.length != centroid.length) {
        throw new IllegalArgumentException("vector dimensions differ: " + target.length + "!=" + centroid.length);
    }
    if (centered.length != target.length) {
        throw new IllegalArgumentException("vector dimensions differ: " + centered.length + "!=" + target.length);
    }
    // internal contract on the output buffer, checked only when assertions are enabled
    assert stats.length == 5;
    IMPL.centerAndCalculateOSQStatsEuclidean(target, centroid, centered, stats);
}

/**
 * Center the target vector and calculate the optimized-scalar quantization statistics,
 * using a six-element statistics array.
 * @param target The vector being quantized
 * @param centroid The centroid of the target vector, must have the same length as {@code target}
 * @param centered The destination of the centered vector, will be overwritten, must have the same
 *                 length as {@code target}
 * @param stats The array to store the statistics, must be of length 6
 * @throws IllegalArgumentException if {@code centroid} or {@code centered} differ in length from {@code target}
 */
public static void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats) {
    final int dims = target.length;
    final int centroidDims = centroid.length;
    if (dims != centroidDims) {
        throw new IllegalArgumentException("vector dimensions differ: " + dims + "!=" + centroidDims);
    }
    final int centeredDims = centered.length;
    if (centeredDims != dims) {
        throw new IllegalArgumentException("vector dimensions differ: " + centeredDims + "!=" + dims);
    }
    assert stats.length == 6;
    IMPL.centerAndCalculateOSQStatsDp(target, centroid, centered, stats);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,100 @@ public float ipFloatByte(float[] q, byte[] d) {
return ipFloatByteImpl(q, d);
}

/**
 * Scalar implementation of the optimized-scalar-quantization loss.
 * <p>
 * Each component is clamped to {@code [interval[0], interval[1]]}, snapped to the nearest
 * quantization level and mapped back to a float. Two sums over the resulting residuals are
 * combined as {@code (1 - lambda) * xe^2 / norm2 + lambda * e}.
 */
@Override
public float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda) {
    final float lower = interval[0];
    final float upper = interval[1];
    float residualDot = 0f;     // sum of x * (x - dequantized(x))
    float residualSquares = 0f; // sum of (x - dequantized(x))^2
    for (int i = 0; i < target.length; i++) {
        final float x = target[i];
        // quantize: clamp into the interval and round to the nearest level index
        final float clamped = Math.min(Math.max(x, lower), upper);
        final int level = Math.round((clamped - lower) * invStep);
        // de-quantize: map the level index back onto the interval
        final float restored = fma(step, level, lower);
        // how far the de-quantized value is from the original
        final float residual = x - restored;
        residualSquares = fma(residual, residual, residualSquares);
        residualDot = fma(x, residual, residualDot);
    }
    final float weightedDot = (1f - lambda) * residualDot * residualDot / norm2;
    final float weightedSquares = lambda * residualSquares;
    return weightedDot + weightedSquares;
}

/**
 * Scalar implementation of the grid-point accumulation for optimized-scalar quantization.
 * <p>
 * Every component {@code v} is clamped to the interval and snapped to its nearest level; with
 * {@code s = level / (points - 1)} the following sums are written to {@code pts}:
 * {@code [sum((1-s)^2), sum((1-s)*s), sum(s^2), sum((1-s)*v), sum(s*v)]}.
 */
@Override
public void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts) {
    final float lower = interval[0];
    final float upper = interval[1];
    // index of the highest quantization level, as a float divisor
    final float topLevel = points - 1;
    float sumLowLow = 0;
    float sumLowHigh = 0;
    float sumHighHigh = 0;
    float sumLowValue = 0;
    float sumHighValue = 0;
    for (int i = 0; i < target.length; i++) {
        final float value = target[i];
        final float clamped = Math.min(Math.max(value, lower), upper);
        final float level = Math.round((clamped - lower) * invStep);
        // fractional position of the snapped level between the lower (0) and upper (1) bound
        final float high = level / topLevel;
        final float low = 1f - high;
        sumLowLow = fma(low, low, sumLowLow);
        sumLowHigh = fma(low, high, sumLowHigh);
        sumHighHigh = fma(high, high, sumHighHigh);
        sumLowValue = fma(low, value, sumLowValue);
        sumHighValue = fma(high, value, sumHighValue);
    }
    pts[0] = sumLowLow;
    pts[1] = sumLowHigh;
    pts[2] = sumHighHigh;
    pts[3] = sumLowValue;
    pts[4] = sumHighValue;
}

/**
 * Scalar implementation: writes {@code target - centroid} into {@code centered} and gathers
 * statistics of the centered values in a single pass.
 * <p>
 * On return {@code stats} holds
 * {@code [mean, variance (sum of squared deviations / length), squared norm, min, max]}.
 * Mean and variance are maintained with an incremental (running) update per element.
 */
@Override
public void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats) {
    final int dims = target.length;
    float mean = 0;
    float sumSquaredDeviations = 0;
    float squaredNorm = 0;
    float lowest = Float.MAX_VALUE;
    float highest = -Float.MAX_VALUE;
    for (int i = 0; i < dims; i++) {
        final float value = target[i] - centroid[i];
        centered[i] = value;
        lowest = Math.min(lowest, value);
        highest = Math.max(highest, value);
        squaredNorm = fma(value, value, squaredNorm);
        // running update: deviation from the old mean, then from the refreshed mean
        final float fromOldMean = value - mean;
        mean += fromOldMean / (i + 1);
        final float fromNewMean = value - mean;
        sumSquaredDeviations = fma(fromOldMean, fromNewMean, sumSquaredDeviations);
    }
    stats[0] = mean;
    stats[1] = sumSquaredDeviations / dims;
    stats[2] = squaredNorm;
    stats[3] = lowest;
    stats[4] = highest;
}

/**
 * Scalar implementation: writes {@code target - centroid} into {@code centered} and gathers
 * statistics of the centered values, plus the dot product of the raw {@code target} with
 * {@code centroid}, in a single pass.
 * <p>
 * On return {@code stats} holds
 * {@code [mean, variance (sum of squared deviations / length), squared norm, min, max, target . centroid]}.
 * Mean and variance are maintained with an incremental (running) update per element.
 */
@Override
public void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats) {
    final int dims = target.length;
    float mean = 0;
    float sumSquaredDeviations = 0;
    float squaredNorm = 0;
    float dotWithCentroid = 0;
    float lowest = Float.MAX_VALUE;
    float highest = -Float.MAX_VALUE;
    for (int i = 0; i < dims; i++) {
        // the dot product uses the un-centered target value
        dotWithCentroid = fma(target[i], centroid[i], dotWithCentroid);
        final float value = target[i] - centroid[i];
        centered[i] = value;
        lowest = Math.min(lowest, value);
        highest = Math.max(highest, value);
        squaredNorm = fma(value, value, squaredNorm);
        // running update: deviation from the old mean, then from the refreshed mean
        final float fromOldMean = value - mean;
        mean += fromOldMean / (i + 1);
        final float fromNewMean = value - mean;
        sumSquaredDeviations = fma(fromOldMean, fromNewMean, sumSquaredDeviations);
    }
    stats[0] = mean;
    stats[1] = sumSquaredDeviations / dims;
    stats[2] = squaredNorm;
    stats[3] = lowest;
    stats[4] = highest;
    stats[5] = dotWithCentroid;
}

/**
 * Convenience overload that delegates to the three-argument {@code ipByteBitImpl} with {@code 0}
 * as the third argument.
 * NOTE(review): the third argument is presumably a start offset into one of the arrays - confirm
 * against the three-argument overload, which is not shown here.
 * @param q first byte array, passed through unchanged
 * @param d second byte array, passed through unchanged
 * @return whatever the three-argument overload returns for {@code (q, d, 0)}
 */
public static int ipByteBitImpl(byte[] q, byte[] d) {
return ipByteBitImpl(q, d, 0);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,12 @@ public interface ESVectorUtilSupport {
float ipFloatBit(float[] q, byte[] d);

float ipFloatByte(float[] q, byte[] d);

/**
 * Calculate the optimized-scalar-quantization loss of {@code target} for the given interval.
 * @param target the vector being quantized, assumed to be centered
 * @param interval two-element array holding the lower and upper bound of the quantization interval
 * @param step distance between adjacent quantization points
 * @param invStep reciprocal of {@code step}
 * @param norm2 the norm squared of the target vector
 * @param lambda the lambda parameter for controlling anisotropic loss calculation
 * @return the loss for the given parameters
 */
float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda);

/**
 * Calculate the grid points for optimized-scalar quantization and store them in {@code pts}.
 * @param target the vector being quantized, assumed to be centered
 * @param interval two-element array holding the lower and upper bound of the quantization interval
 * @param points the quantization points
 * @param invStep {@code (points - 1) / (interval[1] - interval[0])}, precomputed by the caller
 * @param pts output array of length 5 that receives the grid points
 */
void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts);

/**
 * Center the target vector and calculate the optimized-scalar quantization statistics.
 * @param target the vector being quantized
 * @param centroid the centroid of the target vector, same length as {@code target}
 * @param centered destination of the centered vector, will be overwritten, same length as {@code target}
 * @param stats output array of length 5 that receives the statistics
 */
void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats);

/**
 * Center the target vector and calculate the optimized-scalar quantization statistics,
 * using a six-element statistics array.
 * @param target the vector being quantized
 * @param centroid the centroid of the target vector, same length as {@code target}
 * @param centered destination of the centered vector, will be overwritten, same length as {@code target}
 * @param stats output array of length 6 that receives the statistics
 */
void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats);
}
Loading
Loading