Skip to content

Commit 683b56d

Browse files
thecoop authored and omricohenn committed
Panama implementation of painless float-byte vector ops (elastic#123270)
1 parent 88ebe4a commit 683b56d

File tree

2 files changed

+77
-31
lines changed

2 files changed

+77
-31
lines changed

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 41 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.simdvec.internal.vectorization;
1111

1212
import jdk.incubator.vector.ByteVector;
13+
import jdk.incubator.vector.FloatVector;
1314
import jdk.incubator.vector.IntVector;
1415
import jdk.incubator.vector.LongVector;
1516
import jdk.incubator.vector.VectorOperators;
@@ -60,6 +61,9 @@ public float ipFloatBit(float[] q, byte[] d) {
6061

6162
@Override
6263
public float ipFloatByte(float[] q, byte[] d) {
64+
if (BYTE_SPECIES_FOR_PREFFERED_FLOATS != null && q.length >= PREFERRED_FLOAT_SPECIES.length()) {
65+
return ipFloatByteImpl(q, d);
66+
}
6367
return DefaultESVectorUtilSupport.ipFloatByteImpl(q, d);
6468
}
6569

@@ -165,4 +169,41 @@ public static long ipByteBin128(byte[] q, byte[] d) {
165169
}
166170
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
167171
}
172+
173+
/** The float species the platform prefers; the float-byte kernel processes one such vector per iteration. */
private static final VectorSpecies<Float> PREFERRED_FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED;

/**
 * Byte species with the same number of lanes as {@link #PREFERRED_FLOAT_SPECIES}, so that one
 * byte vector widens to exactly one float vector. {@code null} when no vector shape of the
 * required (smaller) size is available, in which case the vectorized kernel must not be used.
 */
private static final VectorSpecies<Byte> BYTE_SPECIES_FOR_PREFFERED_FLOATS;

static {
    VectorSpecies<Byte> matchingByteSpecies;
    try {
        // Each 1-byte lane widens to a Float.BYTES-wide lane, so the byte vector has to be
        // Float.BYTES times narrower than the float vector to keep the lane counts equal.
        int byteVectorBits = PREFERRED_FLOAT_SPECIES.vectorBitSize() / Float.BYTES;
        matchingByteSpecies = VectorSpecies.of(byte.class, VectorShape.forBitSize(byteVectorBits));
    } catch (IllegalArgumentException ignored) {
        // can't get a byte vector size small enough, just use default impl
        matchingByteSpecies = null;
    }
    BYTE_SPECIES_FOR_PREFFERED_FLOATS = matchingByteSpecies;
}
187+
188+
public static float ipFloatByteImpl(float[] q, byte[] d) {
189+
assert BYTE_SPECIES_FOR_PREFFERED_FLOATS != null;
190+
FloatVector acc = FloatVector.zero(PREFERRED_FLOAT_SPECIES);
191+
int i = 0;
192+
193+
int limit = PREFERRED_FLOAT_SPECIES.loopBound(q.length);
194+
for (; i < limit; i += PREFERRED_FLOAT_SPECIES.length()) {
195+
FloatVector qv = FloatVector.fromArray(PREFERRED_FLOAT_SPECIES, q, i);
196+
ByteVector bv = ByteVector.fromArray(BYTE_SPECIES_FOR_PREFFERED_FLOATS, d, i);
197+
acc = qv.fma(bv.castShape(PREFERRED_FLOAT_SPECIES, 0), acc);
198+
}
199+
200+
float sum = acc.reduceLanes(VectorOperators.ADD);
201+
202+
// handle the tail
203+
for (; i < q.length; i++) {
204+
sum += q[i] * d[i];
205+
}
206+
207+
return sum;
208+
}
168209
}

libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

Lines changed: 36 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,11 @@
1313
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
1414

1515
import java.util.Arrays;
16+
import java.util.function.ToDoubleBiFunction;
17+
import java.util.function.ToLongBiFunction;
1618

1719
import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY;
20+
import static org.hamcrest.Matchers.closeTo;
1821

1922
public class ESVectorUtilTests extends BaseVectorizationTests {
2023

@@ -40,8 +43,18 @@ public void testIpFloatBit() {
4043
}
4144

4245
public void testIpFloatByte() {
43-
float[] q = new float[16];
44-
byte[] d = new byte[16];
46+
testIpFloatByteImpl(ESVectorUtil::ipFloatByte);
47+
testIpFloatByteImpl(defaultedProvider.getVectorUtilSupport()::ipFloatByte);
48+
testIpFloatByteImpl(defOrPanamaProvider.getVectorUtilSupport()::ipFloatByte);
49+
}
50+
51+
private void testIpFloatByteImpl(ToDoubleBiFunction<float[], byte[]> impl) {
52+
int vectorSize = randomIntBetween(1, 1024);
53+
// scale the delta according to the vector size
54+
double delta = 1e-5 * vectorSize;
55+
56+
float[] q = new float[vectorSize];
57+
byte[] d = new byte[vectorSize];
4558
for (int i = 0; i < q.length; i++) {
4659
q[i] = random().nextFloat();
4760
}
@@ -51,7 +64,7 @@ public void testIpFloatByte() {
5164
for (int i = 0; i < q.length; i++) {
5265
expected += q[i] * d[i];
5366
}
54-
assertEquals(expected, ESVectorUtil.ipFloatByte(q, d), 1e-6);
67+
assertThat(impl.applyAsDouble(q, d), closeTo(expected, delta));
5568
}
5669

5770
public void testBitAndCount() {
@@ -74,65 +87,57 @@ public void testBasicIpByteBin() {
7487
testBasicIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte);
7588
}
7689

77-
interface IpByteBin {
78-
long apply(byte[] q, byte[] d);
79-
}
80-
81-
interface BitOps {
82-
long apply(byte[] q, byte[] d);
83-
}
84-
85-
void testBasicBitAndImpl(BitOps bitAnd) {
86-
assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 0 }));
87-
assertEquals(0, bitAnd.apply(new byte[] { 1 }, new byte[] { 0 }));
88-
assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 1 }));
89-
assertEquals(1, bitAnd.apply(new byte[] { 1 }, new byte[] { 1 }));
90+
void testBasicBitAndImpl(ToLongBiFunction<byte[], byte[]> bitAnd) {
91+
assertEquals(0, bitAnd.applyAsLong(new byte[] { 0 }, new byte[] { 0 }));
92+
assertEquals(0, bitAnd.applyAsLong(new byte[] { 1 }, new byte[] { 0 }));
93+
assertEquals(0, bitAnd.applyAsLong(new byte[] { 0 }, new byte[] { 1 }));
94+
assertEquals(1, bitAnd.applyAsLong(new byte[] { 1 }, new byte[] { 1 }));
9095
byte[] a = new byte[31];
9196
byte[] b = new byte[31];
9297
random().nextBytes(a);
9398
random().nextBytes(b);
9499
int expected = scalarBitAnd(a, b);
95-
assertEquals(expected, bitAnd.apply(a, b));
100+
assertEquals(expected, bitAnd.applyAsLong(a, b));
96101
}
97102

98-
void testBasicIpByteBinImpl(IpByteBin ipByteBinFunc) {
99-
assertEquals(15L, ipByteBinFunc.apply(new byte[] { 1, 1, 1, 1 }, new byte[] { 1 }));
100-
assertEquals(30L, ipByteBinFunc.apply(new byte[] { 1, 2, 1, 2, 1, 2, 1, 2 }, new byte[] { 1, 2 }));
103+
void testBasicIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
104+
assertEquals(15L, ipByteBinFunc.applyAsLong(new byte[] { 1, 1, 1, 1 }, new byte[] { 1 }));
105+
assertEquals(30L, ipByteBinFunc.applyAsLong(new byte[] { 1, 2, 1, 2, 1, 2, 1, 2 }, new byte[] { 1, 2 }));
101106

102107
var d = new byte[] { 1, 2, 3 };
103108
var q = new byte[] { 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3 };
104109
assert scalarIpByteBin(q, d) == 60L; // 4 + 8 + 16 + 32
105-
assertEquals(60L, ipByteBinFunc.apply(q, d));
110+
assertEquals(60L, ipByteBinFunc.applyAsLong(q, d));
106111

107112
d = new byte[] { 1, 2, 3, 4 };
108113
q = new byte[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 };
109114
assert scalarIpByteBin(q, d) == 75L; // 5 + 10 + 20 + 40
110-
assertEquals(75L, ipByteBinFunc.apply(q, d));
115+
assertEquals(75L, ipByteBinFunc.applyAsLong(q, d));
111116

112117
d = new byte[] { 1, 2, 3, 4, 5 };
113118
q = new byte[] { 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5 };
114119
assert scalarIpByteBin(q, d) == 105L; // 7 + 14 + 28 + 56
115-
assertEquals(105L, ipByteBinFunc.apply(q, d));
120+
assertEquals(105L, ipByteBinFunc.applyAsLong(q, d));
116121

117122
d = new byte[] { 1, 2, 3, 4, 5, 6 };
118123
q = new byte[] { 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6 };
119124
assert scalarIpByteBin(q, d) == 135L; // 9 + 18 + 36 + 72
120-
assertEquals(135L, ipByteBinFunc.apply(q, d));
125+
assertEquals(135L, ipByteBinFunc.applyAsLong(q, d));
121126

122127
d = new byte[] { 1, 2, 3, 4, 5, 6, 7 };
123128
q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7 };
124129
assert scalarIpByteBin(q, d) == 180L; // 12 + 24 + 48 + 96
125-
assertEquals(180L, ipByteBinFunc.apply(q, d));
130+
assertEquals(180L, ipByteBinFunc.applyAsLong(q, d));
126131

127132
d = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 };
128133
q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8 };
129134
assert scalarIpByteBin(q, d) == 195L; // 13 + 26 + 52 + 104
130-
assertEquals(195L, ipByteBinFunc.apply(q, d));
135+
assertEquals(195L, ipByteBinFunc.applyAsLong(q, d));
131136

132137
d = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 };
133138
q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
134139
assert scalarIpByteBin(q, d) == 225L; // 15 + 30 + 60 + 120
135-
assertEquals(225L, ipByteBinFunc.apply(q, d));
140+
assertEquals(225L, ipByteBinFunc.applyAsLong(q, d));
136141
}
137142

138143
public void testIpByteBin() {
@@ -141,23 +146,23 @@ public void testIpByteBin() {
141146
testIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte);
142147
}
143148

144-
void testIpByteBinImpl(IpByteBin ipByteBinFunc) {
149+
void testIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
145150
int iterations = atLeast(50);
146151
for (int i = 0; i < iterations; i++) {
147152
int size = random().nextInt(5000);
148153
var d = new byte[size];
149154
var q = new byte[size * B_QUERY];
150155
random().nextBytes(d);
151156
random().nextBytes(q);
152-
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d));
157+
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.applyAsLong(q, d));
153158

154159
Arrays.fill(d, Byte.MAX_VALUE);
155160
Arrays.fill(q, Byte.MAX_VALUE);
156-
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d));
161+
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.applyAsLong(q, d));
157162

158163
Arrays.fill(d, Byte.MIN_VALUE);
159164
Arrays.fill(q, Byte.MIN_VALUE);
160-
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d));
165+
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.applyAsLong(q, d));
161166
}
162167
}
163168

0 commit comments

Comments
 (0)