Skip to content

Commit 3d7c478

Browse files
author
Yonghong Song
committed
[BPF] Add jump table support with switch statements and computed goto
NOTE 1: We probably need cpu v5 or other flags to enable this feature. We can add it later when necessary. NOTE 2: An option -bpf-min-jump-table-entries is implemented to control the minimum number of entries to use a jump table on BPF. The default value 5 and this is to make it easy to test. Eventually we will increase min jump table entries to be 13. This patch adds jump table support. A new insn 'gotox <reg>' is added to allow goto through a register. The register represents the address in the current section. Example 1 (switch statement): ============================= Code: struct simple_ctx { int x; int y; int z; }; int ret_user, ret_user2; void bar(void); int foo(struct simple_ctx *ctx, struct simple_ctx *ctx2) { switch (ctx->x) { case 1: ret_user = 18; break; case 20: ret_user = 6; break; case 16: ret_user = 9; break; case 6: ret_user = 16; break; case 8: ret_user = 14; break; case 30: ret_user = 2; break; default: ret_user = 1; break; } bar(); switch (ctx2->x) { case 0: ret_user2 = 8; break; case 31: ret_user2 = 5; break; case 13: ret_user2 = 8; break; case 1: ret_user2 = 3; break; case 11: ret_user2 = 4; break; default: ret_user2 = 29; break; } return 0; } Run: clang --target=bpf -O2 -S test.c The assembly code: ... # %bb.1: # %entry r1 <<= 3 r2 = .LJTI0_0 ll r2 += r1 r1 = *(u64 *)(r2 + 0) gotox r1 goto LBB0_8 LBB0_2: w1 = 18 goto LBB0_9 ... # %bb.10: # %sw.epilog r1 <<= 3 r2 = .LJTI0_1 ll r2 += r1 r1 = *(u64 *)(r2 + 0) gotox r1 goto LBB0_15 ... .section .rodata,"a",@progbits .p2align 3, 0x0 .LJTI0_0: .quad LBB0_2 .quad LBB0_8 .quad LBB0_8 .quad LBB0_8 .quad LBB0_7 ... .LJTI0_1: .quad LBB0_11 .quad LBB0_13 Although we do have labels .LJTI0_0 and .LJTI0_1, but since they have prefix '.L' so they won't appear in the .o file like Run: llvm-objdump -Sr test.o ... 4: 67 01 00 00 03 00 00 00 r1 <<= 0x3 5: 18 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 r2 = 0x0 ll 0000000000000028: R_BPF_64_64 .rodata 7: 0f 12 00 00 00 00 00 00 r2 += r1 ... 30: 67 01 00 00 03 00 00 00 r1 <<= 0x3 31: 18 02 00 00 f0 00 00 00 00 00 00 00 00 00 00 00 r2 = 0xf0 ll 00000000000000f8: R_BPF_64_64 .rodata The size of jump table is not obvious. The libbpf needs to check all relocations against .rodata section in order to get precise size in order to construct bpf maps. Example 2 (Simple computed goto): ================================= Code: int bar(int a) { __label__ l1, l2; void * volatile tgt; int ret = 0; if (a) tgt = &&l1; // synthetic jump table generated here else tgt = &&l2; // another synthetic jump table goto *tgt; l1: ret += 1; l2: ret += 2; return ret; } Compile: clang --target=bpf -O2 -c test1.c Objdump: llvm-objdump -Sr test1.o 0: 18 02 00 00 50 00 00 00 00 00 00 00 00 00 00 00 r2 = 0x50 ll 0000000000000000: R_BPF_64_64 .text 2: 16 01 02 00 00 00 00 00 if w1 == 0x0 goto +0x2 <bar+0x28> 3: 18 02 00 00 40 00 00 00 00 00 00 00 00 00 00 00 r2 = 0x40 ll 0000000000000018: R_BPF_64_64 .text 5: 7b 2a f8 ff 00 00 00 00 *(u64 *)(r10 - 0x8) = r2 6: 79 a1 f8 ff 00 00 00 00 r1 = *(u64 *)(r10 - 0x8) 7: 0d 01 00 00 00 00 00 00 gotox r1 8: b4 00 00 00 03 00 00 00 w0 = 0x3 9: 05 00 01 00 00 00 00 00 goto +0x1 <bar+0x58> 10: b4 00 00 00 02 00 00 00 w0 = 0x2 11: 95 00 00 00 00 00 00 00 exit For this case, no need for jump table. Verifier is able to detect precise gotox addresses for different branches. Example 3 (More complicated computed goto): =========================================== Code: int foo(int a, int b) { __label__ l1, l2, l3, l4; void *jt1[] = {[0]=&&l1, [1]=&&l2}; void *jt2[] = {[0]=&&l3, [1]=&&l4}; int ret = 0; goto *jt1[a % 2]; l1: ret += 1; l2: ret += 3; goto *jt2[b % 2]; l3: ret += 5; l4: ret += 7; return ret; } Compile: clang --target=bpf -O2 -S test2.c Asm code: # %bb.1: # %entry r1 <<= 3 r2 = .LJTI0_0 ll r2 += r1 r1 = *(u64 *)(r2 + 0) gotox r1 goto LBB0_8 LBB0_2: w1 = 18 goto LBB0_9 ... # %bb.10: # %sw.epilog r1 <<= 3 r2 = .LJTI0_1 ll r2 += r1 r1 = *(u64 *)(r2 + 0) gotox r1 goto LBB0_15 LBB0_11: w1 = 8 goto LBB0_16 ... .section .rodata,"a",@progbits .p2align 3, 0x0 .LJTI0_0: .quad LBB0_2 .quad LBB0_8 .quad LBB0_8 ... .quad LBB0_7 .LJTI0_1: .quad LBB0_11 .quad LBB0_13 .quad LBB0_15 ... Similar to switch statement case, for the binary, the symbols .LJTI0_* will not show up in the symbol table and jump table will be in .rodata section. With more libbpf work (dealing with .rodata sections etc.), everything should work fine. But we could do better by - Replacing symbols like .L<...> with symbols appearing in symbol table. - Add jump tables to .jumptables section instead of .rodata section. This should make things easier for libbpf and for users. Next two patches try to implement the above.
1 parent 58c3aff commit 3d7c478

File tree

7 files changed

+106
-2
lines changed

7 files changed

+106
-2
lines changed

llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ struct BPFOperand : public MCParsedAsmOperand {
234234
.Case("callx", true)
235235
.Case("goto", true)
236236
.Case("gotol", true)
237+
.Case("gotox", true)
237238
.Case("may_goto", true)
238239
.Case("*", true)
239240
.Case("exit", true)

llvm/lib/Target/BPF/BPFISelLowering.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ static cl::opt<bool> BPFExpandMemcpyInOrder("bpf-expand-memcpy-in-order",
3838
cl::Hidden, cl::init(false),
3939
cl::desc("Expand memcpy into load/store pairs in order"));
4040

41+
static cl::opt<unsigned> BPFMinimumJumpTableEntries(
42+
"bpf-min-jump-table-entries", cl::init(5), cl::Hidden,
43+
cl::desc("Set minimum number of entries to use a jump table on BPF"));
44+
4145
static void fail(const SDLoc &DL, SelectionDAG &DAG, const Twine &Msg,
4246
SDValue Val = {}) {
4347
std::string Str;
@@ -67,12 +71,13 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM,
6771

6872
setOperationAction(ISD::BR_CC, MVT::i64, Custom);
6973
setOperationAction(ISD::BR_JT, MVT::Other, Expand);
70-
setOperationAction(ISD::BRIND, MVT::Other, Expand);
7174
setOperationAction(ISD::BRCOND, MVT::Other, Expand);
7275

7376
setOperationAction(ISD::TRAP, MVT::Other, Custom);
7477

75-
setOperationAction({ISD::GlobalAddress, ISD::ConstantPool}, MVT::i64, Custom);
78+
setOperationAction({ISD::GlobalAddress, ISD::ConstantPool, ISD::JumpTable,
79+
ISD::BlockAddress},
80+
MVT::i64, Custom);
7681

7782
setOperationAction(ISD::DYNAMIC_STACKALLOC, MVT::i64, Custom);
7883
setOperationAction(ISD::STACKSAVE, MVT::Other, Expand);
@@ -159,6 +164,7 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM,
159164

160165
setBooleanContents(ZeroOrOneBooleanContent);
161166
setMaxAtomicSizeInBitsSupported(64);
167+
setMinimumJumpTableEntries(BPFMinimumJumpTableEntries);
162168

163169
// Function alignments
164170
setMinFunctionAlignment(Align(8));
@@ -316,10 +322,14 @@ SDValue BPFTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
316322
report_fatal_error("unimplemented opcode: " + Twine(Op.getOpcode()));
317323
case ISD::BR_CC:
318324
return LowerBR_CC(Op, DAG);
325+
case ISD::JumpTable:
326+
return LowerJumpTable(Op, DAG);
319327
case ISD::GlobalAddress:
320328
return LowerGlobalAddress(Op, DAG);
321329
case ISD::ConstantPool:
322330
return LowerConstantPool(Op, DAG);
331+
case ISD::BlockAddress:
332+
return LowerBlockAddress(Op, DAG);
323333
case ISD::SELECT_CC:
324334
return LowerSELECT_CC(Op, DAG);
325335
case ISD::SDIV:
@@ -780,6 +790,11 @@ SDValue BPFTargetLowering::LowerTRAP(SDValue Op, SelectionDAG &DAG) const {
780790
return LowerCall(CLI, InVals);
781791
}
782792

793+
SDValue BPFTargetLowering::LowerJumpTable(SDValue Op, SelectionDAG &DAG) const {
794+
JumpTableSDNode *N = cast<JumpTableSDNode>(Op);
795+
return getAddr(N, DAG);
796+
}
797+
783798
const char *BPFTargetLowering::getTargetNodeName(unsigned Opcode) const {
784799
switch ((BPFISD::NodeType)Opcode) {
785800
case BPFISD::FIRST_NUMBER:
@@ -811,6 +826,17 @@ static SDValue getTargetNode(ConstantPoolSDNode *N, const SDLoc &DL, EVT Ty,
811826
N->getOffset(), Flags);
812827
}
813828

829+
static SDValue getTargetNode(BlockAddressSDNode *N, const SDLoc &DL, EVT Ty,
830+
SelectionDAG &DAG, unsigned Flags) {
831+
return DAG.getTargetBlockAddress(N->getBlockAddress(), Ty, N->getOffset(),
832+
Flags);
833+
}
834+
835+
static SDValue getTargetNode(JumpTableSDNode *N, const SDLoc &DL, EVT Ty,
836+
SelectionDAG &DAG, unsigned Flags) {
837+
return DAG.getTargetJumpTable(N->getIndex(), Ty, Flags);
838+
}
839+
814840
template <class NodeTy>
815841
SDValue BPFTargetLowering::getAddr(NodeTy *N, SelectionDAG &DAG,
816842
unsigned Flags) const {
@@ -837,6 +863,12 @@ SDValue BPFTargetLowering::LowerConstantPool(SDValue Op,
837863
return getAddr(N, DAG);
838864
}
839865

866+
SDValue BPFTargetLowering::LowerBlockAddress(SDValue Op,
867+
SelectionDAG &DAG) const {
868+
BlockAddressSDNode *N = cast<BlockAddressSDNode>(Op);
869+
return getAddr(N, DAG);
870+
}
871+
840872
unsigned
841873
BPFTargetLowering::EmitSubregExt(MachineInstr &MI, MachineBasicBlock *BB,
842874
unsigned Reg, bool isSigned) const {

llvm/lib/Target/BPF/BPFISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ class BPFTargetLowering : public TargetLowering {
8181
SDValue LowerConstantPool(SDValue Op, SelectionDAG &DAG) const;
8282
SDValue LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const;
8383
SDValue LowerTRAP(SDValue Op, SelectionDAG &DAG) const;
84+
SDValue LowerBlockAddress(SDValue Op, SelectionDAG &DAG) const;
85+
SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const;
8486

8587
template <class NodeTy>
8688
SDValue getAddr(NodeTy *N, SelectionDAG &DAG, unsigned Flags = 0) const;

llvm/lib/Target/BPF/BPFInstrInfo.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,10 @@ bool BPFInstrInfo::analyzeBranch(MachineBasicBlock &MBB,
181181
if (!isUnpredicatedTerminator(*I))
182182
break;
183183

184+
// If a JX insn, we're done.
185+
if (I->getOpcode() == BPF::JX)
186+
break;
187+
184188
// A terminator that isn't a branch can't easily be handled
185189
// by this analysis.
186190
if (!I->isBranch())
@@ -259,3 +263,40 @@ unsigned BPFInstrInfo::removeBranch(MachineBasicBlock &MBB,
259263

260264
return Count;
261265
}
266+
267+
int BPFInstrInfo::getJumpTableIndex(const MachineInstr &MI) const {
268+
// The pattern looks like:
269+
// %0 = LD_imm64 %jump-table.0 ; load jump-table address
270+
// %1 = ADD_rr %0, $another_reg ; address + offset
271+
// %2 = LDD %1, 0 ; load the actual label
272+
// JX %2
273+
const MachineFunction &MF = *MI.getParent()->getParent();
274+
const MachineRegisterInfo &MRI = MF.getRegInfo();
275+
276+
Register Reg = MI.getOperand(0).getReg();
277+
if (!Reg.isVirtual())
278+
return -1;
279+
MachineInstr *Ldd = MRI.getUniqueVRegDef(Reg);
280+
if (Ldd == nullptr || Ldd->getOpcode() != BPF::LDD)
281+
return -1;
282+
283+
Reg = Ldd->getOperand(1).getReg();
284+
if (!Reg.isVirtual())
285+
return -1;
286+
MachineInstr *Add = MRI.getUniqueVRegDef(Reg);
287+
if (Add == nullptr || Add->getOpcode() != BPF::ADD_rr)
288+
return -1;
289+
290+
Reg = Add->getOperand(1).getReg();
291+
if (!Reg.isVirtual())
292+
return -1;
293+
MachineInstr *LDimm64 = MRI.getUniqueVRegDef(Reg);
294+
if (LDimm64 == nullptr || LDimm64->getOpcode() != BPF::LD_imm64)
295+
return -1;
296+
297+
const MachineOperand &MO = LDimm64->getOperand(1);
298+
if (!MO.isJTI())
299+
return -1;
300+
301+
return MO.getIndex();
302+
}

llvm/lib/Target/BPF/BPFInstrInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ class BPFInstrInfo : public BPFGenInstrInfo {
5858
MachineBasicBlock *FBB, ArrayRef<MachineOperand> Cond,
5959
const DebugLoc &DL,
6060
int *BytesAdded = nullptr) const override;
61+
62+
int getJumpTableIndex(const MachineInstr &MI) const override;
63+
6164
private:
6265
void expandMEMCPY(MachineBasicBlock::iterator) const;
6366

llvm/lib/Target/BPF/BPFInstrInfo.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,18 @@ class JMP_RI<BPFJumpOp Opc, string OpcodeStr, PatLeaf Cond>
216216
let BPFClass = BPF_JMP;
217217
}
218218

219+
class JMP_IND<BPFJumpOp Opc, string OpcodeStr, list<dag> Pattern>
220+
: TYPE_ALU_JMP<Opc.Value, BPF_X.Value,
221+
(outs),
222+
(ins GPR:$dst),
223+
!strconcat(OpcodeStr, " $dst"),
224+
Pattern> {
225+
bits<4> dst;
226+
227+
let Inst{51-48} = dst;
228+
let BPFClass = BPF_JMP;
229+
}
230+
219231
class JMP_JCOND<BPFJumpOp Opc, string OpcodeStr, list<dag> Pattern>
220232
: TYPE_ALU_JMP<Opc.Value, BPF_K.Value,
221233
(outs),
@@ -281,6 +293,10 @@ defm JSLT : J<BPF_JSLT, "s<", BPF_CC_LT, BPF_CC_LT_32>;
281293
defm JSLE : J<BPF_JSLE, "s<=", BPF_CC_LE, BPF_CC_LE_32>;
282294
defm JSET : J<BPF_JSET, "&", NoCond, NoCond>;
283295
def JCOND : JMP_JCOND<BPF_JCOND, "may_goto", []>;
296+
297+
let isIndirectBranch = 1 in {
298+
def JX : JMP_IND<BPF_JA, "gotox", [(brind i64:$dst)]>;
299+
}
284300
}
285301

286302
// ALU instructions
@@ -851,6 +867,8 @@ let usesCustomInserter = 1, isCodeGenOnly = 1 in {
851867
// load 64-bit global addr into register
852868
def : Pat<(BPFWrapper tglobaladdr:$in), (LD_imm64 tglobaladdr:$in)>;
853869
def : Pat<(BPFWrapper tconstpool:$in), (LD_imm64 tconstpool:$in)>;
870+
def : Pat<(BPFWrapper tblockaddress:$in), (LD_imm64 tblockaddress:$in)>;
871+
def : Pat<(BPFWrapper tjumptable:$in), (LD_imm64 tjumptable:$in)>;
854872

855873
// 0xffffFFFF doesn't fit into simm32, optimize common case
856874
def : Pat<(i64 (and (i64 GPR:$src), 0xffffFFFF)),

llvm/lib/Target/BPF/BPFMCInstLower.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@ void BPFMCInstLower::Lower(const MachineInstr *MI, MCInst &OutMI) const {
7777
case MachineOperand::MO_ConstantPoolIndex:
7878
MCOp = LowerSymbolOperand(MO, Printer.GetCPISymbol(MO.getIndex()));
7979
break;
80+
case MachineOperand::MO_JumpTableIndex:
81+
MCOp = LowerSymbolOperand(MO, Printer.GetJTISymbol(MO.getIndex()));
82+
break;
83+
case MachineOperand::MO_BlockAddress:
84+
MCOp = LowerSymbolOperand(
85+
MO, Printer.GetBlockAddressSymbol(MO.getBlockAddress()));
86+
break;
8087
}
8188

8289
OutMI.addOperand(MCOp);

0 commit comments

Comments
 (0)