diff --git a/clang/include/clang/Basic/arm_mve.td b/clang/include/clang/Basic/arm_mve.td index 52185ca07da41..1debb94a0a7b8 100644 --- a/clang/include/clang/Basic/arm_mve.td +++ b/clang/include/clang/Basic/arm_mve.td @@ -193,7 +193,7 @@ multiclass FMA { def sq_m_n: Intrinsic:$addend_s, Predicate:$pred), - (seq (splat $addend_s):$addend, pred_cg)>; + (select $pred, (seq (splat $addend_s):$addend, unpred_cg), $m1)>; } } diff --git a/clang/test/CodeGen/arm-mve-intrinsics/ternary.c b/clang/test/CodeGen/arm-mve-intrinsics/ternary.c index 36b2ce063cb18..768d397cb5611 100644 --- a/clang/test/CodeGen/arm-mve-intrinsics/ternary.c +++ b/clang/test/CodeGen/arm-mve-intrinsics/ternary.c @@ -542,12 +542,13 @@ float32x4_t test_vfmaq_m_n_f32(float32x4_t a, float32x4_t b, float32_t c, mve_pr // CHECK-LABEL: @test_vfmasq_m_n_f16( // CHECK-NEXT: entry: -// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <8 x half> poison, half [[C:%.*]], i64 0 -// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <8 x half> [[DOTSPLATINSERT]], <8 x half> poison, <8 x i32> zeroinitializer // CHECK-NEXT: [[TMP0:%.*]] = zext i16 [[P:%.*]] to i32 // CHECK-NEXT: [[TMP1:%.*]] = call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 [[TMP0]]) -// CHECK-NEXT: [[TMP2:%.*]] = call <8 x half> @llvm.arm.mve.fma.predicated.v8f16.v8i1(<8 x half> [[A:%.*]], <8 x half> [[B:%.*]], <8 x half> [[DOTSPLAT]], <8 x i1> [[TMP1]]) -// CHECK-NEXT: ret <8 x half> [[TMP2]] +// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <8 x half> poison, half [[C:%.*]], i64 0 +// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <8 x half> [[DOTSPLATINSERT]], <8 x half> poison, <8 x i32> zeroinitializer +// CHECK-NEXT: [[TMP2:%.*]] = call <8 x half> @llvm.fma.v8f16(<8 x half> [[A:%.*]], <8 x half> [[B:%.*]], <8 x half> [[DOTSPLAT]]) +// CHECK-NEXT: [[TMP3:%.*]] = select <8 x i1> [[TMP1]], <8 x half> [[TMP2]], <8 x half> [[A]] +// CHECK-NEXT: ret <8 x half> [[TMP3]] // float16x8_t test_vfmasq_m_n_f16(float16x8_t a, float16x8_t b, float16_t c, 
mve_pred16_t p) { #ifdef POLYMORPHIC @@ -559,12 +560,13 @@ float16x8_t test_vfmasq_m_n_f16(float16x8_t a, float16x8_t b, float16_t c, mve_p // CHECK-LABEL: @test_vfmasq_m_n_f32( // CHECK-NEXT: entry: -// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <4 x float> poison, float [[C:%.*]], i64 0 -// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <4 x float> [[DOTSPLATINSERT]], <4 x float> poison, <4 x i32> zeroinitializer // CHECK-NEXT: [[TMP0:%.*]] = zext i16 [[P:%.*]] to i32 // CHECK-NEXT: [[TMP1:%.*]] = call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 [[TMP0]]) -// CHECK-NEXT: [[TMP2:%.*]] = call <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> [[A:%.*]], <4 x float> [[B:%.*]], <4 x float> [[DOTSPLAT]], <4 x i1> [[TMP1]]) -// CHECK-NEXT: ret <4 x float> [[TMP2]] +// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <4 x float> poison, float [[C:%.*]], i64 0 +// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <4 x float> [[DOTSPLATINSERT]], <4 x float> poison, <4 x i32> zeroinitializer +// CHECK-NEXT: [[TMP2:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> [[A:%.*]], <4 x float> [[B:%.*]], <4 x float> [[DOTSPLAT]]) +// CHECK-NEXT: [[TMP3:%.*]] = select <4 x i1> [[TMP1]], <4 x float> [[TMP2]], <4 x float> [[A]] +// CHECK-NEXT: ret <4 x float> [[TMP3]] // float32x4_t test_vfmasq_m_n_f32(float32x4_t a, float32x4_t b, float32_t c, mve_pred16_t p) { #ifdef POLYMORPHIC diff --git a/llvm/include/llvm/IR/IntrinsicsARM.td b/llvm/include/llvm/IR/IntrinsicsARM.td index 11b9877091a8e..b18d3fcc9e3f4 100644 --- a/llvm/include/llvm/IR/IntrinsicsARM.td +++ b/llvm/include/llvm/IR/IntrinsicsARM.td @@ -1362,6 +1362,7 @@ def int_arm_mve_vqmovn_predicated: DefaultAttrsIntrinsic<[llvm_anyvector_ty], llvm_i32_ty /* unsigned output */, llvm_i32_ty /* unsigned input */, llvm_i32_ty /* top half */, llvm_anyvector_ty /* pred */], [IntrNoMem]>; +// fma_predicated returns the add operand for disabled lanes. 
def int_arm_mve_fma_predicated: DefaultAttrsIntrinsic<[llvm_anyvector_ty], [LLVMMatchType<0> /* mult op #1 */, LLVMMatchType<0> /* mult op #2 */, LLVMMatchType<0> /* addend */, llvm_anyvector_ty /* pred */], [IntrNoMem]>; diff --git a/llvm/lib/Target/ARM/ARMInstrMVE.td b/llvm/lib/Target/ARM/ARMInstrMVE.td index 22af599f4f085..bdd0d739a0568 100644 --- a/llvm/lib/Target/ARM/ARMInstrMVE.td +++ b/llvm/lib/Target/ARM/ARMInstrMVE.td @@ -5614,8 +5614,6 @@ multiclass MVE_VFMA_qr_multi; - def : Pat<(VTI.Vec (pred_int v1, v2, vs, pred)), - (VTI.Vec (Inst v1, v2, is, ARMVCCThen, pred, zero_reg))>; } else { def : Pat<(VTI.Vec (fma v1, vs, v2)), (VTI.Vec (Inst v2, v1, is))>; diff --git a/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll b/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll index 4b1b8998ce73e..265452c18f81a 100644 --- a/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll +++ b/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll @@ -461,8 +461,10 @@ define arm_aapcs_vfpcc <8 x half> @test_vfmasq_m_n_f16(<8 x half> %a, <8 x half> ; CHECK: @ %bb.0: @ %entry ; CHECK-NEXT: vmov r1, s8 ; CHECK-NEXT: vmsr p0, r0 +; CHECK-NEXT: vdup.16 q2, r1 ; CHECK-NEXT: vpst -; CHECK-NEXT: vfmast.f16 q0, q1, r1 +; CHECK-NEXT: vfmat.f16 q2, q0, q1 +; CHECK-NEXT: vmov q0, q2 ; CHECK-NEXT: bx lr entry: %0 = bitcast float %c.coerce to i32 @@ -476,13 +478,36 @@ entry: ret <8 x half> %4 } +define arm_aapcs_vfpcc <8 x half> @test_vfmasq_m_n_f16_select(<8 x half> %a, <8 x half> %b, float %c.coerce, i16 zeroext %p) { +; CHECK-LABEL: test_vfmasq_m_n_f16_select: +; CHECK: @ %bb.0: @ %entry +; CHECK-NEXT: vmov r1, s8 +; CHECK-NEXT: vmsr p0, r0 +; CHECK-NEXT: vpst +; CHECK-NEXT: vfmast.f16 q0, q1, r1 +; CHECK-NEXT: bx lr +entry: + %0 = bitcast float %c.coerce to i32 + %tmp.0.extract.trunc = trunc i32 %0 to i16 + %1 = bitcast i16 %tmp.0.extract.trunc to half + %.splatinsert = insertelement <8 x half> undef, half %1, i32 0 + %.splat = shufflevector <8 x half> %.splatinsert, <8 x half> undef, <8 x 
i32> zeroinitializer + %2 = zext i16 %p to i32 + %3 = tail call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 %2) + %4 = tail call <8 x half> @llvm.fma.v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %.splat) + %5 = select <8 x i1> %3, <8 x half> %4, <8 x half> %a + ret <8 x half> %5 +} + define arm_aapcs_vfpcc <4 x float> @test_vfmasq_m_n_f32(<4 x float> %a, <4 x float> %b, float %c, i16 zeroext %p) { ; CHECK-LABEL: test_vfmasq_m_n_f32: ; CHECK: @ %bb.0: @ %entry ; CHECK-NEXT: vmov r1, s8 ; CHECK-NEXT: vmsr p0, r0 +; CHECK-NEXT: vdup.32 q2, r1 ; CHECK-NEXT: vpst -; CHECK-NEXT: vfmast.f32 q0, q1, r1 +; CHECK-NEXT: vfmat.f32 q2, q0, q1 +; CHECK-NEXT: vmov q0, q2 ; CHECK-NEXT: bx lr entry: %.splatinsert = insertelement <4 x float> undef, float %c, i32 0 @@ -493,6 +518,24 @@ entry: ret <4 x float> %2 } +define arm_aapcs_vfpcc <4 x float> @test_vfmasq_m_n_f32_select(<4 x float> %a, <4 x float> %b, float %c, i16 zeroext %p) { +; CHECK-LABEL: test_vfmasq_m_n_f32_select: +; CHECK: @ %bb.0: @ %entry +; CHECK-NEXT: vmov r1, s8 +; CHECK-NEXT: vmsr p0, r0 +; CHECK-NEXT: vpst +; CHECK-NEXT: vfmast.f32 q0, q1, r1 +; CHECK-NEXT: bx lr +entry: + %.splatinsert = insertelement <4 x float> undef, float %c, i32 0 + %.splat = shufflevector <4 x float> %.splatinsert, <4 x float> undef, <4 x i32> zeroinitializer + %0 = zext i16 %p to i32 + %1 = tail call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 %0) + %2 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %.splat) + %3 = select <4 x i1> %1, <4 x float> %2, <4 x float> %a + ret <4 x float> %3 +} + define arm_aapcs_vfpcc <8 x half> @test_vfmsq_m_f16(<8 x half> %a, <8 x half> %b, <8 x half> %c, i16 zeroext %p) { ; CHECK-LABEL: test_vfmsq_m_f16: ; CHECK: @ %bb.0: @ %entry diff --git a/llvm/test/CodeGen/Thumb2/mve-qrintr.ll b/llvm/test/CodeGen/Thumb2/mve-qrintr.ll index 151e51fcf0c93..87cce804356de 100644 --- a/llvm/test/CodeGen/Thumb2/mve-qrintr.ll +++ b/llvm/test/CodeGen/Thumb2/mve-qrintr.ll @@ -1536,7 +1536,8 @@ 
while.body: ; preds = %while.body.lr.ph, % %0 = tail call <4 x i1> @llvm.arm.mve.vctp32(i32 %N.addr.013) %1 = tail call fast <4 x float> @llvm.masked.load.v4f32.p0(ptr %s1.addr.014, i32 4, <4 x i1> %0, <4 x float> zeroinitializer) %2 = tail call fast <4 x float> @llvm.masked.load.v4f32.p0(ptr %s2, i32 4, <4 x i1> %0, <4 x float> zeroinitializer) - %3 = tail call fast <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> %1, <4 x float> %2, <4 x float> %.splat, <4 x i1> %0) + %3 = tail call fast <4 x float> @llvm.fma.v4f32(<4 x float> %1, <4 x float> %2, <4 x float> %.splat) + %4 = select <4 x i1> %0, <4 x float> %3, <4 x float> %1 - tail call void @llvm.masked.store.v4f32.p0(<4 x float> %3, ptr %s1.addr.014, i32 4, <4 x i1> %0) + tail call void @llvm.masked.store.v4f32.p0(<4 x float> %4, ptr %s1.addr.014, i32 4, <4 x i1> %0) %add.ptr = getelementptr inbounds float, ptr %s1.addr.014, i32 4 %sub = add nsw i32 %N.addr.013, -4