Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ struct BPFOperand : public MCParsedAsmOperand {
.Case("callx", true)
.Case("goto", true)
.Case("gotol", true)
.Case("gotox", true)
.Case("may_goto", true)
.Case("*", true)
.Case("exit", true)
Expand Down
118 changes: 94 additions & 24 deletions llvm/lib/Target/BPF/BPFAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,52 +11,35 @@
//
//===----------------------------------------------------------------------===//

#include "BPFAsmPrinter.h"
#include "BPF.h"
#include "BPFInstrInfo.h"
#include "BPFMCInstLower.h"
#include "BTFDebug.h"
#include "MCTargetDesc/BPFInstPrinter.h"
#include "TargetInfo/BPFTargetInfo.h"
#include "llvm/BinaryFormat/ELF.h"
#include "llvm/CodeGen/AsmPrinter.h"
#include "llvm/CodeGen/MachineConstantPool.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/IR/Module.h"
#include "llvm/MC/MCAsmInfo.h"
#include "llvm/MC/MCExpr.h"
#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCStreamer.h"
#include "llvm/MC/MCSymbol.h"
#include "llvm/MC/MCSymbolELF.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetLoweringObjectFile.h"
using namespace llvm;

#define DEBUG_TYPE "asm-printer"

namespace {
class BPFAsmPrinter : public AsmPrinter {
public:
explicit BPFAsmPrinter(TargetMachine &TM,
std::unique_ptr<MCStreamer> Streamer)
: AsmPrinter(TM, std::move(Streamer), ID), BTF(nullptr) {}

StringRef getPassName() const override { return "BPF Assembly Printer"; }
bool doInitialization(Module &M) override;
void printOperand(const MachineInstr *MI, int OpNum, raw_ostream &O);
bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
const char *ExtraCode, raw_ostream &O) override;
bool PrintAsmMemoryOperand(const MachineInstr *MI, unsigned OpNum,
const char *ExtraCode, raw_ostream &O) override;

void emitInstruction(const MachineInstr *MI) override;

static char ID;

private:
BTFDebug *BTF;
};
} // namespace

bool BPFAsmPrinter::doInitialization(Module &M) {
AsmPrinter::doInitialization(M);

Expand Down Expand Up @@ -150,6 +133,93 @@ void BPFAsmPrinter::emitInstruction(const MachineInstr *MI) {
EmitToStreamer(*OutStreamer, TmpInst);
}

// This is a label for JX instructions, used for jump table offsets
// computation, e.g.:
//
// .LBPF.JX.0.0: <------- this is the anchor
// .reloc 0, FK_SecRel_8, BPF.JT.0.0
// gotox r1
// ...
// .section .jumptables,"",@progbits
// .L0_0_set_7 = ((LBB0_7-.LBPF.JX.0.0)>>3)-1
// ...
// BPF.JT.0.0: <------- JT definition
// .long .L0_0_set_7
// ...
MCSymbol *BPFAsmPrinter::getJXAnchorSymbol(unsigned JTI) {
const MCAsmInfo *MAI = MF->getContext().getAsmInfo();
SmallString<60> Name;
raw_svector_ostream(Name) << MAI->getPrivateGlobalPrefix() << "BPF.JX."
<< MF->getFunctionNumber() << '.' << JTI;
return MF->getContext().getOrCreateSymbol(Name);
}

MCSymbol *BPFAsmPrinter::getJTPublicSymbol(unsigned JTI) {
SmallString<60> Name;
raw_svector_ostream(Name)
<< "BPF.JT." << MF->getFunctionNumber() << '.' << JTI;
MCSymbol *S = OutContext.getOrCreateSymbol(Name);
if (auto *ES = dyn_cast<MCSymbolELF>(S))
ES->setBinding(ELF::STB_GLOBAL);
return S;
}

void BPFAsmPrinter::emitJumpTableInfo() {
const MachineJumpTableInfo *MJTI = MF->getJumpTableInfo();
if (!MJTI)
return;

const std::vector<MachineJumpTableEntry> &JT = MJTI->getJumpTables();
if (JT.empty())
return;

const TargetLoweringObjectFile &TLOF = getObjFileLowering();
const Function &F = MF->getFunction();
MCSection *JTS = TLOF.getSectionForJumpTable(F, TM);
assert(MJTI->getEntryKind() == MachineJumpTableInfo::EK_LabelDifference32);
unsigned EntrySize = MJTI->getEntrySize(getDataLayout());
OutStreamer->switchSection(JTS);
for (unsigned JTI = 0; JTI < JT.size(); JTI++) {
ArrayRef<MachineBasicBlock *> JTBBs = JT[JTI].MBBs;
if (JTBBs.empty())
continue;

SmallPtrSet<const MachineBasicBlock *, 16> EmittedSets;
auto *Base = MCSymbolRefExpr::create(getJXAnchorSymbol(JTI), OutContext);
for (const MachineBasicBlock *MBB : JTBBs) {
if (!EmittedSets.insert(MBB).second)
continue;

// Offset from gotox to target basic block expressed in number
// of instructions, e.g.:
//
// .L0_0_set_4 = ((LBB0_4 - .LBPF.JX.0.0) >> 3) - 1
const MCExpr *LHS = MCSymbolRefExpr::create(MBB->getSymbol(), OutContext);
OutStreamer->emitAssignment(
GetJTSetSymbol(JTI, MBB->getNumber()),
MCBinaryExpr::createSub(
MCBinaryExpr::createAShr(
MCBinaryExpr::createSub(LHS, Base, OutContext),
MCConstantExpr::create(3, OutContext), OutContext),
MCConstantExpr::create(1, OutContext), OutContext));
}
// BPF.JT.0.0:
// .long .L0_0_set_4
// .long .L0_0_set_2
// ...
// .size BPF.JT.0.0, 128
MCSymbol *JTStart = getJTPublicSymbol(JTI);
OutStreamer->emitLabel(JTStart);
for (const MachineBasicBlock *MBB : JTBBs) {
MCSymbol *SetSymbol = GetJTSetSymbol(JTI, MBB->getNumber());
const MCExpr *V = MCSymbolRefExpr::create(SetSymbol, OutContext);
OutStreamer->emitValue(V, EntrySize);
}
const MCExpr *JTSize = MCConstantExpr::create(JTBBs.size() * 4, OutContext);
OutStreamer->emitELFSize(JTStart, JTSize);
}
}

char BPFAsmPrinter::ID = 0;

INITIALIZE_PASS(BPFAsmPrinter, "bpf-asm-printer", "BPF Assembly Printer", false,
Expand Down
44 changes: 44 additions & 0 deletions llvm/lib/Target/BPF/BPFAsmPrinter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//===-- BPFFrameLowering.h - Define frame lowering for BPF -----*- C++ -*--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_BPF_BPFASMPRINTER_H
#define LLVM_LIB_TARGET_BPF_BPFASMPRINTER_H

#include "BTFDebug.h"
#include "llvm/CodeGen/AsmPrinter.h"

namespace llvm {

class BPFAsmPrinter : public AsmPrinter {
public:
explicit BPFAsmPrinter(TargetMachine &TM,
std::unique_ptr<MCStreamer> Streamer)
: AsmPrinter(TM, std::move(Streamer), ID), BTF(nullptr) {}

StringRef getPassName() const override { return "BPF Assembly Printer"; }
bool doInitialization(Module &M) override;
void printOperand(const MachineInstr *MI, int OpNum, raw_ostream &O);
bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
const char *ExtraCode, raw_ostream &O) override;
bool PrintAsmMemoryOperand(const MachineInstr *MI, unsigned OpNum,
const char *ExtraCode, raw_ostream &O) override;

void emitInstruction(const MachineInstr *MI) override;
MCSymbol *getJTPublicSymbol(unsigned JTI);
MCSymbol *getJXAnchorSymbol(unsigned JTI);
virtual void emitJumpTableInfo() override;

static char ID;

private:
BTFDebug *BTF;
};

} // namespace llvm

#endif /* LLVM_LIB_TARGET_BPF_BPFASMPRINTER_H */
30 changes: 28 additions & 2 deletions llvm/lib/Target/BPF/BPFISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DIBuilder.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/Module.h"
#include "llvm/MC/MCAsmInfo.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
Expand All @@ -38,6 +40,10 @@ static cl::opt<bool> BPFExpandMemcpyInOrder("bpf-expand-memcpy-in-order",
cl::Hidden, cl::init(false),
cl::desc("Expand memcpy into load/store pairs in order"));

static cl::opt<unsigned> BPFMinimumJumpTableEntries(
"bpf-min-jump-table-entries", cl::init(4), cl::Hidden,
cl::desc("Set minimum number of entries to use a jump table on BPF"));

static void fail(const SDLoc &DL, SelectionDAG &DAG, const Twine &Msg,
SDValue Val = {}) {
std::string Str;
Expand Down Expand Up @@ -66,9 +72,10 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM,
setStackPointerRegisterToSaveRestore(BPF::R11);

setOperationAction(ISD::BR_CC, MVT::i64, Custom);
setOperationAction(ISD::BR_JT, MVT::Other, Expand);
setOperationAction(ISD::BRIND, MVT::Other, Expand);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, this does remove restriction to not produce indirect jumps?

Is there a way to control if we want to generate indirect jumps "in general" vs., say, "only for large switches"? (Or even only for a particular switch?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, this does remove restriction to not produce indirect jumps?
Yes, we do not want to expand 'brind', rather we will do pattern matching with 'brind'.

Is there a way to control if we want to generate indirect jumps "in general" vs., say, "only for large switches"? (Or even only for a particular switch?)

Good point. Let me do some experiments with a flag for this. I am not sure whether I could do 'only for a particular switch', but I will do some investigation. Hopefully can find a s solution for that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added an option to control how many cases in a switch statement to use jump table. The default is 4 cases. But you can change it with additional clang option, e.g., the minimum number of cases must be 6, then

clang ... -mllvm -bpf-min-jump-table-entries=6

I checked other targets, there are no control for a specific switch. So I think we do not need them for now.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thanks!

setOperationAction(ISD::BRCOND, MVT::Other, Expand);
LegalizeAction IndirectBrAction = STI.hasGotox() ? Custom : Expand;
setOperationAction(ISD::BR_JT, MVT::Other, IndirectBrAction);
setOperationAction(ISD::BRIND, MVT::Other, IndirectBrAction);

setOperationAction(ISD::TRAP, MVT::Other, Custom);

Expand Down Expand Up @@ -159,6 +166,7 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM,

setBooleanContents(ZeroOrOneBooleanContent);
setMaxAtomicSizeInBitsSupported(64);
setMinimumJumpTableEntries(BPFMinimumJumpTableEntries);

// Function alignments
setMinFunctionAlignment(Align(8));
Expand Down Expand Up @@ -332,6 +340,8 @@ SDValue BPFTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerATOMIC_LOAD_STORE(Op, DAG);
case ISD::TRAP:
return LowerTRAP(Op, DAG);
case ISD::BR_JT:
return LowerBR_JT(Op, DAG);
}
}

Expand Down Expand Up @@ -780,6 +790,16 @@ SDValue BPFTargetLowering::LowerTRAP(SDValue Op, SelectionDAG &DAG) const {
return LowerCall(CLI, InVals);
}

SDValue BPFTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
SDValue Chain = Op->getOperand(0);
SDValue Table = Op->getOperand(1);
SDValue Index = Op->getOperand(2);
JumpTableSDNode *JT = cast<JumpTableSDNode>(Table);
SDLoc DL(Op);
SDValue TargetJT = DAG.getTargetJumpTable(JT->getIndex(), MVT::i32);
return DAG.getNode(BPFISD::BPF_BR_JT, DL, MVT::Other, Chain, TargetJT, Index);
}

const char *BPFTargetLowering::getTargetNodeName(unsigned Opcode) const {
switch ((BPFISD::NodeType)Opcode) {
case BPFISD::FIRST_NUMBER:
Expand All @@ -796,6 +816,8 @@ const char *BPFTargetLowering::getTargetNodeName(unsigned Opcode) const {
return "BPFISD::Wrapper";
case BPFISD::MEMCPY:
return "BPFISD::MEMCPY";
case BPFISD::BPF_BR_JT:
return "BPFISD::BPF_BR_JT";
}
return nullptr;
}
Expand Down Expand Up @@ -1069,3 +1091,7 @@ bool BPFTargetLowering::isLegalAddressingMode(const DataLayout &DL,

return true;
}

unsigned BPFTargetLowering::getJumpTableEncoding() const {
return MachineJumpTableInfo::EK_LabelDifference32;
}
8 changes: 7 additions & 1 deletion llvm/lib/Target/BPF/BPFISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ enum NodeType : unsigned {
SELECT_CC,
BR_CC,
Wrapper,
MEMCPY
MEMCPY,
BPF_BR_JT,
};
}

Expand Down Expand Up @@ -66,6 +67,10 @@ class BPFTargetLowering : public TargetLowering {

MVT getScalarShiftAmountTy(const DataLayout &, EVT) const override;

// Always emit EK_LabelDifference32, computed as difference between
// JX instruction location and target basic block label.
virtual unsigned getJumpTableEncoding() const override;

private:
// Control Instruction Selection Features
bool HasAlu32;
Expand All @@ -81,6 +86,7 @@ class BPFTargetLowering : public TargetLowering {
SDValue LowerConstantPool(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerTRAP(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const;

template <class NodeTy>
SDValue getAddr(NodeTy *N, SelectionDAG &DAG, unsigned Flags = 0) const;
Expand Down
13 changes: 13 additions & 0 deletions llvm/lib/Target/BPF/BPFInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,11 @@ bool BPFInstrInfo::analyzeBranch(MachineBasicBlock &MBB,
if (!isUnpredicatedTerminator(*I))
break;

// From base method doc: ... returning true if it cannot be understood ...
// Indirect branch has multiple destinations and no true/false concepts.
if (I->isIndirectBranch())
return true;

// A terminator that isn't a branch can't easily be handled
// by this analysis.
if (!I->isBranch())
Expand Down Expand Up @@ -259,3 +264,11 @@ unsigned BPFInstrInfo::removeBranch(MachineBasicBlock &MBB,

return Count;
}

int BPFInstrInfo::getJumpTableIndex(const MachineInstr &MI) const {
if (MI.getOpcode() != BPF::JX)
return -1;
const MachineOperand &MO = MI.getOperand(1);
assert(MO.isJTI() && "JX operand #0 should be isJTI");
return MO.getIndex();
}
3 changes: 3 additions & 0 deletions llvm/lib/Target/BPF/BPFInstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ class BPFInstrInfo : public BPFGenInstrInfo {
MachineBasicBlock *FBB, ArrayRef<MachineOperand> Cond,
const DebugLoc &DL,
int *BytesAdded = nullptr) const override;

int getJumpTableIndex(const MachineInstr &MI) const override;

private:
void expandMEMCPY(MachineBasicBlock::iterator) const;

Expand Down
Loading