Skip to content

Commit b02cd72

Browse files
committed
Adds accelerates optimized scalar quantization with vectorized functions
1 parent bc0e1d0 commit b02cd72

File tree

14 files changed

+619
-99
lines changed

14 files changed

+619
-99
lines changed

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,4 +144,71 @@ static int andBitCountLong(byte[] a, byte[] b) {
144144
}
145145
return distance;
146146
}
147+
148+
/**
149+
* Calculate the loss for optimized-scalar quantization for the given parameteres
150+
* @param target The vector being quantized, assumed to be centered
151+
* @param interval The interval for which to calculate the loss
152+
* @param points the quantization points
153+
* @param norm2 The norm squared of the target vector
154+
* @param lambda The lambda parameter for controlling anisotropic loss calculation
155+
* @return The loss for the given parameters
156+
*/
157+
public static float calculateOSQLoss(float[] target, float[] interval, int points, float norm2, float lambda) {
158+
assert interval.length == 2;
159+
float step = ((interval[1] - interval[0]) / (points - 1.0F));
160+
float invStep = 1f / step;
161+
return IMPL.calculateOSQLoss(target, interval, step, invStep, norm2, lambda);
162+
}
163+
164+
/**
165+
* Calculate the grid points for optimized-scalar quantization
166+
* @param target The vector being quantized, assumed to be centered
167+
* @param interval The interval for which to calculate the grid points
168+
* @param points the quantization points
169+
* @param pts The array to store the grid points, must be of length 5
170+
*/
171+
public static void calculateOSQGridPoints(float[] target, float[] interval, int points, float[] pts) {
172+
assert interval.length == 2;
173+
assert pts.length == 5;
174+
float invStep = (points - 1.0F) / (interval[1] - interval[0]);
175+
IMPL.calculateOSQGridPoints(target, interval, points, invStep, pts);
176+
}
177+
178+
/**
179+
* Center the target vector and calculate the optimized-scalar quantization statistics
180+
* @param target The vector being quantized
181+
* @param centroid The centroid of the target vector
182+
* @param centered The destination of the centered vector, will be overwritten
183+
* @param stats The array to store the statistics, must be of length 5
184+
*/
185+
public static void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats) {
186+
assert target.length == centroid.length;
187+
assert stats.length == 5;
188+
if (target.length != centroid.length) {
189+
throw new IllegalArgumentException("vector dimensions differ: " + target.length + "!=" + centroid.length);
190+
}
191+
if (centered.length != target.length) {
192+
throw new IllegalArgumentException("vector dimensions differ: " + centered.length + "!=" + target.length);
193+
}
194+
IMPL.centerAndCalculateOSQStatsEuclidean(target, centroid, centered, stats);
195+
}
196+
197+
/**
198+
* Center the target vector and calculate the optimized-scalar quantization statistics
199+
* @param target The vector being quantized
200+
* @param centroid The centroid of the target vector
201+
* @param centered The destination of the centered vector, will be overwritten
202+
* @param stats The array to store the statistics, must be of length 6
203+
*/
204+
public static void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats) {
205+
if (target.length != centroid.length) {
206+
throw new IllegalArgumentException("vector dimensions differ: " + target.length + "!=" + centroid.length);
207+
}
208+
if (centered.length != target.length) {
209+
throw new IllegalArgumentException("vector dimensions differ: " + centered.length + "!=" + target.length);
210+
}
211+
assert stats.length == 6;
212+
IMPL.centerAndCalculateOSQStatsDp(target, centroid, centered, stats);
213+
}
147214
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,100 @@ public float ipFloatByte(float[] q, byte[] d) {
4444
return ipFloatByteImpl(q, d);
4545
}
4646

47+
@Override
48+
public float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda) {
49+
float a = interval[0];
50+
float b = interval[1];
51+
float xe = 0f;
52+
float e = 0f;
53+
for (float xi : target) {
54+
// this is quantizing and then dequantizing the vector
55+
float xiq = fma(step, Math.round((Math.min(Math.max(xi, a), b) - a) * invStep), a);
56+
// how much does the de-quantized value differ from the original value
57+
float xiiq = xi - xiq;
58+
e = fma(xiiq, xiiq, e);
59+
xe = fma(xi, xiiq, xe);
60+
}
61+
return (1f - lambda) * xe * xe / norm2 + lambda * e;
62+
}
63+
64+
@Override
65+
public void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts) {
66+
float a = interval[0];
67+
float b = interval[1];
68+
float daa = 0;
69+
float dab = 0;
70+
float dbb = 0;
71+
float dax = 0;
72+
float dbx = 0;
73+
for (float v : target) {
74+
float k = Math.round((Math.min(Math.max(v, a), b) - a) * invStep);
75+
float s = k / (points - 1);
76+
float ms = 1f - s;
77+
daa = fma(ms, ms, daa);
78+
dab = fma(ms, s, dab);
79+
dbb = fma(s, s, dbb);
80+
dax = fma(ms, v, dax);
81+
dbx = fma(s, v, dbx);
82+
}
83+
pts[0] = daa;
84+
pts[1] = dab;
85+
pts[2] = dbb;
86+
pts[3] = dax;
87+
pts[4] = dbx;
88+
}
89+
90+
@Override
91+
public void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats) {
92+
float vecMean = 0;
93+
float vecVar = 0;
94+
float norm2 = 0;
95+
float min = Float.MAX_VALUE;
96+
float max = -Float.MAX_VALUE;
97+
for (int i = 0; i < target.length; i++) {
98+
centered[i] = target[i] - centroid[i];
99+
min = Math.min(min, centered[i]);
100+
max = Math.max(max, centered[i]);
101+
norm2 = fma(centered[i], centered[i], norm2);
102+
float delta = centered[i] - vecMean;
103+
vecMean += delta / (i + 1);
104+
float delta2 = centered[i] - vecMean;
105+
vecVar = fma(delta, delta2, vecVar);
106+
}
107+
stats[0] = vecMean;
108+
stats[1] = vecVar / target.length;
109+
stats[2] = norm2;
110+
stats[3] = min;
111+
stats[4] = max;
112+
}
113+
114+
@Override
115+
public void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats) {
116+
float vecMean = 0;
117+
float vecVar = 0;
118+
float norm2 = 0;
119+
float centroidDot = 0;
120+
float min = Float.MAX_VALUE;
121+
float max = -Float.MAX_VALUE;
122+
for (int i = 0; i < target.length; i++) {
123+
centroidDot = fma(target[i], centroid[i], centroidDot);
124+
centered[i] = target[i] - centroid[i];
125+
min = Math.min(min, centered[i]);
126+
max = Math.max(max, centered[i]);
127+
norm2 = fma(centered[i], centered[i], norm2);
128+
float delta = centered[i] - vecMean;
129+
vecMean += delta / (i + 1);
130+
float delta2 = centered[i] - vecMean;
131+
vecVar = fma(delta, delta2, vecVar);
132+
}
133+
stats[0] = vecMean;
134+
stats[1] = vecVar / target.length;
135+
stats[2] = norm2;
136+
stats[3] = min;
137+
stats[4] = max;
138+
stats[5] = centroidDot;
139+
}
140+
47141
public static int ipByteBitImpl(byte[] q, byte[] d) {
48142
return ipByteBitImpl(q, d, 0);
49143
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,12 @@ public interface ESVectorUtilSupport {
2020
float ipFloatBit(float[] q, byte[] d);
2121

2222
float ipFloatByte(float[] q, byte[] d);
23+
24+
float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda);
25+
26+
void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts);
27+
28+
void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats);
29+
30+
void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats);
2331
}

0 commit comments

Comments
 (0)