Skip to content

Conversation

@ostannard
Copy link
Collaborator

For most MVE predicated FMA instructions, disabled lanes will contain the value in the addend operand. However, The VFMAS instruction takes the addend in a GPR, and the output register is shared with the first multiply operand, so disabled lanes will get that value instead. This means that we can't use the same intrinsic as for the other VFMA instructions. Instead, we can codegen the vfmas intrinsic to a regular FMA and select in clang, which the backend already has the patterns to select VFMAS from.

For most MVE predicated FMA instructions, disabled lanes will contain
the value in the addend operand. However, The VFMAS instruction takes
the addend in a GPR, and the output register is shared with the first
multiply operand, so disabled lanes will get that value instead. This
means that we can't use the same intrinsic as for the other VFMA
instructions. Instead, we can codegen the vfmas intrinsic to a regular
FMA and select in clang, which the backend already has the patterns to
select VFMAS from.
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" llvm:ir labels Nov 12, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 12, 2024

@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-clang

@llvm/pr-subscribers-backend-arm

Author: Oliver Stannard (ostannard)

Changes

For most MVE predicated FMA instructions, disabled lanes will contain the value in the addend operand. However, The VFMAS instruction takes the addend in a GPR, and the output register is shared with the first multiply operand, so disabled lanes will get that value instead. This means that we can't use the same intrinsic as for the other VFMA instructions. Instead, we can codegen the vfmas intrinsic to a regular FMA and select in clang, which the backend already has the patterns to select VFMAS from.


Full diff: https://github.com/llvm/llvm-project/pull/115908.diff

6 Files Affected:

  • (modified) clang/include/clang/Basic/arm_mve.td (+1-1)
  • (modified) clang/test/CodeGen/arm-mve-intrinsics/ternary.c (+10-8)
  • (modified) llvm/include/llvm/IR/IntrinsicsARM.td (+1)
  • (modified) llvm/lib/Target/ARM/ARMInstrMVE.td (-2)
  • (modified) llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll (+45-2)
  • (modified) llvm/test/CodeGen/Thumb2/mve-qrintr.ll (+2-1)
diff --git a/clang/include/clang/Basic/arm_mve.td b/clang/include/clang/Basic/arm_mve.td
index 52185ca07da41f..1debb94a0a7b81 100644
--- a/clang/include/clang/Basic/arm_mve.td
+++ b/clang/include/clang/Basic/arm_mve.td
@@ -193,7 +193,7 @@ multiclass FMA<bit add> {
     def sq_m_n: Intrinsic<Vector, (args Vector:$m1, Vector:$m2,
                                         unpromoted<Scalar>:$addend_s,
                                         Predicate:$pred),
-                          (seq (splat $addend_s):$addend, pred_cg)>;
+                          (select $pred, (seq (splat $addend_s):$addend, unpred_cg), $m1)>;
   }
 }
 
diff --git a/clang/test/CodeGen/arm-mve-intrinsics/ternary.c b/clang/test/CodeGen/arm-mve-intrinsics/ternary.c
index 36b2ce063cb188..768d397cb5611f 100644
--- a/clang/test/CodeGen/arm-mve-intrinsics/ternary.c
+++ b/clang/test/CodeGen/arm-mve-intrinsics/ternary.c
@@ -542,12 +542,13 @@ float32x4_t test_vfmaq_m_n_f32(float32x4_t a, float32x4_t b, float32_t c, mve_pr
 
 // CHECK-LABEL: @test_vfmasq_m_n_f16(
 // CHECK-NEXT:  entry:
-// CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <8 x half> poison, half [[C:%.*]], i64 0
-// CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <8 x half> [[DOTSPLATINSERT]], <8 x half> poison, <8 x i32> zeroinitializer
 // CHECK-NEXT:    [[TMP0:%.*]] = zext i16 [[P:%.*]] to i32
 // CHECK-NEXT:    [[TMP1:%.*]] = call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 [[TMP0]])
-// CHECK-NEXT:    [[TMP2:%.*]] = call <8 x half> @llvm.arm.mve.fma.predicated.v8f16.v8i1(<8 x half> [[A:%.*]], <8 x half> [[B:%.*]], <8 x half> [[DOTSPLAT]], <8 x i1> [[TMP1]])
-// CHECK-NEXT:    ret <8 x half> [[TMP2]]
+// CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <8 x half> poison, half [[C:%.*]], i64 0
+// CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <8 x half> [[DOTSPLATINSERT]], <8 x half> poison, <8 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP2:%.*]] = call <8 x half> @llvm.fma.v8f16(<8 x half> [[A:%.*]], <8 x half> [[B:%.*]], <8 x half> [[DOTSPLAT]])
+// CHECK-NEXT:    [[TMP3:%.*]] = select <8 x i1> [[TMP1]], <8 x half> [[TMP2]], <8 x half> [[A]]
+// CHECK-NEXT:    ret <8 x half> [[TMP3]]
 //
 float16x8_t test_vfmasq_m_n_f16(float16x8_t a, float16x8_t b, float16_t c, mve_pred16_t p) {
 #ifdef POLYMORPHIC
@@ -559,12 +560,13 @@ float16x8_t test_vfmasq_m_n_f16(float16x8_t a, float16x8_t b, float16_t c, mve_p
 
 // CHECK-LABEL: @test_vfmasq_m_n_f32(
 // CHECK-NEXT:  entry:
-// CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x float> poison, float [[C:%.*]], i64 0
-// CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x float> [[DOTSPLATINSERT]], <4 x float> poison, <4 x i32> zeroinitializer
 // CHECK-NEXT:    [[TMP0:%.*]] = zext i16 [[P:%.*]] to i32
 // CHECK-NEXT:    [[TMP1:%.*]] = call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 [[TMP0]])
-// CHECK-NEXT:    [[TMP2:%.*]] = call <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> [[A:%.*]], <4 x float> [[B:%.*]], <4 x float> [[DOTSPLAT]], <4 x i1> [[TMP1]])
-// CHECK-NEXT:    ret <4 x float> [[TMP2]]
+// CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x float> poison, float [[C:%.*]], i64 0
+// CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x float> [[DOTSPLATINSERT]], <4 x float> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP2:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> [[A:%.*]], <4 x float> [[B:%.*]], <4 x float> [[DOTSPLAT]])
+// CHECK-NEXT:    [[TMP3:%.*]] = select <4 x i1> [[TMP1]], <4 x float> [[TMP2]], <4 x float> [[A]]
+// CHECK-NEXT:    ret <4 x float> [[TMP3]]
 //
 float32x4_t test_vfmasq_m_n_f32(float32x4_t a, float32x4_t b, float32_t c, mve_pred16_t p) {
 #ifdef POLYMORPHIC
diff --git a/llvm/include/llvm/IR/IntrinsicsARM.td b/llvm/include/llvm/IR/IntrinsicsARM.td
index 11b9877091a8ed..b18d3fcc9e3f44 100644
--- a/llvm/include/llvm/IR/IntrinsicsARM.td
+++ b/llvm/include/llvm/IR/IntrinsicsARM.td
@@ -1362,6 +1362,7 @@ def int_arm_mve_vqmovn_predicated: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
     llvm_i32_ty /* unsigned output */, llvm_i32_ty /* unsigned input */,
     llvm_i32_ty /* top half */, llvm_anyvector_ty /* pred */], [IntrNoMem]>;
 
+// fma_predicated returns the add operand for disabled lanes.
 def int_arm_mve_fma_predicated: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
    [LLVMMatchType<0> /* mult op #1 */, LLVMMatchType<0> /* mult op #2 */,
     LLVMMatchType<0> /* addend */, llvm_anyvector_ty /* pred */], [IntrNoMem]>;
diff --git a/llvm/lib/Target/ARM/ARMInstrMVE.td b/llvm/lib/Target/ARM/ARMInstrMVE.td
index 22af599f4f0859..bdd0d739a05684 100644
--- a/llvm/lib/Target/ARM/ARMInstrMVE.td
+++ b/llvm/lib/Target/ARM/ARMInstrMVE.td
@@ -5614,8 +5614,6 @@ multiclass MVE_VFMA_qr_multi<string iname, MVEVectorVTInfo VTI,
                                   (VTI.Vec (fma v1, v2, vs)),
                                   v1)),
                 (VTI.Vec (Inst v1, v2, is, ARMVCCThen, $pred, zero_reg))>;
-      def : Pat<(VTI.Vec (pred_int v1, v2, vs, pred)),
-                (VTI.Vec (Inst v1, v2, is, ARMVCCThen, pred, zero_reg))>;
     } else {
       def : Pat<(VTI.Vec (fma v1, vs, v2)),
                 (VTI.Vec (Inst v2, v1, is))>;
diff --git a/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll b/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll
index 4b1b8998ce73eb..265452c18f81a4 100644
--- a/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll
@@ -461,8 +461,10 @@ define arm_aapcs_vfpcc <8 x half> @test_vfmasq_m_n_f16(<8 x half> %a, <8 x half>
 ; CHECK:       @ %bb.0: @ %entry
 ; CHECK-NEXT:    vmov r1, s8
 ; CHECK-NEXT:    vmsr p0, r0
+; CHECK-NEXT:    vdup.16 q2, r1
 ; CHECK-NEXT:    vpst
-; CHECK-NEXT:    vfmast.f16 q0, q1, r1
+; CHECK-NEXT:    vfmat.f16 q2, q0, q1
+; CHECK-NEXT:    vmov q0, q2
 ; CHECK-NEXT:    bx lr
 entry:
   %0 = bitcast float %c.coerce to i32
@@ -476,13 +478,36 @@ entry:
   ret <8 x half> %4
 }
 
+define arm_aapcs_vfpcc <8 x half> @test_vfmasq_m_n_f16_select(<8 x half> %a, <8 x half> %b, float %c.coerce, i16 zeroext %p) {
+; CHECK-LABEL: test_vfmasq_m_n_f16_select:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmov r1, s8
+; CHECK-NEXT:    vmsr p0, r0
+; CHECK-NEXT:    vpst
+; CHECK-NEXT:    vfmast.f16 q0, q1, r1
+; CHECK-NEXT:    bx lr
+entry:
+  %0 = bitcast float %c.coerce to i32
+  %tmp.0.extract.trunc = trunc i32 %0 to i16
+  %1 = bitcast i16 %tmp.0.extract.trunc to half
+  %.splatinsert = insertelement <8 x half> undef, half %1, i32 0
+  %.splat = shufflevector <8 x half> %.splatinsert, <8 x half> undef, <8 x i32> zeroinitializer
+  %2 = zext i16 %p to i32
+  %3 = tail call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 %2)
+  %4 = tail call <8 x half> @llvm.fma.v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %.splat)
+  %5 = select <8 x i1> %3, <8 x half> %4, <8 x half> %a
+  ret <8 x half> %5
+}
+
 define arm_aapcs_vfpcc <4 x float> @test_vfmasq_m_n_f32(<4 x float> %a, <4 x float> %b, float %c, i16 zeroext %p) {
 ; CHECK-LABEL: test_vfmasq_m_n_f32:
 ; CHECK:       @ %bb.0: @ %entry
 ; CHECK-NEXT:    vmov r1, s8
 ; CHECK-NEXT:    vmsr p0, r0
+; CHECK-NEXT:    vdup.32 q2, r1
 ; CHECK-NEXT:    vpst
-; CHECK-NEXT:    vfmast.f32 q0, q1, r1
+; CHECK-NEXT:    vfmat.f32 q2, q0, q1
+; CHECK-NEXT:    vmov q0, q2
 ; CHECK-NEXT:    bx lr
 entry:
   %.splatinsert = insertelement <4 x float> undef, float %c, i32 0
@@ -493,6 +518,24 @@ entry:
   ret <4 x float> %2
 }
 
+define arm_aapcs_vfpcc <4 x float> @test_vfmasq_m_n_f32_select(<4 x float> %a, <4 x float> %b, float %c, i16 zeroext %p) {
+; CHECK-LABEL: test_vfmasq_m_n_f32_select:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmov r1, s8
+; CHECK-NEXT:    vmsr p0, r0
+; CHECK-NEXT:    vpst
+; CHECK-NEXT:    vfmast.f32 q0, q1, r1
+; CHECK-NEXT:    bx lr
+entry:
+  %.splatinsert = insertelement <4 x float> undef, float %c, i32 0
+  %.splat = shufflevector <4 x float> %.splatinsert, <4 x float> undef, <4 x i32> zeroinitializer
+  %0 = zext i16 %p to i32
+  %1 = tail call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 %0)
+  %2 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %.splat)
+  %3 = select <4 x i1> %1, <4 x float> %2, <4 x float> %a
+  ret <4 x float> %3
+}
+
 define arm_aapcs_vfpcc <8 x half> @test_vfmsq_m_f16(<8 x half> %a, <8 x half> %b, <8 x half> %c, i16 zeroext %p) {
 ; CHECK-LABEL: test_vfmsq_m_f16:
 ; CHECK:       @ %bb.0: @ %entry
diff --git a/llvm/test/CodeGen/Thumb2/mve-qrintr.ll b/llvm/test/CodeGen/Thumb2/mve-qrintr.ll
index 151e51fcf0c93d..87cce804356dee 100644
--- a/llvm/test/CodeGen/Thumb2/mve-qrintr.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-qrintr.ll
@@ -1536,7 +1536,8 @@ while.body:                                       ; preds = %while.body.lr.ph, %
   %0 = tail call <4 x i1> @llvm.arm.mve.vctp32(i32 %N.addr.013)
   %1 = tail call fast <4 x float> @llvm.masked.load.v4f32.p0(ptr %s1.addr.014, i32 4, <4 x i1> %0, <4 x float> zeroinitializer)
   %2 = tail call fast <4 x float> @llvm.masked.load.v4f32.p0(ptr %s2, i32 4, <4 x i1> %0, <4 x float> zeroinitializer)
-  %3 = tail call fast <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> %1, <4 x float> %2, <4 x float> %.splat, <4 x i1> %0)
+  %3 = tail call fast <4 x float> @llvm.fma.v4f32(<4 x float> %1, <4 x float> %2, <4 x float> %.splat)
+  %4 = select <4 x i1> %0, <4 x float> %3, <4 x float> %1
   tail call void @llvm.masked.store.v4f32.p0(<4 x float> %3, ptr %s1.addr.014, i32 4, <4 x i1> %0)
   %add.ptr = getelementptr inbounds float, ptr %s1.addr.014, i32 4
   %sub = add nsw i32 %N.addr.013, -4

Copy link
Collaborator

@statham-arm statham-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. I suppose the error wouldn't have shown up if you did the simplest possible end-to-end test along the lines of

float32x4_t foo(float32x4_t a, float32x4_t b, float32_t c, mve_pred16_t p) {
  return vfmasq_m(a, b, c, p);
}

but only shows up when more complicated register allocation starts happening around it!

@ostannard
Copy link
Collaborator Author

The pre-commit CI failure is an unrelated LLDB test failure, so I'm ignoring it.

@ostannard ostannard merged commit aba5580 into llvm:main Nov 13, 2024
12 of 14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

backend:ARM clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category llvm:ir miscompilation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants