|
10 | 10 | package org.elasticsearch.simdvec.internal.vectorization; |
11 | 11 |
|
12 | 12 | import jdk.incubator.vector.ByteVector; |
| 13 | +import jdk.incubator.vector.FloatVector; |
13 | 14 | import jdk.incubator.vector.IntVector; |
14 | 15 | import jdk.incubator.vector.LongVector; |
| 16 | +import jdk.incubator.vector.VectorMask; |
15 | 17 | import jdk.incubator.vector.VectorOperators; |
16 | 18 | import jdk.incubator.vector.VectorShape; |
17 | 19 | import jdk.incubator.vector.VectorSpecies; |
@@ -55,6 +57,13 @@ public int ipByteBit(byte[] q, byte[] d) { |
55 | 57 |
|
56 | 58 | @Override |
57 | 59 | public float ipFloatBit(float[] q, byte[] d) { |
| 60 | + if (q.length >= 16) { |
| 61 | + if (VECTOR_BITSIZE >= 512) { |
| 62 | + return ipFloatBit512(q, d); |
| 63 | + } else if (VECTOR_BITSIZE == 256) { |
| 64 | + return ipFloatBit256(q, d); |
| 65 | + } |
| 66 | + } |
58 | 67 | return DefaultESVectorUtilSupport.ipFloatBitImpl(q, d); |
59 | 68 | } |
60 | 69 |
|
@@ -165,4 +174,56 @@ public static long ipByteBin128(byte[] q, byte[] d) { |
165 | 174 | } |
166 | 175 | return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); |
167 | 176 | } |
| 177 | + |
| 178 | + private static final VectorSpecies<Float> FLOAT_SPECIES_8 = FloatVector.SPECIES_256; |
| 179 | + private static final VectorSpecies<Float> FLOAT_SPECIES_16 = FloatVector.SPECIES_512; |
| 180 | + |
| 181 | + private static long reverse(byte b) { |
| 182 | + // see https://graphics.stanford.edu/~seander/bithacks.html#ReverseByteWith64Bits |
| 183 | + return ((((b & 0xff) * 0x80200802L) & 0x0884422110L) * 0x0101010101L >> 32) & 0xff; |
| 184 | + } |
| 185 | + |
| 186 | + static float ipFloatBit512(float[] q, byte[] d) { |
| 187 | + assert q.length == d.length * Byte.SIZE; |
| 188 | + FloatVector acc = FloatVector.zero(FLOAT_SPECIES_16); |
| 189 | + |
| 190 | + int i = 0; |
| 191 | + for (; i < FLOAT_SPECIES_16.loopBound(q.length); i += FLOAT_SPECIES_16.length()) { |
| 192 | + FloatVector floats = FloatVector.fromArray(FLOAT_SPECIES_16, q, i); |
| 193 | + // use the two bytes corresponding to the same sections |
| 194 | + // of the bit vector as a mask for addition |
| 195 | + long maskBits = reverse(d[i / 8]) | reverse(d[i / 8 + 1]) << 8; |
| 196 | + acc = acc.add(floats, VectorMask.fromLong(FLOAT_SPECIES_16, maskBits)); |
| 197 | + } |
| 198 | + |
| 199 | + float sum = acc.reduceLanes(VectorOperators.ADD); |
| 200 | + if (i < q.length) { |
| 201 | + // do the tail |
| 202 | + sum += DefaultESVectorUtilSupport.ipFloatBitImpl(q, d, i); |
| 203 | + } |
| 204 | + |
| 205 | + return sum; |
| 206 | + } |
| 207 | + |
| 208 | + static float ipFloatBit256(float[] q, byte[] d) { |
| 209 | + assert q.length == d.length * Byte.SIZE; |
| 210 | + FloatVector acc = FloatVector.zero(FLOAT_SPECIES_8); |
| 211 | + |
| 212 | + int i = 0; |
| 213 | + for (; i < FLOAT_SPECIES_8.loopBound(q.length); i += FLOAT_SPECIES_8.length()) { |
| 214 | + FloatVector floats = FloatVector.fromArray(FLOAT_SPECIES_8, q, i); |
| 215 | + // use the byte corresponding to the same section |
| 216 | + // of the bit vector as a mask for addition |
| 217 | + long maskBits = reverse(d[i / 8]); |
| 218 | + acc = acc.add(floats, VectorMask.fromLong(FLOAT_SPECIES_8, maskBits)); |
| 219 | + } |
| 220 | + |
| 221 | + float sum = acc.reduceLanes(VectorOperators.ADD); |
| 222 | + if (i < q.length) { |
| 223 | + // do the tail |
| 224 | + sum += DefaultESVectorUtilSupport.ipFloatBitImpl(q, d, i); |
| 225 | + } |
| 226 | + |
| 227 | + return sum; |
| 228 | + } |
168 | 229 | } |
0 commit comments