Skip to content

Commit 312875d

Browse files
committed
Simplify the integral implementation of compact
1 parent b2d601d commit 312875d

File tree

1 file changed

+36
-35
lines changed

1 file changed

+36
-35
lines changed

wasm/src/org.graalvm.wasm.jdk25/src/org/graalvm/wasm/api/Vector128OpsVectorAPI.java

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -438,34 +438,34 @@ public ByteVector binary(ByteVector xVec, ByteVector yVec, int vectorOpcode) {
438438
case Bytecode.VECTOR_F64X2_GT -> f64x2_relop(x, y, VectorOperators.GT);
439439
case Bytecode.VECTOR_F64X2_LE -> f64x2_relop(x, y, VectorOperators.LE);
440440
case Bytecode.VECTOR_F64X2_GE -> f64x2_relop(x, y, VectorOperators.GE);
441-
case Bytecode.VECTOR_I8X16_NARROW_I16X8_S -> narrow(x, y, I16X8, I8X16, VectorOperators.S2B, VectorOperators.ZERO_EXTEND_B2S, Byte.MIN_VALUE, Byte.MAX_VALUE);
442-
case Bytecode.VECTOR_I8X16_NARROW_I16X8_U -> narrow(x, y, I16X8, I8X16, VectorOperators.S2B, VectorOperators.ZERO_EXTEND_B2S, (short) 0, (short) 0xff);
441+
case Bytecode.VECTOR_I8X16_NARROW_I16X8_S -> narrow(x, y, I16X8, I8X16, Byte.MIN_VALUE, Byte.MAX_VALUE);
442+
case Bytecode.VECTOR_I8X16_NARROW_I16X8_U -> narrow(x, y, I16X8, I8X16, (short) 0, (short) 0xff);
443443
case Bytecode.VECTOR_I8X16_ADD -> binop(x, y, I8X16, VectorOperators.ADD);
444444
case Bytecode.VECTOR_I8X16_ADD_SAT_S -> binop(x, y, I8X16, VectorOperators.SADD);
445-
case Bytecode.VECTOR_I8X16_ADD_SAT_U -> binop_sat_u(x, y, I8X16, I16X8, VectorOperators.ZERO_EXTEND_B2S, VectorOperators.S2B, VectorOperators.ADD, 0, 0xff);
445+
case Bytecode.VECTOR_I8X16_ADD_SAT_U -> binop_sat_u(x, y, I8X16, I16X8, VectorOperators.ZERO_EXTEND_B2S, VectorOperators.ADD, 0, 0xff);
446446
case Bytecode.VECTOR_I8X16_SUB -> binop(x, y, I8X16, VectorOperators.SUB);
447447
case Bytecode.VECTOR_I8X16_SUB_SAT_S -> binop(x, y, I8X16, VectorOperators.SSUB);
448-
case Bytecode.VECTOR_I8X16_SUB_SAT_U -> binop_sat_u(x, y, I8X16, I16X8, VectorOperators.ZERO_EXTEND_B2S, VectorOperators.S2B, VectorOperators.SUB, 0, 0xff);
448+
case Bytecode.VECTOR_I8X16_SUB_SAT_U -> binop_sat_u(x, y, I8X16, I16X8, VectorOperators.ZERO_EXTEND_B2S, VectorOperators.SUB, 0, 0xff);
449449
case Bytecode.VECTOR_I8X16_MIN_S -> binop(x, y, I8X16, VectorOperators.MIN);
450450
case Bytecode.VECTOR_I8X16_MIN_U -> binop(x, y, I8X16, VectorOperators.UMIN);
451451
case Bytecode.VECTOR_I8X16_MAX_S -> binop(x, y, I8X16, VectorOperators.MAX);
452452
case Bytecode.VECTOR_I8X16_MAX_U -> binop(x, y, I8X16, VectorOperators.UMAX);
453-
case Bytecode.VECTOR_I8X16_AVGR_U -> avgr(x, y, I8X16, I16X8, VectorOperators.ZERO_EXTEND_B2S, VectorOperators.S2B);
454-
case Bytecode.VECTOR_I16X8_NARROW_I32X4_S -> narrow(x, y, I32X4, I16X8, VectorOperators.I2S, VectorOperators.ZERO_EXTEND_S2I, Short.MIN_VALUE, Short.MAX_VALUE);
455-
case Bytecode.VECTOR_I16X8_NARROW_I32X4_U -> narrow(x, y, I32X4, I16X8, VectorOperators.I2S, VectorOperators.ZERO_EXTEND_S2I, 0, 0xffff);
453+
case Bytecode.VECTOR_I8X16_AVGR_U -> avgr_u(x, y, I8X16, I16X8, VectorOperators.ZERO_EXTEND_B2S);
454+
case Bytecode.VECTOR_I16X8_NARROW_I32X4_S -> narrow(x, y, I32X4, I16X8, Short.MIN_VALUE, Short.MAX_VALUE);
455+
case Bytecode.VECTOR_I16X8_NARROW_I32X4_U -> narrow(x, y, I32X4, I16X8, 0, 0xffff);
456456
case Bytecode.VECTOR_I16X8_Q15MULR_SAT_S, Bytecode.VECTOR_I16X8_RELAXED_Q15MULR_S -> i16x8_q15mulr_sat_s(x, y);
457457
case Bytecode.VECTOR_I16X8_ADD -> binop(x, y, I16X8, VectorOperators.ADD);
458458
case Bytecode.VECTOR_I16X8_ADD_SAT_S -> binop(x, y, I16X8, VectorOperators.SADD);
459-
case Bytecode.VECTOR_I16X8_ADD_SAT_U -> binop_sat_u(x, y, I16X8, I32X4, VectorOperators.ZERO_EXTEND_S2I, VectorOperators.I2S, VectorOperators.ADD, 0, 0xffff);
459+
case Bytecode.VECTOR_I16X8_ADD_SAT_U -> binop_sat_u(x, y, I16X8, I32X4, VectorOperators.ZERO_EXTEND_S2I, VectorOperators.ADD, 0, 0xffff);
460460
case Bytecode.VECTOR_I16X8_SUB -> binop(x, y, I16X8, VectorOperators.SUB);
461461
case Bytecode.VECTOR_I16X8_SUB_SAT_S -> binop(x, y, I16X8, VectorOperators.SSUB);
462-
case Bytecode.VECTOR_I16X8_SUB_SAT_U -> binop_sat_u(x, y, I16X8, I32X4, VectorOperators.ZERO_EXTEND_S2I, VectorOperators.I2S, VectorOperators.SUB, 0, 0xffff);
462+
case Bytecode.VECTOR_I16X8_SUB_SAT_U -> binop_sat_u(x, y, I16X8, I32X4, VectorOperators.ZERO_EXTEND_S2I, VectorOperators.SUB, 0, 0xffff);
463463
case Bytecode.VECTOR_I16X8_MUL -> binop(x, y, I16X8, VectorOperators.MUL);
464464
case Bytecode.VECTOR_I16X8_MIN_S -> binop(x, y, I16X8, VectorOperators.MIN);
465465
case Bytecode.VECTOR_I16X8_MIN_U -> binop(x, y, I16X8, VectorOperators.UMIN);
466466
case Bytecode.VECTOR_I16X8_MAX_S -> binop(x, y, I16X8, VectorOperators.MAX);
467467
case Bytecode.VECTOR_I16X8_MAX_U -> binop(x, y, I16X8, VectorOperators.UMAX);
468-
case Bytecode.VECTOR_I16X8_AVGR_U -> avgr(x, y, I16X8, I32X4, VectorOperators.ZERO_EXTEND_S2I, VectorOperators.I2S);
468+
case Bytecode.VECTOR_I16X8_AVGR_U -> avgr_u(x, y, I16X8, I32X4, VectorOperators.ZERO_EXTEND_S2I);
469469
case Bytecode.VECTOR_I16X8_EXTMUL_LOW_I8X16_S -> extmul(x, y, I8X16, VectorOperators.B2S, 0);
470470
case Bytecode.VECTOR_I16X8_EXTMUL_LOW_I8X16_U -> extmul(x, y, I8X16, VectorOperators.ZERO_EXTEND_B2S, 0);
471471
case Bytecode.VECTOR_I16X8_EXTMUL_HIGH_I8X16_S -> extmul(x, y, I8X16, VectorOperators.B2S, 1);
@@ -883,8 +883,8 @@ private static ByteVector i32x4_trunc_sat_f32x4_u(ByteVector xBytes) {
883883
DoubleVector xHigh = castDouble128(x.convert(VectorOperators.F2D, 1));
884884
LongVector xLowTrunc = truncSatU32(xLow);
885885
LongVector xHighTrunc = truncSatU32(xHigh);
886-
IntVector resultLow = castInt128(compact(xLowTrunc, 0, I64X2, I32X4, VectorOperators.L2I, VectorOperators.ZERO_EXTEND_I2L));
887-
IntVector resultHigh = castInt128(compact(xHighTrunc, 0, I64X2, I32X4, VectorOperators.L2I, VectorOperators.ZERO_EXTEND_I2L));
886+
IntVector resultLow = castInt128(compact(xLowTrunc, 0, I64X2, I32X4));
887+
IntVector resultHigh = castInt128(compact(xHighTrunc, -1, I64X2, I32X4));
888888
IntVector result = resultLow.or(resultHigh);
889889
return result.reinterpretAsBytes();
890890
}
@@ -903,7 +903,7 @@ private static ByteVector f32x4_convert_i32x4_u(ByteVector xBytes) {
903903
private static ByteVector i32x4_trunc_sat_f64x2_u_zero(ByteVector xBytes) {
904904
DoubleVector x = F64X2.reinterpret(xBytes);
905905
LongVector longResult = truncSatU32(x);
906-
IntVector result = castInt128(compact(longResult, 0, I64X2, I32X4, VectorOperators.L2I, VectorOperators.ZERO_EXTEND_I2L));
906+
IntVector result = castInt128(compact(longResult, 0, I64X2, I32X4));
907907
return result.reinterpretAsBytes();
908908
}
909909

@@ -965,8 +965,11 @@ private static ByteVector f64x2_relop(ByteVector xBytes, ByteVector yBytes, Vect
965965
* {@code VectorSupport#convert} in a way that would map a vector of N elements to a vector of M
966966
* elements, where M > N. Such a situation is currently not supported by the Graal compiler
967967
* [GR-68216].
968+
* <p>
969+
* Works only for integral shapes. See {@link #compactGeneral} for the general implementation.
970+
* </p>
968971
*/
969-
private static <E, F> Vector<F> compact(Vector<E> vec, int part, Shape<E> inShape, Shape<F> outShape, VectorOperators.Conversion<E, F> downcast, VectorOperators.Conversion<F, E> upcast) {
972+
private static <E, F> Vector<F> compact(Vector<E> vec, int part, Shape<E> inShape, Shape<F> outShape) {
970973
// Works only for integral shapes.
971974
assert inShape.species().elementType() == short.class && outShape.species().elementType() == byte.class ||
972975
inShape.species().elementType() == int.class && outShape.species().elementType() == short.class ||
@@ -976,8 +979,7 @@ private static <E, F> Vector<F> compact(Vector<E> vec, int part, Shape<E> inShap
976979
case -1 -> outShape.highMask;
977980
default -> throw CompilerDirectives.shouldNotReachHere();
978981
};
979-
VectorSpecies<F> halfSizeOutShape = outShape.species().withShape(VectorShape.S_64_BIT);
980-
return vec.convertShape(downcast, halfSizeOutShape, 0).convertShape(upcast, inShape.species(), 0).reinterpretShape(outShape.species(), 0).rearrange(outShape.compressEvensShuffle, mask);
982+
return vec.reinterpretShape(outShape.species(), 0).rearrange(outShape.compressEvensShuffle, mask);
981983
}
982984

983985
/**
@@ -987,11 +989,11 @@ private static <W, WI, N, NI> Vector<N> compactGeneral(Vector<W> vec, int part,
987989
Shape<WI> wideIntegralShape, Shape<N> narrowShape,
988990
VectorOperators.Conversion<W, N> downcast,
989991
VectorOperators.Conversion<N, NI> asIntegral,
990-
VectorOperators.Conversion<NI, WI> upcast) {
992+
VectorOperators.Conversion<NI, WI> zeroExtend) {
991993
// NI and WI must be integral types, with NI being half the size of WI.
992-
assert upcast.domainType() == byte.class && upcast.rangeType() == short.class ||
993-
upcast.domainType() == short.class && upcast.rangeType() == int.class ||
994-
upcast.domainType() == int.class && upcast.rangeType() == long.class;
994+
assert zeroExtend.domainType() == byte.class && zeroExtend.rangeType() == short.class ||
995+
zeroExtend.domainType() == short.class && zeroExtend.rangeType() == int.class ||
996+
zeroExtend.domainType() == int.class && zeroExtend.rangeType() == long.class;
995997
VectorMask<N> mask = switch (part) {
996998
case 0 -> narrowShape.lowMask;
997999
case -1 -> narrowShape.highMask;
@@ -1000,44 +1002,43 @@ private static <W, WI, N, NI> Vector<N> compactGeneral(Vector<W> vec, int part,
10001002
VectorSpecies<N> halfSizeOutShape = narrowShape.species().withShape(VectorShape.S_64_BIT);
10011003
Vector<N> down = vec.convertShape(downcast, halfSizeOutShape, 0);
10021004
Vector<NI> downIntegral = down.convert(asIntegral, 0);
1003-
Vector<WI> upIntegral = downIntegral.convertShape(upcast, wideIntegralShape.species(), 0);
1005+
Vector<WI> upIntegral = downIntegral.convertShape(zeroExtend, wideIntegralShape.species(), 0);
10041006
Vector<N> narrowEvens = upIntegral.reinterpretShape(narrowShape.species(), 0);
10051007
return narrowEvens.rearrange(narrowShape.compressEvensShuffle, mask);
10061008
}
10071009

1008-
private static <E, F> ByteVector narrow(ByteVector xBytes, ByteVector yBytes,
1009-
Shape<E> inShape, Shape<F> outShape, VectorOperators.Conversion<E, F> downcast, VectorOperators.Conversion<F, E> upcast,
1010-
long min, long max) {
1010+
private static <E, F> ByteVector narrow(ByteVector xBytes, ByteVector yBytes, Shape<E> inShape, Shape<F> outShape, long min, long max) {
10111011
Vector<E> x = inShape.reinterpret(xBytes);
10121012
Vector<E> y = inShape.reinterpret(yBytes);
10131013
Vector<E> xSat = sat(x, inShape, min, max);
10141014
Vector<E> ySat = sat(y, inShape, min, max);
1015-
Vector<F> resultLow = compact(xSat, 0, inShape, outShape, downcast, upcast);
1016-
Vector<F> resultHigh = compact(ySat, -1, inShape, outShape, downcast, upcast);
1015+
Vector<F> resultLow = compact(xSat, 0, inShape, outShape);
1016+
Vector<F> resultHigh = compact(ySat, -1, inShape, outShape);
10171017
Vector<F> result = resultLow.lanewise(VectorOperators.OR, resultHigh);
10181018
return result.reinterpretAsBytes();
10191019
}
10201020

10211021
private static <E, F> ByteVector binop_sat_u(ByteVector xBytes, ByteVector yBytes,
10221022
Shape<E> shape, Shape<F> extendedShape,
1023-
VectorOperators.Conversion<E, F> upcast, VectorOperators.Conversion<F, E> downcast,
1023+
VectorOperators.Conversion<E, F> upcast,
10241024
VectorOperators.Binary op, long min, long max) {
1025-
return upcastBinopDowncast(xBytes, yBytes, shape, extendedShape, upcast, downcast, (x, y) -> {
1025+
return upcastBinopDowncast(xBytes, yBytes, shape, extendedShape, upcast, (x, y) -> {
10261026
Vector<F> rawResult = x.lanewise(op, y);
10271027
Vector<F> satResult = sat(rawResult, extendedShape, min, max);
10281028
return satResult;
10291029
});
10301030
}
10311031

1032-
private static <E, F> ByteVector avgr(ByteVector xBytes, ByteVector yBytes, Shape<E> shape, Shape<F> extendedShape, VectorOperators.Conversion<E, F> upcast,
1033-
VectorOperators.Conversion<F, E> downcast) {
1032+
private static <E, F> ByteVector avgr_u(ByteVector xBytes, ByteVector yBytes,
1033+
Shape<E> shape, Shape<F> extendedShape,
1034+
VectorOperators.Conversion<E, F> upcast) {
10341035
Vector<F> one = extendedShape.broadcast(1);
10351036
Vector<F> two = extendedShape.broadcast(2);
1036-
return upcastBinopDowncast(xBytes, yBytes, shape, extendedShape, upcast, downcast, (x, y) -> x.add(y).add(one).div(two));
1037+
return upcastBinopDowncast(xBytes, yBytes, shape, extendedShape, upcast, (x, y) -> x.add(y).add(one).div(two));
10371038
}
10381039

10391040
private static ByteVector i16x8_q15mulr_sat_s(ByteVector xBytes, ByteVector yBytes) {
1040-
return upcastBinopDowncast(xBytes, yBytes, I16X8, I32X4, VectorOperators.S2I, VectorOperators.I2S, (x, y) -> {
1041+
return upcastBinopDowncast(xBytes, yBytes, I16X8, I32X4, VectorOperators.S2I, (x, y) -> {
10411042
Vector<Integer> rawResult = x.mul(y).add(I32X4.broadcast(1 << 14)).lanewise(VectorOperators.ASHR, I32X4.broadcast(15));
10421043
Vector<Integer> satResult = sat(rawResult, I32X4, Short.MIN_VALUE, Short.MAX_VALUE);
10431044
return satResult;
@@ -1228,16 +1229,16 @@ private static LongVector truncSatU32(DoubleVector x) {
12281229

12291230
private static <E, F> ByteVector upcastBinopDowncast(ByteVector xBytes, ByteVector yBytes,
12301231
Shape<E> shape, Shape<F> extendedShape,
1231-
VectorOperators.Conversion<E, F> upcast, VectorOperators.Conversion<F, E> downcast,
1232+
VectorOperators.Conversion<E, F> upcast,
12321233
BinaryVectorOp<F> op) {
12331234
Vector<E> x = shape.reinterpret(xBytes);
12341235
Vector<E> y = shape.reinterpret(yBytes);
12351236
Vector<F> xLow = x.convert(upcast, 0);
12361237
Vector<F> xHigh = x.convert(upcast, 1);
12371238
Vector<F> yLow = y.convert(upcast, 0);
12381239
Vector<F> yHigh = y.convert(upcast, 1);
1239-
Vector<E> resultLow = compact(op.apply(xLow, yLow), 0, extendedShape, shape, downcast, upcast);
1240-
Vector<E> resultHigh = compact(op.apply(xHigh, yHigh), -1, extendedShape, shape, downcast, upcast);
1240+
Vector<E> resultLow = compact(op.apply(xLow, yLow), 0, extendedShape, shape);
1241+
Vector<E> resultHigh = compact(op.apply(xHigh, yHigh), -1, extendedShape, shape);
12411242
Vector<E> result = resultLow.lanewise(VectorOperators.OR, resultHigh);
12421243
return result.reinterpretAsBytes();
12431244
}

0 commit comments

Comments
 (0)