Skip to content

Commit 6c7ed1a

Browse files
committed
Don't unroll the filler operations
1 parent 9279e05 commit 6c7ed1a

File tree

1 file changed

+16
-42
lines changed

1 file changed

+16
-42
lines changed

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 16 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -236,34 +236,20 @@ static int ipByteBit512(byte[] q, byte[] d) {
236236
+ acc3.reduceLanes(VectorOperators.ADD);
237237
}
238238

239-
sectionLength = INT_SPECIES_256.length() * 4;
239+
sectionLength = INT_SPECIES_256.length();
240240
if (q.length - i >= sectionLength) {
241-
IntVector acc0 = IntVector.zero(INT_SPECIES_256);
242-
IntVector acc1 = IntVector.zero(INT_SPECIES_256);
243-
IntVector acc2 = IntVector.zero(INT_SPECIES_256);
244-
IntVector acc3 = IntVector.zero(INT_SPECIES_256);
241+
// don't unroll this, we want to catch as many as we can before going scalar
242+
IntVector acc = IntVector.zero(INT_SPECIES_256);
245243
int limit = limit(q.length, sectionLength);
246244
for (; i < limit; i += sectionLength) {
247-
var vals0 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0);
248-
var vals1 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length()).castShape(INT_SPECIES_256, 0);
249-
var vals2 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length() * 2)
250-
.castShape(INT_SPECIES_256, 0);
251-
var vals3 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length() * 3)
252-
.castShape(INT_SPECIES_256, 0);
245+
var vals = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0);
253246

254-
long maskBits = Integer.reverse((int) BitUtil.VH_BE_INT.get(d, i / 8));
255-
var mask0 = VectorMask.fromLong(INT_SPECIES_256, maskBits);
256-
var mask1 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 8);
257-
var mask2 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 16);
258-
var mask3 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 24);
247+
long maskBits = Integer.reverse(d[i / 8]) >> 24;
248+
var mask = VectorMask.fromLong(INT_SPECIES_256, maskBits);
259249

260-
acc0 = acc0.add(vals0, mask0);
261-
acc1 = acc1.add(vals1, mask1);
262-
acc2 = acc2.add(vals2, mask2);
263-
acc3 = acc3.add(vals3, mask3);
250+
acc = acc.add(vals, mask);
264251
}
265-
sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
266-
+ acc3.reduceLanes(VectorOperators.ADD);
252+
sum += acc.reduceLanes(VectorOperators.ADD);
267253
}
268254

269255
if (i < q.length) {
@@ -353,32 +339,20 @@ static float ipFloatBit512(float[] q, byte[] d) {
353339
+ acc3.reduceLanes(VectorOperators.ADD);
354340
}
355341

356-
sectionLength = FLOAT_SPECIES_256.length() * 4;
342+
sectionLength = FLOAT_SPECIES_256.length();
357343
if (q.length - i >= sectionLength) {
358-
FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES_256);
359-
FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES_256);
360-
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES_256);
361-
FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES_256);
344+
// don't unroll this, we want to catch as many as we can before going scalar
345+
FloatVector acc = FloatVector.zero(FLOAT_SPECIES_256);
362346
int limit = limit(q.length, sectionLength);
363347
for (; i < limit; i += sectionLength) {
364-
var floats0 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i);
365-
var floats1 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length());
366-
var floats2 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length() * 2);
367-
var floats3 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length() * 3);
348+
var floats = FloatVector.fromArray(FLOAT_SPECIES_256, q, i);
368349

369-
long maskBits = Integer.reverse((int) BitUtil.VH_BE_INT.get(d, i / 8));
370-
var mask0 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits);
371-
var mask1 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 8);
372-
var mask2 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 16);
373-
var mask3 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 24);
350+
long maskBits = Integer.reverse(d[i / 8]) >> 24;
351+
var mask = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits);
374352

375-
acc0 = acc0.add(floats0, mask0);
376-
acc1 = acc1.add(floats1, mask1);
377-
acc2 = acc2.add(floats2, mask2);
378-
acc3 = acc3.add(floats3, mask3);
353+
acc = acc.add(floats, mask);
379354
}
380-
sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
381-
+ acc3.reduceLanes(VectorOperators.ADD);
355+
sum += acc.reduceLanes(VectorOperators.ADD);
382356
}
383357

384358
if (i < q.length) {

0 commit comments

Comments
 (0)