diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 1f7cf7e857d0f..e7363f7aded72 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -8362,9 +8362,23 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op, // 16: <StaticChainOffset> // 24: <FunctionAddressOffset> // 32: - - constexpr unsigned StaticChainOffset = 16; - constexpr unsigned FunctionAddressOffset = 24; + // Offset with branch control flow protection enabled: + // 0: lpad <imm20> + // 4: auipc t3, 0 + // 8: ld t2, 28(t3) + // 12: ld t3, 20(t3) + // 16: jalr t2 + // 20: <StaticChainOffset> + // 28: <FunctionAddressOffset> + // 36: + + const bool HasCFBranch = + Subtarget.hasStdExtZicfilp() && + DAG.getMachineFunction().getFunction().getParent()->getModuleFlag( + "cf-protection-branch"); + const unsigned StaticChainIdx = HasCFBranch ? 5 : 4; + const unsigned StaticChainOffset = StaticChainIdx * 4; + const unsigned FunctionAddressOffset = StaticChainOffset + 8; const MCSubtargetInfo *STI = getTargetMachine().getMCSubtargetInfo(); assert(STI); @@ -8376,38 +8390,70 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op, return Encoding; }; - SDValue OutChains[6]; - - uint32_t Encodings[] = { - // auipc t2, 0 - // Loads the current PC into t2. - GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)), - // ld t0, 24(t2) - // Loads the function address into t0. Note that we are using offsets - // pc-relative to the first instruction of the trampoline. - GetEncoding( - MCInstBuilder(RISCV::LD).addReg(RISCV::X5).addReg(RISCV::X7).addImm( - FunctionAddressOffset)), - // ld t2, 16(t2) - // Load the value of the static chain. - GetEncoding( - MCInstBuilder(RISCV::LD).addReg(RISCV::X7).addReg(RISCV::X7).addImm( - StaticChainOffset)), - // jalr t0 - // Jump to the function. 
- GetEncoding(MCInstBuilder(RISCV::JALR) - .addReg(RISCV::X0) - .addReg(RISCV::X5) - .addImm(0))}; + SmallVector<SDValue> OutChains; + + SmallVector<uint32_t> Encodings; + if (!HasCFBranch) { + Encodings.append( + {// auipc t2, 0 + // Loads the current PC into t2. + GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)), + // ld t0, 24(t2) + // Loads the function address into t0. Note that we are using offsets + // pc-relative to the first instruction of the trampoline. + GetEncoding(MCInstBuilder(RISCV::LD) + .addReg(RISCV::X5) + .addReg(RISCV::X7) + .addImm(FunctionAddressOffset)), + // ld t2, 16(t2) + // Load the value of the static chain. + GetEncoding(MCInstBuilder(RISCV::LD) + .addReg(RISCV::X7) + .addReg(RISCV::X7) + .addImm(StaticChainOffset)), + // jalr t0 + // Jump to the function. + GetEncoding(MCInstBuilder(RISCV::JALR) + .addReg(RISCV::X0) + .addReg(RISCV::X5) + .addImm(0))}); + } else { + Encodings.append( + {// auipc x0, <imm20> (lpad <imm20>) + // Landing pad. + GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X0).addImm(0)), + // auipc t3, 0 + // Loads the current PC into t3. + GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X28).addImm(0)), + // ld t2, (FunctionAddressOffset - 4)(t3) + // Loads the function address into t2. Note that we are using offsets + // pc-relative to the SECOND instruction of the trampoline. + GetEncoding(MCInstBuilder(RISCV::LD) + .addReg(RISCV::X7) + .addReg(RISCV::X28) + .addImm(FunctionAddressOffset - 4)), + // ld t3, (StaticChainOffset - 4)(t3) + // Load the value of the static chain. + GetEncoding(MCInstBuilder(RISCV::LD) + .addReg(RISCV::X28) + .addReg(RISCV::X28) + .addImm(StaticChainOffset - 4)), + // jalr t2 + // Software-guarded jump to the function. + GetEncoding(MCInstBuilder(RISCV::JALR) + .addReg(RISCV::X0) + .addReg(RISCV::X7) + .addImm(0))}); + } // Store encoded instructions. for (auto [Idx, Encoding] : llvm::enumerate(Encodings)) { SDValue Addr = Idx > 0 ? 
DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp, DAG.getConstant(Idx * 4, dl, MVT::i64)) : Trmp; - OutChains[Idx] = DAG.getTruncStore( + OutChains.push_back(DAG.getTruncStore( Root, dl, DAG.getConstant(Encoding, dl, MVT::i64), Addr, - MachinePointerInfo(TrmpAddr, Idx * 4), MVT::i32); + MachinePointerInfo(TrmpAddr, Idx * 4), MVT::i32)); } // Now store the variable part of the trampoline. @@ -8423,16 +8469,18 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op, {StaticChainOffset, StaticChain}, {FunctionAddressOffset, FunctionAddress}, }; - for (auto [Idx, OffsetValue] : llvm::enumerate(OffsetValues)) { + for (auto &OffsetValue : OffsetValues) { SDValue Addr = DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp, DAG.getConstant(OffsetValue.Offset, dl, MVT::i64)); OffsetValue.Addr = Addr; - OutChains[Idx + 4] = + OutChains.push_back( DAG.getStore(Root, dl, OffsetValue.Value, Addr, - MachinePointerInfo(TrmpAddr, OffsetValue.Offset)); + MachinePointerInfo(TrmpAddr, OffsetValue.Offset))); } + assert(OutChains.size() == StaticChainIdx + 2 && + "Size of OutChains mismatch"); SDValue StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains); // The end of instructions of trampoline is the same as the static chain diff --git a/llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll b/llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll new file mode 100644 index 0000000000000..8a338a855c863 --- /dev/null +++ b/llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll @@ -0,0 +1,99 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc -O0 -mtriple=riscv64 -mattr=+experimental-zicfilp -verify-machineinstrs < %s \ +; RUN: | FileCheck -check-prefix=RV64 %s +; RUN: llc -O0 -mtriple=riscv64-unknown-linux-gnu -mattr=+experimental-zicfilp -verify-machineinstrs < %s \ +; RUN: | FileCheck -check-prefix=RV64-LINUX %s + +declare void @llvm.init.trampoline(ptr, ptr, ptr) +declare ptr @llvm.adjust.trampoline(ptr) +declare i64 @f(ptr nest, i64) + 
+define i64 @test0(i64 %n, ptr %p) nounwind { +; RV64-LABEL: test0: +; RV64: # %bb.0: +; RV64-NEXT: lpad 0 +; RV64-NEXT: addi sp, sp, -64 +; RV64-NEXT: sd ra, 56(sp) # 8-byte Folded Spill +; RV64-NEXT: sd a0, 0(sp) # 8-byte Folded Spill +; RV64-NEXT: lui a0, %hi(f) +; RV64-NEXT: addi a0, a0, %lo(f) +; RV64-NEXT: sw a0, 44(sp) +; RV64-NEXT: srli a0, a0, 32 +; RV64-NEXT: sw a0, 48(sp) +; RV64-NEXT: sw a1, 36(sp) +; RV64-NEXT: srli a0, a1, 32 +; RV64-NEXT: sw a0, 40(sp) +; RV64-NEXT: li a0, 23 +; RV64-NEXT: sw a0, 16(sp) +; RV64-NEXT: lui a0, 56 +; RV64-NEXT: addi a0, a0, 103 +; RV64-NEXT: sw a0, 32(sp) +; RV64-NEXT: lui a0, 4324 +; RV64-NEXT: addi a0, a0, -509 +; RV64-NEXT: sw a0, 28(sp) +; RV64-NEXT: lui a0, 6371 +; RV64-NEXT: addi a0, a0, 899 +; RV64-NEXT: sw a0, 24(sp) +; RV64-NEXT: lui a0, 1 +; RV64-NEXT: addi a0, a0, -489 +; RV64-NEXT: sw a0, 20(sp) +; RV64-NEXT: addi a1, sp, 36 +; RV64-NEXT: addi a0, sp, 16 +; RV64-NEXT: sd a0, 8(sp) # 8-byte Folded Spill +; RV64-NEXT: call __clear_cache +; RV64-NEXT: ld a0, 0(sp) # 8-byte Folded Reload +; RV64-NEXT: ld a1, 8(sp) # 8-byte Folded Reload +; RV64-NEXT: jalr a1 +; RV64-NEXT: ld ra, 56(sp) # 8-byte Folded Reload +; RV64-NEXT: addi sp, sp, 64 +; RV64-NEXT: ret +; +; RV64-LINUX-LABEL: test0: +; RV64-LINUX: # %bb.0: +; RV64-LINUX-NEXT: lpad 0 +; RV64-LINUX-NEXT: addi sp, sp, -64 +; RV64-LINUX-NEXT: sd ra, 56(sp) # 8-byte Folded Spill +; RV64-LINUX-NEXT: sd a0, 0(sp) # 8-byte Folded Spill +; RV64-LINUX-NEXT: lui a0, %hi(f) +; RV64-LINUX-NEXT: addi a0, a0, %lo(f) +; RV64-LINUX-NEXT: sw a0, 44(sp) +; RV64-LINUX-NEXT: srli a0, a0, 32 +; RV64-LINUX-NEXT: sw a0, 48(sp) +; RV64-LINUX-NEXT: sw a1, 36(sp) +; RV64-LINUX-NEXT: srli a0, a1, 32 +; RV64-LINUX-NEXT: sw a0, 40(sp) +; RV64-LINUX-NEXT: li a0, 23 +; RV64-LINUX-NEXT: sw a0, 16(sp) +; RV64-LINUX-NEXT: lui a0, 56 +; RV64-LINUX-NEXT: addi a0, a0, 103 +; RV64-LINUX-NEXT: sw a0, 32(sp) +; RV64-LINUX-NEXT: lui a0, 4324 +; RV64-LINUX-NEXT: addi a0, a0, -509 +; RV64-LINUX-NEXT: 
sw a0, 28(sp) +; RV64-LINUX-NEXT: lui a0, 6371 +; RV64-LINUX-NEXT: addi a0, a0, 899 +; RV64-LINUX-NEXT: sw a0, 24(sp) +; RV64-LINUX-NEXT: lui a0, 1 +; RV64-LINUX-NEXT: addi a0, a0, -489 +; RV64-LINUX-NEXT: sw a0, 20(sp) +; RV64-LINUX-NEXT: addi a1, sp, 36 +; RV64-LINUX-NEXT: addi a0, sp, 16 +; RV64-LINUX-NEXT: sd a0, 8(sp) # 8-byte Folded Spill +; RV64-LINUX-NEXT: li a2, 0 +; RV64-LINUX-NEXT: call __riscv_flush_icache +; RV64-LINUX-NEXT: ld a0, 0(sp) # 8-byte Folded Reload +; RV64-LINUX-NEXT: ld a1, 8(sp) # 8-byte Folded Reload +; RV64-LINUX-NEXT: jalr a1 +; RV64-LINUX-NEXT: ld ra, 56(sp) # 8-byte Folded Reload +; RV64-LINUX-NEXT: addi sp, sp, 64 +; RV64-LINUX-NEXT: ret + %alloca = alloca [36 x i8], align 8 + call void @llvm.init.trampoline(ptr %alloca, ptr @f, ptr %p) + %tramp = call ptr @llvm.adjust.trampoline(ptr %alloca) + %ret = call i64 %tramp(i64 %n) + ret i64 %ret +} + +!llvm.module.flags = !{!0} + +!0 = !{i32 8, !"cf-protection-branch", i32 1}