From ff2e0e2a9395a607867e428ec5b6d43b9f6aaf83 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Mon, 24 Feb 2025 14:29:49 +0000 Subject: [PATCH 1/6] Panama implementation of float-byte vector operation --- .../PanamaESVectorUtilSupport.java | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java index 42462f62f6115..4a55300e3823a 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java @@ -10,6 +10,7 @@ package org.elasticsearch.simdvec.internal.vectorization; import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.FloatVector; import jdk.incubator.vector.IntVector; import jdk.incubator.vector.LongVector; import jdk.incubator.vector.VectorOperators; @@ -60,6 +61,9 @@ public float ipFloatBit(float[] q, byte[] d) { @Override public float ipFloatByte(float[] q, byte[] d) { + if (BYTE_FOR_FLOAT_SPECIES != null && q.length >= FLOAT_SPECIES.length()) { + return ipFloatByteImpl(q, d); + } return DefaultESVectorUtilSupport.ipFloatByteImpl(q, d); } @@ -165,4 +169,40 @@ public static long ipByteBin128(byte[] q, byte[] d) { } return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); } + + private static final VectorSpecies FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED; + private static final VectorSpecies BYTE_FOR_FLOAT_SPECIES; + + static { + VectorSpecies byteForFloat; + try { + // calculate vector size to convert from single bytes to 4-byte floats + byteForFloat = VectorSpecies.of(byte.class, VectorShape.forBitSize(FLOAT_SPECIES.vectorBitSize() / 4)); + } catch (IllegalArgumentException e) { + // can't get a byte vector size small enough, just 
use default impl + byteForFloat = null; + } + BYTE_FOR_FLOAT_SPECIES = byteForFloat; + } + + public static float ipFloatByteImpl(float[] q, byte[] d) { + assert BYTE_FOR_FLOAT_SPECIES != null; + float sum = 0; + int i = 0; + + int limit = FLOAT_SPECIES.loopBound(q.length); + for (; i < limit; i += FLOAT_SPECIES.length()) { + FloatVector qv = FloatVector.fromArray(FLOAT_SPECIES, q, i); + ByteVector bv = ByteVector.fromArray(BYTE_FOR_FLOAT_SPECIES, d, i); + // no separate parts needed for the cast, as we've used a byte vector size 1/4th the float vector size + sum += qv.mul(bv.castShape(qv.species(), 0)).reduceLanes(VectorOperators.ADD); + } + + // handle the tail + for (; i < q.length; i++) { + sum += q[i] * d[i]; + } + + return sum; + } } From ddac590d064c519506bef9460172fdbb053e6b87 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Mon, 3 Mar 2025 17:13:44 +0000 Subject: [PATCH 2/6] Add some tests for the panama implementation --- .../simdvec/ESVectorUtilTests.java | 64 ++++++++++--------- 1 file changed, 33 insertions(+), 31 deletions(-) diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java index 7259d8204f071..e8c9644c73ed7 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java @@ -13,6 +13,8 @@ import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider; import java.util.Arrays; +import java.util.function.ToDoubleBiFunction; +import java.util.function.ToLongBiFunction; import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY; @@ -40,8 +42,16 @@ public void testIpFloatBit() { } public void testIpFloatByte() { - float[] q = new float[16]; - byte[] d = new byte[16]; + testIpFloatByteImpl(ESVectorUtil::ipFloatByte); + 
testIpFloatByteImpl(defaultedProvider.getVectorUtilSupport()::ipFloatByte); + testIpFloatByteImpl(defOrPanamaProvider.getVectorUtilSupport()::ipFloatByte); + } + + private void testIpFloatByteImpl(ToDoubleBiFunction impl) { + int vectorSize = randomIntBetween(1, 1024); + + float[] q = new float[vectorSize]; + byte[] d = new byte[vectorSize]; for (int i = 0; i < q.length; i++) { q[i] = random().nextFloat(); } @@ -51,7 +61,7 @@ public void testIpFloatByte() { for (int i = 0; i < q.length; i++) { expected += q[i] * d[i]; } - assertEquals(expected, ESVectorUtil.ipFloatByte(q, d), 1e-6); + assertEquals(expected, impl.applyAsDouble(q, d), 1e-2); } public void testBitAndCount() { @@ -74,65 +84,57 @@ public void testBasicIpByteBin() { testBasicIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte); } - interface IpByteBin { - long apply(byte[] q, byte[] d); - } - - interface BitOps { - long apply(byte[] q, byte[] d); - } - - void testBasicBitAndImpl(BitOps bitAnd) { - assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 0 })); - assertEquals(0, bitAnd.apply(new byte[] { 1 }, new byte[] { 0 })); - assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 1 })); - assertEquals(1, bitAnd.apply(new byte[] { 1 }, new byte[] { 1 })); + void testBasicBitAndImpl(ToLongBiFunction bitAnd) { + assertEquals(0, bitAnd.applyAsLong(new byte[] { 0 }, new byte[] { 0 })); + assertEquals(0, bitAnd.applyAsLong(new byte[] { 1 }, new byte[] { 0 })); + assertEquals(0, bitAnd.applyAsLong(new byte[] { 0 }, new byte[] { 1 })); + assertEquals(1, bitAnd.applyAsLong(new byte[] { 1 }, new byte[] { 1 })); byte[] a = new byte[31]; byte[] b = new byte[31]; random().nextBytes(a); random().nextBytes(b); int expected = scalarBitAnd(a, b); - assertEquals(expected, bitAnd.apply(a, b)); + assertEquals(expected, bitAnd.applyAsLong(a, b)); } - void testBasicIpByteBinImpl(IpByteBin ipByteBinFunc) { - assertEquals(15L, ipByteBinFunc.apply(new byte[] { 1, 1, 1, 1 }, new byte[] { 1 })); - 
assertEquals(30L, ipByteBinFunc.apply(new byte[] { 1, 2, 1, 2, 1, 2, 1, 2 }, new byte[] { 1, 2 })); + void testBasicIpByteBinImpl(ToLongBiFunction ipByteBinFunc) { + assertEquals(15L, ipByteBinFunc.applyAsLong(new byte[] { 1, 1, 1, 1 }, new byte[] { 1 })); + assertEquals(30L, ipByteBinFunc.applyAsLong(new byte[] { 1, 2, 1, 2, 1, 2, 1, 2 }, new byte[] { 1, 2 })); var d = new byte[] { 1, 2, 3 }; var q = new byte[] { 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3 }; assert scalarIpByteBin(q, d) == 60L; // 4 + 8 + 16 + 32 - assertEquals(60L, ipByteBinFunc.apply(q, d)); + assertEquals(60L, ipByteBinFunc.applyAsLong(q, d)); d = new byte[] { 1, 2, 3, 4 }; q = new byte[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 }; assert scalarIpByteBin(q, d) == 75L; // 5 + 10 + 20 + 40 - assertEquals(75L, ipByteBinFunc.apply(q, d)); + assertEquals(75L, ipByteBinFunc.applyAsLong(q, d)); d = new byte[] { 1, 2, 3, 4, 5 }; q = new byte[] { 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5 }; assert scalarIpByteBin(q, d) == 105L; // 7 + 14 + 28 + 56 - assertEquals(105L, ipByteBinFunc.apply(q, d)); + assertEquals(105L, ipByteBinFunc.applyAsLong(q, d)); d = new byte[] { 1, 2, 3, 4, 5, 6 }; q = new byte[] { 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6 }; assert scalarIpByteBin(q, d) == 135L; // 9 + 18 + 36 + 72 - assertEquals(135L, ipByteBinFunc.apply(q, d)); + assertEquals(135L, ipByteBinFunc.applyAsLong(q, d)); d = new byte[] { 1, 2, 3, 4, 5, 6, 7 }; q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7 }; assert scalarIpByteBin(q, d) == 180L; // 12 + 24 + 48 + 96 - assertEquals(180L, ipByteBinFunc.apply(q, d)); + assertEquals(180L, ipByteBinFunc.applyAsLong(q, d)); d = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }; q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8 }; assert scalarIpByteBin(q, d) == 195L; // 13 + 26 + 52 + 104 - assertEquals(195L, 
ipByteBinFunc.apply(q, d)); + assertEquals(195L, ipByteBinFunc.applyAsLong(q, d)); d = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 }; q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; assert scalarIpByteBin(q, d) == 225L; // 15 + 30 + 60 + 120 - assertEquals(225L, ipByteBinFunc.apply(q, d)); + assertEquals(225L, ipByteBinFunc.applyAsLong(q, d)); } public void testIpByteBin() { @@ -141,7 +143,7 @@ public void testIpByteBin() { testIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte); } - void testIpByteBinImpl(IpByteBin ipByteBinFunc) { + void testIpByteBinImpl(ToLongBiFunction ipByteBinFunc) { int iterations = atLeast(50); for (int i = 0; i < iterations; i++) { int size = random().nextInt(5000); @@ -149,15 +151,15 @@ void testIpByteBinImpl(IpByteBin ipByteBinFunc) { var q = new byte[size * B_QUERY]; random().nextBytes(d); random().nextBytes(q); - assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d)); + assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.applyAsLong(q, d)); Arrays.fill(d, Byte.MAX_VALUE); Arrays.fill(q, Byte.MAX_VALUE); - assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d)); + assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.applyAsLong(q, d)); Arrays.fill(d, Byte.MIN_VALUE); Arrays.fill(q, Byte.MIN_VALUE); - assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d)); + assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.applyAsLong(q, d)); } } From d6c2a6a726ddec38b1d3deda0eec014be6b5e39d Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Wed, 5 Mar 2025 13:06:25 +0000 Subject: [PATCH 3/6] Scale delta to the vector size --- .../java/org/elasticsearch/simdvec/ESVectorUtilTests.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java index e8c9644c73ed7..173cb0455a291 
100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java @@ -17,6 +17,7 @@ import java.util.function.ToLongBiFunction; import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY; +import static org.hamcrest.Matchers.closeTo; public class ESVectorUtilTests extends BaseVectorizationTests { @@ -49,6 +50,8 @@ public void testIpFloatByte() { private void testIpFloatByteImpl(ToDoubleBiFunction impl) { int vectorSize = randomIntBetween(1, 1024); + // scale the delta according to the vector size + double delta = 1e-5 * vectorSize; float[] q = new float[vectorSize]; byte[] d = new byte[vectorSize]; @@ -61,7 +64,7 @@ private void testIpFloatByteImpl(ToDoubleBiFunction impl) { for (int i = 0; i < q.length; i++) { expected += q[i] * d[i]; } - assertEquals(expected, impl.applyAsDouble(q, d), 1e-2); + assertThat(impl.applyAsDouble(q, d), closeTo(expected, delta)); } public void testBitAndCount() { From d0c90ee7c45109e1b2de39caebd93bccb874f940 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Mon, 10 Mar 2025 10:15:23 +0000 Subject: [PATCH 4/6] Only reduce lanes at the end --- .../internal/vectorization/PanamaESVectorUtilSupport.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java index 4a55300e3823a..9f16176897df3 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java @@ -187,7 +187,7 @@ public static long ipByteBin128(byte[] q, byte[] d) { public static float ipFloatByteImpl(float[] q, byte[] d) { assert 
BYTE_FOR_FLOAT_SPECIES != null; - float sum = 0; + FloatVector vSum = FloatVector.zero(FLOAT_SPECIES); int i = 0; int limit = FLOAT_SPECIES.loopBound(q.length); @@ -195,9 +195,11 @@ public static float ipFloatByteImpl(float[] q, byte[] d) { FloatVector qv = FloatVector.fromArray(FLOAT_SPECIES, q, i); ByteVector bv = ByteVector.fromArray(BYTE_FOR_FLOAT_SPECIES, d, i); // no separate parts needed for the cast, as we've used a byte vector size 1/4th the float vector size - sum += qv.mul(bv.castShape(qv.species(), 0)).reduceLanes(VectorOperators.ADD); + vSum = qv.mul(bv.castShape(qv.species(), 0)).add(vSum); } + float sum = vSum.reduceLanes(VectorOperators.ADD); + // handle the tail for (; i < q.length; i++) { sum += q[i] * d[i]; From 1428aa874de5baaafa6c39590ca385d030c44f5e Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Mon, 10 Mar 2025 12:13:54 +0000 Subject: [PATCH 5/6] Use FMA (fused multiply-add) when accumulating --- .../internal/vectorization/PanamaESVectorUtilSupport.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java index 9f16176897df3..b805f5d6f0735 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java @@ -195,7 +195,7 @@ public static float ipFloatByteImpl(float[] q, byte[] d) { FloatVector qv = FloatVector.fromArray(FLOAT_SPECIES, q, i); ByteVector bv = ByteVector.fromArray(BYTE_FOR_FLOAT_SPECIES, d, i); // no separate parts needed for the cast, as we've used a byte vector size 1/4th the float vector size - vSum = qv.mul(bv.castShape(qv.species(), 0)).add(vSum); + vSum = qv.fma(bv.castShape(qv.species(), 0), vSum); } float sum = vSum.reduceLanes(VectorOperators.ADD); From
1f7927d008531e9a46842c9a5702cf81b4f35e78 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 14 Mar 2025 12:59:47 +0000 Subject: [PATCH 6/6] Rename vector species constants and tidy up ipFloatByteImpl --- .../PanamaESVectorUtilSupport.java | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java index b805f5d6f0735..8ef3f2a7f9881 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java @@ -61,7 +61,7 @@ public float ipFloatBit(float[] q, byte[] d) { @Override public float ipFloatByte(float[] q, byte[] d) { - if (BYTE_FOR_FLOAT_SPECIES != null && q.length >= FLOAT_SPECIES.length()) { + if (BYTE_SPECIES_FOR_PREFFERED_FLOATS != null && q.length >= PREFERRED_FLOAT_SPECIES.length()) { return ipFloatByteImpl(q, d); } return DefaultESVectorUtilSupport.ipFloatByteImpl(q, d); @@ -170,35 +170,34 @@ public static long ipByteBin128(byte[] q, byte[] d) { return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); } - private static final VectorSpecies FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED; - private static final VectorSpecies BYTE_FOR_FLOAT_SPECIES; + private static final VectorSpecies PREFERRED_FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED; + private static final VectorSpecies BYTE_SPECIES_FOR_PREFFERED_FLOATS; static { VectorSpecies byteForFloat; try { // calculate vector size to convert from single bytes to 4-byte floats - byteForFloat = VectorSpecies.of(byte.class, VectorShape.forBitSize(FLOAT_SPECIES.vectorBitSize() / 4)); + byteForFloat = VectorSpecies.of(byte.class, VectorShape.forBitSize(PREFERRED_FLOAT_SPECIES.vectorBitSize() / Integer.BYTES)); } catch
(IllegalArgumentException e) { // can't get a byte vector size small enough, just use default impl byteForFloat = null; } - BYTE_FOR_FLOAT_SPECIES = byteForFloat; + BYTE_SPECIES_FOR_PREFFERED_FLOATS = byteForFloat; } public static float ipFloatByteImpl(float[] q, byte[] d) { - assert BYTE_FOR_FLOAT_SPECIES != null; - FloatVector vSum = FloatVector.zero(FLOAT_SPECIES); + assert BYTE_SPECIES_FOR_PREFFERED_FLOATS != null; + FloatVector acc = FloatVector.zero(PREFERRED_FLOAT_SPECIES); int i = 0; - int limit = FLOAT_SPECIES.loopBound(q.length); - for (; i < limit; i += FLOAT_SPECIES.length()) { - FloatVector qv = FloatVector.fromArray(FLOAT_SPECIES, q, i); - ByteVector bv = ByteVector.fromArray(BYTE_FOR_FLOAT_SPECIES, d, i); - // no separate parts needed for the cast, as we've used a byte vector size 1/4th the float vector size - vSum = qv.fma(bv.castShape(qv.species(), 0), vSum); + int limit = PREFERRED_FLOAT_SPECIES.loopBound(q.length); + for (; i < limit; i += PREFERRED_FLOAT_SPECIES.length()) { + FloatVector qv = FloatVector.fromArray(PREFERRED_FLOAT_SPECIES, q, i); + ByteVector bv = ByteVector.fromArray(BYTE_SPECIES_FOR_PREFFERED_FLOATS, d, i); + acc = qv.fma(bv.castShape(PREFERRED_FLOAT_SPECIES, 0), acc); } - float sum = vSum.reduceLanes(VectorOperators.ADD); + float sum = acc.reduceLanes(VectorOperators.ADD); // handle the tail for (; i < q.length; i++) {