Skip to content

Commit b529e0c

Browse files
committed
[NVPTX] Lower bfloat16 add/mul/sub as fma on SM80
SM80 has fma for bfloat16 but not add/mul/sub. Currently these are just promoted to f32 but we can instead write them in terms of the fma: ``` FADD(a, b) -> FMA(a, 1.0, b) FMUL(a, b) -> FMA(a, b, 0.0) FSUB(a, b) -> FMA(b, -1.0, a) ``` Unfortunately there is no `fma.ftz` so when ftz is enabled, we still fall back to promotion. This is also the inverse of some generic DAGCombiner patterns, so I've had to add checks to avoid it reversing the legalization which would cause an infinite loop.
1 parent 1d890b0 commit b529e0c

File tree

9 files changed

+229
-381
lines changed

9 files changed

+229
-381
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17559,10 +17559,13 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
1755917559
return N2;
1756017560
}
1756117561

17562+
const bool PreferFMAAdd = (TLI.isOperationLegal(ISD::FMA, VT) &&
17563+
!TLI.isOperationLegal(ISD::FADD, VT));
17564+
1756217565
// FIXME: Support splat of constant.
17563-
if (N0CFP && N0CFP->isExactlyValue(1.0))
17566+
if (!PreferFMAAdd && N0CFP && N0CFP->isExactlyValue(1.0))
1756417567
return matcher.getNode(ISD::FADD, DL, VT, N1, N2);
17565-
if (N1CFP && N1CFP->isExactlyValue(1.0))
17568+
if (!PreferFMAAdd && N1CFP && N1CFP->isExactlyValue(1.0))
1756617569
return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
1756717570

1756817571
// Canonicalize (fma c, x, y) -> (fma x, c, y)
@@ -17594,7 +17597,7 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
1759417597

1759517598
// (fma x, -1, y) -> (fadd (fneg x), y)
1759617599
// FIXME: Support splat of constant.
17597-
if (N1CFP) {
17600+
if (N1CFP && !PreferFMAAdd) {
1759817601
if (N1CFP->isExactlyValue(1.0))
1759917602
return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
1760017603

@@ -17604,15 +17607,14 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
1760417607
AddToWorklist(RHSNeg.getNode());
1760517608
return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
1760617609
}
17607-
17608-
// fma (fneg x), K, y -> fma x -K, y
17609-
if (matcher.match(N0, ISD::FNEG) &&
17610-
(TLI.isOperationLegal(ISD::ConstantFP, VT) ||
17611-
(N1.hasOneUse() &&
17612-
!TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
17613-
return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
17614-
matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
17615-
}
17610+
}
17611+
// fma (fneg x), K, y -> fma x -K, y
17612+
if (N1CFP && matcher.match(N0, ISD::FNEG) &&
17613+
(TLI.isOperationLegal(ISD::ConstantFP, VT) ||
17614+
(N1.hasOneUse() &&
17615+
!TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
17616+
return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
17617+
matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
1761617618
}
1761717619

1761817620
// FIXME: Support splat of constant.

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,16 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
862862
AddPromotedToType(Op, MVT::bf16, MVT::f32);
863863
}
864864

865+
// Lower bf16 add/mul/sub as fma when it avoids promotion
866+
for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
867+
for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
868+
if (getOperationAction(Op, VT) != Legal &&
869+
getOperationAction(ISD::FMA, VT) == Legal) {
870+
setOperationAction(Op, VT, Custom);
871+
}
872+
}
873+
}
874+
865875
// f16/f16x2 neg was introduced in PTX 60, SM_53.
866876
const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
867877
STI.getPTXVersion() >= 60 &&
@@ -2498,6 +2508,62 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
24982508
return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
24992509
}
25002510

2511+
static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
2512+
EVT VT = N->getValueType(0);
2513+
EVT NVT = MVT::f32;
2514+
if (VT.isVector()) {
2515+
NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
2516+
}
2517+
SDLoc DL(N);
2518+
SDValue Tmp0 = DAG.getFPExtendOrRound(N->getOperand(0), DL, NVT);
2519+
SDValue Tmp1 = DAG.getFPExtendOrRound(N->getOperand(1), DL, NVT);
2520+
SDValue Res = DAG.getNode(N->getOpcode(), DL, NVT, Tmp0, Tmp1, N->getFlags());
2521+
return DAG.getFPExtendOrRound(Res, DL, VT);
2522+
}
2523+
2524+
SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
2525+
// No fma.ftz for bf16, so fall back to promotion
2526+
if (useF32FTZ(DAG.getMachineFunction())) {
2527+
return PromoteBinOpToF32(Op.getNode(), DAG);
2528+
}
2529+
2530+
// FADD(a, b) -> FMA(a, 1.0, b)
2531+
SDLoc DL(Op);
2532+
auto VT = Op.getValueType();
2533+
auto One = DAG.getConstantFP(1.0, DL, VT);
2534+
SmallVector<SDValue, 3> Operands{Op->getOperand(0), One, Op->getOperand(1)};
2535+
return DAG.getNode(ISD::FMA, DL, VT, Operands);
2536+
}
2537+
2538+
SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
2539+
// No fma.ftz for bf16, so fall back to promotion
2540+
if (useF32FTZ(DAG.getMachineFunction())) {
2541+
return PromoteBinOpToF32(Op.getNode(), DAG);
2542+
}
2543+
2544+
// FSUB(a, b) -> FMA(b, -1.0, a)
2545+
SDLoc DL(Op);
2546+
auto VT = Op.getValueType();
2547+
auto NegOne = DAG.getConstantFP(-1.0, DL, VT);
2548+
SmallVector<SDValue, 3> Operands{Op->getOperand(1), NegOne,
2549+
Op->getOperand(0)};
2550+
return DAG.getNode(ISD::FMA, DL, VT, Operands);
2551+
}
2552+
2553+
SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
2554+
// No fma.ftz for bf16, so fall back to promotion
2555+
if (useF32FTZ(DAG.getMachineFunction())) {
2556+
return PromoteBinOpToF32(Op.getNode(), DAG);
2557+
}
2558+
2559+
// FMUL(a, b) -> FMA(a, b, 0.0)
2560+
SDLoc DL(Op);
2561+
auto VT = Op.getValueType();
2562+
auto Zero = DAG.getConstantFP(0.0, DL, VT);
2563+
SmallVector<SDValue, 3> Operands{Op->getOperand(0), Op->getOperand(1), Zero};
2564+
return DAG.getNode(ISD::FMA, DL, VT, Operands);
2565+
}
2566+
25012567
SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
25022568
SelectionDAG &DAG) const {
25032569
assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
@@ -2689,6 +2755,13 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
26892755
return LowerSTACKSAVE(Op, DAG);
26902756
case ISD::CopyToReg:
26912757
return LowerCopyToReg_128(Op, DAG);
2758+
case ISD::FADD:
2759+
return LowerFADD(Op, DAG);
2760+
case ISD::FSUB:
2761+
return LowerFSUB(Op, DAG);
2762+
case ISD::FMUL:
2763+
return LowerFMUL(Op, DAG);
2764+
26922765
default:
26932766
llvm_unreachable("Custom lowering not defined for operation");
26942767
}

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,10 @@ class NVPTXTargetLowering : public TargetLowering {
278278
SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
279279
SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
280280

281+
SDValue LowerFADD(SDValue Op, SelectionDAG &DAG) const;
282+
SDValue LowerFSUB(SDValue Op, SelectionDAG &DAG) const;
283+
SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;
284+
281285
SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
282286
SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
283287

llvm/test/CodeGen/NVPTX/atomics-sm90.ll

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -46,58 +46,52 @@ define void @test(ptr %dp0, ptr addrspace(1) %dp1, ptr addrspace(3) %dp3, bfloat
4646
; CHECKPTX71-LABEL: test(
4747
; CHECKPTX71: {
4848
; CHECKPTX71-NEXT: .reg .pred %p<5>;
49-
; CHECKPTX71-NEXT: .reg .b16 %rs<22>;
49+
; CHECKPTX71-NEXT: .reg .b16 %rs<26>;
5050
; CHECKPTX71-NEXT: .reg .b32 %r<4>;
51-
; CHECKPTX71-NEXT: .reg .f32 %f<12>;
5251
; CHECKPTX71-EMPTY:
5352
; CHECKPTX71-NEXT: // %bb.0:
5453
; CHECKPTX71-NEXT: ld.param.b16 %rs13, [test_param_3];
5554
; CHECKPTX71-NEXT: ld.param.u32 %r3, [test_param_2];
5655
; CHECKPTX71-NEXT: ld.param.u32 %r2, [test_param_1];
5756
; CHECKPTX71-NEXT: ld.param.u32 %r1, [test_param_0];
58-
; CHECKPTX71-NEXT: ld.b16 %rs18, [%r1];
59-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f1, %rs13;
57+
; CHECKPTX71-NEXT: ld.b16 %rs22, [%r1];
6058
; CHECKPTX71-NEXT: $L__BB0_1: // %atomicrmw.start14
6159
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
62-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f2, %rs18;
63-
; CHECKPTX71-NEXT: add.rn.f32 %f3, %f2, %f1;
64-
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs14, %f3;
65-
; CHECKPTX71-NEXT: atom.cas.b16 %rs3, [%r1], %rs18, %rs14;
66-
; CHECKPTX71-NEXT: setp.ne.s16 %p1, %rs3, %rs18;
67-
; CHECKPTX71-NEXT: mov.u16 %rs18, %rs3;
60+
; CHECKPTX71-NEXT: mov.b16 %rs14, 0x3F80;
61+
; CHECKPTX71-NEXT: fma.rn.bf16 %rs15, %rs22, %rs14, %rs13;
62+
; CHECKPTX71-NEXT: atom.cas.b16 %rs3, [%r1], %rs22, %rs15;
63+
; CHECKPTX71-NEXT: setp.ne.s16 %p1, %rs3, %rs22;
64+
; CHECKPTX71-NEXT: mov.u16 %rs22, %rs3;
6865
; CHECKPTX71-NEXT: @%p1 bra $L__BB0_1;
6966
; CHECKPTX71-NEXT: // %bb.2: // %atomicrmw.end13
70-
; CHECKPTX71-NEXT: ld.b16 %rs19, [%r1];
67+
; CHECKPTX71-NEXT: ld.b16 %rs23, [%r1];
7168
; CHECKPTX71-NEXT: $L__BB0_3: // %atomicrmw.start8
7269
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
73-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f4, %rs19;
74-
; CHECKPTX71-NEXT: add.rn.f32 %f5, %f4, 0f3F800000;
75-
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs15, %f5;
76-
; CHECKPTX71-NEXT: atom.cas.b16 %rs6, [%r1], %rs19, %rs15;
77-
; CHECKPTX71-NEXT: setp.ne.s16 %p2, %rs6, %rs19;
78-
; CHECKPTX71-NEXT: mov.u16 %rs19, %rs6;
70+
; CHECKPTX71-NEXT: mov.b16 %rs16, 0x3F80;
71+
; CHECKPTX71-NEXT: fma.rn.bf16 %rs17, %rs23, %rs16, %rs16;
72+
; CHECKPTX71-NEXT: atom.cas.b16 %rs6, [%r1], %rs23, %rs17;
73+
; CHECKPTX71-NEXT: setp.ne.s16 %p2, %rs6, %rs23;
74+
; CHECKPTX71-NEXT: mov.u16 %rs23, %rs6;
7975
; CHECKPTX71-NEXT: @%p2 bra $L__BB0_3;
8076
; CHECKPTX71-NEXT: // %bb.4: // %atomicrmw.end7
81-
; CHECKPTX71-NEXT: ld.global.b16 %rs20, [%r2];
77+
; CHECKPTX71-NEXT: ld.global.b16 %rs24, [%r2];
8278
; CHECKPTX71-NEXT: $L__BB0_5: // %atomicrmw.start2
8379
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
84-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f7, %rs20;
85-
; CHECKPTX71-NEXT: add.rn.f32 %f8, %f7, %f1;
86-
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs16, %f8;
87-
; CHECKPTX71-NEXT: atom.global.cas.b16 %rs9, [%r2], %rs20, %rs16;
88-
; CHECKPTX71-NEXT: setp.ne.s16 %p3, %rs9, %rs20;
89-
; CHECKPTX71-NEXT: mov.u16 %rs20, %rs9;
80+
; CHECKPTX71-NEXT: mov.b16 %rs18, 0x3F80;
81+
; CHECKPTX71-NEXT: fma.rn.bf16 %rs19, %rs24, %rs18, %rs13;
82+
; CHECKPTX71-NEXT: atom.global.cas.b16 %rs9, [%r2], %rs24, %rs19;
83+
; CHECKPTX71-NEXT: setp.ne.s16 %p3, %rs9, %rs24;
84+
; CHECKPTX71-NEXT: mov.u16 %rs24, %rs9;
9085
; CHECKPTX71-NEXT: @%p3 bra $L__BB0_5;
9186
; CHECKPTX71-NEXT: // %bb.6: // %atomicrmw.end1
92-
; CHECKPTX71-NEXT: ld.shared.b16 %rs21, [%r3];
87+
; CHECKPTX71-NEXT: ld.shared.b16 %rs25, [%r3];
9388
; CHECKPTX71-NEXT: $L__BB0_7: // %atomicrmw.start
9489
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
95-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f10, %rs21;
96-
; CHECKPTX71-NEXT: add.rn.f32 %f11, %f10, %f1;
97-
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs17, %f11;
98-
; CHECKPTX71-NEXT: atom.shared.cas.b16 %rs12, [%r3], %rs21, %rs17;
99-
; CHECKPTX71-NEXT: setp.ne.s16 %p4, %rs12, %rs21;
100-
; CHECKPTX71-NEXT: mov.u16 %rs21, %rs12;
90+
; CHECKPTX71-NEXT: mov.b16 %rs20, 0x3F80;
91+
; CHECKPTX71-NEXT: fma.rn.bf16 %rs21, %rs25, %rs20, %rs13;
92+
; CHECKPTX71-NEXT: atom.shared.cas.b16 %rs12, [%r3], %rs25, %rs21;
93+
; CHECKPTX71-NEXT: setp.ne.s16 %p4, %rs12, %rs25;
94+
; CHECKPTX71-NEXT: mov.u16 %rs25, %rs12;
10195
; CHECKPTX71-NEXT: @%p4 bra $L__BB0_7;
10296
; CHECKPTX71-NEXT: // %bb.8: // %atomicrmw.end
10397
; CHECKPTX71-NEXT: ret;

llvm/test/CodeGen/NVPTX/bf16-instructions.ll

Lines changed: 28 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,14 @@ define bfloat @test_fadd(bfloat %0, bfloat %1) {
4242
;
4343
; SM80-LABEL: test_fadd(
4444
; SM80: {
45-
; SM80-NEXT: .reg .b16 %rs<4>;
46-
; SM80-NEXT: .reg .f32 %f<4>;
45+
; SM80-NEXT: .reg .b16 %rs<5>;
4746
; SM80-EMPTY:
4847
; SM80-NEXT: // %bb.0:
4948
; SM80-NEXT: ld.param.b16 %rs1, [test_fadd_param_0];
5049
; SM80-NEXT: ld.param.b16 %rs2, [test_fadd_param_1];
51-
; SM80-NEXT: cvt.f32.bf16 %f1, %rs2;
52-
; SM80-NEXT: cvt.f32.bf16 %f2, %rs1;
53-
; SM80-NEXT: add.rn.f32 %f3, %f2, %f1;
54-
; SM80-NEXT: cvt.rn.bf16.f32 %rs3, %f3;
55-
; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
50+
; SM80-NEXT: mov.b16 %rs3, 0x3F80;
51+
; SM80-NEXT: fma.rn.bf16 %rs4, %rs1, %rs3, %rs2;
52+
; SM80-NEXT: st.param.b16 [func_retval0], %rs4;
5653
; SM80-NEXT: ret;
5754
;
5855
; SM80-FTZ-LABEL: test_fadd(
@@ -113,17 +110,14 @@ define bfloat @test_fsub(bfloat %0, bfloat %1) {
113110
;
114111
; SM80-LABEL: test_fsub(
115112
; SM80: {
116-
; SM80-NEXT: .reg .b16 %rs<4>;
117-
; SM80-NEXT: .reg .f32 %f<4>;
113+
; SM80-NEXT: .reg .b16 %rs<5>;
118114
; SM80-EMPTY:
119115
; SM80-NEXT: // %bb.0:
120116
; SM80-NEXT: ld.param.b16 %rs1, [test_fsub_param_0];
121117
; SM80-NEXT: ld.param.b16 %rs2, [test_fsub_param_1];
122-
; SM80-NEXT: cvt.f32.bf16 %f1, %rs2;
123-
; SM80-NEXT: cvt.f32.bf16 %f2, %rs1;
124-
; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1;
125-
; SM80-NEXT: cvt.rn.bf16.f32 %rs3, %f3;
126-
; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
118+
; SM80-NEXT: mov.b16 %rs3, 0xBF80;
119+
; SM80-NEXT: fma.rn.bf16 %rs4, %rs2, %rs3, %rs1;
120+
; SM80-NEXT: st.param.b16 [func_retval0], %rs4;
127121
; SM80-NEXT: ret;
128122
;
129123
; SM80-FTZ-LABEL: test_fsub(
@@ -202,23 +196,14 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
202196
;
203197
; SM80-LABEL: test_faddx2(
204198
; SM80: {
205-
; SM80-NEXT: .reg .b16 %rs<5>;
206-
; SM80-NEXT: .reg .b32 %r<4>;
207-
; SM80-NEXT: .reg .f32 %f<7>;
199+
; SM80-NEXT: .reg .b32 %r<5>;
208200
; SM80-EMPTY:
209201
; SM80-NEXT: // %bb.0:
210-
; SM80-NEXT: ld.param.b32 %r1, [test_faddx2_param_0];
211-
; SM80-NEXT: ld.param.b32 %r2, [test_faddx2_param_1];
212-
; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
213-
; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
214-
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
215-
; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
216-
; SM80-NEXT: add.rn.f32 %f3, %f2, %f1;
217-
; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
218-
; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
219-
; SM80-NEXT: add.rn.f32 %f6, %f5, %f4;
220-
; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
221-
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
202+
; SM80-NEXT: ld.param.b32 %r1, [test_faddx2_param_1];
203+
; SM80-NEXT: ld.param.b32 %r2, [test_faddx2_param_0];
204+
; SM80-NEXT: mov.b32 %r3, 1065369472;
205+
; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r3, %r1;
206+
; SM80-NEXT: st.param.b32 [func_retval0], %r4;
222207
; SM80-NEXT: ret;
223208
;
224209
; SM80-FTZ-LABEL: test_faddx2(
@@ -303,23 +288,14 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
303288
;
304289
; SM80-LABEL: test_fsubx2(
305290
; SM80: {
306-
; SM80-NEXT: .reg .b16 %rs<5>;
307-
; SM80-NEXT: .reg .b32 %r<4>;
308-
; SM80-NEXT: .reg .f32 %f<7>;
291+
; SM80-NEXT: .reg .b32 %r<5>;
309292
; SM80-EMPTY:
310293
; SM80-NEXT: // %bb.0:
311294
; SM80-NEXT: ld.param.b32 %r1, [test_fsubx2_param_0];
312295
; SM80-NEXT: ld.param.b32 %r2, [test_fsubx2_param_1];
313-
; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
314-
; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
315-
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
316-
; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
317-
; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1;
318-
; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
319-
; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
320-
; SM80-NEXT: sub.rn.f32 %f6, %f5, %f4;
321-
; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
322-
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
296+
; SM80-NEXT: mov.b32 %r3, -1082081408;
297+
; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r3, %r1;
298+
; SM80-NEXT: st.param.b32 [func_retval0], %r4;
323299
; SM80-NEXT: ret;
324300
;
325301
; SM80-FTZ-LABEL: test_fsubx2(
@@ -404,23 +380,14 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
404380
;
405381
; SM80-LABEL: test_fmulx2(
406382
; SM80: {
407-
; SM80-NEXT: .reg .b16 %rs<5>;
408-
; SM80-NEXT: .reg .b32 %r<4>;
409-
; SM80-NEXT: .reg .f32 %f<7>;
383+
; SM80-NEXT: .reg .b32 %r<5>;
410384
; SM80-EMPTY:
411385
; SM80-NEXT: // %bb.0:
412-
; SM80-NEXT: ld.param.b32 %r1, [test_fmulx2_param_0];
413-
; SM80-NEXT: ld.param.b32 %r2, [test_fmulx2_param_1];
414-
; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
415-
; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
416-
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
417-
; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
418-
; SM80-NEXT: mul.rn.f32 %f3, %f2, %f1;
419-
; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
420-
; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
421-
; SM80-NEXT: mul.rn.f32 %f6, %f5, %f4;
422-
; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
423-
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
386+
; SM80-NEXT: ld.param.b32 %r1, [test_fmulx2_param_1];
387+
; SM80-NEXT: ld.param.b32 %r2, [test_fmulx2_param_0];
388+
; SM80-NEXT: mov.b32 %r3, 0;
389+
; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r1, %r3;
390+
; SM80-NEXT: st.param.b32 [func_retval0], %r4;
424391
; SM80-NEXT: ret;
425392
;
426393
; SM80-FTZ-LABEL: test_fmulx2(
@@ -727,15 +694,13 @@ define bfloat @test_fadd_imm_1(bfloat %a) #0 {
727694
;
728695
; SM80-LABEL: test_fadd_imm_1(
729696
; SM80: {
730-
; SM80-NEXT: .reg .b16 %rs<3>;
731-
; SM80-NEXT: .reg .f32 %f<3>;
697+
; SM80-NEXT: .reg .b16 %rs<4>;
732698
; SM80-EMPTY:
733699
; SM80-NEXT: // %bb.0:
734700
; SM80-NEXT: ld.param.b16 %rs1, [test_fadd_imm_1_param_0];
735-
; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
736-
; SM80-NEXT: add.rn.f32 %f2, %f1, 0f3F800000;
737-
; SM80-NEXT: cvt.rn.bf16.f32 %rs2, %f2;
738-
; SM80-NEXT: st.param.b16 [func_retval0], %rs2;
701+
; SM80-NEXT: mov.b16 %rs2, 0x3F80;
702+
; SM80-NEXT: fma.rn.bf16 %rs3, %rs1, %rs2, %rs2;
703+
; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
739704
; SM80-NEXT: ret;
740705
;
741706
; SM80-FTZ-LABEL: test_fadd_imm_1(

0 commit comments

Comments
 (0)