Skip to content

Commit c983834

Browse files
committed
Check when the unrolled loops should be run
1 parent 346e294 commit c983834

File tree

1 file changed

+161
-93
lines changed

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 161 additions & 93 deletions
Original file line number | Diff line number | Diff line change
@@ -196,32 +196,65 @@ public static long ipByteBin128(byte[] q, byte[] d) {
196196

197197
static int ipByteBit512(byte[] q, byte[] d) {
198198
assert q.length == d.length * Byte.SIZE;
199-
IntVector acc0 = IntVector.zero(INT_SPECIES_512);
200-
IntVector acc1 = IntVector.zero(INT_SPECIES_512);
201-
IntVector acc2 = IntVector.zero(INT_SPECIES_512);
202-
IntVector acc3 = IntVector.zero(INT_SPECIES_512);
203-
204199
int i = 0;
205-
for (; i < INT_SPECIES_512.loopBound(q.length); i += INT_SPECIES_512.length() * 4) {
206-
var vals0 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i).castShape(INT_SPECIES_512, 0);
207-
var vals1 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length()).castShape(INT_SPECIES_512, 0);
208-
var vals2 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length() * 2).castShape(INT_SPECIES_512, 0);
209-
var vals3 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length() * 3).castShape(INT_SPECIES_512, 0);
210-
211-
long maskBits = Long.reverse((long) BitUtil.VH_BE_LONG.get(d, i / 8));
212-
var mask0 = VectorMask.fromLong(INT_SPECIES_512, maskBits);
213-
var mask1 = VectorMask.fromLong(INT_SPECIES_512, maskBits >> 16);
214-
var mask2 = VectorMask.fromLong(INT_SPECIES_512, maskBits >> 32);
215-
var mask3 = VectorMask.fromLong(INT_SPECIES_512, maskBits >> 48);
216-
217-
acc0 = acc0.add(vals0, mask0);
218-
acc1 = acc1.add(vals1, mask1);
219-
acc2 = acc2.add(vals2, mask2);
220-
acc3 = acc3.add(vals3, mask3);
200+
int sum = 0;
201+
202+
if (q.length >= INT_SPECIES_512.length() * 4) {
203+
IntVector acc0 = IntVector.zero(INT_SPECIES_512);
204+
IntVector acc1 = IntVector.zero(INT_SPECIES_512);
205+
IntVector acc2 = IntVector.zero(INT_SPECIES_512);
206+
IntVector acc3 = IntVector.zero(INT_SPECIES_512);
207+
for (; i < INT_SPECIES_512.loopBound(q.length); i += INT_SPECIES_512.length() * 4) {
208+
var vals0 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i).castShape(INT_SPECIES_512, 0);
209+
var vals1 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length()).castShape(INT_SPECIES_512, 0);
210+
var vals2 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length() * 2)
211+
.castShape(INT_SPECIES_512, 0);
212+
var vals3 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length() * 3)
213+
.castShape(INT_SPECIES_512, 0);
214+
215+
long maskBits = Long.reverse((long) BitUtil.VH_BE_LONG.get(d, i / 8));
216+
var mask0 = VectorMask.fromLong(INT_SPECIES_512, maskBits);
217+
var mask1 = VectorMask.fromLong(INT_SPECIES_512, maskBits >> 16);
218+
var mask2 = VectorMask.fromLong(INT_SPECIES_512, maskBits >> 32);
219+
var mask3 = VectorMask.fromLong(INT_SPECIES_512, maskBits >> 48);
220+
221+
acc0 = acc0.add(vals0, mask0);
222+
acc1 = acc1.add(vals1, mask1);
223+
acc2 = acc2.add(vals2, mask2);
224+
acc3 = acc3.add(vals3, mask3);
225+
}
226+
sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
227+
+ acc3.reduceLanes(VectorOperators.ADD);
228+
}
229+
230+
if (q.length - i >= INT_SPECIES_256.length() * 4) {
231+
IntVector acc0 = IntVector.zero(INT_SPECIES_256);
232+
IntVector acc1 = IntVector.zero(INT_SPECIES_256);
233+
IntVector acc2 = IntVector.zero(INT_SPECIES_256);
234+
IntVector acc3 = IntVector.zero(INT_SPECIES_256);
235+
for (; i < INT_SPECIES_256.loopBound(q.length); i += INT_SPECIES_256.length() * 4) {
236+
var vals0 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0);
237+
var vals1 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length()).castShape(INT_SPECIES_256, 0);
238+
var vals2 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length() * 2)
239+
.castShape(INT_SPECIES_256, 0);
240+
var vals3 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length() * 3)
241+
.castShape(INT_SPECIES_256, 0);
242+
243+
long maskBits = Integer.reverse((int) BitUtil.VH_BE_INT.get(d, i / 8));
244+
var mask0 = VectorMask.fromLong(INT_SPECIES_256, maskBits);
245+
var mask1 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 8);
246+
var mask2 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 16);
247+
var mask3 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 24);
248+
249+
acc0 = acc0.add(vals0, mask0);
250+
acc1 = acc1.add(vals1, mask1);
251+
acc2 = acc2.add(vals2, mask2);
252+
acc3 = acc3.add(vals3, mask3);
253+
}
254+
sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
255+
+ acc3.reduceLanes(VectorOperators.ADD);
221256
}
222257

223-
int sum = acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
224-
+ acc3.reduceLanes(VectorOperators.ADD);
225258
if (i < q.length) {
226259
// do the tail
227260
sum += DefaultESVectorUtilSupport.ipByteBitImpl(q, d, i);
@@ -231,32 +264,37 @@ static int ipByteBit512(byte[] q, byte[] d) {
231264

232265
static int ipByteBit256(byte[] q, byte[] d) {
233266
assert q.length == d.length * Byte.SIZE;
234-
IntVector acc0 = IntVector.zero(INT_SPECIES_256);
235-
IntVector acc1 = IntVector.zero(INT_SPECIES_256);
236-
IntVector acc2 = IntVector.zero(INT_SPECIES_256);
237-
IntVector acc3 = IntVector.zero(INT_SPECIES_256);
238-
239267
int i = 0;
240-
for (; i < INT_SPECIES_256.loopBound(q.length); i += INT_SPECIES_256.length() * 4) {
241-
var vals0 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0);
242-
var vals1 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length()).castShape(INT_SPECIES_256, 0);
243-
var vals2 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length() * 2).castShape(INT_SPECIES_256, 0);
244-
var vals3 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length() * 3).castShape(INT_SPECIES_256, 0);
245-
246-
long maskBits = Integer.reverse((int) BitUtil.VH_BE_INT.get(d, i / 8));
247-
var mask0 = VectorMask.fromLong(INT_SPECIES_256, maskBits);
248-
var mask1 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 8);
249-
var mask2 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 16);
250-
var mask3 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 24);
251-
252-
acc0 = acc0.add(vals0, mask0);
253-
acc1 = acc1.add(vals1, mask1);
254-
acc2 = acc2.add(vals2, mask2);
255-
acc3 = acc3.add(vals3, mask3);
268+
int sum = 0;
269+
270+
if (q.length >= INT_SPECIES_256.length() * 4) {
271+
IntVector acc0 = IntVector.zero(INT_SPECIES_256);
272+
IntVector acc1 = IntVector.zero(INT_SPECIES_256);
273+
IntVector acc2 = IntVector.zero(INT_SPECIES_256);
274+
IntVector acc3 = IntVector.zero(INT_SPECIES_256);
275+
for (; i < INT_SPECIES_256.loopBound(q.length); i += INT_SPECIES_256.length() * 4) {
276+
var vals0 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0);
277+
var vals1 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length()).castShape(INT_SPECIES_256, 0);
278+
var vals2 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length() * 2)
279+
.castShape(INT_SPECIES_256, 0);
280+
var vals3 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length() * 3)
281+
.castShape(INT_SPECIES_256, 0);
282+
283+
long maskBits = Integer.reverse((int) BitUtil.VH_BE_INT.get(d, i / 8));
284+
var mask0 = VectorMask.fromLong(INT_SPECIES_256, maskBits);
285+
var mask1 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 8);
286+
var mask2 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 16);
287+
var mask3 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 24);
288+
289+
acc0 = acc0.add(vals0, mask0);
290+
acc1 = acc1.add(vals1, mask1);
291+
acc2 = acc2.add(vals2, mask2);
292+
acc3 = acc3.add(vals3, mask3);
293+
}
294+
sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
295+
+ acc3.reduceLanes(VectorOperators.ADD);
256296
}
257297

258-
int sum = acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
259-
+ acc3.reduceLanes(VectorOperators.ADD);
260298
if (i < q.length) {
261299
// do the tail
262300
sum += DefaultESVectorUtilSupport.ipByteBitImpl(q, d, i);
@@ -269,32 +307,60 @@ static int ipByteBit256(byte[] q, byte[] d) {
269307

270308
static float ipFloatBit512(float[] q, byte[] d) {
271309
assert q.length == d.length * Byte.SIZE;
272-
FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES_512);
273-
FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES_512);
274-
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES_512);
275-
FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES_512);
276-
277310
int i = 0;
278-
for (; i < FLOAT_SPECIES_512.loopBound(q.length); i += FLOAT_SPECIES_512.length() * 4) {
279-
var floats0 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i);
280-
var floats1 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length());
281-
var floats2 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length() * 2);
282-
var floats3 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length() * 3);
283-
284-
long maskBits = Long.reverse((long) BitUtil.VH_BE_LONG.get(d, i / 8));
285-
var mask0 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits);
286-
var mask1 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits >> 16);
287-
var mask2 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits >> 32);
288-
var mask3 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits >> 48);
289-
290-
acc0 = acc0.add(floats0, mask0);
291-
acc1 = acc1.add(floats1, mask1);
292-
acc2 = acc2.add(floats2, mask2);
293-
acc3 = acc3.add(floats3, mask3);
311+
float sum = 0;
312+
313+
if (q.length >= FLOAT_SPECIES_512.length() * 4) {
314+
FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES_512);
315+
FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES_512);
316+
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES_512);
317+
FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES_512);
318+
for (; i < FLOAT_SPECIES_512.loopBound(q.length); i += FLOAT_SPECIES_512.length() * 4) {
319+
var floats0 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i);
320+
var floats1 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length());
321+
var floats2 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length() * 2);
322+
var floats3 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length() * 3);
323+
324+
long maskBits = Long.reverse((long) BitUtil.VH_BE_LONG.get(d, i / 8));
325+
var mask0 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits);
326+
var mask1 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits >> 16);
327+
var mask2 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits >> 32);
328+
var mask3 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits >> 48);
329+
330+
acc0 = acc0.add(floats0, mask0);
331+
acc1 = acc1.add(floats1, mask1);
332+
acc2 = acc2.add(floats2, mask2);
333+
acc3 = acc3.add(floats3, mask3);
334+
}
335+
sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
336+
+ acc3.reduceLanes(VectorOperators.ADD);
294337
}
295338

296-
float sum = acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
297-
+ acc3.reduceLanes(VectorOperators.ADD);
339+
if (q.length - i >= FLOAT_SPECIES_256.length() * 4) {
340+
FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES_256);
341+
FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES_256);
342+
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES_256);
343+
FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES_256);
344+
for (; i < FLOAT_SPECIES_256.loopBound(q.length); i += FLOAT_SPECIES_256.length() * 4) {
345+
var floats0 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i);
346+
var floats1 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length());
347+
var floats2 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length() * 2);
348+
var floats3 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length() * 3);
349+
350+
long maskBits = Integer.reverse((int) BitUtil.VH_BE_INT.get(d, i / 8));
351+
var mask0 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits);
352+
var mask1 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 8);
353+
var mask2 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 16);
354+
var mask3 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 24);
355+
356+
acc0 = acc0.add(floats0, mask0);
357+
acc1 = acc1.add(floats1, mask1);
358+
acc2 = acc2.add(floats2, mask2);
359+
acc3 = acc3.add(floats3, mask3);
360+
}
361+
sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
362+
+ acc3.reduceLanes(VectorOperators.ADD);
363+
}
298364

299365
if (i < q.length) {
300366
// do the tail
@@ -306,33 +372,35 @@ static float ipFloatBit512(float[] q, byte[] d) {
306372

307373
static float ipFloatBit256(float[] q, byte[] d) {
308374
assert q.length == d.length * Byte.SIZE;
309-
FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES_256);
310-
FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES_256);
311-
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES_256);
312-
FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES_256);
313-
314375
int i = 0;
315-
for (; i < FLOAT_SPECIES_256.loopBound(q.length); i += FLOAT_SPECIES_256.length() * 4) {
316-
var floats0 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i);
317-
var floats1 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length());
318-
var floats2 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length() * 2);
319-
var floats3 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length() * 3);
320-
321-
long maskBits = Integer.reverse((int) BitUtil.VH_BE_INT.get(d, i / 8));
322-
var mask0 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits);
323-
var mask1 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 8);
324-
var mask2 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 16);
325-
var mask3 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 24);
326-
327-
acc0 = acc0.add(floats0, mask0);
328-
acc1 = acc1.add(floats1, mask1);
329-
acc2 = acc2.add(floats2, mask2);
330-
acc3 = acc3.add(floats3, mask3);
376+
float sum = 0;
377+
378+
if (q.length >= FLOAT_SPECIES_256.length() * 4) {
379+
FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES_256);
380+
FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES_256);
381+
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES_256);
382+
FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES_256);
383+
for (; i < FLOAT_SPECIES_256.loopBound(q.length); i += FLOAT_SPECIES_256.length() * 4) {
384+
var floats0 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i);
385+
var floats1 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length());
386+
var floats2 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length() * 2);
387+
var floats3 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length() * 3);
388+
389+
long maskBits = Integer.reverse((int) BitUtil.VH_BE_INT.get(d, i / 8));
390+
var mask0 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits);
391+
var mask1 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 8);
392+
var mask2 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 16);
393+
var mask3 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 24);
394+
395+
acc0 = acc0.add(floats0, mask0);
396+
acc1 = acc1.add(floats1, mask1);
397+
acc2 = acc2.add(floats2, mask2);
398+
acc3 = acc3.add(floats3, mask3);
399+
}
400+
sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
401+
+ acc3.reduceLanes(VectorOperators.ADD);
331402
}
332403

333-
float sum = acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
334-
+ acc3.reduceLanes(VectorOperators.ADD);
335-
336404
if (i < q.length) {
337405
// do the tail
338406
sum += DefaultESVectorUtilSupport.ipFloatBitImpl(q, d, i);

0 commit comments

Comments
 (0)