From 50ad4431ab0e430781baa9739b469432719269ae Mon Sep 17 00:00:00 2001 From: Brendan Dahl Date: Tue, 4 Mar 2025 19:25:35 +0000 Subject: [PATCH] [WebAssembly] Support promoting lower lanes of f16x8 to f32x4. --- clang/lib/Headers/wasm_simd128.h | 9 +++ .../intrinsic-header-tests/wasm_simd128.c | 6 ++ .../WebAssembly/WebAssemblyISelLowering.cpp | 55 ++++++++++++++----- .../WebAssembly/WebAssemblyInstrSIMD.td | 2 + .../CodeGen/WebAssembly/half-precision.ll | 20 +++++++ llvm/test/MC/WebAssembly/simd-encodings.s | 3 + 6 files changed, 80 insertions(+), 15 deletions(-) diff --git a/clang/lib/Headers/wasm_simd128.h b/clang/lib/Headers/wasm_simd128.h index 08e39bf1a79b4..1e6da28c052a0 100644 --- a/clang/lib/Headers/wasm_simd128.h +++ b/clang/lib/Headers/wasm_simd128.h @@ -45,6 +45,7 @@ typedef int __i32x2 __attribute__((__vector_size__(8), __aligned__(8))); typedef unsigned int __u32x2 __attribute__((__vector_size__(8), __aligned__(8))); typedef float __f32x2 __attribute__((__vector_size__(8), __aligned__(8))); +typedef __fp16 __f16x4 __attribute__((__vector_size__(8), __aligned__(8))); #define __DEFAULT_FN_ATTRS \ __attribute__((__always_inline__, __nodebug__, __target__("simd128"), \ @@ -2010,6 +2011,14 @@ static __inline__ v128_t __FP16_FN_ATTRS wasm_f16x8_convert_u16x8(v128_t __a) { return (v128_t) __builtin_convertvector((__u16x8)__a, __f16x8); } +static __inline__ v128_t __FP16_FN_ATTRS +wasm_f32x4_promote_low_f16x8(v128_t __a) { + return (v128_t) __builtin_convertvector( + (__f16x4){((__f16x8)__a)[0], ((__f16x8)__a)[1], ((__f16x8)__a)[2], + ((__f16x8)__a)[3]}, + __f32x4); +} + static __inline__ v128_t __FP16_FN_ATTRS wasm_f16x8_relaxed_madd(v128_t __a, v128_t __b, v128_t __c) { diff --git a/cross-project-tests/intrinsic-header-tests/wasm_simd128.c b/cross-project-tests/intrinsic-header-tests/wasm_simd128.c index b601d90cfcc92..1f4809483589e 100644 --- a/cross-project-tests/intrinsic-header-tests/wasm_simd128.c +++ b/cross-project-tests/intrinsic-header-tests/wasm_simd128.c @@ -1033,6 +1033,12 @@ v128_t test_f64x2_promote_low_f32x4(v128_t a) { return wasm_f64x2_promote_low_f32x4(a); } +// CHECK-LABEL: test_f32x4_promote_low_f16x8: +// CHECK: f32x4.promote_low_f16x8{{$}} +v128_t test_f32x4_promote_low_f16x8(v128_t a) { + return wasm_f32x4_promote_low_f16x8(a); +} + // CHECK-LABEL: test_i8x16_shuffle: // CHECK: i8x16.shuffle 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, // 0{{$}} diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index b24a45c2d8898..9e4621e654347 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -2341,7 +2341,7 @@ WebAssemblyTargetLowering::LowerEXTEND_VECTOR_INREG(SDValue Op, static SDValue LowerConvertLow(SDValue Op, SelectionDAG &DAG) { SDLoc DL(Op); - if (Op.getValueType() != MVT::v2f64) + if (Op.getValueType() != MVT::v2f64 && Op.getValueType() != MVT::v4f32) return SDValue(); auto GetConvertedLane = [](SDValue Op, unsigned &Opcode, SDValue &SrcVec, @@ -2354,6 +2354,7 @@ static SDValue LowerConvertLow(SDValue Op, SelectionDAG &DAG) { Opcode = WebAssemblyISD::CONVERT_LOW_U; break; case ISD::FP_EXTEND: + case ISD::FP16_TO_FP: Opcode = WebAssemblyISD::PROMOTE_LOW; break; default: @@ -2372,36 +2373,60 @@ static SDValue LowerConvertLow(SDValue Op, SelectionDAG &DAG) { return true; }; - unsigned LHSOpcode, RHSOpcode, LHSIndex, RHSIndex; - SDValue LHSSrcVec, RHSSrcVec; - if (!GetConvertedLane(Op.getOperand(0), LHSOpcode, LHSSrcVec, LHSIndex) || - !GetConvertedLane(Op.getOperand(1), RHSOpcode, RHSSrcVec, RHSIndex)) + unsigned NumLanes = Op.getValueType() == MVT::v2f64 ? 2 : 4; + unsigned FirstOpcode = 0, SecondOpcode = 0, ThirdOpcode = 0, FourthOpcode = 0; + unsigned FirstIndex = 0, SecondIndex = 0, ThirdIndex = 0, FourthIndex = 0; + SDValue FirstSrcVec, SecondSrcVec, ThirdSrcVec, FourthSrcVec; + + if (!GetConvertedLane(Op.getOperand(0), FirstOpcode, FirstSrcVec, + FirstIndex) || + !GetConvertedLane(Op.getOperand(1), SecondOpcode, SecondSrcVec, + SecondIndex)) + return SDValue(); + + // If we're converting to v4f32, check the third and fourth lanes, too. + if (NumLanes == 4 && (!GetConvertedLane(Op.getOperand(2), ThirdOpcode, + ThirdSrcVec, ThirdIndex) || + !GetConvertedLane(Op.getOperand(3), FourthOpcode, + FourthSrcVec, FourthIndex))) + return SDValue(); + + if (FirstOpcode != SecondOpcode) return SDValue(); - if (LHSOpcode != RHSOpcode) + // TODO Add an optimization similar to the v2f64 below for shuffling the + // vectors when the lanes are in the wrong order or come from different src + // vectors. + if (NumLanes == 4 && + (FirstOpcode != ThirdOpcode || FirstOpcode != FourthOpcode || + FirstSrcVec != SecondSrcVec || FirstSrcVec != ThirdSrcVec || + FirstSrcVec != FourthSrcVec || FirstIndex != 0 || SecondIndex != 1 || + ThirdIndex != 2 || FourthIndex != 3)) return SDValue(); MVT ExpectedSrcVT; - switch (LHSOpcode) { + switch (FirstOpcode) { case WebAssemblyISD::CONVERT_LOW_S: case WebAssemblyISD::CONVERT_LOW_U: ExpectedSrcVT = MVT::v4i32; break; case WebAssemblyISD::PROMOTE_LOW: - ExpectedSrcVT = MVT::v4f32; + ExpectedSrcVT = NumLanes == 2 ? MVT::v4f32 : MVT::v8i16; break; } - if (LHSSrcVec.getValueType() != ExpectedSrcVT) + if (FirstSrcVec.getValueType() != ExpectedSrcVT) return SDValue(); - auto Src = LHSSrcVec; - if (LHSIndex != 0 || RHSIndex != 1 || LHSSrcVec != RHSSrcVec) { + auto Src = FirstSrcVec; + if (NumLanes == 2 && + (FirstIndex != 0 || SecondIndex != 1 || FirstSrcVec != SecondSrcVec)) { // Shuffle the source vector so that the converted lanes are the low lanes. - Src = DAG.getVectorShuffle( - ExpectedSrcVT, DL, LHSSrcVec, RHSSrcVec, - {static_cast(LHSIndex), static_cast(RHSIndex) + 4, -1, -1}); + Src = DAG.getVectorShuffle(ExpectedSrcVT, DL, FirstSrcVec, SecondSrcVec, + {static_cast(FirstIndex), + static_cast(SecondIndex) + 4, -1, -1}); } - return DAG.getNode(LHSOpcode, DL, MVT::v2f64, Src); + return DAG.getNode(FirstOpcode, DL, NumLanes == 2 ? MVT::v2f64 : MVT::v4f32, + Src); } SDValue WebAssemblyTargetLowering::LowerBUILD_VECTOR(SDValue Op, diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td index c591e5ef181a4..cb3008899f473 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td @@ -1468,6 +1468,8 @@ defm "" : SIMDConvert, SDTCisVec<1>]>; def promote_low : SDNode<"WebAssemblyISD::PROMOTE_LOW", promote_t>; defm "" : SIMDConvert; +defm "" : HalfPrecisionConvert; // Lower extending loads to load64_zero + promote_low def extloadv2f32 : PatFrag<(ops node:$ptr), (extload node:$ptr)> { diff --git a/llvm/test/CodeGen/WebAssembly/half-precision.ll b/llvm/test/CodeGen/WebAssembly/half-precision.ll index 4e8ff5955c63b..837e4db83110d 100644 --- a/llvm/test/CodeGen/WebAssembly/half-precision.ll +++ b/llvm/test/CodeGen/WebAssembly/half-precision.ll @@ -369,3 +369,23 @@ define <8 x half> @shuffle_poison_v8f16(<8 x half> %x, <8 x half> %y) { i32 poison, i32 poison, i32 poison, i32 poison> ret <8 x half> %res } + +define <4 x float> @promote_low_v4f32(<8 x half> %x) { +; CHECK-LABEL: promote_low_v4f32: +; CHECK: .functype promote_low_v4f32 (v128) -> (v128){{$}} +; CHECK-NEXT: f32x4.promote_low_f16x8 $push[[R:[0-9]+]]=, $0 +; CHECK-NEXT: return $pop[[R]] + %v = shufflevector <8 x half> %x, <8 x half> poison, <4 x i32> + %a = fpext <4 x half> %v to <4 x float> + ret <4 x float> %a +} + +define <4 x float> @promote_low_v4f32_2(<8 x half> %x) { +; CHECK-LABEL: promote_low_v4f32_2: +; CHECK: .functype promote_low_v4f32_2 (v128) -> (v128) +; CHECK-NEXT: f32x4.promote_low_f16x8 $push[[R:[0-9]+]]=, $0 +; CHECK-NEXT: return $pop[[R]] + %v = fpext <8 x half> %x to <8 x float> + %a = shufflevector <8 x float> %v, <8 x float> poison, <4 x i32> + ret <4 x float> %a +} diff --git a/llvm/test/MC/WebAssembly/simd-encodings.s b/llvm/test/MC/WebAssembly/simd-encodings.s index 48aec4bc52a0c..57af1daad0226 100644 --- a/llvm/test/MC/WebAssembly/simd-encodings.s +++ b/llvm/test/MC/WebAssembly/simd-encodings.s @@ -935,4 +935,7 @@ main: # CHECK: f16x8.convert_i16x8_u # encoding: [0xfd,0xc8,0x02] f16x8.convert_i16x8_u + # CHECK: f32x4.promote_low_f16x8 # encoding: [0xfd,0xcb,0x02] + f32x4.promote_low_f16x8 + end_function