diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java
index 42462f62f6115..8ef3f2a7f9881 100644
--- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java
@@ -10,6 +10,7 @@
 package org.elasticsearch.simdvec.internal.vectorization;
 
 import jdk.incubator.vector.ByteVector;
+import jdk.incubator.vector.FloatVector;
 import jdk.incubator.vector.IntVector;
 import jdk.incubator.vector.LongVector;
 import jdk.incubator.vector.VectorOperators;
@@ -60,6 +61,9 @@ public float ipFloatBit(float[] q, byte[] d) {
 
     @Override
     public float ipFloatByte(float[] q, byte[] d) {
+        if (BYTE_SPECIES_FOR_PREFERRED_FLOATS != null && q.length >= PREFERRED_FLOAT_SPECIES.length()) {
+            return ipFloatByteImpl(q, d);
+        }
         return DefaultESVectorUtilSupport.ipFloatByteImpl(q, d);
     }
 
@@ -165,4 +169,53 @@ public static long ipByteBin128(byte[] q, byte[] d) {
         }
         return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
     }
+
+    // float species of the platform's preferred shape; its lane count is the step of the vector loop
+    private static final VectorSpecies<Float> PREFERRED_FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED;
+    // byte species with as many lanes as PREFERRED_FLOAT_SPECIES, or null when no such vector shape exists
+    private static final VectorSpecies<Byte> BYTE_SPECIES_FOR_PREFERRED_FLOATS;
+
+    static {
+        VectorSpecies<Byte> byteForFloat;
+        try {
+            // calculate vector size to convert from single bytes to 4-byte floats
+            byteForFloat = VectorSpecies.of(byte.class, VectorShape.forBitSize(PREFERRED_FLOAT_SPECIES.vectorBitSize() / Float.BYTES));
+        } catch (IllegalArgumentException e) {
+            // can't get a byte vector size small enough, just use default impl
+            byteForFloat = null;
+        }
+        BYTE_SPECIES_FOR_PREFERRED_FLOATS = byteForFloat;
+    }
+
+    /**
+     * Computes the inner product of a float vector and a byte vector with the Panama Vector API.
+     * Each step loads one preferred-width float vector from {@code q} and the same number of bytes
+     * from {@code d}, widens the bytes to floats and accumulates with a fused multiply-add; any
+     * remaining elements are added by a scalar loop.
+     *
+     * @param q the float vector
+     * @param d the byte vector, read at the same indices as {@code q}
+     * @return the sum of {@code q[i] * d[i]} over every index of {@code q}, up to float rounding
+     */
+    public static float ipFloatByteImpl(float[] q, byte[] d) {
+        assert BYTE_SPECIES_FOR_PREFERRED_FLOATS != null;
+        FloatVector acc = FloatVector.zero(PREFERRED_FLOAT_SPECIES);
+        int i = 0;
+
+        int limit = PREFERRED_FLOAT_SPECIES.loopBound(q.length);
+        for (; i < limit; i += PREFERRED_FLOAT_SPECIES.length()) {
+            FloatVector qv = FloatVector.fromArray(PREFERRED_FLOAT_SPECIES, q, i);
+            ByteVector bv = ByteVector.fromArray(BYTE_SPECIES_FOR_PREFERRED_FLOATS, d, i);
+            acc = qv.fma(bv.castShape(PREFERRED_FLOAT_SPECIES, 0), acc);
+        }
+
+        float sum = acc.reduceLanes(VectorOperators.ADD);
+
+        // handle the tail
+        for (; i < q.length; i++) {
+            sum += q[i] * d[i];
+        }
+
+        return sum;
+    }
 }
diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java
index 7259d8204f071..173cb0455a291 100644
--- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java
+++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java
@@ -13,8 +13,11 @@
 import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
 
 import java.util.Arrays;
+import java.util.function.ToDoubleBiFunction;
+import java.util.function.ToLongBiFunction;
 
 import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY;
+import static org.hamcrest.Matchers.closeTo;
 
 public class ESVectorUtilTests extends BaseVectorizationTests {
 
@@ -40,8 +43,18 @@ public void testIpFloatBit() {
     }
 
     public void testIpFloatByte() {
-        float[] q = new float[16];
-        byte[] d = new byte[16];
+        testIpFloatByteImpl(ESVectorUtil::ipFloatByte);
+        testIpFloatByteImpl(defaultedProvider.getVectorUtilSupport()::ipFloatByte);
+        testIpFloatByteImpl(defOrPanamaProvider.getVectorUtilSupport()::ipFloatByte);
+    }
+
+    private void testIpFloatByteImpl(ToDoubleBiFunction<float[], byte[]> impl) {
+        int vectorSize = randomIntBetween(1, 1024);
+        // scale the delta according to the vector size
+        double delta = 1e-5 * vectorSize;
+
+        float[] q = new float[vectorSize];
+        byte[] d = new byte[vectorSize];
         for (int i = 0; i < q.length; i++) {
             q[i] = random().nextFloat();
         }
@@ -51,7 +64,7 @@ public void testIpFloatByte() {
         for (int i = 0; i < q.length; i++) {
             expected += q[i] * d[i];
         }
-        assertEquals(expected, ESVectorUtil.ipFloatByte(q, d), 1e-6);
+        assertThat(impl.applyAsDouble(q, d), closeTo(expected, delta));
     }
 
     public void testBitAndCount() {
@@ -74,65 +87,57 @@ public void testBasicIpByteBin() {
         testBasicIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte);
     }
 
-    interface IpByteBin {
-        long apply(byte[] q, byte[] d);
-    }
-
-    interface BitOps {
-        long apply(byte[] q, byte[] d);
-    }
-
-    void testBasicBitAndImpl(BitOps bitAnd) {
-        assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 0 }));
-        assertEquals(0, bitAnd.apply(new byte[] { 1 }, new byte[] { 0 }));
-        assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 1 }));
-        assertEquals(1, bitAnd.apply(new byte[] { 1 }, new byte[] { 1 }));
+    void testBasicBitAndImpl(ToLongBiFunction<byte[], byte[]> bitAnd) {
+        assertEquals(0, bitAnd.applyAsLong(new byte[] { 0 }, new byte[] { 0 }));
+        assertEquals(0, bitAnd.applyAsLong(new byte[] { 1 }, new byte[] { 0 }));
+        assertEquals(0, bitAnd.applyAsLong(new byte[] { 0 }, new byte[] { 1 }));
+        assertEquals(1, bitAnd.applyAsLong(new byte[] { 1 }, new byte[] { 1 }));
         byte[] a = new byte[31];
         byte[] b = new byte[31];
         random().nextBytes(a);
         random().nextBytes(b);
         int expected = scalarBitAnd(a, b);
-        assertEquals(expected, bitAnd.apply(a, b));
+        assertEquals(expected, bitAnd.applyAsLong(a, b));
     }
 
-    void testBasicIpByteBinImpl(IpByteBin ipByteBinFunc) {
-        assertEquals(15L, ipByteBinFunc.apply(new byte[] { 1, 1, 1, 1 }, new byte[] { 1 }));
-        assertEquals(30L, ipByteBinFunc.apply(new byte[] { 1, 2, 1, 2, 1, 2, 1, 2 }, new byte[] { 1, 2 }));
+    void testBasicIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
+        assertEquals(15L, ipByteBinFunc.applyAsLong(new byte[] { 1, 1, 1, 1 }, new byte[] { 1 }));
+        assertEquals(30L, ipByteBinFunc.applyAsLong(new byte[] { 1, 2, 1, 2, 1, 2, 1, 2 }, new byte[] { 1, 2 }));
 
         var d = new byte[] { 1, 2, 3 };
         var q = new byte[] { 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3 };
         assert scalarIpByteBin(q, d) == 60L; // 4 + 8 + 16 + 32
-        assertEquals(60L, ipByteBinFunc.apply(q, d));
+        assertEquals(60L, ipByteBinFunc.applyAsLong(q, d));
 
         d = new byte[] { 1, 2, 3, 4 };
         q = new byte[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 };
         assert scalarIpByteBin(q, d) == 75L; // 5 + 10 + 20 + 40
-        assertEquals(75L, ipByteBinFunc.apply(q, d));
+        assertEquals(75L, ipByteBinFunc.applyAsLong(q, d));
 
         d = new byte[] { 1, 2, 3, 4, 5 };
         q = new byte[] { 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5 };
         assert scalarIpByteBin(q, d) == 105L; // 7 + 14 + 28 + 56
-        assertEquals(105L, ipByteBinFunc.apply(q, d));
+        assertEquals(105L, ipByteBinFunc.applyAsLong(q, d));
 
         d = new byte[] { 1, 2, 3, 4, 5, 6 };
         q = new byte[] { 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6 };
         assert scalarIpByteBin(q, d) == 135L; // 9 + 18 + 36 + 72
-        assertEquals(135L, ipByteBinFunc.apply(q, d));
+        assertEquals(135L, ipByteBinFunc.applyAsLong(q, d));
 
         d = new byte[] { 1, 2, 3, 4, 5, 6, 7 };
         q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7 };
         assert scalarIpByteBin(q, d) == 180L; // 12 + 24 + 48 + 96
-        assertEquals(180L, ipByteBinFunc.apply(q, d));
+        assertEquals(180L, ipByteBinFunc.applyAsLong(q, d));
 
         d = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 };
         q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8 };
         assert scalarIpByteBin(q, d) == 195L; // 13 + 26 + 52 + 104
-        assertEquals(195L, ipByteBinFunc.apply(q, d));
+        assertEquals(195L, ipByteBinFunc.applyAsLong(q, d));
 
         d = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 };
         q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
         assert scalarIpByteBin(q, d) == 225L; // 15 + 30 + 60 + 120
-        assertEquals(225L, ipByteBinFunc.apply(q, d));
+        assertEquals(225L, ipByteBinFunc.applyAsLong(q, d));
     }
 
     public void testIpByteBin() {
@@ -141,7 +146,7 @@ public void testIpByteBin() {
         testIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte);
     }
 
-    void testIpByteBinImpl(IpByteBin ipByteBinFunc) {
+    void testIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
         int iterations = atLeast(50);
         for (int i = 0; i < iterations; i++) {
             int size = random().nextInt(5000);
@@ -149,15 +154,15 @@ void testIpByteBinImpl(IpByteBin ipByteBinFunc) {
             var q = new byte[size * B_QUERY];
             random().nextBytes(d);
             random().nextBytes(q);
-            assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d));
+            assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.applyAsLong(q, d));
 
             Arrays.fill(d, Byte.MAX_VALUE);
             Arrays.fill(q, Byte.MAX_VALUE);
-            assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d));
+            assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.applyAsLong(q, d));
 
             Arrays.fill(d, Byte.MIN_VALUE);
             Arrays.fill(q, Byte.MIN_VALUE);
-            assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d));
+            assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.applyAsLong(q, d));
         }
     }
 