Skip to content

Commit 371d621

Browse files
committed
Use compact everywhere instead of convert when increasing number of lanes
1 parent 1eba7f3 commit 371d621

File tree

1 file changed

+59
-28
lines changed

1 file changed

+59
-28
lines changed

wasm/src/org.graalvm.wasm.jdk25/src/org/graalvm/wasm/api/Vector128OpsVectorAPI.java

Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ public ByteVector unary(ByteVector xVec, int vectorOpcode) {
363363
case Bytecode.VECTOR_I32X4_TRUNC_SAT_F64X2_U_ZERO, Bytecode.VECTOR_I32X4_RELAXED_TRUNC_F64X2_U_ZERO -> I8X16.species().fromArray(fallbackOps.unary(x.toArray(), vectorOpcode), 0); // GR-51421
364364
case Bytecode.VECTOR_F64X2_CONVERT_LOW_I32X4_S -> convert(x, I32X4, VectorOperators.I2D);
365365
case Bytecode.VECTOR_F64X2_CONVERT_LOW_I32X4_U -> f64x2_convert_low_i32x4_u(x);
366-
case Bytecode.VECTOR_F32X4_DEMOTE_F64X2_ZERO -> convert(x, F64X2, VectorOperators.D2F);
366+
case Bytecode.VECTOR_F32X4_DEMOTE_F64X2_ZERO -> f32X4_demote_f64X2_zero(x);
367367
case Bytecode.VECTOR_F64X2_PROMOTE_LOW_F32X4 -> convert(x, F32X4, VectorOperators.F2D);
368368
default -> throw CompilerDirectives.shouldNotReachHere();
369369
});
@@ -439,7 +439,7 @@ public ByteVector binary(ByteVector xVec, ByteVector yVec, int vectorOpcode) {
439439
case Bytecode.VECTOR_I8X16_MIN_U -> binop(x, y, I8X16, VectorOperators.UMIN);
440440
case Bytecode.VECTOR_I8X16_MAX_S -> binop(x, y, I8X16, VectorOperators.MAX);
441441
case Bytecode.VECTOR_I8X16_MAX_U -> binop(x, y, I8X16, VectorOperators.UMAX);
442-
case Bytecode.VECTOR_I8X16_AVGR_U -> avgr(x, y, I8X16, VectorOperators.ZERO_EXTEND_B2S, VectorOperators.S2B);
442+
case Bytecode.VECTOR_I8X16_AVGR_U -> avgr(x, y, I8X16, I16X8, VectorOperators.ZERO_EXTEND_B2S, VectorOperators.S2B);
443443
case Bytecode.VECTOR_I16X8_NARROW_I32X4_S -> narrow(x, y, I32X4, I16X8, VectorOperators.I2S, VectorOperators.ZERO_EXTEND_S2I, Short.MIN_VALUE, Short.MAX_VALUE);
444444
case Bytecode.VECTOR_I16X8_NARROW_I32X4_U -> narrow(x, y, I32X4, I16X8, VectorOperators.I2S, VectorOperators.ZERO_EXTEND_S2I, 0, 0xffff);
445445
case Bytecode.VECTOR_I16X8_Q15MULR_SAT_S, Bytecode.VECTOR_I16X8_RELAXED_Q15MULR_S -> i16x8_q15mulr_sat_s(x, y);
@@ -454,7 +454,7 @@ public ByteVector binary(ByteVector xVec, ByteVector yVec, int vectorOpcode) {
454454
case Bytecode.VECTOR_I16X8_MIN_U -> binop(x, y, I16X8, VectorOperators.UMIN);
455455
case Bytecode.VECTOR_I16X8_MAX_S -> binop(x, y, I16X8, VectorOperators.MAX);
456456
case Bytecode.VECTOR_I16X8_MAX_U -> binop(x, y, I16X8, VectorOperators.UMAX);
457-
case Bytecode.VECTOR_I16X8_AVGR_U -> avgr(x, y, I16X8, VectorOperators.ZERO_EXTEND_S2I, VectorOperators.I2S);
457+
case Bytecode.VECTOR_I16X8_AVGR_U -> avgr(x, y, I16X8, I32X4, VectorOperators.ZERO_EXTEND_S2I, VectorOperators.I2S);
458458
case Bytecode.VECTOR_I16X8_EXTMUL_LOW_I8X16_S -> extmul(x, y, I8X16, VectorOperators.B2S, 0);
459459
case Bytecode.VECTOR_I16X8_EXTMUL_LOW_I8X16_U -> extmul(x, y, I8X16, VectorOperators.ZERO_EXTEND_B2S, 0);
460460
case Bytecode.VECTOR_I16X8_EXTMUL_HIGH_I8X16_S -> extmul(x, y, I8X16, VectorOperators.B2S, 1);
@@ -870,27 +870,29 @@ private static ByteVector i32x4_trunc_sat_f32x4_u(ByteVector xBytes) {
870870
FloatVector x = F32X4.reinterpret(xBytes);
871871
DoubleVector xLow = castDouble128(x.convert(VectorOperators.F2D, 0));
872872
DoubleVector xHigh = castDouble128(x.convert(VectorOperators.F2D, 1));
873-
IntVector resultLow = castInt128(truncSatU32(xLow).convert(VectorOperators.L2I, 0));
874-
IntVector resultHigh = castInt128(truncSatU32(xHigh).convert(VectorOperators.L2I, -1));
875-
Vector<Integer> result = firstNonzero(resultLow, resultHigh);
873+
LongVector xLowTrunc = truncSatU32(xLow);
874+
LongVector xHighTrunc = truncSatU32(xHigh);
875+
IntVector resultLow = castInt128(compact(xLowTrunc, 0, I64X2, I32X4, VectorOperators.L2I, VectorOperators.ZERO_EXTEND_I2L));
876+
// The high half must be compacted with part -1 (upper lanes), not 0: resultLow already occupies
// the lower lanes, and the OR below only merges correctly when the two halves use disjoint lanes.
IntVector resultHigh = castInt128(compact(xHighTrunc, -1, I64X2, I32X4, VectorOperators.L2I, VectorOperators.ZERO_EXTEND_I2L));
877+
IntVector result = resultLow.or(resultHigh);
876878
return result.reinterpretAsBytes();
877879
}
878880

879881
private static ByteVector f32x4_convert_i32x4_u(ByteVector xBytes) {
880882
IntVector x = xBytes.reinterpretAsInts();
881883
LongVector xUnsignedLow = castLong128(x.convert(VectorOperators.ZERO_EXTEND_I2L, 0));
882884
LongVector xUnsignedHigh = castLong128(x.convert(VectorOperators.ZERO_EXTEND_I2L, 1));
883-
FloatVector resultLow = castFloat128(xUnsignedLow.convert(VectorOperators.L2F, 0));
884-
FloatVector resultHigh = castFloat128(xUnsignedHigh.convert(VectorOperators.L2F, -1));
885-
Vector<Float> result = firstNonzero(resultLow, resultHigh);
885+
FloatVector resultLow = castFloat128(compactGeneral(xUnsignedLow, 0, I64X2, F32X4, VectorOperators.L2F, VectorOperators.REINTERPRET_F2I, VectorOperators.ZERO_EXTEND_I2L));
886+
FloatVector resultHigh = castFloat128(compactGeneral(xUnsignedHigh, -1, I64X2, F32X4, VectorOperators.L2F, VectorOperators.REINTERPRET_F2I, VectorOperators.ZERO_EXTEND_I2L));
887+
IntVector result = resultLow.reinterpretAsInts().or(resultHigh.reinterpretAsInts());
886888
return result.reinterpretAsBytes();
887889
}
888890

889891
@SuppressWarnings("unused")
890892
private static ByteVector i32x4_trunc_sat_f64x2_u_zero(ByteVector xBytes) {
891893
DoubleVector x = F64X2.reinterpret(xBytes);
892894
LongVector longResult = truncSatU32(x);
893-
IntVector result = castInt128(longResult.convert(VectorOperators.L2I, 0));
895+
IntVector result = castInt128(compact(longResult, 0, I64X2, I32X4, VectorOperators.L2I, VectorOperators.ZERO_EXTEND_I2L));
894896
return result.reinterpretAsBytes();
895897
}
896898

@@ -902,6 +904,12 @@ private static ByteVector f64x2_convert_low_i32x4_u(ByteVector xBytes) {
902904
return result.reinterpretAsBytes();
903905
}
904906

907+
/**
 * Handles {@code VECTOR_F32X4_DEMOTE_F64X2_ZERO}: reinterprets the input bytes as F64X2,
 * narrows each double to a float with {@code D2F}, and packs the results via
 * {@link #compactGeneral} using part 0 (the low-lane mask of F32X4).
 *
 * NOTE(review): the remaining lanes are presumably zero, as the opcode name suggests; this
 * relies on masked {@code rearrange} zeroing unselected lanes - confirm.
 * NOTE(review): the name uses a capital 'X', unlike sibling helpers such as
 * f64x2_convert_low_i32x4_u; renaming would also require updating the call site in unary().
 *
 * @param xBytes the 128-bit input vector, viewed as bytes
 * @return the converted vector, reinterpreted as bytes
 */
private static ByteVector f32X4_demote_f64X2_zero(ByteVector xBytes) {
908+
DoubleVector x = F64X2.reinterpret(xBytes);
909+
// Floats are not integral, so the integer-only compact() cannot be used here; compactGeneral
// routes the narrowed floats through int/long lanes (REINTERPRET_F2I, ZERO_EXTEND_I2L).
Vector<Float> result = compactGeneral(x, 0, I64X2, F32X4, VectorOperators.D2F, VectorOperators.REINTERPRET_F2I, VectorOperators.ZERO_EXTEND_I2L);
910+
return result.reinterpretAsBytes();
911+
}
912+
905913
private static ByteVector i8x16_swizzle(ByteVector valueBytes, ByteVector indexBytes) {
906914
ByteVector values = valueBytes;
907915
ByteVector indices = indexBytes;
@@ -961,6 +969,31 @@ private static <E, F> Vector<F> compact(Vector<E> vec, int part, Shape<E> inShap
961969
return vec.convertShape(downcast, halfSizeOutShape, 0).convertShape(upcast, inShape.species(), 0).reinterpretShape(outShape.species(), 0).rearrange(outShape.compressEvensShuffle, mask);
962970
}
963971

972+
/**
973+
* Like {@link #compact}, but generalized for non-integral input and output shapes.
*
* The input is narrowed with {@code downcast} into a 64-bit vector of the narrow type,
* viewed as integral lanes with {@code asIntegral}, widened back to a full-size vector of
* {@code wideIntegralShape} with {@code upcast}, reinterpreted as {@code narrowShape}, and
* finally rearranged with {@code narrowShape.compressEvensShuffle} under a low or high mask.
*
* @param vec the wide input vector
* @param part 0 to select {@code narrowShape.lowMask}, -1 to select
*            {@code narrowShape.highMask}; any other value is treated as unreachable
* @param wideIntegralShape integral shape used for the intermediate widening step
* @param narrowShape shape of the result
* @param downcast conversion from the wide element type to the narrow element type
* @param asIntegral conversion from the narrow element type to an integral type of the same
*            lane count
* @param upcast widening conversion between integral types (byte to short, short to int, or
*            int to long, as checked by the assertion below)
* @return the rearranged vector in {@code narrowShape}
974+
*/
975+
private static <W, WI, N, NI> Vector<N> compactGeneral(Vector<W> vec, int part,
976+
Shape<WI> wideIntegralShape, Shape<N> narrowShape,
977+
VectorOperators.Conversion<W, N> downcast,
978+
VectorOperators.Conversion<N, NI> asIntegral,
979+
VectorOperators.Conversion<NI, WI> upcast) {
980+
// NI and WI must be integral types, with NI being half the size of WI.
981+
assert upcast.domainType() == byte.class && upcast.rangeType() == short.class ||
982+
upcast.domainType() == short.class && upcast.rangeType() == int.class ||
983+
upcast.domainType() == int.class && upcast.rangeType() == long.class;
984+
// The part number chooses which lanes of the result are selected by the final rearrange.
VectorMask<N> mask = switch (part) {
985+
case 0 -> narrowShape.lowMask;
986+
case -1 -> narrowShape.highMask;
987+
default -> throw CompilerDirectives.shouldNotReachHere();
988+
};
989+
// Same element type as the result, but in a 64-bit vector shape.
VectorSpecies<N> halfSizeOutShape = narrowShape.species().withShape(VectorShape.S_64_BIT);
990+
// Narrow the elements into the 64-bit vector.
Vector<N> down = vec.convertShape(downcast, halfSizeOutShape, 0);
991+
// View the narrowed lanes as integers so that the integral upcast can be applied.
Vector<NI> downIntegral = down.convert(asIntegral, 0);
992+
// Widen back to the full-size integral shape.
Vector<WI> upIntegral = downIntegral.convertShape(upcast, wideIntegralShape.species(), 0);
993+
// NOTE(review): presumably each narrow value now sits in an even-numbered narrow lane,
// as the variable and shuffle names suggest - confirm against the Shape definition.
Vector<N> narrowEvens = upIntegral.reinterpretShape(narrowShape.species(), 0);
994+
// Rearrange under the mask; per the Vector API, lanes where the mask is unset become zero.
return narrowEvens.rearrange(narrowShape.compressEvensShuffle, mask);
995+
}
996+
964997
private static <E, F> ByteVector narrow(ByteVector xBytes, ByteVector yBytes,
965998
Shape<E> inShape, Shape<F> outShape, VectorOperators.Conversion<E, F> downcast, VectorOperators.Conversion<F, E> upcast,
966999
long min, long max) {
@@ -974,23 +1007,26 @@ private static <E, F> ByteVector narrow(ByteVector xBytes, ByteVector yBytes,
9741007
return result.reinterpretAsBytes();
9751008
}
9761009

977-
private static <E, F> ByteVector binop_sat_u(ByteVector xBytes, ByteVector yBytes, Shape<E> shape, Shape<F> extendedShape, VectorOperators.Conversion<E, F> upcast, VectorOperators.Conversion<F, E> downcast,
1010+
private static <E, F> ByteVector binop_sat_u(ByteVector xBytes, ByteVector yBytes,
1011+
Shape<E> shape, Shape<F> extendedShape,
1012+
VectorOperators.Conversion<E, F> upcast, VectorOperators.Conversion<F, E> downcast,
9781013
VectorOperators.Binary op, long min, long max) {
979-
return upcastBinopDowncast(xBytes, yBytes, shape, upcast, downcast, (x, y) -> {
1014+
return upcastBinopDowncast(xBytes, yBytes, shape, extendedShape, upcast, downcast, (x, y) -> {
9801015
Vector<F> rawResult = x.lanewise(op, y);
9811016
Vector<F> satResult = sat(rawResult, extendedShape, min, max);
9821017
return satResult;
9831018
});
9841019
}
9851020

986-
private static <E, F> ByteVector avgr(ByteVector xBytes, ByteVector yBytes, Shape<E> shape, VectorOperators.Conversion<E, F> upcast, VectorOperators.Conversion<F, E> downcast) {
987-
Vector<F> one = VectorShape.S_128_BIT.withLanes(upcast.rangeType()).broadcast(1);
988-
Vector<F> two = VectorShape.S_128_BIT.withLanes(upcast.rangeType()).broadcast(2);
989-
return upcastBinopDowncast(xBytes, yBytes, shape, upcast, downcast, (x, y) -> x.add(y).add(one).div(two));
1021+
private static <E, F> ByteVector avgr(ByteVector xBytes, ByteVector yBytes, Shape<E> shape, Shape<F> extendedShape, VectorOperators.Conversion<E, F> upcast,
1022+
VectorOperators.Conversion<F, E> downcast) {
1023+
Vector<F> one = extendedShape.broadcast(1);
1024+
Vector<F> two = extendedShape.broadcast(2);
1025+
return upcastBinopDowncast(xBytes, yBytes, shape, extendedShape, upcast, downcast, (x, y) -> x.add(y).add(one).div(two));
9901026
}
9911027

9921028
private static ByteVector i16x8_q15mulr_sat_s(ByteVector xBytes, ByteVector yBytes) {
993-
return upcastBinopDowncast(xBytes, yBytes, I16X8, VectorOperators.S2I, VectorOperators.I2S, (x, y) -> {
1029+
return upcastBinopDowncast(xBytes, yBytes, I16X8, I32X4, VectorOperators.S2I, VectorOperators.I2S, (x, y) -> {
9941030
Vector<Integer> rawResult = x.mul(y).add(I32X4.broadcast(1 << 14)).lanewise(VectorOperators.ASHR, I32X4.broadcast(15));
9951031
Vector<Integer> satResult = sat(rawResult, I32X4, Short.MIN_VALUE, Short.MAX_VALUE);
9961032
return satResult;
@@ -1179,17 +1215,19 @@ private static LongVector truncSatU32(DoubleVector x) {
11791215
return trunc.blend(u32max, overflow).blend(zero, underflow);
11801216
}
11811217

1182-
private static <E, F> ByteVector upcastBinopDowncast(ByteVector xBytes, ByteVector yBytes, Shape<E> shape, VectorOperators.Conversion<E, F> upcast, VectorOperators.Conversion<F, E> downcast,
1218+
private static <E, F> ByteVector upcastBinopDowncast(ByteVector xBytes, ByteVector yBytes,
1219+
Shape<E> shape, Shape<F> extendedShape,
1220+
VectorOperators.Conversion<E, F> upcast, VectorOperators.Conversion<F, E> downcast,
11831221
BinaryVectorOp<F> op) {
11841222
Vector<E> x = shape.reinterpret(xBytes);
11851223
Vector<E> y = shape.reinterpret(yBytes);
11861224
Vector<F> xLow = x.convert(upcast, 0);
11871225
Vector<F> xHigh = x.convert(upcast, 1);
11881226
Vector<F> yLow = y.convert(upcast, 0);
11891227
Vector<F> yHigh = y.convert(upcast, 1);
1190-
Vector<E> resultLow = op.apply(xLow, yLow).convert(downcast, 0);
1191-
Vector<E> resultHigh = op.apply(xHigh, yHigh).convert(downcast, -1);
1192-
Vector<E> result = firstNonzero(resultLow, resultHigh);
1228+
Vector<E> resultLow = compact(op.apply(xLow, yLow), 0, extendedShape, shape, downcast, upcast);
1229+
Vector<E> resultHigh = compact(op.apply(xHigh, yHigh), -1, extendedShape, shape, downcast, upcast);
1230+
Vector<E> result = resultLow.lanewise(VectorOperators.OR, resultHigh);
11931231
return result.reinterpretAsBytes();
11941232
}
11951233

@@ -1210,13 +1248,6 @@ private static <E> VectorMask<E> odds(Shape<E> shape) {
12101248
return VectorMask.fromArray(shape.species(), ALTERNATING_BITS, 1);
12111249
}
12121250

1213-
private static <E> Vector<E> firstNonzero(Vector<E> x, Vector<E> y) {
1214-
// Use this definition instead of the FIRST_NONZERO operators, because the FIRST_NONZERO
1215-
// operator is not compatible with native image
1216-
VectorMask<?> mask = x.viewAsIntegralLanes().compare(VectorOperators.EQ, 0);
1217-
return x.blend(y, mask.cast(x.species()));
1218-
}
1219-
12201251
@Override
12211252
public ByteVector fromArray(byte[] bytes, int offset) {
12221253
return ByteVector.fromArray(I8X16.species(), bytes, offset);

0 commit comments

Comments
 (0)