Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 79 additions & 31 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8362,9 +8362,23 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
// 16: <StaticChainOffset>
// 24: <FunctionAddressOffset>
// 32:

constexpr unsigned StaticChainOffset = 16;
constexpr unsigned FunctionAddressOffset = 24;
// Offset with branch control flow protection enabled:
// 0: lpad <imm20>
// 4: auipc t3, 0
// 8: ld t2, 24(t3)
// 12: ld t3, 16(t3)
// 16: jalr t2
// 20: <StaticChainOffset>
// 28: <FunctionAddressOffset>
// 36:

const bool HasCFBranch =
Subtarget.hasStdExtZicfilp() &&
DAG.getMachineFunction().getFunction().getParent()->getModuleFlag(
"cf-protection-branch");
const unsigned StaticChainIdx = HasCFBranch ? 5 : 4;
const unsigned StaticChainOffset = StaticChainIdx * 4;
const unsigned FunctionAddressOffset = StaticChainOffset + 8;

const MCSubtargetInfo *STI = getTargetMachine().getMCSubtargetInfo();
assert(STI);
Expand All @@ -8376,38 +8390,70 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
return Encoding;
};

SDValue OutChains[6];

uint32_t Encodings[] = {
// auipc t2, 0
// Loads the current PC into t2.
GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)),
// ld t0, 24(t2)
// Loads the function address into t0. Note that we are using offsets
// pc-relative to the first instruction of the trampoline.
GetEncoding(
MCInstBuilder(RISCV::LD).addReg(RISCV::X5).addReg(RISCV::X7).addImm(
FunctionAddressOffset)),
// ld t2, 16(t2)
// Load the value of the static chain.
GetEncoding(
MCInstBuilder(RISCV::LD).addReg(RISCV::X7).addReg(RISCV::X7).addImm(
StaticChainOffset)),
// jalr t0
// Jump to the function.
GetEncoding(MCInstBuilder(RISCV::JALR)
.addReg(RISCV::X0)
.addReg(RISCV::X5)
.addImm(0))};
SmallVector<SDValue> OutChains;

SmallVector<uint32_t> Encodings;
if (!HasCFBranch) {
Encodings.append(
{// auipc t2, 0
// Loads the current PC into t2.
GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)),
// ld t0, 24(t2)
// Loads the function address into t0. Note that we are using offsets
// pc-relative to the first instruction of the trampoline.
GetEncoding(MCInstBuilder(RISCV::LD)
.addReg(RISCV::X5)
.addReg(RISCV::X7)
.addImm(FunctionAddressOffset)),
// ld t2, 16(t2)
// Load the value of the static chain.
GetEncoding(MCInstBuilder(RISCV::LD)
.addReg(RISCV::X7)
.addReg(RISCV::X7)
.addImm(StaticChainOffset)),
// jalr t0
// Jump to the function.
GetEncoding(MCInstBuilder(RISCV::JALR)
.addReg(RISCV::X0)
.addReg(RISCV::X5)
.addImm(0))});
} else {
Encodings.append(
{// auipc x0, <imm20> (lpad <imm20>)
// Landing pad.
GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X0).addImm(0)),
// auipc t3, 0
// Loads the current PC into t3.
GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X28).addImm(0)),
// ld t2, (FunctionAddressOffset - 4)(t3)
// Loads the function address into t2. Note that we are using offsets
// pc-relative to the SECOND instruction of the trampoline.
GetEncoding(MCInstBuilder(RISCV::LD)
.addReg(RISCV::X7)
.addReg(RISCV::X28)
.addImm(FunctionAddressOffset - 4)),
// ld t3, (StaticChainOffset - 4)(t3)
// Load the value of the static chain.
GetEncoding(MCInstBuilder(RISCV::LD)
.addReg(RISCV::X28)
.addReg(RISCV::X28)
.addImm(StaticChainOffset - 4)),
// jalr t2
// Software-guarded jump to the function.
GetEncoding(MCInstBuilder(RISCV::JALR)
.addReg(RISCV::X0)
.addReg(RISCV::X7)
.addImm(0))});
}

// Store encoded instructions.
for (auto [Idx, Encoding] : llvm::enumerate(Encodings)) {
SDValue Addr = Idx > 0 ? DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
DAG.getConstant(Idx * 4, dl, MVT::i64))
: Trmp;
OutChains[Idx] = DAG.getTruncStore(
OutChains.push_back(DAG.getTruncStore(
Root, dl, DAG.getConstant(Encoding, dl, MVT::i64), Addr,
MachinePointerInfo(TrmpAddr, Idx * 4), MVT::i32);
MachinePointerInfo(TrmpAddr, Idx * 4), MVT::i32));
}

// Now store the variable part of the trampoline.
Expand All @@ -8423,16 +8469,18 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
{StaticChainOffset, StaticChain},
{FunctionAddressOffset, FunctionAddress},
};
for (auto [Idx, OffsetValue] : llvm::enumerate(OffsetValues)) {
for (auto &OffsetValue : OffsetValues) {
SDValue Addr =
DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
DAG.getConstant(OffsetValue.Offset, dl, MVT::i64));
OffsetValue.Addr = Addr;
OutChains[Idx + 4] =
OutChains.push_back(
DAG.getStore(Root, dl, OffsetValue.Value, Addr,
MachinePointerInfo(TrmpAddr, OffsetValue.Offset));
MachinePointerInfo(TrmpAddr, OffsetValue.Offset)));
}

assert(OutChains.size() == StaticChainIdx + 2 &&
"Size of OutChains mismatch");
SDValue StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);

// The end of instructions of trampoline is the same as the static chain
Expand Down
99 changes: 99 additions & 0 deletions llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -O0 -mtriple=riscv64 -mattr=+experimental-zicfilp -verify-machineinstrs < %s \
; RUN: | FileCheck -check-prefix=RV64 %s
; RUN: llc -O0 -mtriple=riscv64-unknown-linux-gnu -mattr=+experimental-zicfilp -verify-machineinstrs < %s \
; RUN: | FileCheck -check-prefix=RV64-LINUX %s

declare void @llvm.init.trampoline(ptr, ptr, ptr)
declare ptr @llvm.adjust.trampoline(ptr)
declare i64 @f(ptr nest, i64)

; Initializes a trampoline for @f in a stack buffer, adjusts it, and calls
; through the resulting pointer. The RUN lines enable experimental-zicfilp and
; the module sets the "cf-protection-branch" flag, so the checks below cover
; the trampoline variant that begins with a landing pad.
define i64 @test0(i64 %n, ptr %p) nounwind {
; RV64-LABEL: test0:
; RV64: # %bb.0:
; RV64-NEXT: lpad 0
; RV64-NEXT: addi sp, sp, -64
; RV64-NEXT: sd ra, 56(sp) # 8-byte Folded Spill
; RV64-NEXT: sd a0, 0(sp) # 8-byte Folded Spill
; RV64-NEXT: lui a0, %hi(f)
; RV64-NEXT: addi a0, a0, %lo(f)
; RV64-NEXT: sw a0, 44(sp)
; RV64-NEXT: srli a0, a0, 32
; RV64-NEXT: sw a0, 48(sp)
; RV64-NEXT: sw a1, 36(sp)
; RV64-NEXT: srli a0, a1, 32
; RV64-NEXT: sw a0, 40(sp)
; RV64-NEXT: li a0, 23
; RV64-NEXT: sw a0, 16(sp)
; RV64-NEXT: lui a0, 56
; RV64-NEXT: addi a0, a0, 103
; RV64-NEXT: sw a0, 32(sp)
; RV64-NEXT: lui a0, 4324
; RV64-NEXT: addi a0, a0, -509
; RV64-NEXT: sw a0, 28(sp)
; RV64-NEXT: lui a0, 6371
; RV64-NEXT: addi a0, a0, 899
; RV64-NEXT: sw a0, 24(sp)
; RV64-NEXT: lui a0, 1
; RV64-NEXT: addi a0, a0, -489
; RV64-NEXT: sw a0, 20(sp)
; RV64-NEXT: addi a1, sp, 36
; RV64-NEXT: addi a0, sp, 16
; RV64-NEXT: sd a0, 8(sp) # 8-byte Folded Spill
; RV64-NEXT: call __clear_cache
; RV64-NEXT: ld a0, 0(sp) # 8-byte Folded Reload
; RV64-NEXT: ld a1, 8(sp) # 8-byte Folded Reload
; RV64-NEXT: jalr a1
; RV64-NEXT: ld ra, 56(sp) # 8-byte Folded Reload
; RV64-NEXT: addi sp, sp, 64
; RV64-NEXT: ret
;
; RV64-LINUX-LABEL: test0:
; RV64-LINUX: # %bb.0:
; RV64-LINUX-NEXT: lpad 0
; RV64-LINUX-NEXT: addi sp, sp, -64
; RV64-LINUX-NEXT: sd ra, 56(sp) # 8-byte Folded Spill
; RV64-LINUX-NEXT: sd a0, 0(sp) # 8-byte Folded Spill
; RV64-LINUX-NEXT: lui a0, %hi(f)
; RV64-LINUX-NEXT: addi a0, a0, %lo(f)
; RV64-LINUX-NEXT: sw a0, 44(sp)
; RV64-LINUX-NEXT: srli a0, a0, 32
; RV64-LINUX-NEXT: sw a0, 48(sp)
; RV64-LINUX-NEXT: sw a1, 36(sp)
; RV64-LINUX-NEXT: srli a0, a1, 32
; RV64-LINUX-NEXT: sw a0, 40(sp)
; RV64-LINUX-NEXT: li a0, 23
; RV64-LINUX-NEXT: sw a0, 16(sp)
; RV64-LINUX-NEXT: lui a0, 56
; RV64-LINUX-NEXT: addi a0, a0, 103
; RV64-LINUX-NEXT: sw a0, 32(sp)
; RV64-LINUX-NEXT: lui a0, 4324
; RV64-LINUX-NEXT: addi a0, a0, -509
; RV64-LINUX-NEXT: sw a0, 28(sp)
; RV64-LINUX-NEXT: lui a0, 6371
; RV64-LINUX-NEXT: addi a0, a0, 899
; RV64-LINUX-NEXT: sw a0, 24(sp)
; RV64-LINUX-NEXT: lui a0, 1
; RV64-LINUX-NEXT: addi a0, a0, -489
; RV64-LINUX-NEXT: sw a0, 20(sp)
; RV64-LINUX-NEXT: addi a1, sp, 36
; RV64-LINUX-NEXT: addi a0, sp, 16
; RV64-LINUX-NEXT: sd a0, 8(sp) # 8-byte Folded Spill
; RV64-LINUX-NEXT: li a2, 0
; RV64-LINUX-NEXT: call __riscv_flush_icache
; RV64-LINUX-NEXT: ld a0, 0(sp) # 8-byte Folded Reload
; RV64-LINUX-NEXT: ld a1, 8(sp) # 8-byte Folded Reload
; RV64-LINUX-NEXT: jalr a1
; RV64-LINUX-NEXT: ld ra, 56(sp) # 8-byte Folded Reload
; RV64-LINUX-NEXT: addi sp, sp, 64
; RV64-LINUX-NEXT: ret
; The buffer is 36 bytes: the checks above store five 4-byte instruction words
; at sp+16..sp+32, then the 8-byte static chain (%p) at sp+36 and the 8-byte
; address of @f at sp+44.
%alloca = alloca [36 x i8], align 8
call void @llvm.init.trampoline(ptr %alloca, ptr @f, ptr %p)
%tramp = call ptr @llvm.adjust.trampoline(ptr %alloca)
; Indirect call through the trampoline; %p reaches @f as its nest argument.
%ret = call i64 %tramp(i64 %n)
ret i64 %ret
}

!llvm.module.flags = !{!0}

!0 = !{i32 8, !"cf-protection-branch", i32 1}