Skip to content

Commit ff2e0e2

Browse files
committed
Panama implementation of float-byte vector operation
1 parent 8baba58 commit ff2e0e2

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.simdvec.internal.vectorization;
1111

1212
import jdk.incubator.vector.ByteVector;
13+
import jdk.incubator.vector.FloatVector;
1314
import jdk.incubator.vector.IntVector;
1415
import jdk.incubator.vector.LongVector;
1516
import jdk.incubator.vector.VectorOperators;
@@ -60,6 +61,9 @@ public float ipFloatBit(float[] q, byte[] d) {
6061

6162
@Override
6263
public float ipFloatByte(float[] q, byte[] d) {
64+
if (BYTE_FOR_FLOAT_SPECIES != null && q.length >= FLOAT_SPECIES.length()) {
65+
return ipFloatByteImpl(q, d);
66+
}
6367
return DefaultESVectorUtilSupport.ipFloatByteImpl(q, d);
6468
}
6569

@@ -165,4 +169,40 @@ public static long ipByteBin128(byte[] q, byte[] d) {
165169
}
166170
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
167171
}
172+
173+
private static final VectorSpecies<Float> FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED;
174+
private static final VectorSpecies<Byte> BYTE_FOR_FLOAT_SPECIES;
175+
176+
static {
177+
VectorSpecies<Byte> byteForFloat;
178+
try {
179+
// calculate vector size to convert from single bytes to 4-byte floats
180+
byteForFloat = VectorSpecies.of(byte.class, VectorShape.forBitSize(FLOAT_SPECIES.vectorBitSize() / 4));
181+
} catch (IllegalArgumentException e) {
182+
// can't get a byte vector size small enough, just use default impl
183+
byteForFloat = null;
184+
}
185+
BYTE_FOR_FLOAT_SPECIES = byteForFloat;
186+
}
187+
188+
public static float ipFloatByteImpl(float[] q, byte[] d) {
189+
assert BYTE_FOR_FLOAT_SPECIES != null;
190+
float sum = 0;
191+
int i = 0;
192+
193+
int limit = FLOAT_SPECIES.loopBound(q.length);
194+
for (; i < limit; i += FLOAT_SPECIES.length()) {
195+
FloatVector qv = FloatVector.fromArray(FLOAT_SPECIES, q, i);
196+
ByteVector bv = ByteVector.fromArray(BYTE_FOR_FLOAT_SPECIES, d, i);
197+
// no separate parts needed for the cast, as we've used a byte vector size 1/4th the float vector size
198+
sum += qv.mul(bv.castShape(qv.species(), 0)).reduceLanes(VectorOperators.ADD);
199+
}
200+
201+
// handle the tail
202+
for (; i < q.length; i++) {
203+
sum += q[i] * d[i];
204+
}
205+
206+
return sum;
207+
}
168208
}

0 commit comments

Comments
 (0)