2929#include "llvm/CodeGen/MachineFunction.h"
3030#include "llvm/CodeGen/MachineInstrBuilder.h"
3131#include "llvm/CodeGen/MachineJumpTableInfo.h"
32+ #include "llvm/CodeGen/MachineModuleInfo.h"
3233#include "llvm/CodeGen/MachineRegisterInfo.h"
3334#include "llvm/CodeGen/SDPatternMatch.h"
3435#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
@@ -8362,9 +8363,23 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
83628363 // 16: <StaticChainOffset>
83638364 // 24: <FunctionAddressOffset>
83648365 // 32:
8365-
8366- constexpr unsigned StaticChainOffset = 16;
8367- constexpr unsigned FunctionAddressOffset = 24;
8366+ // Offset with branch control flow protection enabled:
8367+ // 0: lpad <imm20>
8368+ // 4: auipc t3, 0
8369+ // 8: ld t0, 28(t3)
8370+ // 12: ld t3, 20(t3)
8371+ // 16: lui t2, <imm20>
8372+ // 20: jalr t0
8373+ // 24: <StaticChainOffset>
8374+ // 32: <FunctionAddressOffset>
8375+ // 40:
8376+
8377+ const bool HasCFBranch =
8378+ Subtarget.hasStdExtZicfilp() &&
8379+ DAG.getMMI()->getModule()->getModuleFlag("cf-protection-branch");
8380+ const unsigned StaticChainIdx = HasCFBranch ? 6 : 4;
8381+ const unsigned StaticChainOffset = StaticChainIdx * 4;
8382+ const unsigned FunctionAddressOffset = StaticChainOffset + 8;
83688383
83698384 const MCSubtargetInfo *STI = getTargetMachine().getMCSubtargetInfo();
83708385 assert(STI);
@@ -8377,35 +8392,77 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
83778392 };
83788393
83798394 SDValue OutChains[6];
8380-
8381- uint32_t Encodings[] = {
8382- // auipc t2, 0
8383- // Loads the current PC into t2.
8384- GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)),
8385- // ld t0, 24(t2)
8386- // Loads the function address into t0. Note that we are using offsets
8387- // pc-relative to the first instruction of the trampoline.
8388- GetEncoding(
8389- MCInstBuilder(RISCV::LD).addReg(RISCV::X5).addReg(RISCV::X7).addImm(
8390- FunctionAddressOffset)),
8391- // ld t2, 16(t2)
8392- // Load the value of the static chain.
8393- GetEncoding(
8394- MCInstBuilder(RISCV::LD).addReg(RISCV::X7).addReg(RISCV::X7).addImm(
8395- StaticChainOffset)),
8396- // jalr t0
8397- // Jump to the function.
8398- GetEncoding(MCInstBuilder(RISCV::JALR)
8399- .addReg(RISCV::X0)
8400- .addReg(RISCV::X5)
8401- .addImm(0))};
8395+ SDValue OutChainsLPAD[8];
8396+ if (HasCFBranch)
8397+ assert(std::size(OutChainsLPAD) == StaticChainIdx + 2);
8398+ else
8399+ assert(std::size(OutChains) == StaticChainIdx + 2);
8400+
8401+ SmallVector<uint32_t> Encodings;
8402+ if (!HasCFBranch) {
8403+ Encodings.append(
8404+ {// auipc t2, 0
8405+ // Loads the current PC into t2.
8406+ GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)),
8407+ // ld t0, 24(t2)
8408+ // Loads the function address into t0. Note that we are using offsets
8409+ // pc-relative to the first instruction of the trampoline.
8410+ GetEncoding(MCInstBuilder(RISCV::LD)
8411+ .addReg(RISCV::X5)
8412+ .addReg(RISCV::X7)
8413+ .addImm(FunctionAddressOffset)),
8414+ // ld t2, 16(t2)
8415+ // Load the value of the static chain.
8416+ GetEncoding(MCInstBuilder(RISCV::LD)
8417+ .addReg(RISCV::X7)
8418+ .addReg(RISCV::X7)
8419+ .addImm(StaticChainOffset)),
8420+ // jalr t0
8421+ // Jump to the function.
8422+ GetEncoding(MCInstBuilder(RISCV::JALR)
8423+ .addReg(RISCV::X0)
8424+ .addReg(RISCV::X5)
8425+ .addImm(0))});
8426+ } else {
8427+ Encodings.append(
8428+ {// auipc x0, <imm20> (lpad <imm20>)
8429+ // Landing pad.
8430+ GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X0).addImm(0)),
8431+ // auipc t3, 0
8432+ // Loads the current PC into t3.
8433+ GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X28).addImm(0)),
8434+ // ld t0, (FunctionAddressOffset - 4)(t3)
8435+ // Loads the function address into t0. Note that we are using offsets
8436+ // pc-relative to the SECOND instruction of the trampoline.
8437+ GetEncoding(MCInstBuilder(RISCV::LD)
8438+ .addReg(RISCV::X5)
8439+ .addReg(RISCV::X28)
8440+ .addImm(FunctionAddressOffset - 4)),
8441+ // ld t3, (StaticChainOffset - 4)(t3)
8442+ // Load the value of the static chain.
8443+ GetEncoding(MCInstBuilder(RISCV::LD)
8444+ .addReg(RISCV::X28)
8445+ .addReg(RISCV::X28)
8446+ .addImm(StaticChainOffset - 4)),
8447+ // lui t2, <imm20>
8448+ // Setup the landing pad value.
8449+ GetEncoding(MCInstBuilder(RISCV::LUI).addReg(RISCV::X7).addImm(0)),
8450+ // jalr t0
8451+ // Jump to the function.
8452+ GetEncoding(MCInstBuilder(RISCV::JALR)
8453+ .addReg(RISCV::X0)
8454+ .addReg(RISCV::X5)
8455+ .addImm(0))});
8456+ }
8457+
8458+ SDValue *OutChainsUsed = HasCFBranch ? OutChainsLPAD : OutChains;
84028459
84038460 // Store encoded instructions.
84048461 for (auto [Idx, Encoding] : llvm::enumerate(Encodings)) {
84058462 SDValue Addr = Idx > 0 ? DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
84068463 DAG.getConstant(Idx * 4, dl, MVT::i64))
84078464 : Trmp;
8408- OutChains [Idx] = DAG.getTruncStore(
8465+ OutChainsUsed [Idx] = DAG.getTruncStore(
84098466 Root, dl, DAG.getConstant(Encoding, dl, MVT::i64), Addr,
84108467 MachinePointerInfo(TrmpAddr, Idx * 4), MVT::i32);
84118468 }
@@ -8428,12 +8485,16 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
84288485 DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
84298486 DAG.getConstant(OffsetValue.Offset, dl, MVT::i64));
84308487 OffsetValue.Addr = Addr;
8431- OutChains [Idx + 4 ] =
8488+ OutChainsUsed [Idx + StaticChainIdx ] =
84328489 DAG.getStore(Root, dl, OffsetValue.Value, Addr,
84338490 MachinePointerInfo(TrmpAddr, OffsetValue.Offset));
84348491 }
84358492
8436- SDValue StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
8493+ SDValue StoreToken;
8494+ if (HasCFBranch)
8495+ StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChainsLPAD);
8496+ else
8497+ StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
84378498
84388499 // The end of instructions of trampoline is the same as the static chain
84398500 // address that we computed earlier.
0 commit comments