Skip to content

Commit 2486434

Browse files
committed
[CHERIOT] Use capability registers to store f64 values.
This enables each f64 to be passed by value in a single cap register, rather than in a pair of integer registers. This required adding explicit type annotations to various places in the XCheri tblgen files, as the GPCR class can now hold values of type c64 or f64, breaking type inference.
1 parent c650fb7 commit 2486434

File tree

6 files changed

+553
-93
lines changed

6 files changed

+553
-93
lines changed

llvm/lib/Target/RISCV/RISCVCallingConv.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,14 @@ bool llvm::CC_RISCV(unsigned ValNo, MVT ValVT, MVT LocVT,
501501
}
502502
}
503503

504+
// Cheriot uses GPCR without a bitcast when possible.
505+
if (LocVT == MVT::f64 && Subtarget.hasVendorXCheriot() && !IsPureCapVarArgs) {
506+
if (MCRegister Reg = State.AllocateReg(ArgGPCRs)) {
507+
State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
508+
return false;
509+
}
510+
}
511+
504512
// FP smaller than XLen, uses custom GPR.
505513
if (LocVT == MVT::f16 || LocVT == MVT::bf16 ||
506514
(LocVT == MVT::f32 && XLen == 64)) {

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 149 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
178178
addRegisterClass(CapType, &RISCV::GPCRRegClass);
179179
}
180180

181+
if (Subtarget.hasVendorXCheriot()) {
182+
// Cheriot holds f64 values in capability registers.
183+
addRegisterClass(MVT::f64, &RISCV::GPCRRegClass);
184+
}
185+
181186
static const MVT::SimpleValueType BoolVecVTs[] = {
182187
MVT::nxv1i1, MVT::nxv2i1, MVT::nxv4i1, MVT::nxv8i1,
183188
MVT::nxv16i1, MVT::nxv32i1, MVT::nxv64i1};
@@ -724,6 +729,29 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
724729
setLibcallImpl(RTLIB::MEMSET, RTLIB::memset);
725730
}
726731

732+
if (Subtarget.hasVendorXCheriot()) {
733+
// FP64 is "legal" on Cheriot in that we store it in c64 registers, but
734+
// essentially no operations on it are legal other than load/store/copy.
735+
setOperationAction({ISD::ConstantFP, ISD::SELECT_CC, ISD::SETCC}, MVT::f64,
736+
Custom);
737+
738+
// These require custom lowering because their inputs might be f64.
739+
setOperationAction({ISD::SELECT_CC, ISD::SETCC}, MVT::i32, Custom);
740+
setOperationAction({ISD::FP_TO_UINT, ISD::FP_TO_SINT}, MVT::i32, LibCall);
741+
742+
static const unsigned CheriotF64ExpandOps[] = {
743+
ISD::FMINNUM, ISD::FMAXNUM, ISD::FADD, ISD::FSUB,
744+
ISD::FMUL, ISD::FMA, ISD::FDIV, ISD::FSQRT,
745+
ISD::FCEIL, ISD::FTRUNC, ISD::FFLOOR, ISD::FROUND,
746+
ISD::FROUNDEVEN, ISD::FRINT, ISD::FNEARBYINT, ISD::IS_FPCLASS,
747+
ISD::SETCC, ISD::FMAXIMUM, ISD::FMINIMUM, ISD::STRICT_FADD,
748+
ISD::STRICT_FSUB, ISD::STRICT_FMUL, ISD::STRICT_FDIV, ISD::STRICT_FSQRT,
749+
ISD::STRICT_FMA, ISD::FNEG, ISD::FABS, ISD::FCOPYSIGN,
750+
ISD::UINT_TO_FP, ISD::SINT_TO_FP, ISD::BR_CC};
751+
setOperationAction(CheriotF64ExpandOps, MVT::f64, Expand);
752+
setCondCodeAction(FPCCToExpand, MVT::f64, Expand);
753+
}
754+
727755
// TODO: On M-mode only targets, the cycle[h]/time[h] CSR may not be present.
728756
// Unfortunately this can't be determined just from the ISA naming string.
729757
setOperationAction(ISD::READCYCLECOUNTER, MVT::i64,
@@ -6723,11 +6751,44 @@ static SDValue lowerConstant(SDValue Op, SelectionDAG &DAG,
67236751
return SDValue();
67246752
}
67256753

6726-
SDValue RISCVTargetLowering::lowerConstantFP(SDValue Op,
6727-
SelectionDAG &DAG) const {
6754+
SDValue
6755+
RISCVTargetLowering::lowerConstantFP(SDValue Op, SelectionDAG &DAG,
6756+
const RISCVSubtarget &Subtarget) const {
67286757
MVT VT = Op.getSimpleValueType();
67296758
const APFloat &Imm = cast<ConstantFPSDNode>(Op)->getValueAPF();
67306759

6760+
if (Subtarget.hasVendorXCheriot()) {
6761+
// Cheriot needs to custom lower f64 immediates using csethigh
6762+
if (VT != MVT::f64)
6763+
return Op;
6764+
6765+
SDLoc DL(Op);
6766+
uint64_t Val = Imm.bitcastToAPInt().getLimitedValue();
6767+
6768+
// Materialize 0.0 as cnull
6769+
if (Val == 0)
6770+
return DAG.getRegister(getNullCapabilityRegister(), MVT::f64);
6771+
6772+
// Otherwise, materialize the low part into a 32-bit register.
6773+
auto Lo = DAG.getConstant(Val & 0xFFFFFFFF, DL, MVT::i32);
6774+
auto LoAsCap = DAG.getTargetInsertSubreg(RISCV::sub_cap_addr, DL, MVT::c64,
6775+
DAG.getUNDEF(MVT::f64), Lo);
6776+
6777+
// The high half of a capability register is zeroed by integer ops,
6778+
// so if we wanted a zero high half then we are done.
6779+
if (Val >> 32 == 0)
6780+
return DAG.getBitcast(MVT::f64, LoAsCap);
6781+
6782+
// Otherwise, materialize the high half and use csethigh to combine the two
6783+
// halves.
6784+
auto Hi = DAG.getConstant(Val >> 32, DL, MVT::i32);
6785+
auto Cap = DAG.getNode(
6786+
ISD::INTRINSIC_WO_CHAIN, DL, MVT::c64,
6787+
DAG.getTargetConstant(Intrinsic::cheri_cap_high_set, DL, MVT::i32),
6788+
LoAsCap, Hi);
6789+
return DAG.getBitcast(MVT::f64, Cap);
6790+
}
6791+
67316792
// Can this constant be selected by a Zfa FLI instruction?
67326793
bool Negate = false;
67336794
int Index = getLegalZfaFPImm(Imm, VT);
@@ -7346,7 +7407,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
73467407
case ISD::Constant:
73477408
return lowerConstant(Op, DAG, Subtarget);
73487409
case ISD::ConstantFP:
7349-
return lowerConstantFP(Op, DAG);
7410+
return lowerConstantFP(Op, DAG, Subtarget);
73507411
case ISD::SELECT:
73517412
return lowerSELECT(Op, DAG);
73527413
case ISD::BRCOND:
@@ -8205,6 +8266,50 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
82058266
case ISD::VECTOR_COMPRESS:
82068267
return lowerVectorCompress(Op, DAG);
82078268
case ISD::SELECT_CC: {
8269+
if (Subtarget.hasVendorXCheriot() &&
8270+
(Op.getValueType() == MVT::f64 ||
8271+
Op.getOperand(0).getValueType() == MVT::f64)) {
8272+
SDLoc DL(Op);
8273+
EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(),
8274+
Op.getOperand(0).getValueType());
8275+
SDValue SetCC;
8276+
if (Op.getOperand(0).getValueType() == MVT::f64) {
8277+
SDValue NewLHS = Op.getOperand(0), NewRHS = Op.getOperand(1);
8278+
ISD::CondCode CCCode = cast<CondCodeSDNode>(Op.getOperand(4))->get();
8279+
8280+
NewLHS = DAG.getBitcast(MVT::c64, NewLHS);
8281+
NewRHS = DAG.getBitcast(MVT::c64, NewRHS);
8282+
softenSetCCOperands(DAG, MVT::f64, NewLHS, NewRHS, CCCode, DL,
8283+
Op.getOperand(0), Op.getOperand(1));
8284+
8285+
// If softenSetCCOperands returned a scalar, we need to compare the
8286+
// result against zero to select between true and false values.
8287+
if (!NewRHS.getNode()) {
8288+
NewRHS = DAG.getConstant(0, DL, NewLHS.getValueType());
8289+
CCCode = ISD::SETNE;
8290+
}
8291+
8292+
SetCC = DAG.getNode(ISD::SETCC, DL, CCVT, NewLHS, NewRHS,
8293+
DAG.getCondCode(CCCode));
8294+
} else {
8295+
SetCC = DAG.getNode(ISD::SETCC, DL, CCVT, Op.getOperand(0),
8296+
Op.getOperand(1), Op.getOperand(4), Op->getFlags());
8297+
}
8298+
8299+
SDValue LSel = Op.getOperand(2);
8300+
if (LSel.getValueType() == MVT::f64)
8301+
LSel = DAG.getBitcast(MVT::c64, LSel);
8302+
SDValue RSel = Op.getOperand(3);
8303+
if (RSel.getValueType() == MVT::f64)
8304+
RSel = DAG.getBitcast(MVT::c64, RSel);
8305+
8306+
SDValue Select =
8307+
DAG.getSelect(DL, LSel.getValueType(), SetCC, LSel, RSel);
8308+
if (Op.getValueType() == MVT::f64)
8309+
Select = DAG.getBitcast(MVT::f64, Select);
8310+
return Select;
8311+
}
8312+
82088313
// This occurs because we custom legalize SETGT and SETUGT for setcc. That
82098314
// causes LegalizeDAG to think we need to custom legalize select_cc. Expand
82108315
// into separate SETCC+SELECT just like LegalizeDAG.
@@ -8224,13 +8329,43 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
82248329
}
82258330
case ISD::SETCC: {
82268331
MVT OpVT = Op.getOperand(0).getSimpleValueType();
8332+
8333+
if (Subtarget.hasVendorXCheriot() &&
8334+
(OpVT == MVT::f64 || Op.getValueType() == MVT::f64)) {
8335+
8336+
SDNode *N = Op.getNode();
8337+
bool IsStrict = N->isStrictFPOpcode();
8338+
SDValue Op0 = N->getOperand(IsStrict ? 1 : 0);
8339+
SDValue Op1 = N->getOperand(IsStrict ? 2 : 1);
8340+
SDValue Chain = IsStrict ? N->getOperand(0) : SDValue();
8341+
ISD::CondCode CCCode =
8342+
cast<CondCodeSDNode>(N->getOperand(IsStrict ? 3 : 2))->get();
8343+
8344+
EVT VT = Op0.getValueType();
8345+
SDValue NewLHS = DAG.getBitcast(MVT::c64, Op0);
8346+
SDValue NewRHS = DAG.getBitcast(MVT::c64, Op1);
8347+
softenSetCCOperands(DAG, VT, NewLHS, NewRHS, CCCode, SDLoc(N), Op0, Op1,
8348+
Chain, N->getOpcode() == ISD::STRICT_FSETCCS);
8349+
8350+
// Update N to have the operands specified.
8351+
if (NewRHS.getNode()) {
8352+
if (IsStrict)
8353+
NewLHS = DAG.getNode(ISD::SETCC, SDLoc(N), N->getValueType(0), NewLHS,
8354+
NewRHS, DAG.getCondCode(CCCode));
8355+
else
8356+
return SDValue(DAG.UpdateNodeOperands(N, NewLHS, NewRHS,
8357+
DAG.getCondCode(CCCode)),
8358+
0);
8359+
}
8360+
}
8361+
82278362
if (OpVT.isScalarInteger()) {
82288363
MVT VT = Op.getSimpleValueType();
82298364
SDValue LHS = Op.getOperand(0);
82308365
SDValue RHS = Op.getOperand(1);
82318366
ISD::CondCode CCVal = cast<CondCodeSDNode>(Op.getOperand(2))->get();
8232-
assert((CCVal == ISD::SETGT || CCVal == ISD::SETUGT) &&
8233-
"Unexpected CondCode");
8367+
if (CCVal != ISD::SETGT && CCVal != ISD::SETUGT)
8368+
return Op;
82348369

82358370
SDLoc DL(Op);
82368371

@@ -9361,6 +9496,15 @@ SDValue RISCVTargetLowering::lowerSELECT(SDValue Op, SelectionDAG &DAG) const {
93619496
MVT VT = Op.getSimpleValueType();
93629497
MVT XLenVT = Subtarget.getXLenVT();
93639498

9499+
if (Subtarget.hasVendorXCheriot() && VT == MVT::f64) {
9500+
// Perform SELECT on f64 by bitcasting through c64.
9501+
SDValue LHSCap = DAG.getBitcast(MVT::c64, TrueV);
9502+
SDValue RHSCap = DAG.getBitcast(MVT::c64, FalseV);
9503+
SDValue Select =
9504+
DAG.getNode(ISD::SELECT, DL, MVT::c64, CondV, LHSCap, RHSCap);
9505+
return DAG.getBitcast(MVT::f64, Select);
9506+
}
9507+
93649508
// Lower vector SELECTs to VSELECTs by splatting the condition.
93659509
if (VT.isVector()) {
93669510
MVT SplatCondVT = VT.changeVectorElementType(MVT::i1);

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,8 @@ class RISCVTargetLowering : public TargetLowering {
507507
SelectionDAG &DAG) const;
508508
SDValue getTLSDescAddr(GlobalAddressSDNode *N, SelectionDAG &DAG) const;
509509

510-
SDValue lowerConstantFP(SDValue Op, SelectionDAG &DAG) const;
510+
SDValue lowerConstantFP(SDValue Op, SelectionDAG &DAG,
511+
const RISCVSubtarget &Subtarget) const;
511512
SDValue lowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const;
512513
SDValue lowerBlockAddress(SDValue Op, SelectionDAG &DAG) const;
513514
SDValue lowerConstantPool(SDValue Op, SelectionDAG &DAG) const;

0 commit comments

Comments
 (0)