Skip to content

Commit 9dd6d32

Browse files
authored
[8.x] Panama vector accelerated optimized scalar quantization (elastic#127118) (elastic#127269)
* Panama vector accelerated optimized scalar quantization (elastic#127118) * Adds accelerates optimized scalar quantization with vectorized functions * Adding benchmark * Update docs/changelog/127118.yaml * adjusting benchmark and delta (cherry picked from commit 059f91c) * fixing compilation * reverting unnecessary change
1 parent cdb569a commit 9dd6d32

File tree

17 files changed

+704
-99
lines changed

17 files changed

+704
-99
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.benchmark.vector;
11+
12+
import org.apache.lucene.index.VectorSimilarityFunction;
13+
import org.elasticsearch.common.logging.LogConfigurator;
14+
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
15+
import org.openjdk.jmh.annotations.Benchmark;
16+
import org.openjdk.jmh.annotations.BenchmarkMode;
17+
import org.openjdk.jmh.annotations.Fork;
18+
import org.openjdk.jmh.annotations.Level;
19+
import org.openjdk.jmh.annotations.Measurement;
20+
import org.openjdk.jmh.annotations.Mode;
21+
import org.openjdk.jmh.annotations.OutputTimeUnit;
22+
import org.openjdk.jmh.annotations.Param;
23+
import org.openjdk.jmh.annotations.Scope;
24+
import org.openjdk.jmh.annotations.Setup;
25+
import org.openjdk.jmh.annotations.State;
26+
import org.openjdk.jmh.annotations.Warmup;
27+
28+
import java.util.concurrent.ThreadLocalRandom;
29+
import java.util.concurrent.TimeUnit;
30+
31+
@BenchmarkMode(Mode.Throughput)
32+
@OutputTimeUnit(TimeUnit.MILLISECONDS)
33+
@State(Scope.Benchmark)
34+
@Warmup(iterations = 3, time = 1)
35+
@Measurement(iterations = 5, time = 1)
36+
@Fork(value = 3)
37+
public class OptimizedScalarQuantizerBenchmark {
38+
static {
39+
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
40+
}
41+
@Param({ "384", "702", "1024" })
42+
int dims;
43+
44+
float[] vector;
45+
float[] centroid;
46+
byte[] destination;
47+
48+
@Param({ "1", "4", "7" })
49+
byte bits;
50+
51+
OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(VectorSimilarityFunction.DOT_PRODUCT);
52+
53+
@Setup(Level.Iteration)
54+
public void init() {
55+
ThreadLocalRandom random = ThreadLocalRandom.current();
56+
// random byte arrays for binary methods
57+
destination = new byte[dims];
58+
vector = new float[dims];
59+
centroid = new float[dims];
60+
for (int i = 0; i < dims; ++i) {
61+
vector[i] = random.nextFloat();
62+
centroid[i] = random.nextFloat();
63+
}
64+
}
65+
66+
@Benchmark
67+
public byte[] scalar() {
68+
osq.scalarQuantize(vector, destination, bits, centroid);
69+
return destination;
70+
}
71+
72+
@Benchmark
73+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
74+
public byte[] vector() {
75+
osq.scalarQuantize(vector, destination, bits, centroid);
76+
return destination;
77+
}
78+
}

docs/changelog/127118.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 127118
2+
summary: Panama vector accelerated optimized scalar quantization
3+
area: Vector Search
4+
type: enhancement
5+
issues: []

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,4 +131,71 @@ static int andBitCountLong(byte[] a, byte[] b) {
131131
}
132132
return distance;
133133
}
134+
135+
/**
136+
* Calculate the loss for optimized-scalar quantization for the given parameteres
137+
* @param target The vector being quantized, assumed to be centered
138+
* @param interval The interval for which to calculate the loss
139+
* @param points the quantization points
140+
* @param norm2 The norm squared of the target vector
141+
* @param lambda The lambda parameter for controlling anisotropic loss calculation
142+
* @return The loss for the given parameters
143+
*/
144+
public static float calculateOSQLoss(float[] target, float[] interval, int points, float norm2, float lambda) {
145+
assert interval.length == 2;
146+
float step = ((interval[1] - interval[0]) / (points - 1.0F));
147+
float invStep = 1f / step;
148+
return IMPL.calculateOSQLoss(target, interval, step, invStep, norm2, lambda);
149+
}
150+
151+
/**
152+
* Calculate the grid points for optimized-scalar quantization
153+
* @param target The vector being quantized, assumed to be centered
154+
* @param interval The interval for which to calculate the grid points
155+
* @param points the quantization points
156+
* @param pts The array to store the grid points, must be of length 5
157+
*/
158+
public static void calculateOSQGridPoints(float[] target, float[] interval, int points, float[] pts) {
159+
assert interval.length == 2;
160+
assert pts.length == 5;
161+
float invStep = (points - 1.0F) / (interval[1] - interval[0]);
162+
IMPL.calculateOSQGridPoints(target, interval, points, invStep, pts);
163+
}
164+
165+
/**
166+
* Center the target vector and calculate the optimized-scalar quantization statistics
167+
* @param target The vector being quantized
168+
* @param centroid The centroid of the target vector
169+
* @param centered The destination of the centered vector, will be overwritten
170+
* @param stats The array to store the statistics, must be of length 5
171+
*/
172+
public static void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats) {
173+
assert target.length == centroid.length;
174+
assert stats.length == 5;
175+
if (target.length != centroid.length) {
176+
throw new IllegalArgumentException("vector dimensions differ: " + target.length + "!=" + centroid.length);
177+
}
178+
if (centered.length != target.length) {
179+
throw new IllegalArgumentException("vector dimensions differ: " + centered.length + "!=" + target.length);
180+
}
181+
IMPL.centerAndCalculateOSQStatsEuclidean(target, centroid, centered, stats);
182+
}
183+
184+
/**
185+
* Center the target vector and calculate the optimized-scalar quantization statistics
186+
* @param target The vector being quantized
187+
* @param centroid The centroid of the target vector
188+
* @param centered The destination of the centered vector, will be overwritten
189+
* @param stats The array to store the statistics, must be of length 6
190+
*/
191+
public static void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats) {
192+
if (target.length != centroid.length) {
193+
throw new IllegalArgumentException("vector dimensions differ: " + target.length + "!=" + centroid.length);
194+
}
195+
if (centered.length != target.length) {
196+
throw new IllegalArgumentException("vector dimensions differ: " + centered.length + "!=" + target.length);
197+
}
198+
assert stats.length == 6;
199+
IMPL.centerAndCalculateOSQStatsDp(target, centroid, centered, stats);
200+
}
134201
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,100 @@ public float ipFloatBit(float[] q, byte[] d) {
3939
return ipFloatBitImpl(q, d);
4040
}
4141

42+
@Override
43+
public float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda) {
44+
float a = interval[0];
45+
float b = interval[1];
46+
float xe = 0f;
47+
float e = 0f;
48+
for (float xi : target) {
49+
// this is quantizing and then dequantizing the vector
50+
float xiq = fma(step, Math.round((Math.min(Math.max(xi, a), b) - a) * invStep), a);
51+
// how much does the de-quantized value differ from the original value
52+
float xiiq = xi - xiq;
53+
e = fma(xiiq, xiiq, e);
54+
xe = fma(xi, xiiq, xe);
55+
}
56+
return (1f - lambda) * xe * xe / norm2 + lambda * e;
57+
}
58+
59+
@Override
60+
public void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts) {
61+
float a = interval[0];
62+
float b = interval[1];
63+
float daa = 0;
64+
float dab = 0;
65+
float dbb = 0;
66+
float dax = 0;
67+
float dbx = 0;
68+
for (float v : target) {
69+
float k = Math.round((Math.min(Math.max(v, a), b) - a) * invStep);
70+
float s = k / (points - 1);
71+
float ms = 1f - s;
72+
daa = fma(ms, ms, daa);
73+
dab = fma(ms, s, dab);
74+
dbb = fma(s, s, dbb);
75+
dax = fma(ms, v, dax);
76+
dbx = fma(s, v, dbx);
77+
}
78+
pts[0] = daa;
79+
pts[1] = dab;
80+
pts[2] = dbb;
81+
pts[3] = dax;
82+
pts[4] = dbx;
83+
}
84+
85+
@Override
86+
public void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats) {
87+
float vecMean = 0;
88+
float vecVar = 0;
89+
float norm2 = 0;
90+
float min = Float.MAX_VALUE;
91+
float max = -Float.MAX_VALUE;
92+
for (int i = 0; i < target.length; i++) {
93+
centered[i] = target[i] - centroid[i];
94+
min = Math.min(min, centered[i]);
95+
max = Math.max(max, centered[i]);
96+
norm2 = fma(centered[i], centered[i], norm2);
97+
float delta = centered[i] - vecMean;
98+
vecMean += delta / (i + 1);
99+
float delta2 = centered[i] - vecMean;
100+
vecVar = fma(delta, delta2, vecVar);
101+
}
102+
stats[0] = vecMean;
103+
stats[1] = vecVar / target.length;
104+
stats[2] = norm2;
105+
stats[3] = min;
106+
stats[4] = max;
107+
}
108+
109+
@Override
110+
public void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats) {
111+
float vecMean = 0;
112+
float vecVar = 0;
113+
float norm2 = 0;
114+
float centroidDot = 0;
115+
float min = Float.MAX_VALUE;
116+
float max = -Float.MAX_VALUE;
117+
for (int i = 0; i < target.length; i++) {
118+
centroidDot = fma(target[i], centroid[i], centroidDot);
119+
centered[i] = target[i] - centroid[i];
120+
min = Math.min(min, centered[i]);
121+
max = Math.max(max, centered[i]);
122+
norm2 = fma(centered[i], centered[i], norm2);
123+
float delta = centered[i] - vecMean;
124+
vecMean += delta / (i + 1);
125+
float delta2 = centered[i] - vecMean;
126+
vecVar = fma(delta, delta2, vecVar);
127+
}
128+
stats[0] = vecMean;
129+
stats[1] = vecVar / target.length;
130+
stats[2] = norm2;
131+
stats[3] = min;
132+
stats[4] = max;
133+
stats[5] = centroidDot;
134+
}
135+
42136
public static int ipByteBitImpl(byte[] q, byte[] d) {
43137
assert q.length == d.length * Byte.SIZE;
44138
int acc0 = 0;

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,12 @@ public interface ESVectorUtilSupport {
1818
int ipByteBit(byte[] q, byte[] d);
1919

2020
float ipFloatBit(float[] q, byte[] d);
21+
22+
float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda);
23+
24+
void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts);
25+
26+
void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats);
27+
28+
void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats);
2129
}

0 commit comments

Comments
 (0)