Skip to content

Commit 299a919

Browse files
committed
[NVPTX] Lower bfloat16 add/mul/sub as fma on SM80
SM80 has fma for bfloat16 but not add/mul/sub. Currently these are just promoted to f32, but we can instead write them in terms of fma:

```
FADD(a, b) -> FMA(a, 1.0, b)
FMUL(a, b) -> FMA(a, b, 0.0)
FSUB(a, b) -> FMA(b, -1.0, a)
```

Unfortunately there is no `fma.ftz` for bf16, so when ftz is enabled we still fall back to promotion. These rewrites are also the inverse of some generic DAGCombiner patterns, so I've had to add checks to stop the combiner from reversing the legalization, which would cause an infinite loop.
1 parent 852feea commit 299a919

File tree

9 files changed

+229
-381
lines changed

9 files changed

+229
-381
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17534,10 +17534,13 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
1753417534
return N2;
1753517535
}
1753617536

17537+
const bool PreferFMAAdd = (TLI.isOperationLegal(ISD::FMA, VT) &&
17538+
!TLI.isOperationLegal(ISD::FADD, VT));
17539+
1753717540
// FIXME: Support splat of constant.
17538-
if (N0CFP && N0CFP->isExactlyValue(1.0))
17541+
if (!PreferFMAAdd && N0CFP && N0CFP->isExactlyValue(1.0))
1753917542
return matcher.getNode(ISD::FADD, DL, VT, N1, N2);
17540-
if (N1CFP && N1CFP->isExactlyValue(1.0))
17543+
if (!PreferFMAAdd && N1CFP && N1CFP->isExactlyValue(1.0))
1754117544
return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
1754217545

1754317546
// Canonicalize (fma c, x, y) -> (fma x, c, y)
@@ -17569,7 +17572,7 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
1756917572

1757017573
// (fma x, -1, y) -> (fadd (fneg x), y)
1757117574
// FIXME: Support splat of constant.
17572-
if (N1CFP) {
17575+
if (N1CFP && !PreferFMAAdd) {
1757317576
if (N1CFP->isExactlyValue(1.0))
1757417577
return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
1757517578

@@ -17579,15 +17582,14 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
1757917582
AddToWorklist(RHSNeg.getNode());
1758017583
return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
1758117584
}
17582-
17583-
// fma (fneg x), K, y -> fma x -K, y
17584-
if (matcher.match(N0, ISD::FNEG) &&
17585-
(TLI.isOperationLegal(ISD::ConstantFP, VT) ||
17586-
(N1.hasOneUse() &&
17587-
!TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
17588-
return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
17589-
matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
17590-
}
17585+
}
17586+
// fma (fneg x), K, y -> fma x -K, y
17587+
if (N1CFP && matcher.match(N0, ISD::FNEG) &&
17588+
(TLI.isOperationLegal(ISD::ConstantFP, VT) ||
17589+
(N1.hasOneUse() &&
17590+
!TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
17591+
return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
17592+
matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
1759117593
}
1759217594

1759317595
// FIXME: Support splat of constant.

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,16 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
853853
AddPromotedToType(Op, MVT::bf16, MVT::f32);
854854
}
855855

856+
// Lower bf16 add/mul/sub as fma when it avoids promotion
857+
for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
858+
for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
859+
if (getOperationAction(Op, VT) != Legal &&
860+
getOperationAction(ISD::FMA, VT) == Legal) {
861+
setOperationAction(Op, VT, Custom);
862+
}
863+
}
864+
}
865+
856866
// f16/f16x2 neg was introduced in PTX 60, SM_53.
857867
const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
858868
STI.getPTXVersion() >= 60 &&
@@ -2490,6 +2500,62 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
24902500
return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
24912501
}
24922502

2503+
/// Re-express the two-operand floating-point node \p N in f32 precision.
///
/// Both operands are converted to f32 (or to a vector of f32 with the same
/// element count when the result type is a vector), the original opcode is
/// re-emitted on the wide type with the node's flags preserved, and the wide
/// result is converted back to the original value type.
static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
  const EVT OrigVT = N->getValueType(0);
  // Scalars widen to f32; vectors widen element-wise, keeping the lane count.
  const EVT WideVT =
      OrigVT.isVector()
          ? EVT::getVectorVT(*DAG.getContext(), MVT::f32,
                             OrigVT.getVectorElementCount())
          : EVT(MVT::f32);

  const SDLoc Loc(N);
  SDValue WideLHS = DAG.getFPExtendOrRound(N->getOperand(0), Loc, WideVT);
  SDValue WideRHS = DAG.getFPExtendOrRound(N->getOperand(1), Loc, WideVT);
  SDValue WideResult =
      DAG.getNode(N->getOpcode(), Loc, WideVT, WideLHS, WideRHS, N->getFlags());

  // Narrow the f32 result back to the type the caller expects.
  return DAG.getFPExtendOrRound(WideResult, Loc, OrigVT);
}
2515+
2516+
/// Custom lowering for FADD: emit it as a fused multiply-add, or promote the
/// operation to f32 when f32 flush-to-zero is in effect for the function.
SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
  // There is no fma.ftz for bf16, so under FTZ we have to promote instead.
  if (useF32FTZ(DAG.getMachineFunction()))
    return PromoteBinOpToF32(Op.getNode(), DAG);

  // FADD(a, b) -> FMA(a, 1.0, b)
  const SDLoc Loc(Op);
  const EVT Ty = Op.getValueType();
  SDValue FMAOps[] = {Op->getOperand(0), DAG.getConstantFP(1.0, Loc, Ty),
                      Op->getOperand(1)};
  return DAG.getNode(ISD::FMA, Loc, Ty, FMAOps);
}
2529+
2530+
/// Custom lowering for FSUB: emit it as a fused multiply-add with a -1.0
/// multiplier, or promote the operation to f32 when f32 flush-to-zero is in
/// effect for the function.
SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
  // There is no fma.ftz for bf16, so under FTZ we have to promote instead.
  if (useF32FTZ(DAG.getMachineFunction()))
    return PromoteBinOpToF32(Op.getNode(), DAG);

  // FSUB(a, b) -> FMA(b, -1.0, a)
  // The subtrahend is the multiplied operand; the minuend is the addend.
  const SDLoc Loc(Op);
  const EVT Ty = Op.getValueType();
  SDValue FMAOps[] = {Op->getOperand(1), DAG.getConstantFP(-1.0, Loc, Ty),
                      Op->getOperand(0)};
  return DAG.getNode(ISD::FMA, Loc, Ty, FMAOps);
}
2544+
2545+
/// Custom lowering for FMUL: emit it as a fused multiply-add with a
/// negative-zero addend, or promote the operation to f32 when f32
/// flush-to-zero is in effect for the function.
SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
  // No fma.ftz for bf16, so fall back to promotion
  if (useF32FTZ(DAG.getMachineFunction())) {
    return PromoteBinOpToF32(Op.getNode(), DAG);
  }

  // FMUL(a, b) -> FMA(a, b, -0.0)
  //
  // The addend must be negative zero, not positive zero: under
  // round-to-nearest, (-0.0) + (+0.0) == +0.0, so adding +0.0 would flip the
  // sign of a product that is exactly -0.0 (e.g. -1.0 * 0.0). Adding -0.0
  // leaves every product, including both signed zeros, unchanged.
  SDLoc DL(Op);
  auto VT = Op.getValueType();
  auto NegZero = DAG.getConstantFP(-0.0, DL, VT);
  SmallVector<SDValue, 3> Operands{Op->getOperand(0), Op->getOperand(1),
                                   NegZero};
  return DAG.getNode(ISD::FMA, DL, VT, Operands);
}
2558+
24932559
SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
24942560
SelectionDAG &DAG) const {
24952561
assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
@@ -2681,6 +2747,13 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
26812747
return LowerSTACKSAVE(Op, DAG);
26822748
case ISD::CopyToReg:
26832749
return LowerCopyToReg_128(Op, DAG);
2750+
case ISD::FADD:
2751+
return LowerFADD(Op, DAG);
2752+
case ISD::FSUB:
2753+
return LowerFSUB(Op, DAG);
2754+
case ISD::FMUL:
2755+
return LowerFMUL(Op, DAG);
2756+
26842757
default:
26852758
llvm_unreachable("Custom lowering not defined for operation");
26862759
}

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,10 @@ class NVPTXTargetLowering : public TargetLowering {
279279
SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
280280
SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
281281

282+
SDValue LowerFADD(SDValue Op, SelectionDAG &DAG) const;
283+
SDValue LowerFSUB(SDValue Op, SelectionDAG &DAG) const;
284+
SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;
285+
282286
SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
283287
SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
284288

llvm/test/CodeGen/NVPTX/atomics-sm90.ll

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -46,58 +46,52 @@ define void @test(ptr %dp0, ptr addrspace(1) %dp1, ptr addrspace(3) %dp3, bfloat
4646
; CHECKPTX71-LABEL: test(
4747
; CHECKPTX71: {
4848
; CHECKPTX71-NEXT: .reg .pred %p<5>;
49-
; CHECKPTX71-NEXT: .reg .b16 %rs<22>;
49+
; CHECKPTX71-NEXT: .reg .b16 %rs<26>;
5050
; CHECKPTX71-NEXT: .reg .b32 %r<4>;
51-
; CHECKPTX71-NEXT: .reg .f32 %f<12>;
5251
; CHECKPTX71-EMPTY:
5352
; CHECKPTX71-NEXT: // %bb.0:
5453
; CHECKPTX71-NEXT: ld.param.b16 %rs13, [test_param_3];
5554
; CHECKPTX71-NEXT: ld.param.u32 %r3, [test_param_2];
5655
; CHECKPTX71-NEXT: ld.param.u32 %r2, [test_param_1];
5756
; CHECKPTX71-NEXT: ld.param.u32 %r1, [test_param_0];
58-
; CHECKPTX71-NEXT: ld.b16 %rs18, [%r1];
59-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f1, %rs13;
57+
; CHECKPTX71-NEXT: ld.b16 %rs22, [%r1];
6058
; CHECKPTX71-NEXT: $L__BB0_1: // %atomicrmw.start14
6159
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
62-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f2, %rs18;
63-
; CHECKPTX71-NEXT: add.rn.f32 %f3, %f2, %f1;
64-
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs14, %f3;
65-
; CHECKPTX71-NEXT: atom.cas.b16 %rs3, [%r1], %rs18, %rs14;
66-
; CHECKPTX71-NEXT: setp.ne.s16 %p1, %rs3, %rs18;
67-
; CHECKPTX71-NEXT: mov.u16 %rs18, %rs3;
60+
; CHECKPTX71-NEXT: mov.b16 %rs14, 0x3F80;
61+
; CHECKPTX71-NEXT: fma.rn.bf16 %rs15, %rs22, %rs14, %rs13;
62+
; CHECKPTX71-NEXT: atom.cas.b16 %rs3, [%r1], %rs22, %rs15;
63+
; CHECKPTX71-NEXT: setp.ne.s16 %p1, %rs3, %rs22;
64+
; CHECKPTX71-NEXT: mov.u16 %rs22, %rs3;
6865
; CHECKPTX71-NEXT: @%p1 bra $L__BB0_1;
6966
; CHECKPTX71-NEXT: // %bb.2: // %atomicrmw.end13
70-
; CHECKPTX71-NEXT: ld.b16 %rs19, [%r1];
67+
; CHECKPTX71-NEXT: ld.b16 %rs23, [%r1];
7168
; CHECKPTX71-NEXT: $L__BB0_3: // %atomicrmw.start8
7269
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
73-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f4, %rs19;
74-
; CHECKPTX71-NEXT: add.rn.f32 %f5, %f4, 0f3F800000;
75-
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs15, %f5;
76-
; CHECKPTX71-NEXT: atom.cas.b16 %rs6, [%r1], %rs19, %rs15;
77-
; CHECKPTX71-NEXT: setp.ne.s16 %p2, %rs6, %rs19;
78-
; CHECKPTX71-NEXT: mov.u16 %rs19, %rs6;
70+
; CHECKPTX71-NEXT: mov.b16 %rs16, 0x3F80;
71+
; CHECKPTX71-NEXT: fma.rn.bf16 %rs17, %rs23, %rs16, %rs16;
72+
; CHECKPTX71-NEXT: atom.cas.b16 %rs6, [%r1], %rs23, %rs17;
73+
; CHECKPTX71-NEXT: setp.ne.s16 %p2, %rs6, %rs23;
74+
; CHECKPTX71-NEXT: mov.u16 %rs23, %rs6;
7975
; CHECKPTX71-NEXT: @%p2 bra $L__BB0_3;
8076
; CHECKPTX71-NEXT: // %bb.4: // %atomicrmw.end7
81-
; CHECKPTX71-NEXT: ld.global.b16 %rs20, [%r2];
77+
; CHECKPTX71-NEXT: ld.global.b16 %rs24, [%r2];
8278
; CHECKPTX71-NEXT: $L__BB0_5: // %atomicrmw.start2
8379
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
84-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f7, %rs20;
85-
; CHECKPTX71-NEXT: add.rn.f32 %f8, %f7, %f1;
86-
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs16, %f8;
87-
; CHECKPTX71-NEXT: atom.global.cas.b16 %rs9, [%r2], %rs20, %rs16;
88-
; CHECKPTX71-NEXT: setp.ne.s16 %p3, %rs9, %rs20;
89-
; CHECKPTX71-NEXT: mov.u16 %rs20, %rs9;
80+
; CHECKPTX71-NEXT: mov.b16 %rs18, 0x3F80;
81+
; CHECKPTX71-NEXT: fma.rn.bf16 %rs19, %rs24, %rs18, %rs13;
82+
; CHECKPTX71-NEXT: atom.global.cas.b16 %rs9, [%r2], %rs24, %rs19;
83+
; CHECKPTX71-NEXT: setp.ne.s16 %p3, %rs9, %rs24;
84+
; CHECKPTX71-NEXT: mov.u16 %rs24, %rs9;
9085
; CHECKPTX71-NEXT: @%p3 bra $L__BB0_5;
9186
; CHECKPTX71-NEXT: // %bb.6: // %atomicrmw.end1
92-
; CHECKPTX71-NEXT: ld.shared.b16 %rs21, [%r3];
87+
; CHECKPTX71-NEXT: ld.shared.b16 %rs25, [%r3];
9388
; CHECKPTX71-NEXT: $L__BB0_7: // %atomicrmw.start
9489
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
95-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f10, %rs21;
96-
; CHECKPTX71-NEXT: add.rn.f32 %f11, %f10, %f1;
97-
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs17, %f11;
98-
; CHECKPTX71-NEXT: atom.shared.cas.b16 %rs12, [%r3], %rs21, %rs17;
99-
; CHECKPTX71-NEXT: setp.ne.s16 %p4, %rs12, %rs21;
100-
; CHECKPTX71-NEXT: mov.u16 %rs21, %rs12;
90+
; CHECKPTX71-NEXT: mov.b16 %rs20, 0x3F80;
91+
; CHECKPTX71-NEXT: fma.rn.bf16 %rs21, %rs25, %rs20, %rs13;
92+
; CHECKPTX71-NEXT: atom.shared.cas.b16 %rs12, [%r3], %rs25, %rs21;
93+
; CHECKPTX71-NEXT: setp.ne.s16 %p4, %rs12, %rs25;
94+
; CHECKPTX71-NEXT: mov.u16 %rs25, %rs12;
10195
; CHECKPTX71-NEXT: @%p4 bra $L__BB0_7;
10296
; CHECKPTX71-NEXT: // %bb.8: // %atomicrmw.end
10397
; CHECKPTX71-NEXT: ret;

llvm/test/CodeGen/NVPTX/bf16-instructions.ll

Lines changed: 28 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,14 @@ define bfloat @test_fadd(bfloat %0, bfloat %1) {
4242
;
4343
; SM80-LABEL: test_fadd(
4444
; SM80: {
45-
; SM80-NEXT: .reg .b16 %rs<4>;
46-
; SM80-NEXT: .reg .f32 %f<4>;
45+
; SM80-NEXT: .reg .b16 %rs<5>;
4746
; SM80-EMPTY:
4847
; SM80-NEXT: // %bb.0:
4948
; SM80-NEXT: ld.param.b16 %rs1, [test_fadd_param_0];
5049
; SM80-NEXT: ld.param.b16 %rs2, [test_fadd_param_1];
51-
; SM80-NEXT: cvt.f32.bf16 %f1, %rs2;
52-
; SM80-NEXT: cvt.f32.bf16 %f2, %rs1;
53-
; SM80-NEXT: add.rn.f32 %f3, %f2, %f1;
54-
; SM80-NEXT: cvt.rn.bf16.f32 %rs3, %f3;
55-
; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
50+
; SM80-NEXT: mov.b16 %rs3, 0x3F80;
51+
; SM80-NEXT: fma.rn.bf16 %rs4, %rs1, %rs3, %rs2;
52+
; SM80-NEXT: st.param.b16 [func_retval0], %rs4;
5653
; SM80-NEXT: ret;
5754
;
5855
; SM80-FTZ-LABEL: test_fadd(
@@ -113,17 +110,14 @@ define bfloat @test_fsub(bfloat %0, bfloat %1) {
113110
;
114111
; SM80-LABEL: test_fsub(
115112
; SM80: {
116-
; SM80-NEXT: .reg .b16 %rs<4>;
117-
; SM80-NEXT: .reg .f32 %f<4>;
113+
; SM80-NEXT: .reg .b16 %rs<5>;
118114
; SM80-EMPTY:
119115
; SM80-NEXT: // %bb.0:
120116
; SM80-NEXT: ld.param.b16 %rs1, [test_fsub_param_0];
121117
; SM80-NEXT: ld.param.b16 %rs2, [test_fsub_param_1];
122-
; SM80-NEXT: cvt.f32.bf16 %f1, %rs2;
123-
; SM80-NEXT: cvt.f32.bf16 %f2, %rs1;
124-
; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1;
125-
; SM80-NEXT: cvt.rn.bf16.f32 %rs3, %f3;
126-
; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
118+
; SM80-NEXT: mov.b16 %rs3, 0xBF80;
119+
; SM80-NEXT: fma.rn.bf16 %rs4, %rs2, %rs3, %rs1;
120+
; SM80-NEXT: st.param.b16 [func_retval0], %rs4;
127121
; SM80-NEXT: ret;
128122
;
129123
; SM80-FTZ-LABEL: test_fsub(
@@ -202,23 +196,14 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
202196
;
203197
; SM80-LABEL: test_faddx2(
204198
; SM80: {
205-
; SM80-NEXT: .reg .b16 %rs<5>;
206-
; SM80-NEXT: .reg .b32 %r<4>;
207-
; SM80-NEXT: .reg .f32 %f<7>;
199+
; SM80-NEXT: .reg .b32 %r<5>;
208200
; SM80-EMPTY:
209201
; SM80-NEXT: // %bb.0:
210-
; SM80-NEXT: ld.param.b32 %r1, [test_faddx2_param_0];
211-
; SM80-NEXT: ld.param.b32 %r2, [test_faddx2_param_1];
212-
; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
213-
; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
214-
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
215-
; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
216-
; SM80-NEXT: add.rn.f32 %f3, %f2, %f1;
217-
; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
218-
; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
219-
; SM80-NEXT: add.rn.f32 %f6, %f5, %f4;
220-
; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
221-
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
202+
; SM80-NEXT: ld.param.b32 %r1, [test_faddx2_param_1];
203+
; SM80-NEXT: ld.param.b32 %r2, [test_faddx2_param_0];
204+
; SM80-NEXT: mov.b32 %r3, 1065369472;
205+
; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r3, %r1;
206+
; SM80-NEXT: st.param.b32 [func_retval0], %r4;
222207
; SM80-NEXT: ret;
223208
;
224209
; SM80-FTZ-LABEL: test_faddx2(
@@ -303,23 +288,14 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
303288
;
304289
; SM80-LABEL: test_fsubx2(
305290
; SM80: {
306-
; SM80-NEXT: .reg .b16 %rs<5>;
307-
; SM80-NEXT: .reg .b32 %r<4>;
308-
; SM80-NEXT: .reg .f32 %f<7>;
291+
; SM80-NEXT: .reg .b32 %r<5>;
309292
; SM80-EMPTY:
310293
; SM80-NEXT: // %bb.0:
311294
; SM80-NEXT: ld.param.b32 %r1, [test_fsubx2_param_0];
312295
; SM80-NEXT: ld.param.b32 %r2, [test_fsubx2_param_1];
313-
; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
314-
; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
315-
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
316-
; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
317-
; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1;
318-
; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
319-
; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
320-
; SM80-NEXT: sub.rn.f32 %f6, %f5, %f4;
321-
; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
322-
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
296+
; SM80-NEXT: mov.b32 %r3, -1082081408;
297+
; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r3, %r1;
298+
; SM80-NEXT: st.param.b32 [func_retval0], %r4;
323299
; SM80-NEXT: ret;
324300
;
325301
; SM80-FTZ-LABEL: test_fsubx2(
@@ -404,23 +380,14 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
404380
;
405381
; SM80-LABEL: test_fmulx2(
406382
; SM80: {
407-
; SM80-NEXT: .reg .b16 %rs<5>;
408-
; SM80-NEXT: .reg .b32 %r<4>;
409-
; SM80-NEXT: .reg .f32 %f<7>;
383+
; SM80-NEXT: .reg .b32 %r<5>;
410384
; SM80-EMPTY:
411385
; SM80-NEXT: // %bb.0:
412-
; SM80-NEXT: ld.param.b32 %r1, [test_fmulx2_param_0];
413-
; SM80-NEXT: ld.param.b32 %r2, [test_fmulx2_param_1];
414-
; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
415-
; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
416-
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
417-
; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
418-
; SM80-NEXT: mul.rn.f32 %f3, %f2, %f1;
419-
; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
420-
; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
421-
; SM80-NEXT: mul.rn.f32 %f6, %f5, %f4;
422-
; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
423-
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
386+
; SM80-NEXT: ld.param.b32 %r1, [test_fmulx2_param_1];
387+
; SM80-NEXT: ld.param.b32 %r2, [test_fmulx2_param_0];
388+
; SM80-NEXT: mov.b32 %r3, 0;
389+
; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r1, %r3;
390+
; SM80-NEXT: st.param.b32 [func_retval0], %r4;
424391
; SM80-NEXT: ret;
425392
;
426393
; SM80-FTZ-LABEL: test_fmulx2(
@@ -727,15 +694,13 @@ define bfloat @test_fadd_imm_1(bfloat %a) #0 {
727694
;
728695
; SM80-LABEL: test_fadd_imm_1(
729696
; SM80: {
730-
; SM80-NEXT: .reg .b16 %rs<3>;
731-
; SM80-NEXT: .reg .f32 %f<3>;
697+
; SM80-NEXT: .reg .b16 %rs<4>;
732698
; SM80-EMPTY:
733699
; SM80-NEXT: // %bb.0:
734700
; SM80-NEXT: ld.param.b16 %rs1, [test_fadd_imm_1_param_0];
735-
; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
736-
; SM80-NEXT: add.rn.f32 %f2, %f1, 0f3F800000;
737-
; SM80-NEXT: cvt.rn.bf16.f32 %rs2, %f2;
738-
; SM80-NEXT: st.param.b16 [func_retval0], %rs2;
701+
; SM80-NEXT: mov.b16 %rs2, 0x3F80;
702+
; SM80-NEXT: fma.rn.bf16 %rs3, %rs1, %rs2, %rs2;
703+
; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
739704
; SM80-NEXT: ret;
740705
;
741706
; SM80-FTZ-LABEL: test_fadd_imm_1(

0 commit comments

Comments
 (0)