Skip to content

Commit 346e294

Browse files
committed
Unroll loops for more oomph
1 parent 8c1ac09 commit 346e294

File tree

1 file changed

+90
-31
lines changed

1 file changed

+90
-31
lines changed

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 90 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313
import jdk.incubator.vector.FloatVector;
1414
import jdk.incubator.vector.IntVector;
1515
import jdk.incubator.vector.LongVector;
16-
import jdk.incubator.vector.Vector;
1716
import jdk.incubator.vector.VectorMask;
1817
import jdk.incubator.vector.VectorOperators;
1918
import jdk.incubator.vector.VectorShape;
@@ -197,17 +196,32 @@ public static long ipByteBin128(byte[] q, byte[] d) {
197196

198197
static int ipByteBit512(byte[] q, byte[] d) {
199198
assert q.length == d.length * Byte.SIZE;
200-
IntVector acc = IntVector.zero(INT_SPECIES_512);
199+
IntVector acc0 = IntVector.zero(INT_SPECIES_512);
200+
IntVector acc1 = IntVector.zero(INT_SPECIES_512);
201+
IntVector acc2 = IntVector.zero(INT_SPECIES_512);
202+
IntVector acc3 = IntVector.zero(INT_SPECIES_512);
201203

202204
int i = 0;
203-
for (; i < BYTE_SPECIES_FOR_INT_512.loopBound(q.length); i += BYTE_SPECIES_FOR_INT_512.length()) {
204-
Vector<Integer> bytes = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i).castShape(INT_SPECIES_512, 0);
205-
long maskBits = Integer.reverse((short) BitUtil.VH_BE_SHORT.get(d, i / 8)) >> 16;
206-
207-
acc = acc.add(bytes, VectorMask.fromLong(INT_SPECIES_512, maskBits));
205+
for (; i < INT_SPECIES_512.loopBound(q.length); i += INT_SPECIES_512.length() * 4) {
206+
var vals0 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i).castShape(INT_SPECIES_512, 0);
207+
var vals1 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length()).castShape(INT_SPECIES_512, 0);
208+
var vals2 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length() * 2).castShape(INT_SPECIES_512, 0);
209+
var vals3 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length() * 3).castShape(INT_SPECIES_512, 0);
210+
211+
long maskBits = Long.reverse((long) BitUtil.VH_BE_LONG.get(d, i / 8));
212+
var mask0 = VectorMask.fromLong(INT_SPECIES_512, maskBits);
213+
var mask1 = VectorMask.fromLong(INT_SPECIES_512, maskBits >> 16);
214+
var mask2 = VectorMask.fromLong(INT_SPECIES_512, maskBits >> 32);
215+
var mask3 = VectorMask.fromLong(INT_SPECIES_512, maskBits >> 48);
216+
217+
acc0 = acc0.add(vals0, mask0);
218+
acc1 = acc1.add(vals1, mask1);
219+
acc2 = acc2.add(vals2, mask2);
220+
acc3 = acc3.add(vals3, mask3);
208221
}
209222

210-
int sum = acc.reduceLanes(VectorOperators.ADD);
223+
int sum = acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
224+
+ acc3.reduceLanes(VectorOperators.ADD);
211225
if (i < q.length) {
212226
// do the tail
213227
sum += DefaultESVectorUtilSupport.ipByteBitImpl(q, d, i);
@@ -217,17 +231,32 @@ static int ipByteBit512(byte[] q, byte[] d) {
217231

218232
static int ipByteBit256(byte[] q, byte[] d) {
219233
assert q.length == d.length * Byte.SIZE;
220-
IntVector acc = IntVector.zero(INT_SPECIES_256);
234+
IntVector acc0 = IntVector.zero(INT_SPECIES_256);
235+
IntVector acc1 = IntVector.zero(INT_SPECIES_256);
236+
IntVector acc2 = IntVector.zero(INT_SPECIES_256);
237+
IntVector acc3 = IntVector.zero(INT_SPECIES_256);
221238

222239
int i = 0;
223-
for (; i < BYTE_SPECIES_FOR_INT_256.loopBound(q.length); i += BYTE_SPECIES_FOR_INT_256.length()) {
224-
Vector<Integer> bytes = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0);
225-
long maskBits = Integer.reverse(d[i / 8]) >> 24;
226-
227-
acc = acc.add(bytes, VectorMask.fromLong(INT_SPECIES_256, maskBits));
240+
for (; i < INT_SPECIES_256.loopBound(q.length); i += INT_SPECIES_256.length() * 4) {
241+
var vals0 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0);
242+
var vals1 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length()).castShape(INT_SPECIES_256, 0);
243+
var vals2 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length() * 2).castShape(INT_SPECIES_256, 0);
244+
var vals3 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length() * 3).castShape(INT_SPECIES_256, 0);
245+
246+
long maskBits = Integer.reverse((int) BitUtil.VH_BE_INT.get(d, i / 8));
247+
var mask0 = VectorMask.fromLong(INT_SPECIES_256, maskBits);
248+
var mask1 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 8);
249+
var mask2 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 16);
250+
var mask3 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 24);
251+
252+
acc0 = acc0.add(vals0, mask0);
253+
acc1 = acc1.add(vals1, mask1);
254+
acc2 = acc2.add(vals2, mask2);
255+
acc3 = acc3.add(vals3, mask3);
228256
}
229257

230-
int sum = acc.reduceLanes(VectorOperators.ADD);
258+
int sum = acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
259+
+ acc3.reduceLanes(VectorOperators.ADD);
231260
if (i < q.length) {
232261
// do the tail
233262
sum += DefaultESVectorUtilSupport.ipByteBitImpl(q, d, i);
@@ -240,18 +269,33 @@ static int ipByteBit256(byte[] q, byte[] d) {
240269

241270
static float ipFloatBit512(float[] q, byte[] d) {
242271
assert q.length == d.length * Byte.SIZE;
243-
FloatVector acc = FloatVector.zero(FLOAT_SPECIES_512);
272+
FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES_512);
273+
FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES_512);
274+
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES_512);
275+
FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES_512);
244276

245277
int i = 0;
246-
for (; i < FLOAT_SPECIES_512.loopBound(q.length); i += FLOAT_SPECIES_512.length()) {
247-
FloatVector floats = FloatVector.fromArray(FLOAT_SPECIES_512, q, i);
248-
// use the two bytes corresponding to the same sections
249-
// of the bit vector as a mask for addition
250-
long maskBits = Integer.reverse((short) BitUtil.VH_BE_SHORT.get(d, i / 8)) >> 16;
251-
acc = acc.add(floats, VectorMask.fromLong(FLOAT_SPECIES_512, maskBits));
278+
for (; i < FLOAT_SPECIES_512.loopBound(q.length); i += FLOAT_SPECIES_512.length() * 4) {
279+
var floats0 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i);
280+
var floats1 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length());
281+
var floats2 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length() * 2);
282+
var floats3 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length() * 3);
283+
284+
long maskBits = Long.reverse((long) BitUtil.VH_BE_LONG.get(d, i / 8));
285+
var mask0 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits);
286+
var mask1 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits >> 16);
287+
var mask2 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits >> 32);
288+
var mask3 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits >> 48);
289+
290+
acc0 = acc0.add(floats0, mask0);
291+
acc1 = acc1.add(floats1, mask1);
292+
acc2 = acc2.add(floats2, mask2);
293+
acc3 = acc3.add(floats3, mask3);
252294
}
253295

254-
float sum = acc.reduceLanes(VectorOperators.ADD);
296+
float sum = acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
297+
+ acc3.reduceLanes(VectorOperators.ADD);
298+
255299
if (i < q.length) {
256300
// do the tail
257301
sum += DefaultESVectorUtilSupport.ipFloatBitImpl(q, d, i);
@@ -262,18 +306,33 @@ static float ipFloatBit512(float[] q, byte[] d) {
262306

263307
static float ipFloatBit256(float[] q, byte[] d) {
264308
assert q.length == d.length * Byte.SIZE;
265-
FloatVector acc = FloatVector.zero(FLOAT_SPECIES_256);
309+
FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES_256);
310+
FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES_256);
311+
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES_256);
312+
FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES_256);
266313

267314
int i = 0;
268-
for (; i < FLOAT_SPECIES_256.loopBound(q.length); i += FLOAT_SPECIES_256.length()) {
269-
FloatVector floats = FloatVector.fromArray(FLOAT_SPECIES_256, q, i);
270-
// use the byte corresponding to the same section
271-
// of the bit vector as a mask for addition
272-
long maskBits = Integer.reverse(d[i / 8]) >> 24;
273-
acc = acc.add(floats, VectorMask.fromLong(FLOAT_SPECIES_256, maskBits));
315+
for (; i < FLOAT_SPECIES_256.loopBound(q.length); i += FLOAT_SPECIES_256.length() * 4) {
316+
var floats0 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i);
317+
var floats1 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length());
318+
var floats2 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length() * 2);
319+
var floats3 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length() * 3);
320+
321+
long maskBits = Integer.reverse((int) BitUtil.VH_BE_INT.get(d, i / 8));
322+
var mask0 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits);
323+
var mask1 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 8);
324+
var mask2 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 16);
325+
var mask3 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 24);
326+
327+
acc0 = acc0.add(floats0, mask0);
328+
acc1 = acc1.add(floats1, mask1);
329+
acc2 = acc2.add(floats2, mask2);
330+
acc3 = acc3.add(floats3, mask3);
274331
}
275332

276-
float sum = acc.reduceLanes(VectorOperators.ADD);
333+
float sum = acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
334+
+ acc3.reduceLanes(VectorOperators.ADD);
335+
277336
if (i < q.length) {
278337
// do the tail
279338
sum += DefaultESVectorUtilSupport.ipFloatBitImpl(q, d, i);

0 commit comments

Comments
 (0)