Skip to content

Commit c5b53c2

Browse files
author
Yonghong Song
committed
[RFC][BPF] Support Jump Table
NOTE: We probably need cpu v5 or other flags to enable this feature. We can add it later when necessary. - Generate all jump tables in a single section named .jumptables. - Represent each jump table as a symbol: - value points to an offset within .jumptables; - size encodes jump table size in bytes. - Indirect jump is a gotox instruction: - dst register is an index within the table; - accompanied by a R_BPF_64_64 relocation pointing to a jump table symbol. clang -S: .LJTI0_0: .reloc 0, FK_SecRel_8, .BPF.JT.0.0 gotox r1 goto LBB0_2 LBB0_4: ... .section .jumptables,"",@progbits .L0_0_set_4 = ((LBB0_4-.LBPF.JX.0.0)>>3)-1 .L0_0_set_2 = ((LBB0_2-.LBPF.JX.0.0)>>3)-1 ... .BPF.JT.0.0: .long .L0_0_set_4 .long .L0_0_set_2 ... llvm-readelf -r --sections --symbols: Section Headers: [Nr] Name Type Address Off Size ES Flg Lk Inf Al ... [ 4] .jumptables PROGBITS 0000000000000000 000118 000100 00 0 0 1 ... Relocation section '.rel.text' at offset 0x2a8 contains 2 entries: Offset Info Type Symbol's Value Symbol's Name 0000000000000010 0000000300000001 R_BPF_64_64 0000000000000000 .BPF.JT.0.0 ... Symbol table '.symtab' contains 6 entries: Num: Value Size Type Bind Vis Ndx Name ... 2: 0000000000000000 112 FUNC GLOBAL DEFAULT 2 foo 3: 0000000000000000 128 NOTYPE GLOBAL DEFAULT 4 .BPF.JT.0.0 ... llvm-objdump -Sdr: 0000000000000000 <foo>: ... 2: gotox r1 0000000000000010: R_BPF_64_64 .BPF.JT.0.0 An option -bpf-min-jump-table-entries is implemented to control the minimum number of entries to use a jump table on BPF. The default value 4, but it can be changed with the following clang option clang ... -mllvm -bpf-min-jump-table-entries=6 where the number of jump table cases needs to be >= 6 in order to use jump table.
1 parent 09fb20e commit c5b53c2

15 files changed

+283
-5
lines changed

llvm/include/llvm/CodeGen/AsmPrinter.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,8 @@ class LLVM_ABI AsmPrinter : public MachineFunctionPass {
663663
MCSymbol *GetExternalSymbolSymbol(const Twine &Sym) const;
664664

665665
/// Return the symbol for the specified jump table entry.
666-
MCSymbol *GetJTISymbol(unsigned JTID, bool isLinkerPrivate = false) const;
666+
virtual MCSymbol *GetJTISymbol(unsigned JTID,
667+
bool isLinkerPrivate = false) const;
667668

668669
/// Return the symbol for the specified jump table .set
669670
/// FIXME: privatize to AsmPrinter.

llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ struct BPFOperand : public MCParsedAsmOperand {
234234
.Case("callx", true)
235235
.Case("goto", true)
236236
.Case("gotol", true)
237+
.Case("gotox", true)
237238
.Case("may_goto", true)
238239
.Case("*", true)
239240
.Case("exit", true)

llvm/lib/Target/BPF/BPFAsmPrinter.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,24 @@
1717
#include "BTFDebug.h"
1818
#include "MCTargetDesc/BPFInstPrinter.h"
1919
#include "TargetInfo/BPFTargetInfo.h"
20+
#include "llvm/BinaryFormat/ELF.h"
2021
#include "llvm/CodeGen/AsmPrinter.h"
2122
#include "llvm/CodeGen/MachineConstantPool.h"
2223
#include "llvm/CodeGen/MachineInstr.h"
24+
#include "llvm/CodeGen/MachineJumpTableInfo.h"
2325
#include "llvm/CodeGen/MachineModuleInfo.h"
26+
#include "llvm/CodeGen/TargetLowering.h"
2427
#include "llvm/IR/Module.h"
2528
#include "llvm/MC/MCAsmInfo.h"
29+
#include "llvm/MC/MCExpr.h"
2630
#include "llvm/MC/MCInst.h"
2731
#include "llvm/MC/MCStreamer.h"
2832
#include "llvm/MC/MCSymbol.h"
33+
#include "llvm/MC/MCSymbolELF.h"
2934
#include "llvm/MC/TargetRegistry.h"
3035
#include "llvm/Support/Compiler.h"
3136
#include "llvm/Support/raw_ostream.h"
37+
#include "llvm/Target/TargetLoweringObjectFile.h"
3238
using namespace llvm;
3339

3440
#define DEBUG_TYPE "asm-printer"
@@ -49,6 +55,9 @@ class BPFAsmPrinter : public AsmPrinter {
4955
const char *ExtraCode, raw_ostream &O) override;
5056

5157
void emitInstruction(const MachineInstr *MI) override;
58+
virtual MCSymbol *GetJTISymbol(unsigned JTID,
59+
bool isLinkerPrivate = false) const override;
60+
virtual void emitJumpTableInfo() override;
5261

5362
static char ID;
5463

@@ -150,6 +159,74 @@ void BPFAsmPrinter::emitInstruction(const MachineInstr *MI) {
150159
EmitToStreamer(*OutStreamer, TmpInst);
151160
}
152161

162+
MCSymbol *BPFAsmPrinter::GetJTISymbol(unsigned JTID,
163+
bool isLinkerPrivate) const {
164+
SmallString<60> Name;
165+
raw_svector_ostream(Name)
166+
<< "BPF.JT." << MF->getFunctionNumber() << '.' << JTID;
167+
MCSymbol *S = OutContext.getOrCreateSymbol(Name);
168+
if (auto *ES = dyn_cast<MCSymbolELF>(S))
169+
ES->setBinding(ELF::STB_GLOBAL);
170+
return S;
171+
}
172+
173+
void BPFAsmPrinter::emitJumpTableInfo() {
174+
const MachineJumpTableInfo *MJTI = MF->getJumpTableInfo();
175+
if (!MJTI)
176+
return;
177+
178+
const std::vector<MachineJumpTableEntry> &JT = MJTI->getJumpTables();
179+
if (JT.empty())
180+
return;
181+
182+
const TargetLoweringObjectFile &TLOF = getObjFileLowering();
183+
const Function &F = MF->getFunction();
184+
MCSection *JTS = TLOF.getSectionForJumpTable(F, TM);
185+
assert(MJTI->getEntryKind() == MachineJumpTableInfo::EK_LabelDifference32);
186+
unsigned EntrySize = MJTI->getEntrySize(getDataLayout());
187+
OutStreamer->switchSection(JTS);
188+
for (unsigned JTI = 0; JTI < JT.size(); JTI++) {
189+
ArrayRef<MachineBasicBlock *> JTBBs = JT[JTI].MBBs;
190+
if (JTBBs.empty())
191+
continue;
192+
193+
SmallPtrSet<const MachineBasicBlock *, 16> EmittedSets;
194+
const TargetLowering *TLI = MF->getSubtarget().getTargetLowering();
195+
const MCExpr *Base = TLI->getPICJumpTableRelocBaseExpr(MF, JTI, OutContext);
196+
for (const MachineBasicBlock *MBB : JTBBs) {
197+
if (!EmittedSets.insert(MBB).second)
198+
continue;
199+
200+
// Offset from gotox to target basic block expressed in number
201+
// of instructions, e.g.:
202+
//
203+
// .L0_0_set_4 = ((LBB0_4 - .LBPF.JX.0.0) >> 3) - 1
204+
const MCExpr *LHS = MCSymbolRefExpr::create(MBB->getSymbol(), OutContext);
205+
OutStreamer->emitAssignment(
206+
GetJTSetSymbol(JTI, MBB->getNumber()),
207+
MCBinaryExpr::createSub(
208+
MCBinaryExpr::createAShr(
209+
MCBinaryExpr::createSub(LHS, Base, OutContext),
210+
MCConstantExpr::create(3, OutContext), OutContext),
211+
MCConstantExpr::create(1, OutContext), OutContext));
212+
}
213+
// BPF.JT.0.0:
214+
// .long .L0_0_set_4
215+
// .long .L0_0_set_2
216+
// ...
217+
// .size BPF.JT.0.0, 128
218+
MCSymbol *JTStart = GetJTISymbol(JTI);
219+
OutStreamer->emitLabel(JTStart);
220+
for (const MachineBasicBlock *MBB : JTBBs) {
221+
MCSymbol *SetSymbol = GetJTSetSymbol(JTI, MBB->getNumber());
222+
const MCExpr *V = MCSymbolRefExpr::create(SetSymbol, OutContext);
223+
OutStreamer->emitValue(V, EntrySize);
224+
}
225+
const MCExpr *JTSize = MCConstantExpr::create(JTBBs.size() * 4, OutContext);
226+
OutStreamer->emitELFSize(JTStart, JTSize);
227+
}
228+
}
229+
153230
char BPFAsmPrinter::ID = 0;
154231

155232
INITIALIZE_PASS(BPFAsmPrinter, "bpf-asm-printer", "BPF Assembly Printer", false,

llvm/lib/Target/BPF/BPFISelLowering.cpp

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818
#include "llvm/CodeGen/MachineFrameInfo.h"
1919
#include "llvm/CodeGen/MachineFunction.h"
2020
#include "llvm/CodeGen/MachineInstrBuilder.h"
21+
#include "llvm/CodeGen/MachineJumpTableInfo.h"
2122
#include "llvm/CodeGen/MachineRegisterInfo.h"
2223
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
2324
#include "llvm/CodeGen/ValueTypes.h"
2425
#include "llvm/IR/DIBuilder.h"
2526
#include "llvm/IR/DiagnosticInfo.h"
2627
#include "llvm/IR/DiagnosticPrinter.h"
2728
#include "llvm/IR/Module.h"
29+
#include "llvm/MC/MCAsmInfo.h"
2830
#include "llvm/Support/Debug.h"
2931
#include "llvm/Support/ErrorHandling.h"
3032
#include "llvm/Support/MathExtras.h"
@@ -38,6 +40,10 @@ static cl::opt<bool> BPFExpandMemcpyInOrder("bpf-expand-memcpy-in-order",
3840
cl::Hidden, cl::init(false),
3941
cl::desc("Expand memcpy into load/store pairs in order"));
4042

43+
static cl::opt<unsigned> BPFMinimumJumpTableEntries(
44+
"bpf-min-jump-table-entries", cl::init(4), cl::Hidden,
45+
cl::desc("Set minimum number of entries to use a jump table on BPF"));
46+
4147
static void fail(const SDLoc &DL, SelectionDAG &DAG, const Twine &Msg,
4248
SDValue Val = {}) {
4349
std::string Str;
@@ -66,8 +72,7 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM,
6672
setStackPointerRegisterToSaveRestore(BPF::R11);
6773

6874
setOperationAction(ISD::BR_CC, MVT::i64, Custom);
69-
setOperationAction(ISD::BR_JT, MVT::Other, Expand);
70-
setOperationAction(ISD::BRIND, MVT::Other, Expand);
75+
setOperationAction(ISD::BR_JT, MVT::Other, Custom);
7176
setOperationAction(ISD::BRCOND, MVT::Other, Expand);
7277

7378
setOperationAction(ISD::TRAP, MVT::Other, Custom);
@@ -159,6 +164,7 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM,
159164

160165
setBooleanContents(ZeroOrOneBooleanContent);
161166
setMaxAtomicSizeInBitsSupported(64);
167+
setMinimumJumpTableEntries(BPFMinimumJumpTableEntries);
162168

163169
// Function alignments
164170
setMinFunctionAlignment(Align(8));
@@ -332,6 +338,8 @@ SDValue BPFTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
332338
return LowerATOMIC_LOAD_STORE(Op, DAG);
333339
case ISD::TRAP:
334340
return LowerTRAP(Op, DAG);
341+
case ISD::BR_JT:
342+
return LowerBR_JT(Op, DAG);
335343
}
336344
}
337345

@@ -780,6 +788,16 @@ SDValue BPFTargetLowering::LowerTRAP(SDValue Op, SelectionDAG &DAG) const {
780788
return LowerCall(CLI, InVals);
781789
}
782790

791+
SDValue BPFTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
792+
SDValue Chain = Op->getOperand(0);
793+
SDValue Table = Op->getOperand(1);
794+
SDValue Index = Op->getOperand(2);
795+
JumpTableSDNode *JT = cast<JumpTableSDNode>(Table);
796+
SDLoc DL(Op);
797+
SDValue TargetJT = DAG.getTargetJumpTable(JT->getIndex(), MVT::i32);
798+
return DAG.getNode(BPFISD::BPF_BR_JT, DL, MVT::Other, Chain, TargetJT, Index);
799+
}
800+
783801
const char *BPFTargetLowering::getTargetNodeName(unsigned Opcode) const {
784802
switch ((BPFISD::NodeType)Opcode) {
785803
case BPFISD::FIRST_NUMBER:
@@ -796,6 +814,8 @@ const char *BPFTargetLowering::getTargetNodeName(unsigned Opcode) const {
796814
return "BPFISD::Wrapper";
797815
case BPFISD::MEMCPY:
798816
return "BPFISD::MEMCPY";
817+
case BPFISD::BPF_BR_JT:
818+
return "BPFISD::BPF_BR_JT";
799819
}
800820
return nullptr;
801821
}
@@ -1069,3 +1089,21 @@ bool BPFTargetLowering::isLegalAddressingMode(const DataLayout &DL,
10691089

10701090
return true;
10711091
}
1092+
1093+
MCSymbol *BPFTargetLowering::getJXAnchorSymbol(const MachineFunction *MF,
1094+
unsigned JTI) {
1095+
const MCAsmInfo *MAI = MF->getContext().getAsmInfo();
1096+
SmallString<60> Name;
1097+
raw_svector_ostream(Name) << MAI->getPrivateGlobalPrefix() << "BPF.JX."
1098+
<< MF->getFunctionNumber() << '.' << JTI;
1099+
return MF->getContext().getOrCreateSymbol(Name);
1100+
}
1101+
1102+
unsigned BPFTargetLowering::getJumpTableEncoding() const {
1103+
return MachineJumpTableInfo::EK_LabelDifference32;
1104+
}
1105+
1106+
const MCExpr *BPFTargetLowering::getPICJumpTableRelocBaseExpr(
1107+
const MachineFunction *MF, unsigned JTI, MCContext &Ctx) const {
1108+
return MCSymbolRefExpr::create(getJXAnchorSymbol(MF, JTI), Ctx);
1109+
}

llvm/lib/Target/BPF/BPFISelLowering.h

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ enum NodeType : unsigned {
2828
SELECT_CC,
2929
BR_CC,
3030
Wrapper,
31-
MEMCPY
31+
MEMCPY,
32+
BPF_BR_JT,
3233
};
3334
}
3435

@@ -66,6 +67,31 @@ class BPFTargetLowering : public TargetLowering {
6667

6768
MVT getScalarShiftAmountTy(const DataLayout &, EVT) const override;
6869

70+
// Always emit EK_LabelDifference32, computed as difference between
71+
// JX instruction location and target basic block label.
72+
virtual unsigned getJumpTableEncoding() const override;
73+
74+
// This is a label for JX instructions, used for jump table offsets
75+
// computation, e.g.:
76+
//
77+
// .LBPF.JX.0.0: <------- this is the anchor
78+
// .reloc 0, FK_SecRel_8, BPF.JT.0.0
79+
// gotox r1
80+
// ...
81+
// .section .jumptables,"",@progbits
82+
// .L0_0_set_7 = ((LBB0_7-.LBPF.JX.0.0)>>3)-1
83+
// ...
84+
// BPF.JT.0.0: <------- JT definition
85+
// .long .L0_0_set_7
86+
// ...
87+
static MCSymbol *getJXAnchorSymbol(const MachineFunction *MF, unsigned JTI);
88+
89+
// Refers to a symbol returned by getJXAnchorSymbol(), used by
90+
// AsmPrinter::emitJumpTableInfo() to define the .L0_0_set_7 etc above.
91+
virtual const MCExpr *
92+
getPICJumpTableRelocBaseExpr(const MachineFunction *MF, unsigned JTI,
93+
MCContext &Ctx) const override;
94+
6995
private:
7096
// Control Instruction Selection Features
7197
bool HasAlu32;
@@ -81,6 +107,7 @@ class BPFTargetLowering : public TargetLowering {
81107
SDValue LowerConstantPool(SDValue Op, SelectionDAG &DAG) const;
82108
SDValue LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const;
83109
SDValue LowerTRAP(SDValue Op, SelectionDAG &DAG) const;
110+
SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const;
84111

85112
template <class NodeTy>
86113
SDValue getAddr(NodeTy *N, SelectionDAG &DAG, unsigned Flags = 0) const;

llvm/lib/Target/BPF/BPFInstrInfo.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,10 @@ bool BPFInstrInfo::analyzeBranch(MachineBasicBlock &MBB,
181181
if (!isUnpredicatedTerminator(*I))
182182
break;
183183

184+
// If a JX insn, we're done.
185+
if (I->getOpcode() == BPF::JX)
186+
break;
187+
184188
// A terminator that isn't a branch can't easily be handled
185189
// by this analysis.
186190
if (!I->isBranch())
@@ -259,3 +263,11 @@ unsigned BPFInstrInfo::removeBranch(MachineBasicBlock &MBB,
259263

260264
return Count;
261265
}
266+
267+
int BPFInstrInfo::getJumpTableIndex(const MachineInstr &MI) const {
268+
if (MI.getOpcode() != BPF::JX)
269+
return -1;
270+
const MachineOperand &MO = MI.getOperand(1);
271+
assert(MO.isJTI() && "JX operand #0 should be isJTI");
272+
return MO.getIndex();
273+
}

llvm/lib/Target/BPF/BPFInstrInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ class BPFInstrInfo : public BPFGenInstrInfo {
5858
MachineBasicBlock *FBB, ArrayRef<MachineOperand> Cond,
5959
const DebugLoc &DL,
6060
int *BytesAdded = nullptr) const override;
61+
62+
int getJumpTableIndex(const MachineInstr &MI) const override;
63+
6164
private:
6265
void expandMEMCPY(MachineBasicBlock::iterator) const;
6366

llvm/lib/Target/BPF/BPFInstrInfo.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ def SDT_BPFMEMCPY : SDTypeProfile<0, 4, [SDTCisVT<0, i64>,
3131
SDTCisVT<1, i64>,
3232
SDTCisVT<2, i64>,
3333
SDTCisVT<3, i64>]>;
34+
def SDT_BPFBrJt : SDTypeProfile<0, 2, [SDTCisVT<0, i32>, // jump table
35+
SDTCisVT<1, i64>]>; // index
3436

3537
def BPFcall : SDNode<"BPFISD::CALL", SDT_BPFCall,
3638
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
@@ -49,6 +51,9 @@ def BPFWrapper : SDNode<"BPFISD::Wrapper", SDT_BPFWrapper>;
4951
def BPFmemcpy : SDNode<"BPFISD::MEMCPY", SDT_BPFMEMCPY,
5052
[SDNPHasChain, SDNPInGlue, SDNPOutGlue,
5153
SDNPMayStore, SDNPMayLoad]>;
54+
def BPFBrJt : SDNode<"BPFISD::BPF_BR_JT", SDT_BPFBrJt,
55+
[SDNPHasChain]>;
56+
5257
def BPFIsLittleEndian : Predicate<"Subtarget->isLittleEndian()">;
5358
def BPFIsBigEndian : Predicate<"!Subtarget->isLittleEndian()">;
5459
def BPFHasALU32 : Predicate<"Subtarget->getHasAlu32()">;
@@ -183,6 +188,15 @@ class TYPE_LD_ST<bits<3> mode, bits<2> size,
183188
let Inst{60-59} = size;
184189
}
185190

191+
// For indirect jump
192+
class TYPE_IND_JMP<bits<4> op, bits<1> srctype,
193+
dag outs, dag ins, string asmstr, list<dag> pattern>
194+
: InstBPF<outs, ins, asmstr, pattern> {
195+
196+
let Inst{63-60} = op;
197+
let Inst{59} = srctype;
198+
}
199+
186200
// jump instructions
187201
class JMP_RR<BPFJumpOp Opc, string OpcodeStr, PatLeaf Cond>
188202
: TYPE_ALU_JMP<Opc.Value, BPF_X.Value,
@@ -216,6 +230,18 @@ class JMP_RI<BPFJumpOp Opc, string OpcodeStr, PatLeaf Cond>
216230
let BPFClass = BPF_JMP;
217231
}
218232

233+
class JMP_IND<BPFJumpOp Opc, string OpcodeStr, list<dag> Pattern>
234+
: TYPE_ALU_JMP<Opc.Value, BPF_X.Value,
235+
(outs),
236+
(ins GPR:$dst, i32imm:$jt),
237+
!strconcat(OpcodeStr, " $dst"),
238+
Pattern> {
239+
bits<4> dst;
240+
241+
let Inst{51-48} = dst;
242+
let BPFClass = BPF_JMP;
243+
}
244+
219245
class JMP_JCOND<BPFJumpOp Opc, string OpcodeStr, list<dag> Pattern>
220246
: TYPE_ALU_JMP<Opc.Value, BPF_K.Value,
221247
(outs),
@@ -281,6 +307,10 @@ defm JSLT : J<BPF_JSLT, "s<", BPF_CC_LT, BPF_CC_LT_32>;
281307
defm JSLE : J<BPF_JSLE, "s<=", BPF_CC_LE, BPF_CC_LE_32>;
282308
defm JSET : J<BPF_JSET, "&", NoCond, NoCond>;
283309
def JCOND : JMP_JCOND<BPF_JCOND, "may_goto", []>;
310+
311+
let isIndirectBranch = 1 in {
312+
def JX : JMP_IND<BPF_JA, "gotox", [(BPFBrJt tjumptable:$jt, i64:$dst)]>;
313+
}
284314
}
285315

286316
// ALU instructions

0 commit comments

Comments
 (0)