|
10 | 10 | package org.elasticsearch.simdvec.internal.vectorization; |
11 | 11 |
|
12 | 12 | import jdk.incubator.vector.ByteVector; |
| 13 | +import jdk.incubator.vector.FloatVector; |
13 | 14 | import jdk.incubator.vector.IntVector; |
14 | 15 | import jdk.incubator.vector.LongVector; |
15 | 16 | import jdk.incubator.vector.VectorOperators; |
@@ -60,6 +61,9 @@ public float ipFloatBit(float[] q, byte[] d) { |
60 | 61 |
|
61 | 62 | @Override |
62 | 63 | public float ipFloatByte(float[] q, byte[] d) { |
| 64 | + if (BYTE_FOR_FLOAT_SPECIES != null && q.length >= FLOAT_SPECIES.length()) { |
| 65 | + return ipFloatByteImpl(q, d); |
| 66 | + } |
63 | 67 | return DefaultESVectorUtilSupport.ipFloatByteImpl(q, d); |
64 | 68 | } |
65 | 69 |
|
@@ -165,4 +169,40 @@ public static long ipByteBin128(byte[] q, byte[] d) { |
165 | 169 | } |
166 | 170 | return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); |
167 | 171 | } |
| 172 | + |
| 173 | + private static final VectorSpecies<Float> FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED; |
| 174 | + private static final VectorSpecies<Byte> BYTE_FOR_FLOAT_SPECIES; |
| 175 | + |
| 176 | + static { |
| 177 | + VectorSpecies<Byte> byteForFloat; |
| 178 | + try { |
| 179 | + // calculate vector size to convert from single bytes to 4-byte floats |
| 180 | + byteForFloat = VectorSpecies.of(byte.class, VectorShape.forBitSize(FLOAT_SPECIES.vectorBitSize() / 4)); |
| 181 | + } catch (IllegalArgumentException e) { |
| 182 | + // can't get a byte vector size small enough, just use default impl |
| 183 | + byteForFloat = null; |
| 184 | + } |
| 185 | + BYTE_FOR_FLOAT_SPECIES = byteForFloat; |
| 186 | + } |
| 187 | + |
| 188 | + public static float ipFloatByteImpl(float[] q, byte[] d) { |
| 189 | + assert BYTE_FOR_FLOAT_SPECIES != null; |
| 190 | + float sum = 0; |
| 191 | + int i = 0; |
| 192 | + |
| 193 | + int limit = FLOAT_SPECIES.loopBound(q.length); |
| 194 | + for (; i < limit; i += FLOAT_SPECIES.length()) { |
| 195 | + FloatVector qv = FloatVector.fromArray(FLOAT_SPECIES, q, i); |
| 196 | + ByteVector bv = ByteVector.fromArray(BYTE_FOR_FLOAT_SPECIES, d, i); |
| 197 | + // no separate parts needed for the cast, as we've used a byte vector size 1/4th the float vector size |
| 198 | + sum += qv.mul(bv.castShape(qv.species(), 0)).reduceLanes(VectorOperators.ADD); |
| 199 | + } |
| 200 | + |
| 201 | + // handle the tail |
| 202 | + for (; i < q.length; i++) { |
| 203 | + sum += q[i] * d[i]; |
| 204 | + } |
| 205 | + |
| 206 | + return sum; |
| 207 | + } |
168 | 208 | } |
0 commit comments