Skip to content

Commit c3fb2e1

Browse files
[BPF] Support Jump Table (#149715)
Add jump table (switch statement and computed goto) support for the BPF backend. A `gotox <reg>` instruction is implemented, where `<reg>` holds the target instruction that the gotox will jump to. For a switch statement like ``` ... switch (ctx->x) { case 1: ret_user = 18; break; case 20: ret_user = 6; break; case 16: ret_user = 9; break; case 6: ret_user = 16; break; case 8: ret_user = 14; break; case 30: ret_user = 2; break; default: ret_user = 1; break; } ... ``` the final binary is: ``` 4: 67 01 00 00 03 00 00 00 r1 <<= 0x3 5: 18 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 r2 = 0x0 ll 0000000000000028: R_BPF_64_64 BPF.JT.0.0 7: 0f 12 00 00 00 00 00 00 r2 += r1 ... Symbol table: 4: 0000000000000000 240 OBJECT GLOBAL DEFAULT 4 BPF.JT.0.0 5: 0000000000000000 4 OBJECT GLOBAL DEFAULT 6 ret_user 6: 0000000000000000 0 NOTYPE GLOBAL DEFAULT UND bar 7: 00000000000000f0 256 OBJECT GLOBAL DEFAULT 4 BPF.JT.0.1 and [ 4] .jumptables PROGBITS 0000000000000000 0001c8 0001f0 00 0 0 1 ``` Note that for the above example, `-mllvm -bpf-min-jump-table-entries=5` must be passed in the compilation flags, because the current default for bpf-min-jump-table-entries is 13. For example:
``` clang --target=bpf -mcpu=v4 -O2 -mllvm -bpf-min-jump-table-entries=5 -S -g test.c ``` For computed goto like ``` int foo(int a, int b) { __label__ l1, l2, l3, l4; void *jt1[] = {[0]=&&l1, [1]=&&l2}; void *jt2[] = {[0]=&&l3, [1]=&&l4}; int ret = 0; goto *jt1[a % 2]; l1: ret += 1; l2: ret += 3; goto *jt2[b % 2]; l3: ret += 5; l4: ret += 7; return ret; } ``` The final binary: ``` 12: bf 23 20 00 00 00 00 00 r3 = (s32)r2 13: 67 03 00 00 03 00 00 00 r3 <<= 0x3 14: 18 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 r2 = 0x0 ll 0000000000000070: R_BPF_64_64 BPF.JT.0.0 16: 0f 32 00 00 00 00 00 00 r2 += r3 17: bf 11 20 00 00 00 00 00 r1 = (s32)r1 18: 67 01 00 00 03 00 00 00 r1 <<= 0x3 19: 18 03 00 00 00 00 00 00 00 00 00 00 00 00 00 00 r3 = 0x0 ll 0000000000000098: R_BPF_64_64 BPF.JT.0.1 21: 0f 13 00 00 00 00 00 00 r3 += r1 [ 4] .jumptables PROGBITS 0000000000000000 000160 000020 00 0 0 1 4: 0000000000000000 16 OBJECT GLOBAL DEFAULT 4 BPF.JT.0.0 5: 0000000000000010 16 OBJECT GLOBAL DEFAULT 4 BPF.JT.0.1 ``` A more complicated test with both switch-statement triggered jump table and compute gotos: ``` $ cat test3.c struct simple_ctx { int x; int y; int z; }; int ret_user, ret_user2; void bar(void); int foo(struct simple_ctx *ctx, struct simple_ctx *ctx2, int a, int b) { __label__ l1, l2, l3, l4; void *jt1[] = {[0]=&&l1, [1]=&&l2}; void *jt2[] = {[0]=&&l3, [1]=&&l4}; int ret = 0; goto *jt1[a % 2]; l1: ret += 1; l2: ret += 3; goto *jt2[b % 2]; l3: ret += 5; l4: ret += 7; bar(); switch (ctx->x) { case 1: ret_user = 18; break; case 20: ret_user = 6; break; case 16: ret_user = 9; break; case 6: ret_user = 16; break; case 8: ret_user = 14; break; case 30: ret_user = 2; break; default: ret_user = 1; break; } return ret; } ``` Compile with ``` clang --target=bpf -mcpu=v4 -O2 -S test3.c clang --target=bpf -mcpu=v4 -O2 -c test3.c ``` The binary: ``` /* For computed goto */ 13: bf 42 20 00 00 00 00 00 r2 = (s32)r4 14: 67 02 00 00 03 00 00 00 r2 <<= 0x3 15: 18 01 00 00 00 00 00 00 00 00 
00 00 00 00 00 00 r1 = 0x0 ll 0000000000000078: R_BPF_64_64 BPF.JT.0.1 17: 0f 21 00 00 00 00 00 00 r1 += r2 18: bf 32 20 00 00 00 00 00 r2 = (s32)r3 19: 67 02 00 00 03 00 00 00 r2 <<= 0x3 20: 18 03 00 00 00 00 00 00 00 00 00 00 00 00 00 00 r3 = 0x0 ll 00000000000000a0: R_BPF_64_64 BPF.JT.0.2 22: 0f 23 00 00 00 00 00 00 r3 += r2 /* For switch statement */ 39: 67 01 00 00 03 00 00 00 r1 <<= 0x3 40: 18 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 r2 = 0x0 ll 0000000000000140: R_BPF_64_64 BPF.JT.0.0 42: 0f 12 00 00 00 00 00 00 r2 += r1 ``` You can see jump table symbols are all different.
1 parent 0864965 commit c3fb2e1

19 files changed

+716
-39
lines changed

llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ struct BPFOperand : public MCParsedAsmOperand {
234234
.Case("callx", true)
235235
.Case("goto", true)
236236
.Case("gotol", true)
237+
.Case("gotox", true)
237238
.Case("may_goto", true)
238239
.Case("*", true)
239240
.Case("exit", true)

llvm/lib/Target/BPF/BPFAsmPrinter.cpp

Lines changed: 90 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,52 +11,35 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14+
#include "BPFAsmPrinter.h"
1415
#include "BPF.h"
1516
#include "BPFInstrInfo.h"
1617
#include "BPFMCInstLower.h"
1718
#include "BTFDebug.h"
1819
#include "MCTargetDesc/BPFInstPrinter.h"
1920
#include "TargetInfo/BPFTargetInfo.h"
21+
#include "llvm/BinaryFormat/ELF.h"
2022
#include "llvm/CodeGen/AsmPrinter.h"
2123
#include "llvm/CodeGen/MachineConstantPool.h"
2224
#include "llvm/CodeGen/MachineInstr.h"
25+
#include "llvm/CodeGen/MachineJumpTableInfo.h"
2326
#include "llvm/CodeGen/MachineModuleInfo.h"
27+
#include "llvm/CodeGen/TargetLowering.h"
2428
#include "llvm/IR/Module.h"
2529
#include "llvm/MC/MCAsmInfo.h"
30+
#include "llvm/MC/MCExpr.h"
2631
#include "llvm/MC/MCInst.h"
2732
#include "llvm/MC/MCStreamer.h"
2833
#include "llvm/MC/MCSymbol.h"
34+
#include "llvm/MC/MCSymbolELF.h"
2935
#include "llvm/MC/TargetRegistry.h"
3036
#include "llvm/Support/Compiler.h"
3137
#include "llvm/Support/raw_ostream.h"
38+
#include "llvm/Target/TargetLoweringObjectFile.h"
3239
using namespace llvm;
3340

3441
#define DEBUG_TYPE "asm-printer"
3542

36-
// Former file-local declaration of the BPF assembly printer. Every line of
// this block carries a '-' diff marker: the commit deletes it from the .cpp
// file, and the same change adds a BPFAsmPrinter class in BPFAsmPrinter.h.
namespace {
37-
class BPFAsmPrinter : public AsmPrinter {
38-
public:
39-
// Forwards the target machine and streamer to AsmPrinter; BTF starts null.
explicit BPFAsmPrinter(TargetMachine &TM,
40-
std::unique_ptr<MCStreamer> Streamer)
41-
: AsmPrinter(TM, std::move(Streamer), ID), BTF(nullptr) {}
42-
43-
StringRef getPassName() const override { return "BPF Assembly Printer"; }
44-
bool doInitialization(Module &M) override;
45-
void printOperand(const MachineInstr *MI, int OpNum, raw_ostream &O);
46-
bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
47-
const char *ExtraCode, raw_ostream &O) override;
48-
bool PrintAsmMemoryOperand(const MachineInstr *MI, unsigned OpNum,
49-
const char *ExtraCode, raw_ostream &O) override;
50-
51-
void emitInstruction(const MachineInstr *MI) override;
52-
53-
// Pass identification; its address is handed to the AsmPrinter constructor.
static char ID;
54-
55-
private:
56-
// BTF debug handler pointer; null at construction.
BTFDebug *BTF;
57-
};
58-
} // namespace
59-
6043
bool BPFAsmPrinter::doInitialization(Module &M) {
6144
AsmPrinter::doInitialization(M);
6245

@@ -69,6 +52,45 @@ bool BPFAsmPrinter::doInitialization(Module &M) {
6952
return false;
7053
}
7154

55+
/// Return this printer's target machine reference, viewed as the BPF-specific
/// BPFTargetMachine type. The cast is unchecked.
const BPFTargetMachine &BPFAsmPrinter::getBTM() const {
  const auto &BTM = static_cast<const BPFTargetMachine &>(TM);
  return BTM;
}
58+
59+
/// Module-level finalization for the BPF asm printer.
///
/// When the subtarget supports gotox, private constant globals whose
/// initializer is an array made up entirely of block addresses are replaced
/// with poison and erased before the generic AsmPrinter finalization runs.
/// Such arrays are presumably the IR-level tables that backed computed gotos;
/// the printer emits its own jump-table symbols instead.
///
/// \param M the module being finalized.
/// \returns the result of AsmPrinter::doFinalization(M).
bool BPFAsmPrinter::doFinalization(Module &M) {
  // Remove unused globals which are previously used for jump table.
  const BPFSubtarget *Subtarget = getBTM().getSubtargetImpl();
  if (Subtarget->hasGotox()) {
    // Collect first and erase afterwards so the global list is not modified
    // while it is being iterated.
    std::vector<GlobalVariable *> Targets;
    for (GlobalVariable &Global : M.globals()) {
      if (Global.getLinkage() != GlobalValue::PrivateLinkage)
        continue;
      if (!Global.isConstant() || !Global.hasInitializer())
        continue;

      // The initializer is already a Constant; only the array test is needed.
      auto *CA = dyn_cast<ConstantArray>(Global.getInitializer());
      if (!CA)
        continue;

      // Every element, including the first, must be a block address. Other
      // private constant arrays (e.g. tables of function or string pointers)
      // must survive, otherwise they are never emitted.
      bool AllBlockAddresses = true;
      for (unsigned I = 0, E = CA->getNumOperands(); I != E; ++I) {
        if (!isa<BlockAddress>(CA->getOperand(I))) {
          AllBlockAddresses = false;
          break;
        }
      }
      if (!AllBlockAddresses)
        continue;

      Targets.push_back(&Global);
    }

    for (GlobalVariable *GV : Targets) {
      GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
      GV->dropAllReferences();
      GV->eraseFromParent();
    }
  }

  return AsmPrinter::doFinalization(M);
}
93+
7294
void BPFAsmPrinter::printOperand(const MachineInstr *MI, int OpNum,
7395
raw_ostream &O) {
7496
const MachineOperand &MO = MI->getOperand(OpNum);
@@ -150,6 +172,50 @@ void BPFAsmPrinter::emitInstruction(const MachineInstr *MI) {
150172
EmitToStreamer(*OutStreamer, TmpInst);
151173
}
152174

175+
/// Get or create the global object symbol that names one jump table of the
/// current machine function.
///
/// The name has the form "BPF.JT.<function number>.<JTI>", so tables from
/// different functions and different tables within a function get distinct
/// symbols.
///
/// \param JTI index of the jump table within the current function.
/// \returns the symbol, marked STB_GLOBAL / STT_OBJECT.
MCSymbol *BPFAsmPrinter::getJTPublicSymbol(unsigned JTI) {
  SmallString<60> Name;
  raw_svector_ostream(Name)
      << "BPF.JT." << MF->getFunctionNumber() << '.' << JTI;
  MCSymbol *S = OutContext.getOrCreateSymbol(Name);
  // static_cast is an unchecked downcast and never produces null here, so the
  // ELF attributes are set unconditionally rather than behind a test that
  // always passes. NOTE(review): this relies on the context creating ELF
  // symbols, as the previous code also assumed.
  auto *ES = static_cast<MCSymbolELF *>(S);
  ES->setBinding(ELF::STB_GLOBAL);
  ES->setType(ELF::STT_OBJECT);
  return S;
}
186+
187+
void BPFAsmPrinter::emitJumpTableInfo() {
188+
const MachineJumpTableInfo *MJTI = MF->getJumpTableInfo();
189+
if (!MJTI)
190+
return;
191+
192+
const std::vector<MachineJumpTableEntry> &JT = MJTI->getJumpTables();
193+
if (JT.empty())
194+
return;
195+
196+
const TargetLoweringObjectFile &TLOF = getObjFileLowering();
197+
const Function &F = MF->getFunction();
198+
MCSection *JTS = TLOF.getSectionForJumpTable(F, TM);
199+
assert(MJTI->getEntryKind() == MachineJumpTableInfo::EK_BlockAddress);
200+
unsigned EntrySize = MJTI->getEntrySize(getDataLayout());
201+
OutStreamer->switchSection(JTS);
202+
for (unsigned JTI = 0; JTI < JT.size(); JTI++) {
203+
ArrayRef<MachineBasicBlock *> JTBBs = JT[JTI].MBBs;
204+
if (JTBBs.empty())
205+
continue;
206+
207+
MCSymbol *JTStart = getJTPublicSymbol(JTI);
208+
OutStreamer->emitLabel(JTStart);
209+
for (const MachineBasicBlock *MBB : JTBBs) {
210+
const MCExpr *LHS = MCSymbolRefExpr::create(MBB->getSymbol(), OutContext);
211+
OutStreamer->emitValue(LHS, EntrySize);
212+
}
213+
const MCExpr *JTSize =
214+
MCConstantExpr::create(JTBBs.size() * EntrySize, OutContext);
215+
OutStreamer->emitELFSize(JTStart, JTSize);
216+
}
217+
}
218+
153219
char BPFAsmPrinter::ID = 0;
154220

155221
INITIALIZE_PASS(BPFAsmPrinter, "bpf-asm-printer", "BPF Assembly Printer", false,
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
//===-- BPFAsmPrinter.h - BPF Assembly Printer -----------------*- C++ -*--===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_LIB_TARGET_BPF_BPFASMPRINTER_H
10+
#define LLVM_LIB_TARGET_BPF_BPFASMPRINTER_H
11+
12+
#include "BPFTargetMachine.h"
13+
#include "BTFDebug.h"
14+
#include "llvm/CodeGen/AsmPrinter.h"
15+
16+
namespace llvm {
17+
18+
class BPFAsmPrinter : public AsmPrinter {
19+
public:
20+
explicit BPFAsmPrinter(TargetMachine &TM,
21+
std::unique_ptr<MCStreamer> Streamer)
22+
: AsmPrinter(TM, std::move(Streamer), ID), BTF(nullptr), TM(TM) {}
23+
24+
StringRef getPassName() const override { return "BPF Assembly Printer"; }
25+
bool doInitialization(Module &M) override;
26+
bool doFinalization(Module &M) override;
27+
void printOperand(const MachineInstr *MI, int OpNum, raw_ostream &O);
28+
bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
29+
const char *ExtraCode, raw_ostream &O) override;
30+
bool PrintAsmMemoryOperand(const MachineInstr *MI, unsigned OpNum,
31+
const char *ExtraCode, raw_ostream &O) override;
32+
33+
void emitInstruction(const MachineInstr *MI) override;
34+
MCSymbol *getJTPublicSymbol(unsigned JTI);
35+
virtual void emitJumpTableInfo() override;
36+
37+
static char ID;
38+
39+
private:
40+
BTFDebug *BTF;
41+
TargetMachine &TM;
42+
43+
const BPFTargetMachine &getBTM() const;
44+
};
45+
46+
} // namespace llvm
47+
48+
#endif /* LLVM_LIB_TARGET_BPF_BPFASMPRINTER_H */

0 commit comments

Comments
 (0)