Skip to content

Commit b2d601d

Browse files
committed
Use constant masks
1 parent 371d621 commit b2d601d

File tree

1 file changed

+27
-33
lines changed

1 file changed

+27
-33
lines changed

wasm/src/org.graalvm.wasm.jdk25/src/org/graalvm/wasm/api/Vector128OpsVectorAPI.java

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,17 @@ private abstract static class Shape<E> {
8080
public final VectorShuffle<E> compressEvensShuffle = VectorShuffle.fromOp(species(), i -> (i * 2) % species().length());
8181
public final VectorMask<E> lowMask = VectorMask.fromLong(species(), (1L << (species().length() / 2)) - 1);
8282
public final VectorMask<E> highMask = VectorMask.fromLong(species(), ((1L << (species().length() / 2)) - 1) << (species().length() / 2));
83+
public final VectorMask<E> evensMask;
84+
public final VectorMask<E> oddsMask;
85+
86+
// Precomputes the even-lane and odd-lane masks for this shape from one alternating boolean pattern.
// NOTE(review): species() is called while the instance is still being constructed; presumably
// subclass implementations do not depend on subclass instance state — confirm.
protected Shape() {
87+
boolean[] values = new boolean[species().length() + 1]; // one extra element so the offset-1 load below stays in bounds
88+
for (int i = 0; i < values.length; i++) {
89+
values[i] = i % 2 == 0; // true, false, true, ... with true at index 0
90+
}
91+
evensMask = species().loadMask(values, 0); // lane i reads values[i]: set when i is even
92+
oddsMask = species().loadMask(values, 1); // lane i reads values[i + 1]: set when i is odd
93+
}
8394

8495
public abstract Vector<E> reinterpret(ByteVector bytes);
8596

@@ -738,8 +749,8 @@ private static <E> ByteVector unop(ByteVector xBytes, Shape<E> shape, VectorOper
738749

739750
// Reinterprets xBytes as lanes of the given shape, converts the even-indexed and the odd-indexed
// lanes separately with conv, adds the two converted vectors lane-wise and returns the sum as bytes.
// NOTE(review): assumes conv widens each lane to twice its size, so that part 0 of the conversion
// covers exactly the lanes packed to the front by compress — confirm at the call sites.
private static <E, F> ByteVector extadd_pairwise(ByteVector xBytes, Shape<E> shape, VectorOperators.Conversion<E, F> conv) {
740751
Vector<E> x = shape.reinterpret(xBytes);
741-
Vector<F> evens = x.compress(evens(shape)).convert(conv, 0);
742-
Vector<F> odds = x.compress(odds(shape)).convert(conv, 0);
752+
Vector<F> evens = x.compress(shape.evensMask).convert(conv, 0); // compress packs lanes 0, 2, 4, ... to the front before converting
753+
Vector<F> odds = x.compress(shape.oddsMask).convert(conv, 0); // same for lanes 1, 3, 5, ...
743754
Vector<F> result = evens.add(odds); // lane k holds converted x[2k] plus converted x[2k + 1]
744755
return result.reinterpretAsBytes();
745756
}
@@ -1045,10 +1056,10 @@ private static <E, F> ByteVector extmul(ByteVector xBytes, ByteVector yBytes, Sh
10451056
private static ByteVector i32x4_dot_i16x8_s(ByteVector xBytes, ByteVector yBytes) {
10461057
ShortVector x = xBytes.reinterpretAsShorts();
10471058
ShortVector y = yBytes.reinterpretAsShorts();
1048-
Vector<Integer> xEvens = castInt128(x.compress(castShort128Mask(evens(I16X8))).convert(VectorOperators.S2I, 0));
1049-
Vector<Integer> xOdds = castInt128(x.compress(castShort128Mask(odds(I16X8))).convert(VectorOperators.S2I, 0));
1050-
Vector<Integer> yEvens = castInt128(y.compress(castShort128Mask(evens(I16X8))).convert(VectorOperators.S2I, 0));
1051-
Vector<Integer> yOdds = castInt128(y.compress(castShort128Mask(odds(I16X8))).convert(VectorOperators.S2I, 0));
1059+
Vector<Integer> xEvens = castInt128(x.compress(castShort128Mask(I16X8.evensMask)).convert(VectorOperators.S2I, 0));
1060+
Vector<Integer> xOdds = castInt128(x.compress(castShort128Mask(I16X8.oddsMask)).convert(VectorOperators.S2I, 0));
1061+
Vector<Integer> yEvens = castInt128(y.compress(castShort128Mask(I16X8.evensMask)).convert(VectorOperators.S2I, 0));
1062+
Vector<Integer> yOdds = castInt128(y.compress(castShort128Mask(I16X8.oddsMask)).convert(VectorOperators.S2I, 0));
10521063
Vector<Integer> xMulYEvens = xEvens.mul(yEvens);
10531064
Vector<Integer> xMulYOdds = xOdds.mul(yOdds);
10541065
Vector<Integer> dot = xMulYEvens.lanewise(VectorOperators.ADD, xMulYOdds);
@@ -1070,10 +1081,10 @@ private static <E> ByteVector pmax(ByteVector xBytes, ByteVector yBytes, Shape<E
10701081
}
10711082

10721083
private static ByteVector i16x8_relaxed_dot_i8x16_i7x16_s(ByteVector x, ByteVector y) {
1073-
Vector<Short> xEvens = castShort128(x.compress(castByte128Mask(evens(I8X16))).convert(VectorOperators.B2S, 0));
1074-
Vector<Short> xOdds = castShort128(x.compress(castByte128Mask(odds(I8X16))).convert(VectorOperators.B2S, 0));
1075-
Vector<Short> yEvens = castShort128(y.compress(castByte128Mask(evens(I8X16))).convert(VectorOperators.B2S, 0));
1076-
Vector<Short> yOdds = castShort128(y.compress(castByte128Mask(odds(I8X16))).convert(VectorOperators.B2S, 0));
1084+
Vector<Short> xEvens = castShort128(x.compress(castByte128Mask(I8X16.evensMask)).convert(VectorOperators.B2S, 0));
1085+
Vector<Short> xOdds = castShort128(x.compress(castByte128Mask(I8X16.oddsMask)).convert(VectorOperators.B2S, 0));
1086+
Vector<Short> yEvens = castShort128(y.compress(castByte128Mask(I8X16.evensMask)).convert(VectorOperators.B2S, 0));
1087+
Vector<Short> yOdds = castShort128(y.compress(castByte128Mask(I8X16.oddsMask)).convert(VectorOperators.B2S, 0));
10771088
Vector<Short> xMulYEvens = xEvens.mul(yEvens);
10781089
Vector<Short> xMulYOdds = xOdds.mul(yOdds);
10791090
Vector<Short> dot = xMulYEvens.lanewise(VectorOperators.SADD, xMulYOdds);
@@ -1112,15 +1123,15 @@ private static ByteVector f64x2_ternop(ByteVector xBytes, ByteVector yBytes, Byt
11121123

11131124
private static ByteVector i32x4_relaxed_dot_i8x16_i7x16_add_s(ByteVector x, ByteVector y, ByteVector zBytes) {
11141125
IntVector z = zBytes.reinterpretAsInts();
1115-
ShortVector xEvens = castShort128(x.compress(castByte128Mask(evens(I8X16))).convert(VectorOperators.B2S, 0));
1116-
ShortVector xOdds = castShort128(x.compress(castByte128Mask(odds(I8X16))).convert(VectorOperators.B2S, 0));
1117-
ShortVector yEvens = castShort128(y.compress(castByte128Mask(evens(I8X16))).convert(VectorOperators.B2S, 0));
1118-
ShortVector yOdds = castShort128(y.compress(castByte128Mask(odds(I8X16))).convert(VectorOperators.B2S, 0));
1126+
ShortVector xEvens = castShort128(x.compress(castByte128Mask(I8X16.evensMask)).convert(VectorOperators.B2S, 0));
1127+
ShortVector xOdds = castShort128(x.compress(castByte128Mask(I8X16.oddsMask)).convert(VectorOperators.B2S, 0));
1128+
ShortVector yEvens = castShort128(y.compress(castByte128Mask(I8X16.evensMask)).convert(VectorOperators.B2S, 0));
1129+
ShortVector yOdds = castShort128(y.compress(castByte128Mask(I8X16.oddsMask)).convert(VectorOperators.B2S, 0));
11191130
ShortVector xMulYEvens = xEvens.mul(yEvens);
11201131
ShortVector xMulYOdds = xOdds.mul(yOdds);
11211132
ShortVector dot = xMulYEvens.lanewise(VectorOperators.SADD, xMulYOdds);
1122-
IntVector dotEvens = castInt128(dot.compress(castShort128Mask(evens(I16X8))).convert(VectorOperators.S2I, 0));
1123-
IntVector dotOdds = castInt128(dot.compress(castShort128Mask(odds(I16X8))).convert(VectorOperators.S2I, 0));
1133+
IntVector dotEvens = castInt128(dot.compress(castShort128Mask(I16X8.evensMask)).convert(VectorOperators.S2I, 0));
1134+
IntVector dotOdds = castInt128(dot.compress(castShort128Mask(I16X8.oddsMask)).convert(VectorOperators.S2I, 0));
11241135
IntVector dots = dotEvens.add(dotOdds);
11251136
IntVector result = dots.add(z);
11261137
return result.reinterpretAsBytes();
@@ -1231,23 +1242,6 @@ private static <E, F> ByteVector upcastBinopDowncast(ByteVector xBytes, ByteVect
12311242
return result.reinterpretAsBytes();
12321243
}
12331244

1234-
private static final boolean[] ALTERNATING_BITS;
1235-
1236-
static {
1237-
ALTERNATING_BITS = new boolean[I8X16.species().length() + 1];
1238-
for (int i = 0; i < ALTERNATING_BITS.length; i++) {
1239-
ALTERNATING_BITS[i] = i % 2 == 0;
1240-
}
1241-
}
1242-
1243-
private static <E> VectorMask<E> evens(Shape<E> shape) {
1244-
return VectorMask.fromArray(shape.species(), ALTERNATING_BITS, 0);
1245-
}
1246-
1247-
private static <E> VectorMask<E> odds(Shape<E> shape) {
1248-
return VectorMask.fromArray(shape.species(), ALTERNATING_BITS, 1);
1249-
}
1250-
12511245
@Override
12521246
public ByteVector fromArray(byte[] bytes, int offset) {
12531247
return ByteVector.fromArray(I8X16.species(), bytes, offset);

0 commit comments

Comments
 (0)