Skip to content

Commit d9d67c4

Browse files
committed
[NVPTX] Lower bfloat16 add/mul/sub as fma on SM80
SM80 has fma for bfloat16 but not add/mul/sub. Currently these are just promoted to f32, but we can instead write them in terms of the fma:

```
FADD(a, b) -> FMA(a, 1.0, b)
FMUL(a, b) -> FMA(a, b, 0.0)
FSUB(a, b) -> FMA(b, -1.0, a)
```

Unfortunately there is no `fma.ftz` for bf16, so when ftz is enabled we still fall back to promotion. These rewrites are also the inverse of some generic DAGCombiner patterns, so I've had to add checks that stop the combiner from reversing the legalization, which would otherwise cause an infinite loop.
1 parent 222ff18 commit d9d67c4

File tree

9 files changed

+229
-381
lines changed

9 files changed

+229
-381
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17551,10 +17551,13 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
1755117551
return N2;
1755217552
}
1755317553

17554+
const bool PreferFMAAdd = (TLI.isOperationLegal(ISD::FMA, VT) &&
17555+
!TLI.isOperationLegal(ISD::FADD, VT));
17556+
1755417557
// FIXME: Support splat of constant.
17555-
if (N0CFP && N0CFP->isExactlyValue(1.0))
17558+
if (!PreferFMAAdd && N0CFP && N0CFP->isExactlyValue(1.0))
1755617559
return matcher.getNode(ISD::FADD, DL, VT, N1, N2);
17557-
if (N1CFP && N1CFP->isExactlyValue(1.0))
17560+
if (!PreferFMAAdd && N1CFP && N1CFP->isExactlyValue(1.0))
1755817561
return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
1755917562

1756017563
// Canonicalize (fma c, x, y) -> (fma x, c, y)
@@ -17586,7 +17589,7 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
1758617589

1758717590
// (fma x, -1, y) -> (fadd (fneg x), y)
1758817591
// FIXME: Support splat of constant.
17589-
if (N1CFP) {
17592+
if (N1CFP && !PreferFMAAdd) {
1759017593
if (N1CFP->isExactlyValue(1.0))
1759117594
return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
1759217595

@@ -17596,15 +17599,14 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
1759617599
AddToWorklist(RHSNeg.getNode());
1759717600
return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
1759817601
}
17599-
17600-
// fma (fneg x), K, y -> fma x -K, y
17601-
if (matcher.match(N0, ISD::FNEG) &&
17602-
(TLI.isOperationLegal(ISD::ConstantFP, VT) ||
17603-
(N1.hasOneUse() &&
17604-
!TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
17605-
return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
17606-
matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
17607-
}
17602+
}
17603+
// fma (fneg x), K, y -> fma x -K, y
17604+
if (N1CFP && matcher.match(N0, ISD::FNEG) &&
17605+
(TLI.isOperationLegal(ISD::ConstantFP, VT) ||
17606+
(N1.hasOneUse() &&
17607+
!TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
17608+
return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
17609+
matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
1760817610
}
1760917611

1761017612
// FIXME: Support splat of constant.

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,16 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
853853
AddPromotedToType(Op, MVT::bf16, MVT::f32);
854854
}
855855

856+
// Lower bf16 add/mul/sub as fma when it avoids promotion
857+
for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
858+
for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
859+
if (getOperationAction(Op, VT) != Legal &&
860+
getOperationAction(ISD::FMA, VT) == Legal) {
861+
setOperationAction(Op, VT, Custom);
862+
}
863+
}
864+
}
865+
856866
// f16/f16x2 neg was introduced in PTX 60, SM_53.
857867
const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
858868
STI.getPTXVersion() >= 60 &&
@@ -2490,6 +2500,62 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
24902500
return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
24912501
}
24922502

2503+
// Promote the two-operand FP node N to f32: both operands are converted to
// f32, the same opcode (with N's flags) is re-emitted in f32, and the result
// is converted back to N's original value type. For a vector type, the f32
// vector uses the same element count as the original.
static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
2504+
EVT VT = N->getValueType(0);
2505+
EVT NVT = MVT::f32;
2506+
if (VT.isVector()) {
2507+
NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
2508+
}
2509+
SDLoc DL(N);
2510+
SDValue Tmp0 = DAG.getFPExtendOrRound(N->getOperand(0), DL, NVT);
2511+
SDValue Tmp1 = DAG.getFPExtendOrRound(N->getOperand(1), DL, NVT);
2512+
SDValue Res = DAG.getNode(N->getOpcode(), DL, NVT, Tmp0, Tmp1, N->getFlags());
2513+
return DAG.getFPExtendOrRound(Res, DL, VT);
2514+
}
2515+
2516+
// Custom lowering for FADD. When f32 FTZ is in effect for the current
// function the node is promoted to f32; otherwise the add is rewritten as an
// FMA whose multiplier is the constant 1.0 (a * 1.0 + b).
SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
2517+
// No fma.ftz for bf16, so fall back to promotion
2518+
if (useF32FTZ(DAG.getMachineFunction())) {
2519+
return PromoteBinOpToF32(Op.getNode(), DAG);
2520+
}
2521+
2522+
// FADD(a, b) -> FMA(a, 1.0, b)
2523+
SDLoc DL(Op);
2524+
auto VT = Op.getValueType();
2525+
auto One = DAG.getConstantFP(1.0, DL, VT);
2526+
SmallVector<SDValue, 3> Operands{Op->getOperand(0), One, Op->getOperand(1)};
2527+
return DAG.getNode(ISD::FMA, DL, VT, Operands);
2528+
}
2529+
2530+
// Custom lowering for FSUB. When f32 FTZ is in effect for the current
// function the node is promoted to f32; otherwise the subtraction a - b is
// rewritten as an FMA that negates the subtrahend: b * -1.0 + a.
SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
2531+
// No fma.ftz for bf16, so fall back to promotion
2532+
if (useF32FTZ(DAG.getMachineFunction())) {
2533+
return PromoteBinOpToF32(Op.getNode(), DAG);
2534+
}
2535+
2536+
// FSUB(a, b) -> FMA(b, -1.0, a)
2537+
SDLoc DL(Op);
2538+
auto VT = Op.getValueType();
2539+
auto NegOne = DAG.getConstantFP(-1.0, DL, VT);
2540+
SmallVector<SDValue, 3> Operands{Op->getOperand(1), NegOne,
2541+
Op->getOperand(0)};
2542+
return DAG.getNode(ISD::FMA, DL, VT, Operands);
2543+
}
2544+
2545+
// Custom lowering for FMUL. When f32 FTZ is in effect for the current
// function the node is promoted to f32; otherwise the multiply is rewritten as
// an FMA whose addend is the additive identity.
SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
2546+
// No fma.ftz for bf16, so fall back to promotion
2547+
if (useF32FTZ(DAG.getMachineFunction())) {
2548+
return PromoteBinOpToF32(Op.getNode(), DAG);
2549+
}
2550+
2551+
// FMUL(a, b) -> FMA(a, b, -0.0)
// NOTE: the addend must be -0.0, not +0.0. Under round-to-nearest,
// (-0.0) + (+0.0) == +0.0, so a +0.0 addend would flip the sign of a product
// that is -0.0; x + (-0.0) == x holds for every x, including both zeros.
2552+
SDLoc DL(Op);
2553+
auto VT = Op.getValueType();
2554+
auto NegZero = DAG.getConstantFP(-0.0, DL, VT);
2555+
SmallVector<SDValue, 3> Operands{Op->getOperand(0), Op->getOperand(1), NegZero};
2556+
return DAG.getNode(ISD::FMA, DL, VT, Operands);
2557+
}
2558+
24932559
SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
24942560
SelectionDAG &DAG) const {
24952561
assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
@@ -2681,6 +2747,13 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
26812747
return LowerSTACKSAVE(Op, DAG);
26822748
case ISD::CopyToReg:
26832749
return LowerCopyToReg_128(Op, DAG);
2750+
case ISD::FADD:
2751+
return LowerFADD(Op, DAG);
2752+
case ISD::FSUB:
2753+
return LowerFSUB(Op, DAG);
2754+
case ISD::FMUL:
2755+
return LowerFMUL(Op, DAG);
2756+
26842757
default:
26852758
llvm_unreachable("Custom lowering not defined for operation");
26862759
}

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,10 @@ class NVPTXTargetLowering : public TargetLowering {
279279
SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
280280
SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
281281

282+
SDValue LowerFADD(SDValue Op, SelectionDAG &DAG) const;
283+
SDValue LowerFSUB(SDValue Op, SelectionDAG &DAG) const;
284+
SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;
285+
282286
SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
283287
SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
284288

llvm/test/CodeGen/NVPTX/atomics-sm90.ll

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -46,58 +46,52 @@ define void @test(ptr %dp0, ptr addrspace(1) %dp1, ptr addrspace(3) %dp3, bfloat
4646
; CHECKPTX71-LABEL: test(
4747
; CHECKPTX71: {
4848
; CHECKPTX71-NEXT: .reg .pred %p<5>;
49-
; CHECKPTX71-NEXT: .reg .b16 %rs<22>;
49+
; CHECKPTX71-NEXT: .reg .b16 %rs<26>;
5050
; CHECKPTX71-NEXT: .reg .b32 %r<4>;
51-
; CHECKPTX71-NEXT: .reg .f32 %f<12>;
5251
; CHECKPTX71-EMPTY:
5352
; CHECKPTX71-NEXT: // %bb.0:
5453
; CHECKPTX71-NEXT: ld.param.b16 %rs13, [test_param_3];
5554
; CHECKPTX71-NEXT: ld.param.u32 %r3, [test_param_2];
5655
; CHECKPTX71-NEXT: ld.param.u32 %r2, [test_param_1];
5756
; CHECKPTX71-NEXT: ld.param.u32 %r1, [test_param_0];
58-
; CHECKPTX71-NEXT: ld.b16 %rs18, [%r1];
59-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f1, %rs13;
57+
; CHECKPTX71-NEXT: ld.b16 %rs22, [%r1];
6058
; CHECKPTX71-NEXT: $L__BB0_1: // %atomicrmw.start14
6159
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
62-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f2, %rs18;
63-
; CHECKPTX71-NEXT: add.rn.f32 %f3, %f2, %f1;
64-
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs14, %f3;
65-
; CHECKPTX71-NEXT: atom.cas.b16 %rs3, [%r1], %rs18, %rs14;
66-
; CHECKPTX71-NEXT: setp.ne.s16 %p1, %rs3, %rs18;
67-
; CHECKPTX71-NEXT: mov.u16 %rs18, %rs3;
60+
; CHECKPTX71-NEXT: mov.b16 %rs14, 0x3F80;
61+
; CHECKPTX71-NEXT: fma.rn.bf16 %rs15, %rs22, %rs14, %rs13;
62+
; CHECKPTX71-NEXT: atom.cas.b16 %rs3, [%r1], %rs22, %rs15;
63+
; CHECKPTX71-NEXT: setp.ne.s16 %p1, %rs3, %rs22;
64+
; CHECKPTX71-NEXT: mov.u16 %rs22, %rs3;
6865
; CHECKPTX71-NEXT: @%p1 bra $L__BB0_1;
6966
; CHECKPTX71-NEXT: // %bb.2: // %atomicrmw.end13
70-
; CHECKPTX71-NEXT: ld.b16 %rs19, [%r1];
67+
; CHECKPTX71-NEXT: ld.b16 %rs23, [%r1];
7168
; CHECKPTX71-NEXT: $L__BB0_3: // %atomicrmw.start8
7269
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
73-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f4, %rs19;
74-
; CHECKPTX71-NEXT: add.rn.f32 %f5, %f4, 0f3F800000;
75-
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs15, %f5;
76-
; CHECKPTX71-NEXT: atom.cas.b16 %rs6, [%r1], %rs19, %rs15;
77-
; CHECKPTX71-NEXT: setp.ne.s16 %p2, %rs6, %rs19;
78-
; CHECKPTX71-NEXT: mov.u16 %rs19, %rs6;
70+
; CHECKPTX71-NEXT: mov.b16 %rs16, 0x3F80;
71+
; CHECKPTX71-NEXT: fma.rn.bf16 %rs17, %rs23, %rs16, %rs16;
72+
; CHECKPTX71-NEXT: atom.cas.b16 %rs6, [%r1], %rs23, %rs17;
73+
; CHECKPTX71-NEXT: setp.ne.s16 %p2, %rs6, %rs23;
74+
; CHECKPTX71-NEXT: mov.u16 %rs23, %rs6;
7975
; CHECKPTX71-NEXT: @%p2 bra $L__BB0_3;
8076
; CHECKPTX71-NEXT: // %bb.4: // %atomicrmw.end7
81-
; CHECKPTX71-NEXT: ld.global.b16 %rs20, [%r2];
77+
; CHECKPTX71-NEXT: ld.global.b16 %rs24, [%r2];
8278
; CHECKPTX71-NEXT: $L__BB0_5: // %atomicrmw.start2
8379
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
84-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f7, %rs20;
85-
; CHECKPTX71-NEXT: add.rn.f32 %f8, %f7, %f1;
86-
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs16, %f8;
87-
; CHECKPTX71-NEXT: atom.global.cas.b16 %rs9, [%r2], %rs20, %rs16;
88-
; CHECKPTX71-NEXT: setp.ne.s16 %p3, %rs9, %rs20;
89-
; CHECKPTX71-NEXT: mov.u16 %rs20, %rs9;
80+
; CHECKPTX71-NEXT: mov.b16 %rs18, 0x3F80;
81+
; CHECKPTX71-NEXT: fma.rn.bf16 %rs19, %rs24, %rs18, %rs13;
82+
; CHECKPTX71-NEXT: atom.global.cas.b16 %rs9, [%r2], %rs24, %rs19;
83+
; CHECKPTX71-NEXT: setp.ne.s16 %p3, %rs9, %rs24;
84+
; CHECKPTX71-NEXT: mov.u16 %rs24, %rs9;
9085
; CHECKPTX71-NEXT: @%p3 bra $L__BB0_5;
9186
; CHECKPTX71-NEXT: // %bb.6: // %atomicrmw.end1
92-
; CHECKPTX71-NEXT: ld.shared.b16 %rs21, [%r3];
87+
; CHECKPTX71-NEXT: ld.shared.b16 %rs25, [%r3];
9388
; CHECKPTX71-NEXT: $L__BB0_7: // %atomicrmw.start
9489
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
95-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f10, %rs21;
96-
; CHECKPTX71-NEXT: add.rn.f32 %f11, %f10, %f1;
97-
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs17, %f11;
98-
; CHECKPTX71-NEXT: atom.shared.cas.b16 %rs12, [%r3], %rs21, %rs17;
99-
; CHECKPTX71-NEXT: setp.ne.s16 %p4, %rs12, %rs21;
100-
; CHECKPTX71-NEXT: mov.u16 %rs21, %rs12;
90+
; CHECKPTX71-NEXT: mov.b16 %rs20, 0x3F80;
91+
; CHECKPTX71-NEXT: fma.rn.bf16 %rs21, %rs25, %rs20, %rs13;
92+
; CHECKPTX71-NEXT: atom.shared.cas.b16 %rs12, [%r3], %rs25, %rs21;
93+
; CHECKPTX71-NEXT: setp.ne.s16 %p4, %rs12, %rs25;
94+
; CHECKPTX71-NEXT: mov.u16 %rs25, %rs12;
10195
; CHECKPTX71-NEXT: @%p4 bra $L__BB0_7;
10296
; CHECKPTX71-NEXT: // %bb.8: // %atomicrmw.end
10397
; CHECKPTX71-NEXT: ret;

llvm/test/CodeGen/NVPTX/bf16-instructions.ll

Lines changed: 28 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,14 @@ define bfloat @test_fadd(bfloat %0, bfloat %1) {
4242
;
4343
; SM80-LABEL: test_fadd(
4444
; SM80: {
45-
; SM80-NEXT: .reg .b16 %rs<4>;
46-
; SM80-NEXT: .reg .f32 %f<4>;
45+
; SM80-NEXT: .reg .b16 %rs<5>;
4746
; SM80-EMPTY:
4847
; SM80-NEXT: // %bb.0:
4948
; SM80-NEXT: ld.param.b16 %rs1, [test_fadd_param_0];
5049
; SM80-NEXT: ld.param.b16 %rs2, [test_fadd_param_1];
51-
; SM80-NEXT: cvt.f32.bf16 %f1, %rs2;
52-
; SM80-NEXT: cvt.f32.bf16 %f2, %rs1;
53-
; SM80-NEXT: add.rn.f32 %f3, %f2, %f1;
54-
; SM80-NEXT: cvt.rn.bf16.f32 %rs3, %f3;
55-
; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
50+
; SM80-NEXT: mov.b16 %rs3, 0x3F80;
51+
; SM80-NEXT: fma.rn.bf16 %rs4, %rs1, %rs3, %rs2;
52+
; SM80-NEXT: st.param.b16 [func_retval0], %rs4;
5653
; SM80-NEXT: ret;
5754
;
5855
; SM80-FTZ-LABEL: test_fadd(
@@ -113,17 +110,14 @@ define bfloat @test_fsub(bfloat %0, bfloat %1) {
113110
;
114111
; SM80-LABEL: test_fsub(
115112
; SM80: {
116-
; SM80-NEXT: .reg .b16 %rs<4>;
117-
; SM80-NEXT: .reg .f32 %f<4>;
113+
; SM80-NEXT: .reg .b16 %rs<5>;
118114
; SM80-EMPTY:
119115
; SM80-NEXT: // %bb.0:
120116
; SM80-NEXT: ld.param.b16 %rs1, [test_fsub_param_0];
121117
; SM80-NEXT: ld.param.b16 %rs2, [test_fsub_param_1];
122-
; SM80-NEXT: cvt.f32.bf16 %f1, %rs2;
123-
; SM80-NEXT: cvt.f32.bf16 %f2, %rs1;
124-
; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1;
125-
; SM80-NEXT: cvt.rn.bf16.f32 %rs3, %f3;
126-
; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
118+
; SM80-NEXT: mov.b16 %rs3, 0xBF80;
119+
; SM80-NEXT: fma.rn.bf16 %rs4, %rs2, %rs3, %rs1;
120+
; SM80-NEXT: st.param.b16 [func_retval0], %rs4;
127121
; SM80-NEXT: ret;
128122
;
129123
; SM80-FTZ-LABEL: test_fsub(
@@ -202,23 +196,14 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
202196
;
203197
; SM80-LABEL: test_faddx2(
204198
; SM80: {
205-
; SM80-NEXT: .reg .b16 %rs<5>;
206-
; SM80-NEXT: .reg .b32 %r<4>;
207-
; SM80-NEXT: .reg .f32 %f<7>;
199+
; SM80-NEXT: .reg .b32 %r<5>;
208200
; SM80-EMPTY:
209201
; SM80-NEXT: // %bb.0:
210-
; SM80-NEXT: ld.param.b32 %r1, [test_faddx2_param_0];
211-
; SM80-NEXT: ld.param.b32 %r2, [test_faddx2_param_1];
212-
; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
213-
; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
214-
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
215-
; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
216-
; SM80-NEXT: add.rn.f32 %f3, %f2, %f1;
217-
; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
218-
; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
219-
; SM80-NEXT: add.rn.f32 %f6, %f5, %f4;
220-
; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
221-
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
202+
; SM80-NEXT: ld.param.b32 %r1, [test_faddx2_param_1];
203+
; SM80-NEXT: ld.param.b32 %r2, [test_faddx2_param_0];
204+
; SM80-NEXT: mov.b32 %r3, 1065369472;
205+
; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r3, %r1;
206+
; SM80-NEXT: st.param.b32 [func_retval0], %r4;
222207
; SM80-NEXT: ret;
223208
;
224209
; SM80-FTZ-LABEL: test_faddx2(
@@ -303,23 +288,14 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
303288
;
304289
; SM80-LABEL: test_fsubx2(
305290
; SM80: {
306-
; SM80-NEXT: .reg .b16 %rs<5>;
307-
; SM80-NEXT: .reg .b32 %r<4>;
308-
; SM80-NEXT: .reg .f32 %f<7>;
291+
; SM80-NEXT: .reg .b32 %r<5>;
309292
; SM80-EMPTY:
310293
; SM80-NEXT: // %bb.0:
311294
; SM80-NEXT: ld.param.b32 %r1, [test_fsubx2_param_0];
312295
; SM80-NEXT: ld.param.b32 %r2, [test_fsubx2_param_1];
313-
; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
314-
; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
315-
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
316-
; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
317-
; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1;
318-
; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
319-
; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
320-
; SM80-NEXT: sub.rn.f32 %f6, %f5, %f4;
321-
; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
322-
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
296+
; SM80-NEXT: mov.b32 %r3, -1082081408;
297+
; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r3, %r1;
298+
; SM80-NEXT: st.param.b32 [func_retval0], %r4;
323299
; SM80-NEXT: ret;
324300
;
325301
; SM80-FTZ-LABEL: test_fsubx2(
@@ -404,23 +380,14 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
404380
;
405381
; SM80-LABEL: test_fmulx2(
406382
; SM80: {
407-
; SM80-NEXT: .reg .b16 %rs<5>;
408-
; SM80-NEXT: .reg .b32 %r<4>;
409-
; SM80-NEXT: .reg .f32 %f<7>;
383+
; SM80-NEXT: .reg .b32 %r<5>;
410384
; SM80-EMPTY:
411385
; SM80-NEXT: // %bb.0:
412-
; SM80-NEXT: ld.param.b32 %r1, [test_fmulx2_param_0];
413-
; SM80-NEXT: ld.param.b32 %r2, [test_fmulx2_param_1];
414-
; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
415-
; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
416-
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
417-
; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
418-
; SM80-NEXT: mul.rn.f32 %f3, %f2, %f1;
419-
; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
420-
; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
421-
; SM80-NEXT: mul.rn.f32 %f6, %f5, %f4;
422-
; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
423-
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
386+
; SM80-NEXT: ld.param.b32 %r1, [test_fmulx2_param_1];
387+
; SM80-NEXT: ld.param.b32 %r2, [test_fmulx2_param_0];
388+
; SM80-NEXT: mov.b32 %r3, 0;
389+
; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r1, %r3;
390+
; SM80-NEXT: st.param.b32 [func_retval0], %r4;
424391
; SM80-NEXT: ret;
425392
;
426393
; SM80-FTZ-LABEL: test_fmulx2(
@@ -727,15 +694,13 @@ define bfloat @test_fadd_imm_1(bfloat %a) #0 {
727694
;
728695
; SM80-LABEL: test_fadd_imm_1(
729696
; SM80: {
730-
; SM80-NEXT: .reg .b16 %rs<3>;
731-
; SM80-NEXT: .reg .f32 %f<3>;
697+
; SM80-NEXT: .reg .b16 %rs<4>;
732698
; SM80-EMPTY:
733699
; SM80-NEXT: // %bb.0:
734700
; SM80-NEXT: ld.param.b16 %rs1, [test_fadd_imm_1_param_0];
735-
; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
736-
; SM80-NEXT: add.rn.f32 %f2, %f1, 0f3F800000;
737-
; SM80-NEXT: cvt.rn.bf16.f32 %rs2, %f2;
738-
; SM80-NEXT: st.param.b16 [func_retval0], %rs2;
701+
; SM80-NEXT: mov.b16 %rs2, 0x3F80;
702+
; SM80-NEXT: fma.rn.bf16 %rs3, %rs1, %rs2, %rs2;
703+
; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
739704
; SM80-NEXT: ret;
740705
;
741706
; SM80-FTZ-LABEL: test_fadd_imm_1(

0 commit comments

Comments
 (0)