Skip to content

Commit 059f91c

Browse files
authored
Panama vector accelerated optimized scalar quantization (#127118)
* Accelerates optimized scalar quantization with vectorized functions * Adding benchmark * Update docs/changelog/127118.yaml * adjusting benchmark and delta
1 parent ad0fe78 commit 059f91c

File tree

16 files changed

+702
-99
lines changed

16 files changed

+702
-99
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.benchmark.vector;
11+
12+
import org.apache.lucene.index.VectorSimilarityFunction;
13+
import org.elasticsearch.common.logging.LogConfigurator;
14+
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
15+
import org.openjdk.jmh.annotations.Benchmark;
16+
import org.openjdk.jmh.annotations.BenchmarkMode;
17+
import org.openjdk.jmh.annotations.Fork;
18+
import org.openjdk.jmh.annotations.Level;
19+
import org.openjdk.jmh.annotations.Measurement;
20+
import org.openjdk.jmh.annotations.Mode;
21+
import org.openjdk.jmh.annotations.OutputTimeUnit;
22+
import org.openjdk.jmh.annotations.Param;
23+
import org.openjdk.jmh.annotations.Scope;
24+
import org.openjdk.jmh.annotations.Setup;
25+
import org.openjdk.jmh.annotations.State;
26+
import org.openjdk.jmh.annotations.Warmup;
27+
28+
import java.util.concurrent.ThreadLocalRandom;
29+
import java.util.concurrent.TimeUnit;
30+
31+
@BenchmarkMode(Mode.Throughput)
32+
@OutputTimeUnit(TimeUnit.MILLISECONDS)
33+
@State(Scope.Benchmark)
34+
@Warmup(iterations = 3, time = 1)
35+
@Measurement(iterations = 5, time = 1)
36+
@Fork(value = 3)
37+
public class OptimizedScalarQuantizerBenchmark {
38+
static {
39+
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
40+
}
41+
@Param({ "384", "702", "1024" })
42+
int dims;
43+
44+
float[] vector;
45+
float[] centroid;
46+
byte[] destination;
47+
48+
@Param({ "1", "4", "7" })
49+
byte bits;
50+
51+
OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(VectorSimilarityFunction.DOT_PRODUCT);
52+
53+
@Setup(Level.Iteration)
54+
public void init() {
55+
ThreadLocalRandom random = ThreadLocalRandom.current();
56+
// random byte arrays for binary methods
57+
destination = new byte[dims];
58+
vector = new float[dims];
59+
centroid = new float[dims];
60+
for (int i = 0; i < dims; ++i) {
61+
vector[i] = random.nextFloat();
62+
centroid[i] = random.nextFloat();
63+
}
64+
}
65+
66+
@Benchmark
67+
public byte[] scalar() {
68+
osq.scalarQuantize(vector, destination, bits, centroid);
69+
return destination;
70+
}
71+
72+
@Benchmark
73+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
74+
public byte[] vector() {
75+
osq.scalarQuantize(vector, destination, bits, centroid);
76+
return destination;
77+
}
78+
}

docs/changelog/127118.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 127118
2+
summary: Panama vector accelerated optimized scalar quantization
3+
area: Vector Search
4+
type: enhancement
5+
issues: []

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,4 +144,71 @@ static int andBitCountLong(byte[] a, byte[] b) {
144144
}
145145
return distance;
146146
}
147+
148+
/**
149+
* Calculate the loss for optimized-scalar quantization for the given parameteres
150+
* @param target The vector being quantized, assumed to be centered
151+
* @param interval The interval for which to calculate the loss
152+
* @param points the quantization points
153+
* @param norm2 The norm squared of the target vector
154+
* @param lambda The lambda parameter for controlling anisotropic loss calculation
155+
* @return The loss for the given parameters
156+
*/
157+
public static float calculateOSQLoss(float[] target, float[] interval, int points, float norm2, float lambda) {
158+
assert interval.length == 2;
159+
float step = ((interval[1] - interval[0]) / (points - 1.0F));
160+
float invStep = 1f / step;
161+
return IMPL.calculateOSQLoss(target, interval, step, invStep, norm2, lambda);
162+
}
163+
164+
/**
165+
* Calculate the grid points for optimized-scalar quantization
166+
* @param target The vector being quantized, assumed to be centered
167+
* @param interval The interval for which to calculate the grid points
168+
* @param points the quantization points
169+
* @param pts The array to store the grid points, must be of length 5
170+
*/
171+
public static void calculateOSQGridPoints(float[] target, float[] interval, int points, float[] pts) {
172+
assert interval.length == 2;
173+
assert pts.length == 5;
174+
float invStep = (points - 1.0F) / (interval[1] - interval[0]);
175+
IMPL.calculateOSQGridPoints(target, interval, points, invStep, pts);
176+
}
177+
178+
/**
179+
* Center the target vector and calculate the optimized-scalar quantization statistics
180+
* @param target The vector being quantized
181+
* @param centroid The centroid of the target vector
182+
* @param centered The destination of the centered vector, will be overwritten
183+
* @param stats The array to store the statistics, must be of length 5
184+
*/
185+
public static void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats) {
186+
assert target.length == centroid.length;
187+
assert stats.length == 5;
188+
if (target.length != centroid.length) {
189+
throw new IllegalArgumentException("vector dimensions differ: " + target.length + "!=" + centroid.length);
190+
}
191+
if (centered.length != target.length) {
192+
throw new IllegalArgumentException("vector dimensions differ: " + centered.length + "!=" + target.length);
193+
}
194+
IMPL.centerAndCalculateOSQStatsEuclidean(target, centroid, centered, stats);
195+
}
196+
197+
/**
198+
* Center the target vector and calculate the optimized-scalar quantization statistics
199+
* @param target The vector being quantized
200+
* @param centroid The centroid of the target vector
201+
* @param centered The destination of the centered vector, will be overwritten
202+
* @param stats The array to store the statistics, must be of length 6
203+
*/
204+
public static void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats) {
205+
if (target.length != centroid.length) {
206+
throw new IllegalArgumentException("vector dimensions differ: " + target.length + "!=" + centroid.length);
207+
}
208+
if (centered.length != target.length) {
209+
throw new IllegalArgumentException("vector dimensions differ: " + centered.length + "!=" + target.length);
210+
}
211+
assert stats.length == 6;
212+
IMPL.centerAndCalculateOSQStatsDp(target, centroid, centered, stats);
213+
}
147214
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,100 @@ public float ipFloatByte(float[] q, byte[] d) {
4444
return ipFloatByteImpl(q, d);
4545
}
4646

47+
@Override
48+
public float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda) {
49+
float a = interval[0];
50+
float b = interval[1];
51+
float xe = 0f;
52+
float e = 0f;
53+
for (float xi : target) {
54+
// this is quantizing and then dequantizing the vector
55+
float xiq = fma(step, Math.round((Math.min(Math.max(xi, a), b) - a) * invStep), a);
56+
// how much does the de-quantized value differ from the original value
57+
float xiiq = xi - xiq;
58+
e = fma(xiiq, xiiq, e);
59+
xe = fma(xi, xiiq, xe);
60+
}
61+
return (1f - lambda) * xe * xe / norm2 + lambda * e;
62+
}
63+
64+
@Override
65+
public void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts) {
66+
float a = interval[0];
67+
float b = interval[1];
68+
float daa = 0;
69+
float dab = 0;
70+
float dbb = 0;
71+
float dax = 0;
72+
float dbx = 0;
73+
for (float v : target) {
74+
float k = Math.round((Math.min(Math.max(v, a), b) - a) * invStep);
75+
float s = k / (points - 1);
76+
float ms = 1f - s;
77+
daa = fma(ms, ms, daa);
78+
dab = fma(ms, s, dab);
79+
dbb = fma(s, s, dbb);
80+
dax = fma(ms, v, dax);
81+
dbx = fma(s, v, dbx);
82+
}
83+
pts[0] = daa;
84+
pts[1] = dab;
85+
pts[2] = dbb;
86+
pts[3] = dax;
87+
pts[4] = dbx;
88+
}
89+
90+
@Override
91+
public void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats) {
92+
float vecMean = 0;
93+
float vecVar = 0;
94+
float norm2 = 0;
95+
float min = Float.MAX_VALUE;
96+
float max = -Float.MAX_VALUE;
97+
for (int i = 0; i < target.length; i++) {
98+
centered[i] = target[i] - centroid[i];
99+
min = Math.min(min, centered[i]);
100+
max = Math.max(max, centered[i]);
101+
norm2 = fma(centered[i], centered[i], norm2);
102+
float delta = centered[i] - vecMean;
103+
vecMean += delta / (i + 1);
104+
float delta2 = centered[i] - vecMean;
105+
vecVar = fma(delta, delta2, vecVar);
106+
}
107+
stats[0] = vecMean;
108+
stats[1] = vecVar / target.length;
109+
stats[2] = norm2;
110+
stats[3] = min;
111+
stats[4] = max;
112+
}
113+
114+
@Override
115+
public void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats) {
116+
float vecMean = 0;
117+
float vecVar = 0;
118+
float norm2 = 0;
119+
float centroidDot = 0;
120+
float min = Float.MAX_VALUE;
121+
float max = -Float.MAX_VALUE;
122+
for (int i = 0; i < target.length; i++) {
123+
centroidDot = fma(target[i], centroid[i], centroidDot);
124+
centered[i] = target[i] - centroid[i];
125+
min = Math.min(min, centered[i]);
126+
max = Math.max(max, centered[i]);
127+
norm2 = fma(centered[i], centered[i], norm2);
128+
float delta = centered[i] - vecMean;
129+
vecMean += delta / (i + 1);
130+
float delta2 = centered[i] - vecMean;
131+
vecVar = fma(delta, delta2, vecVar);
132+
}
133+
stats[0] = vecMean;
134+
stats[1] = vecVar / target.length;
135+
stats[2] = norm2;
136+
stats[3] = min;
137+
stats[4] = max;
138+
stats[5] = centroidDot;
139+
}
140+
47141
public static int ipByteBitImpl(byte[] q, byte[] d) {
48142
return ipByteBitImpl(q, d, 0);
49143
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,12 @@ public interface ESVectorUtilSupport {
2020
float ipFloatBit(float[] q, byte[] d);
2121

2222
float ipFloatByte(float[] q, byte[] d);
23+
24+
/**
 * Calculates the optimized-scalar-quantization loss of {@code target} for the given interval.
 *
 * @param target the vector being quantized
 * @param interval two-element array holding the lower and upper interval bounds
 * @param step distance between adjacent quantization levels
 * @param invStep reciprocal of {@code step}
 * @param norm2 squared norm of the target vector
 * @param lambda blend factor between the two loss terms
 * @return the loss
 */
float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda);
25+
26+
/**
 * Accumulates the five sums used to derive the optimized-scalar-quantization grid points.
 *
 * @param target the vector being quantized
 * @param interval two-element array holding the lower and upper interval bounds
 * @param points the number of quantization points
 * @param invStep {@code (points - 1) / (interval[1] - interval[0])}, as computed by the public entry point
 * @param pts output array of length 5; it is overwritten
 */
void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts);
27+
28+
/**
 * Centers {@code target} around {@code centroid} into {@code centered} and computes statistics of the
 * centered vector.
 *
 * @param target the vector being quantized
 * @param centroid the centroid subtracted from {@code target}
 * @param centered output array for the centered vector; it is overwritten
 * @param stats output array of length 5; the scalar implementation stores mean, variance, squared norm,
 *              minimum and maximum, in that order
 */
void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats);
29+
30+
/**
 * Centers {@code target} around {@code centroid} into {@code centered} and computes statistics of the
 * centered vector plus the dot product of {@code target} with {@code centroid}.
 *
 * @param target the vector being quantized
 * @param centroid the centroid subtracted from {@code target}
 * @param centered output array for the centered vector; it is overwritten
 * @param stats output array of length 6; the scalar implementation stores mean, variance, squared norm,
 *              minimum, maximum and the target/centroid dot product, in that order
 */
void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats);
2331
}

0 commit comments

Comments
 (0)