Skip to content

Commit 3e342ea

Browse files
committed
[CHERIOT] Use capability registers to store f64 values.
This enables each f64 to be passed by value in a single cap register, rather than in pairs of integer registers. This required adding explicit type annotations to various places in the XCheri tblgen files, as the GPCR class can now hold values type c64 or f64, breaking type inference.
1 parent 44339d9 commit 3e342ea

File tree

6 files changed

+553
-93
lines changed

6 files changed

+553
-93
lines changed

llvm/lib/Target/RISCV/RISCVCallingConv.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,14 @@ bool llvm::CC_RISCV(unsigned ValNo, MVT ValVT, MVT LocVT,
486486
}
487487
}
488488

489+
// Cheriot uses GPCR without a bitcast when possible.
490+
if (LocVT == MVT::f64 && Subtarget.hasVendorXCheriot() && !IsPureCapVarArgs) {
491+
if (MCRegister Reg = State.AllocateReg(ArgGPCRs)) {
492+
State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
493+
return false;
494+
}
495+
}
496+
489497
// FP smaller than XLen, uses custom GPR.
490498
if (LocVT == MVT::f16 || LocVT == MVT::bf16 ||
491499
(LocVT == MVT::f32 && XLen == 64)) {

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 149 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
170170
addRegisterClass(CapType, &RISCV::GPCRRegClass);
171171
}
172172

173+
if (Subtarget.hasVendorXCheriot()) {
174+
// Cheriot holds f64's in capability registers.
175+
addRegisterClass(MVT::f64, &RISCV::GPCRRegClass);
176+
}
177+
173178
static const MVT::SimpleValueType BoolVecVTs[] = {
174179
MVT::nxv1i1, MVT::nxv2i1, MVT::nxv4i1, MVT::nxv8i1,
175180
MVT::nxv16i1, MVT::nxv32i1, MVT::nxv64i1};
@@ -680,6 +685,29 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
680685
setLibcallName(RTLIB::MEMSET, "memset");
681686
}
682687

688+
if (Subtarget.hasVendorXCheriot()) {
689+
// FP64 is "legal" on Cheriot in that we store it in c64 registers, but
690+
// essentially no operations on it are legal other than load/store/copy.
691+
setOperationAction({ISD::ConstantFP, ISD::SELECT_CC, ISD::SETCC}, MVT::f64,
692+
Custom);
693+
694+
// These require custom lowering because their inputs might be f64.
695+
setOperationAction({ISD::SELECT_CC, ISD::SETCC}, MVT::i32, Custom);
696+
setOperationAction({ISD::FP_TO_UINT, ISD::FP_TO_SINT}, MVT::i32, LibCall);
697+
698+
static const unsigned CheriotF64ExpandOps[] = {
699+
ISD::FMINNUM, ISD::FMAXNUM, ISD::FADD, ISD::FSUB,
700+
ISD::FMUL, ISD::FMA, ISD::FDIV, ISD::FSQRT,
701+
ISD::FCEIL, ISD::FTRUNC, ISD::FFLOOR, ISD::FROUND,
702+
ISD::FROUNDEVEN, ISD::FRINT, ISD::FNEARBYINT, ISD::IS_FPCLASS,
703+
ISD::SETCC, ISD::FMAXIMUM, ISD::FMINIMUM, ISD::STRICT_FADD,
704+
ISD::STRICT_FSUB, ISD::STRICT_FMUL, ISD::STRICT_FDIV, ISD::STRICT_FSQRT,
705+
ISD::STRICT_FMA, ISD::FNEG, ISD::FABS, ISD::FCOPYSIGN,
706+
ISD::UINT_TO_FP, ISD::SINT_TO_FP, ISD::BR_CC};
707+
setOperationAction(CheriotF64ExpandOps, MVT::f64, Expand);
708+
setCondCodeAction(FPCCToExpand, MVT::f64, Expand);
709+
}
710+
683711
// TODO: On M-mode only targets, the cycle[h]/time[h] CSR may not be present.
684712
// Unfortunately this can't be determined just from the ISA naming string.
685713
setOperationAction(ISD::READCYCLECOUNTER, MVT::i64,
@@ -6145,11 +6173,44 @@ static SDValue lowerConstant(SDValue Op, SelectionDAG &DAG,
61456173
return SDValue();
61466174
}
61476175

6148-
SDValue RISCVTargetLowering::lowerConstantFP(SDValue Op,
6149-
SelectionDAG &DAG) const {
6176+
SDValue
6177+
RISCVTargetLowering::lowerConstantFP(SDValue Op, SelectionDAG &DAG,
6178+
const RISCVSubtarget &Subtarget) const {
61506179
MVT VT = Op.getSimpleValueType();
61516180
const APFloat &Imm = cast<ConstantFPSDNode>(Op)->getValueAPF();
61526181

6182+
if (Subtarget.hasVendorXCheriot()) {
6183+
// Cheriot needs to custom lower f64 immediates using csethigh
6184+
if (VT != MVT::f64)
6185+
return Op;
6186+
6187+
SDLoc DL(Op);
6188+
uint64_t Val = Imm.bitcastToAPInt().getLimitedValue();
6189+
6190+
// Materialize 0.0 as cnull
6191+
if (Val == 0)
6192+
return DAG.getRegister(getNullCapabilityRegister(), MVT::f64);
6193+
6194+
// Otherwise, materialize the low part into a 32-bit register.
6195+
auto Lo = DAG.getConstant(Val & 0xFFFFFFFF, DL, MVT::i32);
6196+
auto LoAsCap = DAG.getTargetInsertSubreg(RISCV::sub_cap_addr, DL, MVT::c64,
6197+
DAG.getUNDEF(MVT::f64), Lo);
6198+
6199+
// The high half of a capability register is zeroed by integer ops,
6200+
// so if we wanted a zero high half then we are done.
6201+
if (Val >> 32 == 0)
6202+
return DAG.getBitcast(MVT::f64, LoAsCap);
6203+
6204+
// Otherwise, materialize the high half and use csethigh to combine the two
6205+
// halve.
6206+
auto Hi = DAG.getConstant(Val >> 32, DL, MVT::i32);
6207+
auto Cap = DAG.getNode(
6208+
ISD::INTRINSIC_WO_CHAIN, DL, MVT::c64,
6209+
DAG.getTargetConstant(Intrinsic::cheri_cap_high_set, DL, MVT::i32),
6210+
LoAsCap, Hi);
6211+
return DAG.getBitcast(MVT::f64, Cap);
6212+
}
6213+
61536214
// Can this constant be selected by a Zfa FLI instruction?
61546215
bool Negate = false;
61556216
int Index = getLegalZfaFPImm(Imm, VT);
@@ -6799,7 +6860,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
67996860
case ISD::Constant:
68006861
return lowerConstant(Op, DAG, Subtarget);
68016862
case ISD::ConstantFP:
6802-
return lowerConstantFP(Op, DAG);
6863+
return lowerConstantFP(Op, DAG, Subtarget);
68036864
case ISD::SELECT:
68046865
return lowerSELECT(Op, DAG);
68056866
case ISD::BRCOND:
@@ -7589,6 +7650,50 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
75897650
case ISD::VECTOR_COMPRESS:
75907651
return lowerVectorCompress(Op, DAG);
75917652
case ISD::SELECT_CC: {
7653+
if (Subtarget.hasVendorXCheriot() &&
7654+
(Op.getValueType() == MVT::f64 ||
7655+
Op.getOperand(0).getValueType() == MVT::f64)) {
7656+
SDLoc DL(Op);
7657+
EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(),
7658+
Op.getOperand(0).getValueType());
7659+
SDValue SetCC;
7660+
if (Op.getOperand(0).getValueType() == MVT::f64) {
7661+
SDValue NewLHS = Op.getOperand(0), NewRHS = Op.getOperand(1);
7662+
ISD::CondCode CCCode = cast<CondCodeSDNode>(Op.getOperand(4))->get();
7663+
7664+
NewLHS = DAG.getBitcast(MVT::c64, NewLHS);
7665+
NewRHS = DAG.getBitcast(MVT::c64, NewRHS);
7666+
softenSetCCOperands(DAG, MVT::f64, NewLHS, NewRHS, CCCode, DL,
7667+
Op.getOperand(0), Op.getOperand(1));
7668+
7669+
// If softenSetCCOperands returned a scalar, we need to compare the
7670+
// result against zero to select between true and false values.
7671+
if (!NewRHS.getNode()) {
7672+
NewRHS = DAG.getConstant(0, DL, NewLHS.getValueType());
7673+
CCCode = ISD::SETNE;
7674+
}
7675+
7676+
SetCC = DAG.getNode(ISD::SETCC, DL, CCVT, NewLHS, NewRHS,
7677+
DAG.getCondCode(CCCode));
7678+
} else {
7679+
SetCC = DAG.getNode(ISD::SETCC, DL, CCVT, Op.getOperand(0),
7680+
Op.getOperand(1), Op.getOperand(4), Op->getFlags());
7681+
}
7682+
7683+
SDValue LSel = Op.getOperand(2);
7684+
if (LSel.getValueType() == MVT::f64)
7685+
LSel = DAG.getBitcast(MVT::c64, LSel);
7686+
SDValue RSel = Op.getOperand(3);
7687+
if (RSel.getValueType() == MVT::f64)
7688+
RSel = DAG.getBitcast(MVT::c64, RSel);
7689+
7690+
SDValue Select =
7691+
DAG.getSelect(DL, LSel.getValueType(), SetCC, LSel, RSel);
7692+
if (Op.getValueType() == MVT::f64)
7693+
Select = DAG.getBitcast(MVT::f64, Select);
7694+
return Select;
7695+
}
7696+
75927697
// This occurs because we custom legalize SETGT and SETUGT for setcc. That
75937698
// causes LegalizeDAG to think we need to custom legalize select_cc. Expand
75947699
// into separate SETCC+SELECT just like LegalizeDAG.
@@ -7608,13 +7713,43 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
76087713
}
76097714
case ISD::SETCC: {
76107715
MVT OpVT = Op.getOperand(0).getSimpleValueType();
7716+
7717+
if (Subtarget.hasVendorXCheriot() &&
7718+
(OpVT == MVT::f64 || Op.getValueType() == MVT::f64)) {
7719+
7720+
SDNode *N = Op.getNode();
7721+
bool IsStrict = N->isStrictFPOpcode();
7722+
SDValue Op0 = N->getOperand(IsStrict ? 1 : 0);
7723+
SDValue Op1 = N->getOperand(IsStrict ? 2 : 1);
7724+
SDValue Chain = IsStrict ? N->getOperand(0) : SDValue();
7725+
ISD::CondCode CCCode =
7726+
cast<CondCodeSDNode>(N->getOperand(IsStrict ? 3 : 2))->get();
7727+
7728+
EVT VT = Op0.getValueType();
7729+
SDValue NewLHS = DAG.getBitcast(MVT::c64, Op0);
7730+
SDValue NewRHS = DAG.getBitcast(MVT::c64, Op1);
7731+
softenSetCCOperands(DAG, VT, NewLHS, NewRHS, CCCode, SDLoc(N), Op0, Op1,
7732+
Chain, N->getOpcode() == ISD::STRICT_FSETCCS);
7733+
7734+
// Update N to have the operands specified.
7735+
if (NewRHS.getNode()) {
7736+
if (IsStrict)
7737+
NewLHS = DAG.getNode(ISD::SETCC, SDLoc(N), N->getValueType(0), NewLHS,
7738+
NewRHS, DAG.getCondCode(CCCode));
7739+
else
7740+
return SDValue(DAG.UpdateNodeOperands(N, NewLHS, NewRHS,
7741+
DAG.getCondCode(CCCode)),
7742+
0);
7743+
}
7744+
}
7745+
76117746
if (OpVT.isScalarInteger()) {
76127747
MVT VT = Op.getSimpleValueType();
76137748
SDValue LHS = Op.getOperand(0);
76147749
SDValue RHS = Op.getOperand(1);
76157750
ISD::CondCode CCVal = cast<CondCodeSDNode>(Op.getOperand(2))->get();
7616-
assert((CCVal == ISD::SETGT || CCVal == ISD::SETUGT) &&
7617-
"Unexpected CondCode");
7751+
if (CCVal != ISD::SETGT && CCVal != ISD::SETUGT)
7752+
return Op;
76187753

76197754
SDLoc DL(Op);
76207755

@@ -8633,6 +8768,15 @@ SDValue RISCVTargetLowering::lowerSELECT(SDValue Op, SelectionDAG &DAG) const {
86338768
MVT VT = Op.getSimpleValueType();
86348769
MVT XLenVT = Subtarget.getXLenVT();
86358770

8771+
if (Subtarget.hasVendorXCheriot() && VT == MVT::f64) {
8772+
// Perform SELECT_CC on f64 by bitcasting through c64.
8773+
SDValue LHSCap = DAG.getBitcast(MVT::c64, TrueV);
8774+
SDValue RHSCap = DAG.getBitcast(MVT::c64, FalseV);
8775+
SDValue Select =
8776+
DAG.getNode(ISD::SELECT, DL, MVT::c64, CondV, LHSCap, RHSCap);
8777+
return DAG.getBitcast(MVT::f64, Select);
8778+
}
8779+
86368780
// Lower vector SELECTs to VSELECTs by splatting the condition.
86378781
if (VT.isVector()) {
86388782
MVT SplatCondVT = VT.changeVectorElementType(MVT::i1);

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,8 @@ class RISCVTargetLowering : public TargetLowering {
981981
SelectionDAG &DAG) const;
982982
SDValue getTLSDescAddr(GlobalAddressSDNode *N, SelectionDAG &DAG) const;
983983

984-
SDValue lowerConstantFP(SDValue Op, SelectionDAG &DAG) const;
984+
SDValue lowerConstantFP(SDValue Op, SelectionDAG &DAG,
985+
const RISCVSubtarget &Subtarget) const;
985986
SDValue lowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const;
986987
SDValue lowerBlockAddress(SDValue Op, SelectionDAG &DAG) const;
987988
SDValue lowerConstantPool(SDValue Op, SelectionDAG &DAG) const;

0 commit comments

Comments
 (0)