Skip to content

Commit aba5580

Browse files
authored
[ARM] Fix operand order for MVE predicated VFMAS (#115908)
For most MVE predicated FMA instructions, disabled lanes will contain the value in the addend operand. However, The VFMAS instruction takes the addend in a GPR, and the output register is shared with the first multiply operand, so disabled lanes will get that value instead. This means that we can't use the same intrinsic as for the other VFMA instructions. Instead, we can codegen the vfmas intrinsic to a regular FMA and select in clang, which the backend already has the patterns to select VFMAS from.
1 parent 856c47b commit aba5580

File tree

6 files changed

+59
-14
lines changed

6 files changed

+59
-14
lines changed

clang/include/clang/Basic/arm_mve.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ multiclass FMA<bit add> {
193193
def sq_m_n: Intrinsic<Vector, (args Vector:$m1, Vector:$m2,
194194
unpromoted<Scalar>:$addend_s,
195195
Predicate:$pred),
196-
(seq (splat $addend_s):$addend, pred_cg)>;
196+
(select $pred, (seq (splat $addend_s):$addend, unpred_cg), $m1)>;
197197
}
198198
}
199199

clang/test/CodeGen/arm-mve-intrinsics/ternary.c

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -542,12 +542,13 @@ float32x4_t test_vfmaq_m_n_f32(float32x4_t a, float32x4_t b, float32_t c, mve_pr
542542

543543
// CHECK-LABEL: @test_vfmasq_m_n_f16(
544544
// CHECK-NEXT: entry:
545-
// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <8 x half> poison, half [[C:%.*]], i64 0
546-
// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <8 x half> [[DOTSPLATINSERT]], <8 x half> poison, <8 x i32> zeroinitializer
547545
// CHECK-NEXT: [[TMP0:%.*]] = zext i16 [[P:%.*]] to i32
548546
// CHECK-NEXT: [[TMP1:%.*]] = call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 [[TMP0]])
549-
// CHECK-NEXT: [[TMP2:%.*]] = call <8 x half> @llvm.arm.mve.fma.predicated.v8f16.v8i1(<8 x half> [[A:%.*]], <8 x half> [[B:%.*]], <8 x half> [[DOTSPLAT]], <8 x i1> [[TMP1]])
550-
// CHECK-NEXT: ret <8 x half> [[TMP2]]
547+
// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <8 x half> poison, half [[C:%.*]], i64 0
548+
// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <8 x half> [[DOTSPLATINSERT]], <8 x half> poison, <8 x i32> zeroinitializer
549+
// CHECK-NEXT: [[TMP2:%.*]] = call <8 x half> @llvm.fma.v8f16(<8 x half> [[A:%.*]], <8 x half> [[B:%.*]], <8 x half> [[DOTSPLAT]])
550+
// CHECK-NEXT: [[TMP3:%.*]] = select <8 x i1> [[TMP1]], <8 x half> [[TMP2]], <8 x half> [[A]]
551+
// CHECK-NEXT: ret <8 x half> [[TMP3]]
551552
//
552553
float16x8_t test_vfmasq_m_n_f16(float16x8_t a, float16x8_t b, float16_t c, mve_pred16_t p) {
553554
#ifdef POLYMORPHIC
@@ -559,12 +560,13 @@ float16x8_t test_vfmasq_m_n_f16(float16x8_t a, float16x8_t b, float16_t c, mve_p
559560

560561
// CHECK-LABEL: @test_vfmasq_m_n_f32(
561562
// CHECK-NEXT: entry:
562-
// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <4 x float> poison, float [[C:%.*]], i64 0
563-
// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <4 x float> [[DOTSPLATINSERT]], <4 x float> poison, <4 x i32> zeroinitializer
564563
// CHECK-NEXT: [[TMP0:%.*]] = zext i16 [[P:%.*]] to i32
565564
// CHECK-NEXT: [[TMP1:%.*]] = call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 [[TMP0]])
566-
// CHECK-NEXT: [[TMP2:%.*]] = call <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> [[A:%.*]], <4 x float> [[B:%.*]], <4 x float> [[DOTSPLAT]], <4 x i1> [[TMP1]])
567-
// CHECK-NEXT: ret <4 x float> [[TMP2]]
565+
// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <4 x float> poison, float [[C:%.*]], i64 0
566+
// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <4 x float> [[DOTSPLATINSERT]], <4 x float> poison, <4 x i32> zeroinitializer
567+
// CHECK-NEXT: [[TMP2:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> [[A:%.*]], <4 x float> [[B:%.*]], <4 x float> [[DOTSPLAT]])
568+
// CHECK-NEXT: [[TMP3:%.*]] = select <4 x i1> [[TMP1]], <4 x float> [[TMP2]], <4 x float> [[A]]
569+
// CHECK-NEXT: ret <4 x float> [[TMP3]]
568570
//
569571
float32x4_t test_vfmasq_m_n_f32(float32x4_t a, float32x4_t b, float32_t c, mve_pred16_t p) {
570572
#ifdef POLYMORPHIC

llvm/include/llvm/IR/IntrinsicsARM.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,6 +1362,7 @@ def int_arm_mve_vqmovn_predicated: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
13621362
llvm_i32_ty /* unsigned output */, llvm_i32_ty /* unsigned input */,
13631363
llvm_i32_ty /* top half */, llvm_anyvector_ty /* pred */], [IntrNoMem]>;
13641364

1365+
// fma_predicated returns the add operand for disabled lanes.
13651366
def int_arm_mve_fma_predicated: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
13661367
[LLVMMatchType<0> /* mult op #1 */, LLVMMatchType<0> /* mult op #2 */,
13671368
LLVMMatchType<0> /* addend */, llvm_anyvector_ty /* pred */], [IntrNoMem]>;

llvm/lib/Target/ARM/ARMInstrMVE.td

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5614,8 +5614,6 @@ multiclass MVE_VFMA_qr_multi<string iname, MVEVectorVTInfo VTI,
56145614
(VTI.Vec (fma v1, v2, vs)),
56155615
v1)),
56165616
(VTI.Vec (Inst v1, v2, is, ARMVCCThen, $pred, zero_reg))>;
5617-
def : Pat<(VTI.Vec (pred_int v1, v2, vs, pred)),
5618-
(VTI.Vec (Inst v1, v2, is, ARMVCCThen, pred, zero_reg))>;
56195617
} else {
56205618
def : Pat<(VTI.Vec (fma v1, vs, v2)),
56215619
(VTI.Vec (Inst v2, v1, is))>;

llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,8 +461,10 @@ define arm_aapcs_vfpcc <8 x half> @test_vfmasq_m_n_f16(<8 x half> %a, <8 x half>
461461
; CHECK: @ %bb.0: @ %entry
462462
; CHECK-NEXT: vmov r1, s8
463463
; CHECK-NEXT: vmsr p0, r0
464+
; CHECK-NEXT: vdup.16 q2, r1
464465
; CHECK-NEXT: vpst
465-
; CHECK-NEXT: vfmast.f16 q0, q1, r1
466+
; CHECK-NEXT: vfmat.f16 q2, q0, q1
467+
; CHECK-NEXT: vmov q0, q2
466468
; CHECK-NEXT: bx lr
467469
entry:
468470
%0 = bitcast float %c.coerce to i32
@@ -476,13 +478,36 @@ entry:
476478
ret <8 x half> %4
477479
}
478480

481+
define arm_aapcs_vfpcc <8 x half> @test_vfmasq_m_n_f16_select(<8 x half> %a, <8 x half> %b, float %c.coerce, i16 zeroext %p) {
482+
; CHECK-LABEL: test_vfmasq_m_n_f16_select:
483+
; CHECK: @ %bb.0: @ %entry
484+
; CHECK-NEXT: vmov r1, s8
485+
; CHECK-NEXT: vmsr p0, r0
486+
; CHECK-NEXT: vpst
487+
; CHECK-NEXT: vfmast.f16 q0, q1, r1
488+
; CHECK-NEXT: bx lr
489+
entry:
490+
%0 = bitcast float %c.coerce to i32
491+
%tmp.0.extract.trunc = trunc i32 %0 to i16
492+
%1 = bitcast i16 %tmp.0.extract.trunc to half
493+
%.splatinsert = insertelement <8 x half> undef, half %1, i32 0
494+
%.splat = shufflevector <8 x half> %.splatinsert, <8 x half> undef, <8 x i32> zeroinitializer
495+
%2 = zext i16 %p to i32
496+
%3 = tail call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 %2)
497+
%4 = tail call <8 x half> @llvm.fma.v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %.splat)
498+
%5 = select <8 x i1> %3, <8 x half> %4, <8 x half> %a
499+
ret <8 x half> %5
500+
}
501+
479502
define arm_aapcs_vfpcc <4 x float> @test_vfmasq_m_n_f32(<4 x float> %a, <4 x float> %b, float %c, i16 zeroext %p) {
480503
; CHECK-LABEL: test_vfmasq_m_n_f32:
481504
; CHECK: @ %bb.0: @ %entry
482505
; CHECK-NEXT: vmov r1, s8
483506
; CHECK-NEXT: vmsr p0, r0
507+
; CHECK-NEXT: vdup.32 q2, r1
484508
; CHECK-NEXT: vpst
485-
; CHECK-NEXT: vfmast.f32 q0, q1, r1
509+
; CHECK-NEXT: vfmat.f32 q2, q0, q1
510+
; CHECK-NEXT: vmov q0, q2
486511
; CHECK-NEXT: bx lr
487512
entry:
488513
%.splatinsert = insertelement <4 x float> undef, float %c, i32 0
@@ -493,6 +518,24 @@ entry:
493518
ret <4 x float> %2
494519
}
495520

521+
define arm_aapcs_vfpcc <4 x float> @test_vfmasq_m_n_f32_select(<4 x float> %a, <4 x float> %b, float %c, i16 zeroext %p) {
522+
; CHECK-LABEL: test_vfmasq_m_n_f32_select:
523+
; CHECK: @ %bb.0: @ %entry
524+
; CHECK-NEXT: vmov r1, s8
525+
; CHECK-NEXT: vmsr p0, r0
526+
; CHECK-NEXT: vpst
527+
; CHECK-NEXT: vfmast.f32 q0, q1, r1
528+
; CHECK-NEXT: bx lr
529+
entry:
530+
%.splatinsert = insertelement <4 x float> undef, float %c, i32 0
531+
%.splat = shufflevector <4 x float> %.splatinsert, <4 x float> undef, <4 x i32> zeroinitializer
532+
%0 = zext i16 %p to i32
533+
%1 = tail call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 %0)
534+
%2 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %.splat)
535+
%3 = select <4 x i1> %1, <4 x float> %2, <4 x float> %a
536+
ret <4 x float> %3
537+
}
538+
496539
define arm_aapcs_vfpcc <8 x half> @test_vfmsq_m_f16(<8 x half> %a, <8 x half> %b, <8 x half> %c, i16 zeroext %p) {
497540
; CHECK-LABEL: test_vfmsq_m_f16:
498541
; CHECK: @ %bb.0: @ %entry

llvm/test/CodeGen/Thumb2/mve-qrintr.ll

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1536,7 +1536,8 @@ while.body: ; preds = %while.body.lr.ph, %
15361536
%0 = tail call <4 x i1> @llvm.arm.mve.vctp32(i32 %N.addr.013)
15371537
%1 = tail call fast <4 x float> @llvm.masked.load.v4f32.p0(ptr %s1.addr.014, i32 4, <4 x i1> %0, <4 x float> zeroinitializer)
15381538
%2 = tail call fast <4 x float> @llvm.masked.load.v4f32.p0(ptr %s2, i32 4, <4 x i1> %0, <4 x float> zeroinitializer)
1539-
%3 = tail call fast <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> %1, <4 x float> %2, <4 x float> %.splat, <4 x i1> %0)
1539+
%3 = tail call fast <4 x float> @llvm.fma.v4f32(<4 x float> %1, <4 x float> %2, <4 x float> %.splat)
1540+
%4 = select <4 x i1> %0, <4 x float> %3, <4 x float> %1
15401541
tail call void @llvm.masked.store.v4f32.p0(<4 x float> %3, ptr %s1.addr.014, i32 4, <4 x i1> %0)
15411542
%add.ptr = getelementptr inbounds float, ptr %s1.addr.014, i32 4
15421543
%sub = add nsw i32 %N.addr.013, -4

0 commit comments

Comments
 (0)