Skip to content

Commit 1eba7f3

Browse files
committed
Generic impl of narrow that uses compact instead of convert
1 parent baec29a commit 1eba7f3

File tree

1 file changed

+48
-45
lines changed

1 file changed

+48
-45
lines changed

wasm/src/org.graalvm.wasm.jdk25/src/org/graalvm/wasm/api/Vector128OpsVectorAPI.java

Lines changed: 48 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -75,30 +75,34 @@ static Vector128Ops<ByteVector> create() {
7575
return new Vector128OpsVectorAPI();
7676
}
7777

78-
private interface Shape<E> {
78+
/**
 * Base class for the per-lane-type "shape" helpers. A shape exposes its {@link VectorSpecies},
 * a way to reinterpret a {@code ByteVector} as a vector of that species, zero/broadcast
 * factories, and a precomputed shuffle plus two masks derived from the species' lane count.
 *
 * @param <E> boxed lane type of the vectors this shape describes
 */
private abstract static class Shape<E> {
7979

80-
Vector<E> reinterpret(ByteVector bytes);
80+
// NOTE(review): the three field initializers below run during Shape's construction, i.e.
// before any subclass instance fields are initialized, and they call the abstract species().
// This is only safe if every subclass's species() does not depend on its own instance state
// (presumably they return a constant species) - confirm against the subclasses.
//
// Shuffle in which output lane i reads source lane (i * 2) % length, i.e. the even-indexed
// source lanes in order, repeated once the index wraps around.
public final VectorShuffle<E> compressEvensShuffle = VectorShuffle.fromOp(species(), i -> (i * 2) % species().length());
81+
// Mask built from a bit pattern with the low length/2 bits set (lower half of the lanes).
public final VectorMask<E> lowMask = VectorMask.fromLong(species(), (1L << (species().length() / 2)) - 1);
82+
// Same bit pattern shifted left by length/2 (upper half of the lanes).
public final VectorMask<E> highMask = VectorMask.fromLong(species(), ((1L << (species().length() / 2)) - 1) << (species().length() / 2));
8183

82-
VectorSpecies<E> species();
84+
// Views the given byte vector as a vector of this shape; implemented by each subclass.
public abstract Vector<E> reinterpret(ByteVector bytes);
8385

84-
default Vector<E> zero() {
86+
// The species describing this shape; also used above to size the shuffle and masks.
public abstract VectorSpecies<E> species();
87+
88+
public Vector<E> zero() {
8589
return species().zero();
8690
}
8791

88-
default Vector<E> broadcast(long e) {
92+
public Vector<E> broadcast(long e) {
8993
return species().broadcast(e);
9094
}
9195

9296
/**
9397
* This is used by floating-point Shapes to be able to broadcast -0.0, which cannot be
9498
* faithfully represented as a long.
9599
*/
96-
default Vector<E> broadcast(@SuppressWarnings("unused") double e) {
100+
public Vector<E> broadcast(@SuppressWarnings("unused") double e) {
97101
// Base implementation always throws; per the Javadoc above, floating-point shapes are the
// ones expected to provide a real implementation.
throw CompilerDirectives.shouldNotReachHere();
98102
}
99103
}
100104

101-
private static final class I8X16Shape implements Shape<Byte> {
105+
private static final class I8X16Shape extends Shape<Byte> {
102106

103107
private I8X16Shape() {
104108
}
@@ -130,7 +134,7 @@ public ByteVector broadcast(byte e) {
130134

131135
private static final I8X16Shape I8X16 = new I8X16Shape();
132136

133-
private static final class I16X8Shape implements Shape<Short> {
137+
private static final class I16X8Shape extends Shape<Short> {
134138

135139
private I16X8Shape() {
136140
}
@@ -162,7 +166,7 @@ public ShortVector broadcast(short e) {
162166

163167
private static final I16X8Shape I16X8 = new I16X8Shape();
164168

165-
private static final class I32X4Shape implements Shape<Integer> {
169+
private static final class I32X4Shape extends Shape<Integer> {
166170

167171
private I32X4Shape() {
168172
}
@@ -194,7 +198,7 @@ public IntVector broadcast(int e) {
194198

195199
private static final I32X4Shape I32X4 = new I32X4Shape();
196200

197-
private static final class I64X2Shape implements Shape<Long> {
201+
private static final class I64X2Shape extends Shape<Long> {
198202

199203
private I64X2Shape() {
200204
}
@@ -222,7 +226,7 @@ public LongVector broadcast(long e) {
222226

223227
private static final I64X2Shape I64X2 = new I64X2Shape();
224228

225-
private static final class F32X4Shape implements Shape<Float> {
229+
private static final class F32X4Shape extends Shape<Float> {
226230

227231
private F32X4Shape() {
228232
}
@@ -263,7 +267,7 @@ public FloatVector broadcast(float e) {
263267

264268
private static final F32X4Shape F32X4 = new F32X4Shape();
265269

266-
private static final class F64X2Shape implements Shape<Double> {
270+
private static final class F64X2Shape extends Shape<Double> {
267271

268272
private F64X2Shape() {
269273
}
@@ -423,8 +427,8 @@ public ByteVector binary(ByteVector xVec, ByteVector yVec, int vectorOpcode) {
423427
case Bytecode.VECTOR_F64X2_GT -> f64x2_relop(x, y, VectorOperators.GT);
424428
case Bytecode.VECTOR_F64X2_LE -> f64x2_relop(x, y, VectorOperators.LE);
425429
case Bytecode.VECTOR_F64X2_GE -> f64x2_relop(x, y, VectorOperators.GE);
426-
case Bytecode.VECTOR_I8X16_NARROW_I16X8_S -> i8x16_narrow_i16x8_s(x, y, Byte.MIN_VALUE, Byte.MAX_VALUE);
427-
case Bytecode.VECTOR_I8X16_NARROW_I16X8_U -> i8x16_narrow_i16x8_s(x, y, (short) 0, (short) 0xff);
430+
case Bytecode.VECTOR_I8X16_NARROW_I16X8_S -> narrow(x, y, I16X8, I8X16, VectorOperators.S2B, VectorOperators.ZERO_EXTEND_B2S, Byte.MIN_VALUE, Byte.MAX_VALUE);
431+
case Bytecode.VECTOR_I8X16_NARROW_I16X8_U -> narrow(x, y, I16X8, I8X16, VectorOperators.S2B, VectorOperators.ZERO_EXTEND_B2S, (short) 0, (short) 0xff);
428432
case Bytecode.VECTOR_I8X16_ADD -> binop(x, y, I8X16, VectorOperators.ADD);
429433
case Bytecode.VECTOR_I8X16_ADD_SAT_S -> binop(x, y, I8X16, VectorOperators.SADD);
430434
case Bytecode.VECTOR_I8X16_ADD_SAT_U -> binop_sat_u(x, y, I8X16, I16X8, VectorOperators.ZERO_EXTEND_B2S, VectorOperators.S2B, VectorOperators.ADD, 0, 0xff);
@@ -436,8 +440,8 @@ public ByteVector binary(ByteVector xVec, ByteVector yVec, int vectorOpcode) {
436440
case Bytecode.VECTOR_I8X16_MAX_S -> binop(x, y, I8X16, VectorOperators.MAX);
437441
case Bytecode.VECTOR_I8X16_MAX_U -> binop(x, y, I8X16, VectorOperators.UMAX);
438442
case Bytecode.VECTOR_I8X16_AVGR_U -> avgr(x, y, I8X16, VectorOperators.ZERO_EXTEND_B2S, VectorOperators.S2B);
439-
case Bytecode.VECTOR_I16X8_NARROW_I32X4_S -> narrow(x, y, I32X4, VectorOperators.I2S, Short.MIN_VALUE, Short.MAX_VALUE);
440-
case Bytecode.VECTOR_I16X8_NARROW_I32X4_U -> narrow(x, y, I32X4, VectorOperators.I2S, 0, 0xffff);
443+
case Bytecode.VECTOR_I16X8_NARROW_I32X4_S -> narrow(x, y, I32X4, I16X8, VectorOperators.I2S, VectorOperators.ZERO_EXTEND_S2I, Short.MIN_VALUE, Short.MAX_VALUE);
444+
case Bytecode.VECTOR_I16X8_NARROW_I32X4_U -> narrow(x, y, I32X4, I16X8, VectorOperators.I2S, VectorOperators.ZERO_EXTEND_S2I, 0, 0xffff);
441445
case Bytecode.VECTOR_I16X8_Q15MULR_SAT_S, Bytecode.VECTOR_I16X8_RELAXED_Q15MULR_S -> i16x8_q15mulr_sat_s(x, y);
442446
case Bytecode.VECTOR_I16X8_ADD -> binop(x, y, I16X8, VectorOperators.ADD);
443447
case Bytecode.VECTOR_I16X8_ADD_SAT_S -> binop(x, y, I16X8, VectorOperators.SADD);
@@ -937,37 +941,36 @@ private static ByteVector f64x2_relop(ByteVector xBytes, ByteVector yBytes, Vect
937941
return result.reinterpretAsBytes();
938942
}
939943

940-
private static final VectorShuffle<Byte> EVENS_I8X16 = VectorShuffle.fromOp(ByteVector.SPECIES_128, i -> (i * 2) % 16);
941-
private static final VectorMask<Byte> LOW_I8X16 = VectorMask.fromLong(ByteVector.SPECIES_128, (1L << 8) - 1);
942-
private static final VectorMask<Byte> HIGH_I8X16 = VectorMask.fromLong(ByteVector.SPECIES_128, ((1L << 8) - 1) << 8);
943-
944-
private static ByteVector i8x16_compact_16x8(Vector<Short> vec, int part) {
945-
VectorMask<Byte> mask = switch (part) {
946-
case 0 -> LOW_I8X16;
947-
case -1 -> HIGH_I8X16;
944+
/**
945+
* Implements {@code vec.convert(downcast, part)} while avoiding the use of
946+
* {@code VectorSupport#convert} in a way that would map a vector of N elements to a vector of M
947+
* elements, where M > N. Such a situation is currently not supported by the Graal compiler
948+
* [GR-68216].
949+
*/
950+
private static <E, F> Vector<F> compact(Vector<E> vec, int part, Shape<E> inShape, Shape<F> outShape, VectorOperators.Conversion<E, F> downcast, VectorOperators.Conversion<F, E> upcast) {
951+
// Works only for integral shapes.
952+
assert inShape.species().elementType() == short.class && outShape.species().elementType() == byte.class ||
953+
inShape.species().elementType() == int.class && outShape.species().elementType() == short.class ||
954+
inShape.species().elementType() == long.class && outShape.species().elementType() == int.class;
955+
VectorMask<F> mask = switch (part) {
956+
case 0 -> outShape.lowMask;
957+
case -1 -> outShape.highMask;
948958
default -> throw CompilerDirectives.shouldNotReachHere();
949959
};
950-
return vec.convertShape(VectorOperators.S2B, ByteVector.SPECIES_64, 0).convertShape(VectorOperators.ZERO_EXTEND_B2S, ShortVector.SPECIES_128, 0).reinterpretAsBytes().rearrange(EVENS_I8X16, mask);
951-
}
952-
953-
private static ByteVector i8x16_narrow_i16x8_s(ByteVector xBytes, ByteVector yBytes, short min, short max) {
954-
ShortVector x = I16X8.reinterpret(xBytes);
955-
ShortVector y = I16X8.reinterpret(yBytes);
956-
Vector<Short> xSat = sat(x, I16X8, min, max);
957-
Vector<Short> ySat = sat(y, I16X8, min, max);
958-
ByteVector resultLow = i8x16_compact_16x8(xSat, 0);
959-
ByteVector resultHigh = i8x16_compact_16x8(ySat, -1);
960-
return resultLow.or(resultHigh);
961-
}
962-
963-
private static <E, F> ByteVector narrow(ByteVector xBytes, ByteVector yBytes, Shape<E> shape, VectorOperators.Conversion<E, F> conv, long min, long max) {
964-
Vector<E> x = shape.reinterpret(xBytes);
965-
Vector<E> y = shape.reinterpret(yBytes);
966-
Vector<E> xSat = sat(x, shape, min, max);
967-
Vector<E> ySat = sat(y, shape, min, max);
968-
Vector<F> resultLow = xSat.convert(conv, 0);
969-
Vector<F> resultHigh = ySat.convert(conv, -1);
970-
Vector<F> result = firstNonzero(resultLow, resultHigh);
960+
VectorSpecies<F> halfSizeOutShape = outShape.species().withShape(VectorShape.S_64_BIT);
961+
return vec.convertShape(downcast, halfSizeOutShape, 0).convertShape(upcast, inShape.species(), 0).reinterpretShape(outShape.species(), 0).rearrange(outShape.compressEvensShuffle, mask);
962+
}
963+
964+
/**
 * Generic narrowing of two input vectors into a single output vector. Both inputs are
 * reinterpreted as {@code inShape} vectors, passed through {@code sat} with {@code min} and
 * {@code max}, compacted to {@code outShape} lanes via {@code compact}, and the two compacted
 * vectors are merged with a lane-wise OR.
 *
 * @param <E> lane type of the inputs (wider type)
 * @param <F> lane type of the result (narrower type)
 * @param xBytes first input; compacted with part {@code 0}, for which {@code compact} selects
 *            {@code outShape.lowMask}
 * @param yBytes second input; compacted with part {@code -1}, for which {@code compact} selects
 *            {@code outShape.highMask}
 * @param inShape shape used to reinterpret and saturate the inputs
 * @param outShape shape of the narrowed result
 * @param downcast conversion from {@code E} to {@code F}, forwarded to {@code compact}
 * @param upcast conversion from {@code F} back to {@code E}, forwarded to {@code compact}
 * @param min lower bound forwarded to {@code sat}
 * @param max upper bound forwarded to {@code sat}
 * @return the merged result, reinterpreted as bytes
 */
private static <E, F> ByteVector narrow(ByteVector xBytes, ByteVector yBytes,
965+
Shape<E> inShape, Shape<F> outShape, VectorOperators.Conversion<E, F> downcast, VectorOperators.Conversion<F, E> upcast,
966+
long min, long max) {
967+
Vector<E> x = inShape.reinterpret(xBytes);
968+
Vector<E> y = inShape.reinterpret(yBytes);
969+
// sat(...) is defined outside this view; presumably it clamps every lane to [min, max] so
// that the subsequent downcast cannot wrap around - TODO confirm.
Vector<E> xSat = sat(x, inShape, min, max);
970+
Vector<E> ySat = sat(y, inShape, min, max);
971+
// Part 0 / part -1 choose which half of the output lanes each input ends up in.
Vector<F> resultLow = compact(xSat, 0, inShape, outShape, downcast, upcast);
972+
Vector<F> resultHigh = compact(ySat, -1, inShape, outShape, downcast, upcast);
973+
// NOTE(review): merging with OR relies on each compacted vector being zero outside its own
// masked half - presumably guaranteed by the masked rearrange inside compact; confirm.
Vector<F> result = resultLow.lanewise(VectorOperators.OR, resultHigh);
971974
return result.reinterpretAsBytes();
972975
}
973976

0 commit comments

Comments
 (0)