Skip to content

Commit 7752755

Browse files
author
Yonghong Song
committed
[BPF] Add jump table support with switch statements and computed goto
NOTE 1: We probably need cpu v5 or other flags to enable this feature. We can add it later when necessary. Let us use cpu v4 for now. NOTE 2: An option -bpf-min-jump-table-entries is implemented to control the minimum number of entries to use a jump table on BPF. The default value is 5, which makes it easy to test. Eventually we will increase the minimum number of jump table entries to 13. This patch adds jump table support. A new insn 'gotox <reg>' is added to allow goto through a register. The register represents the address in the current section. Example 1 (switch statement): ============================= Code: struct simple_ctx { int x; int y; int z; }; int ret_user, ret_user2; void bar(void); int foo(struct simple_ctx *ctx, struct simple_ctx *ctx2) { switch (ctx->x) { case 1: ret_user = 18; break; case 20: ret_user = 6; break; case 16: ret_user = 9; break; case 6: ret_user = 16; break; case 8: ret_user = 14; break; case 30: ret_user = 2; break; default: ret_user = 1; break; } bar(); switch (ctx2->x) { case 0: ret_user2 = 8; break; case 31: ret_user2 = 5; break; case 13: ret_user2 = 8; break; case 1: ret_user2 = 3; break; case 11: ret_user2 = 4; break; default: ret_user2 = 29; break; } return 0; } Run: clang --target=bpf -mcpu=v4 -O2 -S test.c The assembly code: ... # %bb.1: # %entry w1 = 18 r2 <<= 2 r3 = .LJTI0_0 ll r3 += r2 r2 = *(s32 *)(r3 + 0) gotox r2 LBB0_2: # %sw.bb1 w1 = 6 goto LBB0_8 ... # %bb.9: # %sw.epilog w1 = 8 r2 <<= 2 r3 = .LJTI0_1 ll r3 += r2 r2 = *(s32 *)(r3 + 0) gotox r2 LBB0_10: # %sw.bb8 w1 = 5 goto LBB0_14 ... .section .rodata,"a",@progbits .p2align 2, 0x0 .LJTI0_0: .long LBB0_8-.LJTI0_0 .long LBB0_7-.LJTI0_0 ... .long LBB0_6-.LJTI0_0 .LJTI0_1: .long LBB0_14-.LJTI0_1 .long LBB0_11-.LJTI0_1 ... Although we do have the labels .LJTI0_0 and .LJTI0_1, they have the prefix '.L', so they won't appear in the .o file like other symbols. Run: llvm-objdump -Sr test.o ... 
5: 67 02 00 00 02 00 00 00 r2 <<= 0x2 6: 18 03 00 00 00 00 00 00 00 00 00 00 00 00 00 00 r3 = 0x0 ll 0000000000000030: R_BPF_64_64 .rodata 8: 0f 23 00 00 00 00 00 00 r3 += r2 ... 29: 67 02 00 00 02 00 00 00 r2 <<= 0x2 30: 18 03 00 00 78 00 00 00 00 00 00 00 00 00 00 00 r3 = 0x78 ll 00000000000000f0: R_BPF_64_64 .rodata 32: 0f 23 00 00 00 00 00 00 r3 += r2 The size of the jump table is not obvious. libbpf needs to check all relocations against the .rodata section to determine the precise size when constructing bpf maps. Example 2 (Simple computed goto): ================================= Code: int bar(int a) { __label__ l1, l2; void * volatile tgt; int ret = 0; if (a) tgt = &&l1; // synthetic jump table generated here else tgt = &&l2; // another synthetic jump table goto *tgt; l1: ret += 1; l2: ret += 2; return ret; } Compile: clang --target=bpf -mcpu=v4 -O2 -c test1.c Objdump: llvm-objdump -Sr test1.o 0: 18 02 00 00 50 00 00 00 00 00 00 00 00 00 00 00 r2 = 0x50 ll 0000000000000000: R_BPF_64_64 .text 2: 16 01 02 00 00 00 00 00 if w1 == 0x0 goto +0x2 <bar+0x28> 3: 18 02 00 00 40 00 00 00 00 00 00 00 00 00 00 00 r2 = 0x40 ll 0000000000000018: R_BPF_64_64 .text 5: 7b 2a f8 ff 00 00 00 00 *(u64 *)(r10 - 0x8) = r2 6: 79 a1 f8 ff 00 00 00 00 r1 = *(u64 *)(r10 - 0x8) 7: 0d 01 00 00 00 00 00 00 gotox r1 8: b4 00 00 00 03 00 00 00 w0 = 0x3 9: 05 00 01 00 00 00 00 00 goto +0x1 <bar+0x58> 10: b4 00 00 00 02 00 00 00 w0 = 0x2 11: 95 00 00 00 00 00 00 00 exit In this case, there is no need for a jump table. The verifier is able to detect the precise gotox addresses for the different branches. Example 3 (More complicated computed goto): =========================================== Code: int foo(int a, int b) { __label__ l1, l2, l3, l4; void *jt1[] = {[0]=&&l1, [1]=&&l2}; void *jt2[] = {[0]=&&l3, [1]=&&l4}; int ret = 0; goto *jt1[a % 2]; l1: ret += 1; l2: ret += 3; goto *jt2[b % 2]; l3: ret += 5; l4: ret += 7; return ret; } Compile: clang --target=bpf -mcpu=v4 -O2 -S test2.c Asm code: ... 
r3 = (s32)r2 r3 <<= 3 r2 = .L__const.foo.jt2 ll r2 += r3 r1 = (s32)r1 r1 <<= 3 r3 = .L__const.foo.jt1 ll r3 += r1 w0 = 0 r1 = *(u64 *)(r3 + 0) gotox r1 .Ltmp0: # Block address taken LBB0_1: # %l1 # =>This Inner Loop Header: Depth=1 w0 += 1 w0 += 3 r1 = *(u64 *)(r2 + 0) gotox r1 .Ltmp1: # Block address taken LBB0_2: # %l2 ... .type .L__const.foo.jt1,@object # @__const.foo.jt1 .section .rodata,"a",@progbits .p2align 3, 0x0 .L__const.foo.jt1: .quad .Ltmp0 .quad .Ltmp1 .size .L__const.foo.jt1, 16 .type .L__const.foo.jt2,@object # @__const.foo.jt2 .p2align 3, 0x0 .L__const.foo.jt2: .quad .Ltmp2 .quad .Ltmp3 .size .L__const.foo.jt2, 16 Similar to the switch statement case, in the binary the symbols .L__const.foo.jt* will not show up in the symbol table and the jump tables will be in the .rodata section. With more libbpf work (dealing with .rodata sections etc.), everything should work fine. But we could do better by - Replacing symbols like .L<...> with symbols that appear in the symbol table. - Adding jump tables to a .jumptables section instead of the .rodata section. This should make things easier for libbpf and for users. The next two patches try to implement the above.
1 parent 58c3aff commit 7752755

File tree

7 files changed

+114
-2
lines changed

7 files changed

+114
-2
lines changed

llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ struct BPFOperand : public MCParsedAsmOperand {
234234
.Case("callx", true)
235235
.Case("goto", true)
236236
.Case("gotol", true)
237+
.Case("gotox", true)
237238
.Case("may_goto", true)
238239
.Case("*", true)
239240
.Case("exit", true)

llvm/lib/Target/BPF/BPFISelLowering.cpp

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/CodeGen/MachineFrameInfo.h"
1919
#include "llvm/CodeGen/MachineFunction.h"
2020
#include "llvm/CodeGen/MachineInstrBuilder.h"
21+
#include "llvm/CodeGen/MachineJumpTableInfo.h"
2122
#include "llvm/CodeGen/MachineRegisterInfo.h"
2223
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
2324
#include "llvm/CodeGen/ValueTypes.h"
@@ -38,6 +39,10 @@ static cl::opt<bool> BPFExpandMemcpyInOrder("bpf-expand-memcpy-in-order",
3839
cl::Hidden, cl::init(false),
3940
cl::desc("Expand memcpy into load/store pairs in order"));
4041

42+
static cl::opt<unsigned> BPFMinimumJumpTableEntries(
43+
"bpf-min-jump-table-entries", cl::init(5), cl::Hidden,
44+
cl::desc("Set minimum number of entries to use a jump table on BPF"));
45+
4146
static void fail(const SDLoc &DL, SelectionDAG &DAG, const Twine &Msg,
4247
SDValue Val = {}) {
4348
std::string Str;
@@ -67,12 +72,13 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM,
6772

6873
setOperationAction(ISD::BR_CC, MVT::i64, Custom);
6974
setOperationAction(ISD::BR_JT, MVT::Other, Expand);
70-
setOperationAction(ISD::BRIND, MVT::Other, Expand);
7175
setOperationAction(ISD::BRCOND, MVT::Other, Expand);
7276

7377
setOperationAction(ISD::TRAP, MVT::Other, Custom);
7478

75-
setOperationAction({ISD::GlobalAddress, ISD::ConstantPool}, MVT::i64, Custom);
79+
setOperationAction({ISD::GlobalAddress, ISD::ConstantPool, ISD::JumpTable,
80+
ISD::BlockAddress},
81+
MVT::i64, Custom);
7682

7783
setOperationAction(ISD::DYNAMIC_STACKALLOC, MVT::i64, Custom);
7884
setOperationAction(ISD::STACKSAVE, MVT::Other, Expand);
@@ -159,6 +165,7 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM,
159165

160166
setBooleanContents(ZeroOrOneBooleanContent);
161167
setMaxAtomicSizeInBitsSupported(64);
168+
setMinimumJumpTableEntries(BPFMinimumJumpTableEntries);
162169

163170
// Function alignments
164171
setMinFunctionAlignment(Align(8));
@@ -246,6 +253,10 @@ bool BPFTargetLowering::isZExtFree(SDValue Val, EVT VT2) const {
246253
return TargetLoweringBase::isZExtFree(Val, VT2);
247254
}
248255

256+
unsigned BPFTargetLowering::getJumpTableEncoding() const {
257+
return MachineJumpTableInfo::EK_LabelDifference32;
258+
}
259+
249260
BPFTargetLowering::ConstraintType
250261
BPFTargetLowering::getConstraintType(StringRef Constraint) const {
251262
if (Constraint.size() == 1) {
@@ -316,10 +327,14 @@ SDValue BPFTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
316327
report_fatal_error("unimplemented opcode: " + Twine(Op.getOpcode()));
317328
case ISD::BR_CC:
318329
return LowerBR_CC(Op, DAG);
330+
case ISD::JumpTable:
331+
return LowerJumpTable(Op, DAG);
319332
case ISD::GlobalAddress:
320333
return LowerGlobalAddress(Op, DAG);
321334
case ISD::ConstantPool:
322335
return LowerConstantPool(Op, DAG);
336+
case ISD::BlockAddress:
337+
return LowerBlockAddress(Op, DAG);
323338
case ISD::SELECT_CC:
324339
return LowerSELECT_CC(Op, DAG);
325340
case ISD::SDIV:
@@ -780,6 +795,11 @@ SDValue BPFTargetLowering::LowerTRAP(SDValue Op, SelectionDAG &DAG) const {
780795
return LowerCall(CLI, InVals);
781796
}
782797

798+
SDValue BPFTargetLowering::LowerJumpTable(SDValue Op, SelectionDAG &DAG) const {
799+
JumpTableSDNode *N = cast<JumpTableSDNode>(Op);
800+
return getAddr(N, DAG);
801+
}
802+
783803
const char *BPFTargetLowering::getTargetNodeName(unsigned Opcode) const {
784804
switch ((BPFISD::NodeType)Opcode) {
785805
case BPFISD::FIRST_NUMBER:
@@ -811,6 +831,17 @@ static SDValue getTargetNode(ConstantPoolSDNode *N, const SDLoc &DL, EVT Ty,
811831
N->getOffset(), Flags);
812832
}
813833

834+
static SDValue getTargetNode(BlockAddressSDNode *N, const SDLoc &DL, EVT Ty,
835+
SelectionDAG &DAG, unsigned Flags) {
836+
return DAG.getTargetBlockAddress(N->getBlockAddress(), Ty, N->getOffset(),
837+
Flags);
838+
}
839+
840+
static SDValue getTargetNode(JumpTableSDNode *N, const SDLoc &DL, EVT Ty,
841+
SelectionDAG &DAG, unsigned Flags) {
842+
return DAG.getTargetJumpTable(N->getIndex(), Ty, Flags);
843+
}
844+
814845
template <class NodeTy>
815846
SDValue BPFTargetLowering::getAddr(NodeTy *N, SelectionDAG &DAG,
816847
unsigned Flags) const {
@@ -837,6 +868,12 @@ SDValue BPFTargetLowering::LowerConstantPool(SDValue Op,
837868
return getAddr(N, DAG);
838869
}
839870

871+
SDValue BPFTargetLowering::LowerBlockAddress(SDValue Op,
872+
SelectionDAG &DAG) const {
873+
BlockAddressSDNode *N = cast<BlockAddressSDNode>(Op);
874+
return getAddr(N, DAG);
875+
}
876+
840877
unsigned
841878
BPFTargetLowering::EmitSubregExt(MachineInstr &MI, MachineBasicBlock *BB,
842879
unsigned Reg, bool isSigned) const {

llvm/lib/Target/BPF/BPFISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ class BPFTargetLowering : public TargetLowering {
6666

6767
MVT getScalarShiftAmountTy(const DataLayout &, EVT) const override;
6868

69+
unsigned getJumpTableEncoding() const override;
70+
6971
private:
7072
// Control Instruction Selection Features
7173
bool HasAlu32;
@@ -81,6 +83,8 @@ class BPFTargetLowering : public TargetLowering {
8183
SDValue LowerConstantPool(SDValue Op, SelectionDAG &DAG) const;
8284
SDValue LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const;
8385
SDValue LowerTRAP(SDValue Op, SelectionDAG &DAG) const;
86+
SDValue LowerBlockAddress(SDValue Op, SelectionDAG &DAG) const;
87+
SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const;
8488

8589
template <class NodeTy>
8690
SDValue getAddr(NodeTy *N, SelectionDAG &DAG, unsigned Flags = 0) const;

llvm/lib/Target/BPF/BPFInstrInfo.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,11 @@ bool BPFInstrInfo::analyzeBranch(MachineBasicBlock &MBB,
181181
if (!isUnpredicatedTerminator(*I))
182182
break;
183183

184+
// From base method doc: ... returning true if it cannot be understood ...
185+
// Indirect branch has multiple destinations and no true/false concepts.
186+
if (I->isIndirectBranch())
187+
return true;
188+
184189
// A terminator that isn't a branch can't easily be handled
185190
// by this analysis.
186191
if (!I->isBranch())
@@ -259,3 +264,40 @@ unsigned BPFInstrInfo::removeBranch(MachineBasicBlock &MBB,
259264

260265
return Count;
261266
}
267+
268+
int BPFInstrInfo::getJumpTableIndex(const MachineInstr &MI) const {
269+
// The pattern looks like:
270+
// %0 = LD_imm64 %jump-table.0 ; load jump-table address
271+
// %1 = ADD_rr %0, $another_reg ; address + offset
272+
// %2 = LDD %1, 0 ; load the actual label
273+
// JX %2
274+
const MachineFunction &MF = *MI.getParent()->getParent();
275+
const MachineRegisterInfo &MRI = MF.getRegInfo();
276+
277+
Register Reg = MI.getOperand(0).getReg();
278+
if (!Reg.isVirtual())
279+
return -1;
280+
MachineInstr *Ldd = MRI.getUniqueVRegDef(Reg);
281+
if (Ldd == nullptr || Ldd->getOpcode() != BPF::LDD)
282+
return -1;
283+
284+
Reg = Ldd->getOperand(1).getReg();
285+
if (!Reg.isVirtual())
286+
return -1;
287+
MachineInstr *Add = MRI.getUniqueVRegDef(Reg);
288+
if (Add == nullptr || Add->getOpcode() != BPF::ADD_rr)
289+
return -1;
290+
291+
Reg = Add->getOperand(1).getReg();
292+
if (!Reg.isVirtual())
293+
return -1;
294+
MachineInstr *LDimm64 = MRI.getUniqueVRegDef(Reg);
295+
if (LDimm64 == nullptr || LDimm64->getOpcode() != BPF::LD_imm64)
296+
return -1;
297+
298+
const MachineOperand &MO = LDimm64->getOperand(1);
299+
if (!MO.isJTI())
300+
return -1;
301+
302+
return MO.getIndex();
303+
}

llvm/lib/Target/BPF/BPFInstrInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ class BPFInstrInfo : public BPFGenInstrInfo {
5858
MachineBasicBlock *FBB, ArrayRef<MachineOperand> Cond,
5959
const DebugLoc &DL,
6060
int *BytesAdded = nullptr) const override;
61+
62+
int getJumpTableIndex(const MachineInstr &MI) const override;
63+
6164
private:
6265
void expandMEMCPY(MachineBasicBlock::iterator) const;
6366

llvm/lib/Target/BPF/BPFInstrInfo.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,18 @@ class JMP_RI<BPFJumpOp Opc, string OpcodeStr, PatLeaf Cond>
216216
let BPFClass = BPF_JMP;
217217
}
218218

219+
class JMP_IND<BPFJumpOp Opc, string OpcodeStr, list<dag> Pattern>
220+
: TYPE_ALU_JMP<Opc.Value, BPF_X.Value,
221+
(outs),
222+
(ins GPR:$dst),
223+
!strconcat(OpcodeStr, " $dst"),
224+
Pattern> {
225+
bits<4> dst;
226+
227+
let Inst{51-48} = dst;
228+
let BPFClass = BPF_JMP;
229+
}
230+
219231
class JMP_JCOND<BPFJumpOp Opc, string OpcodeStr, list<dag> Pattern>
220232
: TYPE_ALU_JMP<Opc.Value, BPF_K.Value,
221233
(outs),
@@ -281,6 +293,10 @@ defm JSLT : J<BPF_JSLT, "s<", BPF_CC_LT, BPF_CC_LT_32>;
281293
defm JSLE : J<BPF_JSLE, "s<=", BPF_CC_LE, BPF_CC_LE_32>;
282294
defm JSET : J<BPF_JSET, "&", NoCond, NoCond>;
283295
def JCOND : JMP_JCOND<BPF_JCOND, "may_goto", []>;
296+
297+
let isIndirectBranch = 1 in {
298+
def JX : JMP_IND<BPF_JA, "gotox", [(brind i64:$dst)]>;
299+
}
284300
}
285301

286302
// ALU instructions
@@ -851,6 +867,8 @@ let usesCustomInserter = 1, isCodeGenOnly = 1 in {
851867
// load 64-bit global addr into register
852868
def : Pat<(BPFWrapper tglobaladdr:$in), (LD_imm64 tglobaladdr:$in)>;
853869
def : Pat<(BPFWrapper tconstpool:$in), (LD_imm64 tconstpool:$in)>;
870+
def : Pat<(BPFWrapper tblockaddress:$in), (LD_imm64 tblockaddress:$in)>;
871+
def : Pat<(BPFWrapper tjumptable:$in), (LD_imm64 tjumptable:$in)>;
854872

855873
// 0xffffFFFF doesn't fit into simm32, optimize common case
856874
def : Pat<(i64 (and (i64 GPR:$src), 0xffffFFFF)),

llvm/lib/Target/BPF/BPFMCInstLower.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@ void BPFMCInstLower::Lower(const MachineInstr *MI, MCInst &OutMI) const {
7777
case MachineOperand::MO_ConstantPoolIndex:
7878
MCOp = LowerSymbolOperand(MO, Printer.GetCPISymbol(MO.getIndex()));
7979
break;
80+
case MachineOperand::MO_JumpTableIndex:
81+
MCOp = LowerSymbolOperand(MO, Printer.GetJTISymbol(MO.getIndex()));
82+
break;
83+
case MachineOperand::MO_BlockAddress:
84+
MCOp = LowerSymbolOperand(
85+
MO, Printer.GetBlockAddressSymbol(MO.getBlockAddress()));
86+
break;
8087
}
8188

8289
OutMI.addOperand(MCOp);

0 commit comments

Comments
 (0)