Skip to content

Commit d02379c

Browse files
committed
[RISC-V] Adjust trampoline code for branch control flow protection
1 parent 54da543 commit d02379c

File tree

2 files changed

+184
-28
lines changed

2 files changed

+184
-28
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 89 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
#include "llvm/CodeGen/MachineFunction.h"
3030
#include "llvm/CodeGen/MachineInstrBuilder.h"
3131
#include "llvm/CodeGen/MachineJumpTableInfo.h"
32+
#include "llvm/CodeGen/MachineModuleInfo.h"
3233
#include "llvm/CodeGen/MachineRegisterInfo.h"
3334
#include "llvm/CodeGen/SDPatternMatch.h"
3435
#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
@@ -8362,9 +8363,23 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
83628363
// 16: <StaticChainOffset>
83638364
// 24: <FunctionAddressOffset>
83648365
// 32:
8365-
8366-
constexpr unsigned StaticChainOffset = 16;
8367-
constexpr unsigned FunctionAddressOffset = 24;
8366+
// Offset with branch control flow protection enabled:
8367+
// 0: lpad <imm20>
8368+
// 4: auipc t3, 0
8369+
// 8: ld t0, 28(t3)
8370+
// 12: ld t3, 20(t3)
8371+
// 16: lui t2, <imm20>
8372+
// 20: jalr t0
8373+
// 24: <StaticChainOffset>
8374+
// 32: <FunctionAddressOffset>
8375+
// 40:
8376+
8377+
const bool HasCFBranch =
8378+
Subtarget.hasStdExtZicfilp() &&
8379+
DAG.getMMI()->getModule()->getModuleFlag("cf-protection-branch");
8380+
const unsigned StaticChainIdx = HasCFBranch ? 6 : 4;
8381+
const unsigned StaticChainOffset = StaticChainIdx * 4;
8382+
const unsigned FunctionAddressOffset = StaticChainOffset + 8;
83688383

83698384
const MCSubtargetInfo *STI = getTargetMachine().getMCSubtargetInfo();
83708385
assert(STI);
@@ -8377,35 +8392,77 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
83778392
};
83788393

83798394
SDValue OutChains[6];
8380-
8381-
uint32_t Encodings[] = {
8382-
// auipc t2, 0
8383-
// Loads the current PC into t2.
8384-
GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)),
8385-
// ld t0, 24(t2)
8386-
// Loads the function address into t0. Note that we are using offsets
8387-
// pc-relative to the first instruction of the trampoline.
8388-
GetEncoding(
8389-
MCInstBuilder(RISCV::LD).addReg(RISCV::X5).addReg(RISCV::X7).addImm(
8390-
FunctionAddressOffset)),
8391-
// ld t2, 16(t2)
8392-
// Load the value of the static chain.
8393-
GetEncoding(
8394-
MCInstBuilder(RISCV::LD).addReg(RISCV::X7).addReg(RISCV::X7).addImm(
8395-
StaticChainOffset)),
8396-
// jalr t0
8397-
// Jump to the function.
8398-
GetEncoding(MCInstBuilder(RISCV::JALR)
8399-
.addReg(RISCV::X0)
8400-
.addReg(RISCV::X5)
8401-
.addImm(0))};
8395+
SDValue OutChainsLPAD[8];
8396+
if (HasCFBranch)
8397+
assert(std::size(OutChainsLPAD) == StaticChainIdx + 2);
8398+
else
8399+
assert(std::size(OutChains) == StaticChainIdx + 2);
8400+
8401+
SmallVector<uint32_t> Encodings;
8402+
if (!HasCFBranch) {
8403+
Encodings.append(
8404+
{// auipc t2, 0
8405+
// Loads the current PC into t2.
8406+
GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)),
8407+
// ld t0, 24(t2)
8408+
// Loads the function address into t0. Note that we are using offsets
8409+
// pc-relative to the first instruction of the trampoline.
8410+
GetEncoding(MCInstBuilder(RISCV::LD)
8411+
.addReg(RISCV::X5)
8412+
.addReg(RISCV::X7)
8413+
.addImm(FunctionAddressOffset)),
8414+
// ld t2, 16(t2)
8415+
// Load the value of the static chain.
8416+
GetEncoding(MCInstBuilder(RISCV::LD)
8417+
.addReg(RISCV::X7)
8418+
.addReg(RISCV::X7)
8419+
.addImm(StaticChainOffset)),
8420+
// jalr t0
8421+
// Jump to the function.
8422+
GetEncoding(MCInstBuilder(RISCV::JALR)
8423+
.addReg(RISCV::X0)
8424+
.addReg(RISCV::X5)
8425+
.addImm(0))});
8426+
} else {
8427+
Encodings.append(
8428+
{// auipc x0, <imm20> (lpad <imm20>)
8429+
// Landing pad.
8430+
GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X0).addImm(0)),
8431+
// auipc t3, 0
8432+
// Loads the current PC into t3.
8433+
GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X28).addImm(0)),
8434+
// ld t0, (FunctionAddressOffset - 4)(t3)
8435+
// Loads the function address into t0. Note that we are using offsets
8436+
// pc-relative to the SECOND instruction of the trampoline.
8437+
GetEncoding(MCInstBuilder(RISCV::LD)
8438+
.addReg(RISCV::X5)
8439+
.addReg(RISCV::X28)
8440+
.addImm(FunctionAddressOffset - 4)),
8441+
// ld t3, (StaticChainOffset - 4)(t3)
8442+
// Load the value of the static chain.
8443+
GetEncoding(MCInstBuilder(RISCV::LD)
8444+
.addReg(RISCV::X28)
8445+
.addReg(RISCV::X28)
8446+
.addImm(StaticChainOffset - 4)),
8447+
// lui t2, <imm20>
8448+
// Setup the landing pad value.
8449+
GetEncoding(MCInstBuilder(RISCV::LUI).addReg(RISCV::X7).addImm(0)),
8450+
// jalr t0
8451+
// Jump to the function.
8452+
GetEncoding(MCInstBuilder(RISCV::JALR)
8453+
.addReg(RISCV::X0)
8454+
.addReg(RISCV::X5)
8455+
.addImm(0))});
8456+
}
8457+
8458+
SDValue *OutChainsUsed = HasCFBranch ? OutChainsLPAD : OutChains;
84028459

84038460
// Store encoded instructions.
84048461
for (auto [Idx, Encoding] : llvm::enumerate(Encodings)) {
84058462
SDValue Addr = Idx > 0 ? DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
84068463
DAG.getConstant(Idx * 4, dl, MVT::i64))
84078464
: Trmp;
8408-
OutChains[Idx] = DAG.getTruncStore(
8465+
OutChainsUsed[Idx] = DAG.getTruncStore(
84098466
Root, dl, DAG.getConstant(Encoding, dl, MVT::i64), Addr,
84108467
MachinePointerInfo(TrmpAddr, Idx * 4), MVT::i32);
84118468
}
@@ -8428,12 +8485,16 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
84288485
DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
84298486
DAG.getConstant(OffsetValue.Offset, dl, MVT::i64));
84308487
OffsetValue.Addr = Addr;
8431-
OutChains[Idx + 4] =
8488+
OutChainsUsed[Idx + StaticChainIdx] =
84328489
DAG.getStore(Root, dl, OffsetValue.Value, Addr,
84338490
MachinePointerInfo(TrmpAddr, OffsetValue.Offset));
84348491
}
84358492

8436-
SDValue StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
8493+
SDValue StoreToken;
8494+
if (HasCFBranch)
8495+
StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChainsLPAD);
8496+
else
8497+
StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
84378498

84388499
// The end of instructions of trampoline is the same as the static chain
84398500
// address that we computed earlier.
Lines changed: 95 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,95 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc -O0 -mtriple=riscv64 -mattr=+experimental-zicfilp -verify-machineinstrs < %s \
3+
; RUN: | FileCheck -check-prefix=RV64 %s
4+
; RUN: llc -O0 -mtriple=riscv64-unknown-linux-gnu -mattr=+experimental-zicfilp -verify-machineinstrs < %s \
5+
; RUN: | FileCheck -check-prefix=RV64-LINUX %s
6+
7+
declare void @llvm.init.trampoline(ptr, ptr, ptr)
8+
declare ptr @llvm.adjust.trampoline(ptr)
9+
declare i64 @f(ptr nest, i64)
10+
11+
define i64 @test0(i64 %n, ptr %p) nounwind {
12+
; RV64-LABEL: test0:
13+
; RV64: # %bb.0:
14+
; RV64-NEXT: lpad 0
15+
; RV64-NEXT: addi sp, sp, -64
16+
; RV64-NEXT: sd ra, 56(sp) # 8-byte Folded Spill
17+
; RV64-NEXT: sd a0, 0(sp) # 8-byte Folded Spill
18+
; RV64-NEXT: lui a0, %hi(f)
19+
; RV64-NEXT: addi a0, a0, %lo(f)
20+
; RV64-NEXT: sd a0, 48(sp)
21+
; RV64-NEXT: sd a1, 40(sp)
22+
; RV64-NEXT: li a0, 951
23+
; RV64-NEXT: sw a0, 32(sp)
24+
; RV64-NEXT: li a0, 23
25+
; RV64-NEXT: sw a0, 16(sp)
26+
; RV64-NEXT: lui a0, 40
27+
; RV64-NEXT: addiw a0, a0, 103
28+
; RV64-NEXT: sw a0, 36(sp)
29+
; RV64-NEXT: lui a0, 5348
30+
; RV64-NEXT: addiw a0, a0, -509
31+
; RV64-NEXT: sw a0, 28(sp)
32+
; RV64-NEXT: lui a0, 7395
33+
; RV64-NEXT: addiw a0, a0, 643
34+
; RV64-NEXT: sw a0, 24(sp)
35+
; RV64-NEXT: lui a0, 1
36+
; RV64-NEXT: addiw a0, a0, -489
37+
; RV64-NEXT: sw a0, 20(sp)
38+
; RV64-NEXT: addi a1, sp, 40
39+
; RV64-NEXT: addi a0, sp, 16
40+
; RV64-NEXT: sd a0, 8(sp) # 8-byte Folded Spill
41+
; RV64-NEXT: call __clear_cache
42+
; RV64-NEXT: ld a0, 0(sp) # 8-byte Folded Reload
43+
; RV64-NEXT: ld a1, 8(sp) # 8-byte Folded Reload
44+
; RV64-NEXT: jalr a1
45+
; RV64-NEXT: ld ra, 56(sp) # 8-byte Folded Reload
46+
; RV64-NEXT: addi sp, sp, 64
47+
; RV64-NEXT: ret
48+
;
49+
; RV64-LINUX-LABEL: test0:
50+
; RV64-LINUX: # %bb.0:
51+
; RV64-LINUX-NEXT: lpad 0
52+
; RV64-LINUX-NEXT: addi sp, sp, -64
53+
; RV64-LINUX-NEXT: sd ra, 56(sp) # 8-byte Folded Spill
54+
; RV64-LINUX-NEXT: sd a0, 0(sp) # 8-byte Folded Spill
55+
; RV64-LINUX-NEXT: lui a0, %hi(f)
56+
; RV64-LINUX-NEXT: addi a0, a0, %lo(f)
57+
; RV64-LINUX-NEXT: sd a0, 48(sp)
58+
; RV64-LINUX-NEXT: sd a1, 40(sp)
59+
; RV64-LINUX-NEXT: li a0, 951
60+
; RV64-LINUX-NEXT: sw a0, 32(sp)
61+
; RV64-LINUX-NEXT: li a0, 23
62+
; RV64-LINUX-NEXT: sw a0, 16(sp)
63+
; RV64-LINUX-NEXT: lui a0, 40
64+
; RV64-LINUX-NEXT: addiw a0, a0, 103
65+
; RV64-LINUX-NEXT: sw a0, 36(sp)
66+
; RV64-LINUX-NEXT: lui a0, 5348
67+
; RV64-LINUX-NEXT: addiw a0, a0, -509
68+
; RV64-LINUX-NEXT: sw a0, 28(sp)
69+
; RV64-LINUX-NEXT: lui a0, 7395
70+
; RV64-LINUX-NEXT: addiw a0, a0, 643
71+
; RV64-LINUX-NEXT: sw a0, 24(sp)
72+
; RV64-LINUX-NEXT: lui a0, 1
73+
; RV64-LINUX-NEXT: addiw a0, a0, -489
74+
; RV64-LINUX-NEXT: sw a0, 20(sp)
75+
; RV64-LINUX-NEXT: addi a1, sp, 40
76+
; RV64-LINUX-NEXT: addi a0, sp, 16
77+
; RV64-LINUX-NEXT: sd a0, 8(sp) # 8-byte Folded Spill
78+
; RV64-LINUX-NEXT: li a2, 0
79+
; RV64-LINUX-NEXT: call __riscv_flush_icache
80+
; RV64-LINUX-NEXT: ld a0, 0(sp) # 8-byte Folded Reload
81+
; RV64-LINUX-NEXT: ld a1, 8(sp) # 8-byte Folded Reload
82+
; RV64-LINUX-NEXT: jalr a1
83+
; RV64-LINUX-NEXT: ld ra, 56(sp) # 8-byte Folded Reload
84+
; RV64-LINUX-NEXT: addi sp, sp, 64
85+
; RV64-LINUX-NEXT: ret
86+
%alloca = alloca [40 x i8], align 8
87+
call void @llvm.init.trampoline(ptr %alloca, ptr @f, ptr %p)
88+
%tramp = call ptr @llvm.adjust.trampoline(ptr %alloca)
89+
%ret = call i64 %tramp(i64 %n)
90+
ret i64 %ret
91+
}
92+
93+
!llvm.module.flags = !{!0}
94+
95+
!0 = !{i32 8, !"cf-protection-branch", i32 1}

0 commit comments

Comments
 (0)