Skip to content

Commit de6d7a6

Browse files
authored
[RISCV] Expand Zfa fli+fneg cases during lowering instead of during isel. (#108316)
Most of the constants fli can generate are positive numbers. We can use fli+fneg to generate their negative versions. Previously, we considered such negative constants as "legal" and let isel generate the fli+fneg. However, it is useful to expose the fneg to DAG combines to fold with fadd to produce fsub or with fma to produce fnmadd, fnmsub, or fmsub. This patch moves the fneg creation to lowering so that the fneg will be visible to the last DAG combine. I might move the rest of Zfa handling from isel to lowering as a follow up. Fixes #107772.
1 parent b2e8b8f commit de6d7a6

File tree

7 files changed

+219
-32
lines changed

7 files changed

+219
-32
lines changed

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -889,33 +889,25 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
889889
}
890890
case ISD::ConstantFP: {
891891
const APFloat &APF = cast<ConstantFPSDNode>(Node)->getValueAPF();
892-
auto [FPImm, NeedsFNeg] =
893-
static_cast<const RISCVTargetLowering *>(TLI)->getLegalZfaFPImm(APF,
894-
VT);
892+
int FPImm = static_cast<const RISCVTargetLowering *>(TLI)->getLegalZfaFPImm(
893+
APF, VT);
895894
if (FPImm >= 0) {
896895
unsigned Opc;
897-
unsigned FNegOpc;
898896
switch (VT.SimpleTy) {
899897
default:
900898
llvm_unreachable("Unexpected size");
901899
case MVT::f16:
902900
Opc = RISCV::FLI_H;
903-
FNegOpc = RISCV::FSGNJN_H;
904901
break;
905902
case MVT::f32:
906903
Opc = RISCV::FLI_S;
907-
FNegOpc = RISCV::FSGNJN_S;
908904
break;
909905
case MVT::f64:
910906
Opc = RISCV::FLI_D;
911-
FNegOpc = RISCV::FSGNJN_D;
912907
break;
913908
}
914909
SDNode *Res = CurDAG->getMachineNode(
915910
Opc, DL, VT, CurDAG->getTargetConstant(FPImm, DL, XLenVT));
916-
if (NeedsFNeg)
917-
Res = CurDAG->getMachineNode(FNegOpc, DL, VT, SDValue(Res, 0),
918-
SDValue(Res, 0));
919911

920912
ReplaceNode(Node, Res);
921913
return;
@@ -3563,9 +3555,8 @@ bool RISCVDAGToDAGISel::selectScalarFPAsInt(SDValue N, SDValue &Imm) {
35633555
// Even if this FPImm requires an additional FNEG (i.e. the second element of
35643556
// the returned pair is true) we still prefer FLI + FNEG over immediate
35653557
// materialization as the latter might generate a longer instruction sequence.
3566-
if (static_cast<const RISCVTargetLowering *>(TLI)
3567-
->getLegalZfaFPImm(APF, VT)
3568-
.first >= 0)
3558+
if (static_cast<const RISCVTargetLowering *>(TLI)->getLegalZfaFPImm(APF,
3559+
VT) >= 0)
35693560
return false;
35703561

35713562
MVT XLenVT = Subtarget->getXLenVT();

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
469469
setOperationAction(ISD::IS_FPCLASS, MVT::f16, Custom);
470470
setOperationAction({ISD::FMAXIMUM, ISD::FMINIMUM}, MVT::f16,
471471
Subtarget.hasStdExtZfa() ? Legal : Custom);
472+
if (Subtarget.hasStdExtZfa())
473+
setOperationAction(ISD::ConstantFP, MVT::f16, Custom);
472474
} else {
473475
setOperationAction(ZfhminZfbfminPromoteOps, MVT::f16, Promote);
474476
setOperationAction({ISD::FMAXIMUM, ISD::FMINIMUM}, MVT::f16, Promote);
@@ -533,6 +535,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
533535
setOperationAction(ISD::FP16_TO_FP, MVT::f32, Custom);
534536

535537
if (Subtarget.hasStdExtZfa()) {
538+
setOperationAction(ISD::ConstantFP, MVT::f32, Custom);
536539
setOperationAction(ISD::FNEARBYINT, MVT::f32, Legal);
537540
setOperationAction({ISD::FMAXIMUM, ISD::FMINIMUM}, MVT::f32, Legal);
538541
} else {
@@ -550,6 +553,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
550553
setOperationAction(ISD::BITCAST, MVT::i64, Custom);
551554

552555
if (Subtarget.hasStdExtZfa()) {
556+
setOperationAction(ISD::ConstantFP, MVT::f64, Custom);
553557
setOperationAction(FPRndMode, MVT::f64, Legal);
554558
setOperationAction(ISD::FNEARBYINT, MVT::f64, Legal);
555559
setOperationAction({ISD::FMAXIMUM, ISD::FMINIMUM}, MVT::f64, Legal);
@@ -2238,17 +2242,11 @@ bool RISCVTargetLowering::isOffsetFoldingLegal(
22382242
return false;
22392243
}
22402244

2241-
// Return one of the followings:
2242-
// (1) `{0-31 value, false}` if FLI is available for Imm's type and FP value.
2243-
// (2) `{0-31 value, true}` if Imm is negative and FLI is available for its
2244-
// positive counterpart, which will be materialized from the first returned
2245-
// element. The second returned element indicated that there should be a FNEG
2246-
// followed.
2247-
// (3) `{-1, _}` if there is no way FLI can be used to materialize Imm.
2248-
std::pair<int, bool> RISCVTargetLowering::getLegalZfaFPImm(const APFloat &Imm,
2249-
EVT VT) const {
2245+
// Returns 0-31 if the fli instruction is available for the type and this is
2246+
// legal FP immediate for the type. Returns -1 otherwise.
2247+
int RISCVTargetLowering::getLegalZfaFPImm(const APFloat &Imm, EVT VT) const {
22502248
if (!Subtarget.hasStdExtZfa())
2251-
return std::make_pair(-1, false);
2249+
return -1;
22522250

22532251
bool IsSupportedVT = false;
22542252
if (VT == MVT::f16) {
@@ -2261,14 +2259,9 @@ std::pair<int, bool> RISCVTargetLowering::getLegalZfaFPImm(const APFloat &Imm,
22612259
}
22622260

22632261
if (!IsSupportedVT)
2264-
return std::make_pair(-1, false);
2262+
return -1;
22652263

2266-
int Index = RISCVLoadFPImm::getLoadFPImm(Imm);
2267-
if (Index < 0 && Imm.isNegative())
2268-
// Try the combination of its positive counterpart + FNEG.
2269-
return std::make_pair(RISCVLoadFPImm::getLoadFPImm(-Imm), true);
2270-
else
2271-
return std::make_pair(Index, false);
2264+
return RISCVLoadFPImm::getLoadFPImm(Imm);
22722265
}
22732266

22742267
bool RISCVTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
@@ -2286,7 +2279,7 @@ bool RISCVTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
22862279
if (!IsLegalVT)
22872280
return false;
22882281

2289-
if (getLegalZfaFPImm(Imm, VT).first >= 0)
2282+
if (getLegalZfaFPImm(Imm, VT) >= 0)
22902283
return true;
22912284

22922285
// Cannot create a 64 bit floating-point immediate value for rv32.
@@ -5816,6 +5809,29 @@ static SDValue lowerConstant(SDValue Op, SelectionDAG &DAG,
58165809
return SDValue();
58175810
}
58185811

5812+
SDValue RISCVTargetLowering::lowerConstantFP(SDValue Op,
5813+
SelectionDAG &DAG) const {
5814+
MVT VT = Op.getSimpleValueType();
5815+
const APFloat &Imm = cast<ConstantFPSDNode>(Op)->getValueAPF();
5816+
5817+
if (getLegalZfaFPImm(Imm, VT) >= 0)
5818+
return Op;
5819+
5820+
if (!Imm.isNegative())
5821+
return SDValue();
5822+
5823+
int Index = getLegalZfaFPImm(-Imm, VT);
5824+
if (Index < 0)
5825+
return SDValue();
5826+
5827+
// Emit an FLI+FNEG. We use a custom node to hide from constant folding.
5828+
SDLoc DL(Op);
5829+
SDValue Const =
5830+
DAG.getNode(RISCVISD::FLI, Op, VT,
5831+
DAG.getTargetConstant(Index, DL, Subtarget.getXLenVT()));
5832+
return DAG.getNode(ISD::FNEG, Op, VT, Const);
5833+
}
5834+
58195835
static SDValue LowerATOMIC_FENCE(SDValue Op, SelectionDAG &DAG,
58205836
const RISCVSubtarget &Subtarget) {
58215837
SDLoc dl(Op);
@@ -6435,6 +6451,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
64356451
return lowerGlobalTLSAddress(Op, DAG);
64366452
case ISD::Constant:
64376453
return lowerConstant(Op, DAG, Subtarget);
6454+
case ISD::ConstantFP:
6455+
return lowerConstantFP(Op, DAG);
64386456
case ISD::SELECT:
64396457
return lowerSELECT(Op, DAG);
64406458
case ISD::BRCOND:
@@ -19978,6 +19996,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
1997819996
NODE_NAME_CASE(FSGNJX)
1997919997
NODE_NAME_CASE(FMAX)
1998019998
NODE_NAME_CASE(FMIN)
19999+
NODE_NAME_CASE(FLI)
1998120000
NODE_NAME_CASE(READ_COUNTER_WIDE)
1998220001
NODE_NAME_CASE(BREV8)
1998320002
NODE_NAME_CASE(ORC_B)

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ enum NodeType : unsigned {
130130
// Floating point fmax and fmin matching the RISC-V instruction semantics.
131131
FMAX, FMIN,
132132

133+
// Zfa fli instruction for constant materialization.
134+
FLI,
135+
133136
// A read of the 64-bit counter CSR on a 32-bit target (returns (Lo, Hi)).
134137
// It takes a chain operand and another two target constant operands (the
135138
// CSR numbers of the low and high parts of the counter).
@@ -524,7 +527,7 @@ class RISCVTargetLowering : public TargetLowering {
524527
SmallVectorImpl<Use *> &Ops) const override;
525528
bool shouldScalarizeBinop(SDValue VecOp) const override;
526529
bool isOffsetFoldingLegal(const GlobalAddressSDNode *GA) const override;
527-
std::pair<int, bool> getLegalZfaFPImm(const APFloat &Imm, EVT VT) const;
530+
int getLegalZfaFPImm(const APFloat &Imm, EVT VT) const;
528531
bool isFPImmLegal(const APFloat &Imm, EVT VT,
529532
bool ForCodeSize) const override;
530533
bool isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
@@ -914,6 +917,7 @@ class RISCVTargetLowering : public TargetLowering {
914917
SDValue getDynamicTLSAddr(GlobalAddressSDNode *N, SelectionDAG &DAG) const;
915918
SDValue getTLSDescAddr(GlobalAddressSDNode *N, SelectionDAG &DAG) const;
916919

920+
SDValue lowerConstantFP(SDValue Op, SelectionDAG &DAG) const;
917921
SDValue lowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const;
918922
SDValue lowerBlockAddress(SDValue Op, SelectionDAG &DAG) const;
919923
SDValue lowerConstantPool(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/RISCV/RISCVInstrInfoZfa.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14+
//===----------------------------------------------------------------------===//
15+
// RISC-V specific DAG Nodes.
16+
//===----------------------------------------------------------------------===//
17+
18+
def SDT_RISCVFLI
19+
: SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVT<1, XLenVT>]>;
20+
21+
def riscv_fli : SDNode<"RISCVISD::FLI", SDT_RISCVFLI>;
22+
1423
//===----------------------------------------------------------------------===//
1524
// Operand and SDNode transformation definitions.
1625
//===----------------------------------------------------------------------===//
@@ -189,6 +198,8 @@ def : InstAlias<"fgeq.h $rd, $rs, $rt",
189198
//===----------------------------------------------------------------------===//
190199

191200
let Predicates = [HasStdExtZfa] in {
201+
def: Pat<(f32 (riscv_fli timm:$imm)), (FLI_S timm:$imm)>;
202+
192203
def: PatFprFpr<fminimum, FMINM_S, FPR32, f32>;
193204
def: PatFprFpr<fmaximum, FMAXM_S, FPR32, f32>;
194205

@@ -211,6 +222,8 @@ def: PatSetCC<FPR32, strict_fsetcc, SETOLE, FLEQ_S, f32>;
211222
} // Predicates = [HasStdExtZfa]
212223

213224
let Predicates = [HasStdExtZfa, HasStdExtD] in {
225+
def: Pat<(f64 (riscv_fli timm:$imm)), (FLI_D timm:$imm)>;
226+
214227
def: PatFprFpr<fminimum, FMINM_D, FPR64, f64>;
215228
def: PatFprFpr<fmaximum, FMAXM_D, FPR64, f64>;
216229

@@ -239,6 +252,8 @@ def : Pat<(RISCVBuildPairF64 GPR:$rs1, GPR:$rs2),
239252
}
240253

241254
let Predicates = [HasStdExtZfa, HasStdExtZfh] in {
255+
def: Pat<(f16 (riscv_fli timm:$imm)), (FLI_H timm:$imm)>;
256+
242257
def: PatFprFpr<fminimum, FMINM_H, FPR16, f16>;
243258
def: PatFprFpr<fmaximum, FMAXM_H, FPR16, f16>;
244259

llvm/test/CodeGen/RISCV/double-zfa.ll

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,3 +330,44 @@ define double @fmvp_d_x(i64 %a) {
330330
%or = bitcast i64 %a to double
331331
ret double %or
332332
}
333+
334+
define double @fadd_neg_0p5(double %x) {
335+
; CHECK-LABEL: fadd_neg_0p5:
336+
; CHECK: # %bb.0:
337+
; CHECK-NEXT: fli.d fa5, 0.5
338+
; CHECK-NEXT: fsub.d fa0, fa0, fa5
339+
; CHECK-NEXT: ret
340+
%a = fadd double %x, -0.5
341+
ret double %a
342+
}
343+
344+
define double @fma_neg_addend(double %x, double %y) nounwind {
345+
; CHECK-LABEL: fma_neg_addend:
346+
; CHECK: # %bb.0:
347+
; CHECK-NEXT: fli.d fa5, 0.5
348+
; CHECK-NEXT: fmsub.d fa0, fa0, fa1, fa5
349+
; CHECK-NEXT: ret
350+
%a = call double @llvm.fma.f32(double %x, double %y, double -0.5)
351+
ret double %a
352+
}
353+
354+
define double @fma_neg_multiplicand(double %x, double %y) nounwind {
355+
; CHECK-LABEL: fma_neg_multiplicand:
356+
; CHECK: # %bb.0:
357+
; CHECK-NEXT: fli.d fa5, 0.125
358+
; CHECK-NEXT: fnmsub.d fa0, fa5, fa0, fa1
359+
; CHECK-NEXT: ret
360+
%a = call double @llvm.fma.f32(double %x, double -0.125, double %y)
361+
ret double %a
362+
}
363+
364+
define double @fma_neg_addend_multiplicand(double %x) nounwind {
365+
; CHECK-LABEL: fma_neg_addend_multiplicand:
366+
; CHECK: # %bb.0:
367+
; CHECK-NEXT: fli.d fa5, 0.25
368+
; CHECK-NEXT: fli.d fa4, 0.5
369+
; CHECK-NEXT: fnmadd.d fa0, fa4, fa0, fa5
370+
; CHECK-NEXT: ret
371+
%a = call double @llvm.fma.f32(double %x, double -0.5, double -0.25)
372+
ret double %a
373+
}

llvm/test/CodeGen/RISCV/float-zfa.ll

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,44 @@ define void @fli_remat() {
269269
tail call void @foo(float 1.000000e+00, float 1.000000e+00)
270270
ret void
271271
}
272+
273+
define float @fadd_neg_0p5(float %x) {
274+
; CHECK-LABEL: fadd_neg_0p5:
275+
; CHECK: # %bb.0:
276+
; CHECK-NEXT: fli.s fa5, 0.5
277+
; CHECK-NEXT: fsub.s fa0, fa0, fa5
278+
; CHECK-NEXT: ret
279+
%a = fadd float %x, -0.5
280+
ret float %a
281+
}
282+
283+
define float @fma_neg_addend(float %x, float %y) nounwind {
284+
; CHECK-LABEL: fma_neg_addend:
285+
; CHECK: # %bb.0:
286+
; CHECK-NEXT: fli.s fa5, 0.5
287+
; CHECK-NEXT: fmsub.s fa0, fa0, fa1, fa5
288+
; CHECK-NEXT: ret
289+
%a = call float @llvm.fma.f32(float %x, float %y, float -0.5)
290+
ret float %a
291+
}
292+
293+
define float @fma_neg_multiplicand(float %x, float %y) nounwind {
294+
; CHECK-LABEL: fma_neg_multiplicand:
295+
; CHECK: # %bb.0:
296+
; CHECK-NEXT: fli.s fa5, 0.125
297+
; CHECK-NEXT: fnmsub.s fa0, fa5, fa0, fa1
298+
; CHECK-NEXT: ret
299+
%a = call float @llvm.fma.f32(float %x, float -0.125, float %y)
300+
ret float %a
301+
}
302+
303+
define float @fma_neg_addend_multiplicand(float %x) nounwind {
304+
; CHECK-LABEL: fma_neg_addend_multiplicand:
305+
; CHECK: # %bb.0:
306+
; CHECK-NEXT: fli.s fa5, 0.25
307+
; CHECK-NEXT: fli.s fa4, 0.5
308+
; CHECK-NEXT: fnmadd.s fa0, fa4, fa0, fa5
309+
; CHECK-NEXT: ret
310+
%a = call float @llvm.fma.f32(float %x, float -0.5, float -0.25)
311+
ret float %a
312+
}

0 commit comments

Comments
 (0)