Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package org.elasticsearch.simdvec.internal.vectorization;

import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.LongVector;
import jdk.incubator.vector.VectorOperators;
Expand Down Expand Up @@ -60,6 +61,9 @@ public float ipFloatBit(float[] q, byte[] d) {

@Override
public float ipFloatByte(float[] q, byte[] d) {
if (BYTE_FOR_FLOAT_SPECIES != null && q.length >= FLOAT_SPECIES.length()) {
return ipFloatByteImpl(q, d);
}
return DefaultESVectorUtilSupport.ipFloatByteImpl(q, d);
}

Expand Down Expand Up @@ -165,4 +169,40 @@ public static long ipByteBin128(byte[] q, byte[] d) {
}
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
}

private static final VectorSpecies<Float> FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED;
private static final VectorSpecies<Byte> BYTE_FOR_FLOAT_SPECIES;

static {
VectorSpecies<Byte> byteForFloat;
try {
// calculate vector size to convert from single bytes to 4-byte floats
byteForFloat = VectorSpecies.of(byte.class, VectorShape.forBitSize(FLOAT_SPECIES.vectorBitSize() / 4));
} catch (IllegalArgumentException e) {
// can't get a byte vector size small enough, just use default impl
byteForFloat = null;
}
BYTE_FOR_FLOAT_SPECIES = byteForFloat;
}

public static float ipFloatByteImpl(float[] q, byte[] d) {
assert BYTE_FOR_FLOAT_SPECIES != null;
float sum = 0;
int i = 0;

int limit = FLOAT_SPECIES.loopBound(q.length);
for (; i < limit; i += FLOAT_SPECIES.length()) {
FloatVector qv = FloatVector.fromArray(FLOAT_SPECIES, q, i);
ByteVector bv = ByteVector.fromArray(BYTE_FOR_FLOAT_SPECIES, d, i);
// no separate parts needed for the cast, as we've used a byte vector size 1/4th the float vector size
sum += qv.mul(bv.castShape(qv.species(), 0)).reduceLanes(VectorOperators.ADD);
}

// handle the tail
for (; i < q.length; i++) {
sum += q[i] * d[i];
}

return sum;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;

import java.util.Arrays;
import java.util.function.ToDoubleBiFunction;
import java.util.function.ToLongBiFunction;

import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY;

Expand Down Expand Up @@ -40,8 +42,16 @@ public void testIpFloatBit() {
}

public void testIpFloatByte() {
float[] q = new float[16];
byte[] d = new byte[16];
testIpFloatByteImpl(ESVectorUtil::ipFloatByte);
testIpFloatByteImpl(defaultedProvider.getVectorUtilSupport()::ipFloatByte);
testIpFloatByteImpl(defOrPanamaProvider.getVectorUtilSupport()::ipFloatByte);
}

private void testIpFloatByteImpl(ToDoubleBiFunction<float[], byte[]> impl) {
int vectorSize = randomIntBetween(1, 1024);

float[] q = new float[vectorSize];
byte[] d = new byte[vectorSize];
for (int i = 0; i < q.length; i++) {
q[i] = random().nextFloat();
}
Expand All @@ -51,7 +61,7 @@ public void testIpFloatByte() {
for (int i = 0; i < q.length; i++) {
expected += q[i] * d[i];
}
assertEquals(expected, ESVectorUtil.ipFloatByte(q, d), 1e-6);
assertEquals(expected, impl.applyAsDouble(q, d), 1e-2);
}

public void testBitAndCount() {
Expand All @@ -74,65 +84,57 @@ public void testBasicIpByteBin() {
testBasicIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte);
}

interface IpByteBin {
long apply(byte[] q, byte[] d);
}

interface BitOps {
long apply(byte[] q, byte[] d);
}

void testBasicBitAndImpl(BitOps bitAnd) {
assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 0 }));
assertEquals(0, bitAnd.apply(new byte[] { 1 }, new byte[] { 0 }));
assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 1 }));
assertEquals(1, bitAnd.apply(new byte[] { 1 }, new byte[] { 1 }));
void testBasicBitAndImpl(ToLongBiFunction<byte[], byte[]> bitAnd) {
assertEquals(0, bitAnd.applyAsLong(new byte[] { 0 }, new byte[] { 0 }));
assertEquals(0, bitAnd.applyAsLong(new byte[] { 1 }, new byte[] { 0 }));
assertEquals(0, bitAnd.applyAsLong(new byte[] { 0 }, new byte[] { 1 }));
assertEquals(1, bitAnd.applyAsLong(new byte[] { 1 }, new byte[] { 1 }));
byte[] a = new byte[31];
byte[] b = new byte[31];
random().nextBytes(a);
random().nextBytes(b);
int expected = scalarBitAnd(a, b);
assertEquals(expected, bitAnd.apply(a, b));
assertEquals(expected, bitAnd.applyAsLong(a, b));
}

void testBasicIpByteBinImpl(IpByteBin ipByteBinFunc) {
assertEquals(15L, ipByteBinFunc.apply(new byte[] { 1, 1, 1, 1 }, new byte[] { 1 }));
assertEquals(30L, ipByteBinFunc.apply(new byte[] { 1, 2, 1, 2, 1, 2, 1, 2 }, new byte[] { 1, 2 }));
void testBasicIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
assertEquals(15L, ipByteBinFunc.applyAsLong(new byte[] { 1, 1, 1, 1 }, new byte[] { 1 }));
assertEquals(30L, ipByteBinFunc.applyAsLong(new byte[] { 1, 2, 1, 2, 1, 2, 1, 2 }, new byte[] { 1, 2 }));

var d = new byte[] { 1, 2, 3 };
var q = new byte[] { 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3 };
assert scalarIpByteBin(q, d) == 60L; // 4 + 8 + 16 + 32
assertEquals(60L, ipByteBinFunc.apply(q, d));
assertEquals(60L, ipByteBinFunc.applyAsLong(q, d));

d = new byte[] { 1, 2, 3, 4 };
q = new byte[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 };
assert scalarIpByteBin(q, d) == 75L; // 5 + 10 + 20 + 40
assertEquals(75L, ipByteBinFunc.apply(q, d));
assertEquals(75L, ipByteBinFunc.applyAsLong(q, d));

d = new byte[] { 1, 2, 3, 4, 5 };
q = new byte[] { 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5 };
assert scalarIpByteBin(q, d) == 105L; // 7 + 14 + 28 + 56
assertEquals(105L, ipByteBinFunc.apply(q, d));
assertEquals(105L, ipByteBinFunc.applyAsLong(q, d));

d = new byte[] { 1, 2, 3, 4, 5, 6 };
q = new byte[] { 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6 };
assert scalarIpByteBin(q, d) == 135L; // 9 + 18 + 36 + 72
assertEquals(135L, ipByteBinFunc.apply(q, d));
assertEquals(135L, ipByteBinFunc.applyAsLong(q, d));

d = new byte[] { 1, 2, 3, 4, 5, 6, 7 };
q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7 };
assert scalarIpByteBin(q, d) == 180L; // 12 + 24 + 48 + 96
assertEquals(180L, ipByteBinFunc.apply(q, d));
assertEquals(180L, ipByteBinFunc.applyAsLong(q, d));

d = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 };
q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8 };
assert scalarIpByteBin(q, d) == 195L; // 13 + 26 + 52 + 104
assertEquals(195L, ipByteBinFunc.apply(q, d));
assertEquals(195L, ipByteBinFunc.applyAsLong(q, d));

d = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 };
q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
assert scalarIpByteBin(q, d) == 225L; // 15 + 30 + 60 + 120
assertEquals(225L, ipByteBinFunc.apply(q, d));
assertEquals(225L, ipByteBinFunc.applyAsLong(q, d));
}

public void testIpByteBin() {
Expand All @@ -141,23 +143,23 @@ public void testIpByteBin() {
testIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte);
}

void testIpByteBinImpl(IpByteBin ipByteBinFunc) {
void testIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
int iterations = atLeast(50);
for (int i = 0; i < iterations; i++) {
int size = random().nextInt(5000);
var d = new byte[size];
var q = new byte[size * B_QUERY];
random().nextBytes(d);
random().nextBytes(q);
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d));
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.applyAsLong(q, d));

Arrays.fill(d, Byte.MAX_VALUE);
Arrays.fill(q, Byte.MAX_VALUE);
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d));
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.applyAsLong(q, d));

Arrays.fill(d, Byte.MIN_VALUE);
Arrays.fill(q, Byte.MIN_VALUE);
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d));
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.applyAsLong(q, d));
}
}

Expand Down