|
9 | 9 |
|
10 | 10 | package org.elasticsearch.simdvec; |
11 | 11 |
|
| 12 | +import org.apache.lucene.store.IndexInput; |
12 | 13 | import org.apache.lucene.util.BitUtil; |
13 | 14 | import org.apache.lucene.util.Constants; |
14 | 15 | import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport; |
15 | 16 | import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider; |
16 | 17 |
|
| 18 | +import java.io.IOException; |
17 | 19 | import java.lang.invoke.MethodHandle; |
18 | 20 | import java.lang.invoke.MethodHandles; |
19 | 21 | import java.lang.invoke.MethodType; |
@@ -41,6 +43,10 @@ public class ESVectorUtil { |
41 | 43 |
|
42 | 44 | private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport(); |
43 | 45 |
|
| 46 | + public static ES91OSQVectorsScorer getES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException { |
| 47 | + return ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(input, dimension); |
| 48 | + } |
| 49 | + |
44 | 50 | public static long ipByteBinByte(byte[] q, byte[] d) { |
45 | 51 | if (q.length != d.length * B_QUERY) { |
46 | 52 | throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length); |
@@ -211,4 +217,40 @@ public static void centerAndCalculateOSQStatsDp(float[] target, float[] centroid |
211 | 217 | assert stats.length == 6; |
212 | 218 | IMPL.centerAndCalculateOSQStatsDp(target, centroid, centered, stats); |
213 | 219 | } |
| 220 | + |
| 221 | + /** |
| 222 | + * Calculates the difference between two vectors and stores the result in a third vector. |
| 223 | + * @param v1 the first vector |
| 224 | + * @param v2 the second vector |
| 225 | + * @param result the result vector, must be the same length as the input vectors |
| 226 | + */ |
| 227 | + public static void subtract(float[] v1, float[] v2, float[] result) { |
| 228 | + if (v1.length != v2.length) { |
| 229 | + throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + v2.length); |
| 230 | + } |
| 231 | + if (result.length != v1.length) { |
| 232 | + throw new IllegalArgumentException("vector dimensions differ: " + result.length + "!=" + v1.length); |
| 233 | + } |
| 234 | + for (int i = 0; i < v1.length; i++) { |
| 235 | + result[i] = v1[i] - v2[i]; |
| 236 | + } |
| 237 | + } |
| 238 | + |
| 239 | + /** |
| 240 | + * calculates the spill-over score for a vector and a centroid, given its residual with |
| 241 | + * its actually nearest centroid |
| 242 | + * @param v1 the vector |
| 243 | + * @param centroid the centroid |
| 244 | + * @param originalResidual the residual with the actually nearest centroid |
| 245 | + * @return the spill-over score (soar) |
| 246 | + */ |
| 247 | + public static float soarResidual(float[] v1, float[] centroid, float[] originalResidual) { |
| 248 | + if (v1.length != centroid.length) { |
| 249 | + throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + centroid.length); |
| 250 | + } |
| 251 | + if (originalResidual.length != v1.length) { |
| 252 | + throw new IllegalArgumentException("vector dimensions differ: " + originalResidual.length + "!=" + v1.length); |
| 253 | + } |
| 254 | + return IMPL.soarResidual(v1, centroid, originalResidual); |
| 255 | + } |
214 | 256 | } |
0 commit comments