Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.simdvec.internal.vectorization.ES91OSQVectorsScorer;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.simdvec.internal.vectorization;
package org.elasticsearch.simdvec;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@

package org.elasticsearch.simdvec;

import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.Constants;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;

import java.io.IOException;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
Expand Down Expand Up @@ -41,6 +43,10 @@ public class ESVectorUtil {

private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport();

/**
 * Creates an {@link ES91OSQVectorsScorer} through the active {@link ESVectorizationProvider}.
 *
 * @param input the index input handed to the provider's scorer factory
 * @param dimension the dimension value handed to the provider's scorer factory
 * @return the scorer returned by the provider
 * @throws IOException if the provider's factory method throws it
 */
public static ES91OSQVectorsScorer getES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
    final ESVectorizationProvider provider = ESVectorizationProvider.getInstance();
    return provider.newES91OSQVectorsScorer(input, dimension);
}

public static long ipByteBinByte(byte[] q, byte[] d) {
if (q.length != d.length * B_QUERY) {
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length);
Expand Down Expand Up @@ -211,4 +217,40 @@ public static void centerAndCalculateOSQStatsDp(float[] target, float[] centroid
assert stats.length == 6;
IMPL.centerAndCalculateOSQStatsDp(target, centroid, centered, stats);
}

/**
 * Computes the element-wise difference {@code v1 - v2} and writes it into {@code result}.
 *
 * @param v1 the minuend vector
 * @param v2 the subtrahend vector, must have the same length as {@code v1}
 * @param result the destination vector, must have the same length as {@code v1}
 * @throws IllegalArgumentException if the three arrays do not all have the same length
 */
public static void subtract(float[] v1, float[] v2, float[] result) {
    final int dims = v1.length;
    // Validate the two inputs against each other first, then the destination against the inputs.
    if (dims != v2.length) {
        throw new IllegalArgumentException("vector dimensions differ: " + dims + "!=" + v2.length);
    }
    if (dims != result.length) {
        throw new IllegalArgumentException("vector dimensions differ: " + result.length + "!=" + dims);
    }
    int idx = 0;
    while (idx < dims) {
        result[idx] = v1[idx] - v2[idx];
        idx++;
    }
}

/**
 * Calculates the spill-over score (soar) for a vector and a centroid, given the vector's
 * residual with its actually nearest centroid. The computation itself is delegated to the
 * active {@link ESVectorUtilSupport} implementation once the lengths have been validated.
 *
 * @param v1 the vector
 * @param centroid the centroid, must have the same length as {@code v1}
 * @param originalResidual the residual with the actually nearest centroid, must have the same length as {@code v1}
 * @return the spill-over score (soar)
 * @throws IllegalArgumentException if the three arrays do not all have the same length
 */
public static float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
    final int dims = v1.length;
    if (dims != centroid.length) {
        throw new IllegalArgumentException("vector dimensions differ: " + dims + "!=" + centroid.length);
    }
    if (dims != originalResidual.length) {
        throw new IllegalArgumentException("vector dimensions differ: " + originalResidual.length + "!=" + dims);
    }
    final float score = IMPL.soarResidual(v1, centroid, originalResidual);
    return score;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,18 @@ public void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float
stats[5] = centroidDot;
}

/**
 * Scalar implementation: accumulates {@code (v1[i] - centroid[i]) * originalResidual[i]} over all
 * indices, using this class's {@code fma} helper for each multiply-add step.
 */
@Override
public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
    assert v1.length == centroid.length;
    assert v1.length == originalResidual.length;
    final int dims = v1.length;
    float accumulated = 0;
    int idx = 0;
    while (idx < dims) {
        // difference between the vector and the candidate centroid at this index
        final float diff = v1[idx] - centroid[idx];
        accumulated = fma(diff, originalResidual[idx], accumulated);
        idx++;
    }
    return accumulated;
}

/**
 * Convenience overload: forwards {@code q} and {@code d} to the three-argument
 * {@code ipByteBitImpl} with {@code 0} as the third argument and returns its result.
 * NOTE(review): the meaning of the third argument is defined by the three-argument overload,
 * which is not visible here - presumably a start offset; confirm before relying on it.
 */
public static int ipByteBitImpl(byte[] q, byte[] d) {
return ipByteBitImpl(q, d, 0);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package org.elasticsearch.simdvec.internal.vectorization;

import org.apache.lucene.store.IndexInput;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;

import java.io.IOException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,7 @@ public interface ESVectorUtilSupport {
void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats);

void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats);

/**
 * Computes the value used as the spill-over (soar) score for a vector against a centroid.
 * The implementations visible alongside this declaration accumulate
 * {@code (v1[i] - centroid[i]) * originalResidual[i]} over all indices and only assert
 * (rather than check) that the three arrays have equal length.
 *
 * @param v1 the vector
 * @param centroid the centroid
 * @param originalResidual the residual of the vector with its nearest centroid
 * @return the accumulated score
 */
float soarResidual(float[] v1, float[] centroid, float[] originalResidual);

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package org.elasticsearch.simdvec.internal.vectorization;

import org.apache.lucene.store.IndexInput;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;

import java.io.IOException;
import java.util.Objects;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.apache.lucene.util.Constants;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;

import java.io.IOException;
import java.util.Locale;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,49 @@ public float calculateOSQLoss(float[] target, float[] interval, float step, floa
return (1f - lambda) * xe * xe / norm2 + lambda * e;
}

/**
 * Vectorized implementation: accumulates {@code (v1[i] - centroid[i]) * originalResidual[i]} over all
 * indices. Inputs longer than two species-widths are processed with a two-way unrolled vector loop
 * followed by a single-vector loop; any remaining elements are handled by a scalar loop.
 */
@Override
public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
    assert v1.length == centroid.length;
    assert v1.length == originalResidual.length;
    final int dims = v1.length;
    final int lanes = FLOAT_SPECIES.length();
    float proj = 0;
    int i = 0;
    if (dims > 2 * lanes) {
        final int vectorBound = FLOAT_SPECIES.loopBound(dims);
        // stop one vector early so the second load of the unrolled body stays below vectorBound
        final int unrolledBound = vectorBound - lanes;
        FloatVector accFirst = FloatVector.zero(FLOAT_SPECIES);
        FloatVector accSecond = FloatVector.zero(FLOAT_SPECIES);
        while (i < unrolledBound) {
            // first half of the unrolled step
            FloatVector vecLo = FloatVector.fromArray(FLOAT_SPECIES, v1, i);
            FloatVector cenLo = FloatVector.fromArray(FLOAT_SPECIES, centroid, i);
            FloatVector resLo = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i);
            accFirst = fma(vecLo.sub(cenLo), resLo, accFirst);

            // second half of the unrolled step, one vector further along
            final int next = i + lanes;
            FloatVector vecHi = FloatVector.fromArray(FLOAT_SPECIES, v1, next);
            FloatVector cenHi = FloatVector.fromArray(FLOAT_SPECIES, centroid, next);
            FloatVector resHi = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, next);
            accSecond = fma(vecHi.sub(cenHi), resHi, accSecond);

            i += 2 * lanes;
        }
        // remaining full vectors that did not fit the unrolled loop
        while (i < vectorBound) {
            FloatVector vec = FloatVector.fromArray(FLOAT_SPECIES, v1, i);
            FloatVector cen = FloatVector.fromArray(FLOAT_SPECIES, centroid, i);
            FloatVector res = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i);
            accFirst = fma(vec.sub(cen), res, accFirst);
            i += lanes;
        }
        proj += accFirst.add(accSecond).reduceLanes(ADD);
    }
    // scalar tail for elements beyond the last full vector (or the whole input when it is short)
    while (i < dims) {
        final float diff = v1[i] - centroid[i];
        proj = fma(diff, originalResidual[i], proj);
        i++;
    }
    return proj;
}

private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,22 @@ public void testOsqGridPoints() {
}
}

/**
 * Checks that the soarResidual result from the default-or-Panama provider matches the one from the
 * defaulted provider on random inputs, within a tolerance proportional to the vector length.
 */
public void testSoarOverspillScore() {
    final int dims = random().nextInt(128, 512);
    final float[] vector = new float[dims];
    final float[] centroid = new float[dims];
    final float[] preResidual = new float[dims];
    // fill the three arrays index by index so the sequence of random draws stays interleaved
    for (int d = 0; d < dims; d++) {
        vector[d] = random().nextFloat();
        centroid[d] = random().nextFloat();
        preResidual[d] = random().nextFloat();
    }
    final float tolerance = 1e-5f * dims;
    final float expected = defaultedProvider.getVectorUtilSupport().soarResidual(vector, centroid, preResidual);
    final float result = defOrPanamaProvider.getVectorUtilSupport().soarResidual(vector, centroid, preResidual);
    assertEquals(expected, result, tolerance);
}

void testIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
int iterations = atLeast(50);
for (int i = 0; i < iterations; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;

import static org.hamcrest.Matchers.lessThan;

Expand Down
3 changes: 2 additions & 1 deletion server/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,8 @@
org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat,
org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat,
org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat,
org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat;
org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat,
org.elasticsearch.index.codec.vectors.IVFVectorsFormat;

provides org.apache.lucene.codecs.Codec
with
Expand Down
Loading