Skip to content

Commit df6a2fa

Browse files
committed
LLVM freeze instruction between mask and div 5/5
1 parent f9cfd50 commit df6a2fa

File tree

4 files changed

+46
-70
lines changed

4 files changed

+46
-70
lines changed

third_party/intel/lib/LLVMIR/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ add_triton_library(TritonIntelLLVMIR
33

44
DEPENDS
55
LLVMIRIncGen
6-
)
6+
)
Lines changed: 34 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,57 @@
11
#include "LLVMPasses.h"
2-
#include "llvm/IR/Instructions.h"
3-
#include "llvm/Analysis/ValueTracking.h"
42
#include "llvm/Analysis/TargetTransformInfo.h"
3+
#include "llvm/Analysis/ValueTracking.h"
54
#include "llvm/IR/Dominators.h"
5+
#include "llvm/IR/Instructions.h"
66

77
using namespace llvm;
88

9-
static bool processPhiNode(PHINode *phiNode, BasicBlock& BB) {
10-
llvm::errs() << "YOLO: " << *phiNode << "\n";
11-
12-
const auto phiHasNullValue = any_of(phiNode->incoming_values(), [](Use& U) {
13-
if (Constant *C = dyn_cast<Constant>(&U)) {
14-
return C->isNullValue();
15-
}
16-
return false;
17-
});
9+
static bool processPhiNode(PHINode *PhiNode, BasicBlock &BB) {
10+
if (!any_of(PhiNode->incoming_values(), [](Use &U) {
11+
if (Constant *C = dyn_cast<Constant>(&U)) {
12+
return C->isNullValue();
13+
}
14+
return false;
15+
})) {
16+
return false;
17+
}
1818

1919
bool Changed = false;
20-
if (phiHasNullValue) {
21-
for (Instruction &I : BB) {
22-
if (I.getOpcode() == Instruction::SDiv || I.getOpcode() == Instruction::SRem) {
23-
const size_t OpIdx = 1; // I.getOpcode() == Instruction::SRem ? 0 : 1;
24-
if (I.getOperand(OpIdx) == phiNode) {
25-
auto *freezePhi = new FreezeInst(phiNode, phiNode->getName() + ".frozen", I.getIterator());
26-
I.setOperand(OpIdx, freezePhi);
27-
Changed = true;
28-
}
20+
for (Instruction &I : BB) {
21+
if (I.getOpcode() == Instruction::SDiv ||
22+
I.getOpcode() == Instruction::SRem) {
23+
const size_t OpIdx = 1;
24+
if (I.getOperand(OpIdx) == PhiNode) {
25+
auto *freezePhi = new FreezeInst(
26+
PhiNode, PhiNode->getName() + ".frozen", I.getIterator());
27+
I.setOperand(OpIdx, freezePhi);
28+
Changed = true;
2929
}
3030
}
31-
#if 0
32-
auto FindUse = llvm::find_if(phiNode->users(), [](auto *U) {
33-
auto *Use = cast<Instruction>(U);
34-
llvm::errs() << "User: " << *Use << "\n";
35-
return (Use->getOpcode() == Instruction::SDiv || Use->getOpcode() == Instruction::SRem);
36-
});
37-
if (FindUse == phiNode->user_end()) {
38-
llvm::errs() << "no div :(\n";
39-
return false;
40-
}
41-
auto *Use = cast<Instruction>(*FindUse);
42-
assert(Use->isIntDivRem());
43-
const size_t OpIdx = Use->getOpcode() == Instruction::SRem ? 0 : 1;
44-
if (Use->getOperand(OpIdx) == phiNode) {
45-
llvm::errs() << "Got our user! " << *Use << "\n";
46-
llvm::errs() << "Operand 1: " << *Use->getOperand(1) << "\n";
47-
auto *freezePhi = new FreezeInst(phiNode, phiNode->getName() + ".frozen", Use->getIterator());
48-
Use->setOperand(OpIdx, freezePhi);
49-
Changed = true;
50-
}
51-
#endif
5231
}
53-
return Changed;
32+
return Changed;
5433
}
5534

56-
static bool runOnFunction(Function& F, const TargetTransformInfo &TTI,
57-
const DominatorTree &DT) {
58-
bool Changed = false;
35+
static bool runOnFunction(Function &F) {
36+
bool Changed = false;
5937

60-
SmallVector<PHINode *> PhiNodes;
61-
for (BasicBlock &BB : F) {
62-
for (Instruction &inst : BB) {
63-
if (PHINode *phiNode = dyn_cast<PHINode>(&inst)) {
64-
Changed |= processPhiNode(phiNode, BB);
38+
SmallVector<PHINode *> PhiNodes;
39+
for (BasicBlock &BB : F) {
40+
for (Instruction &I : BB) {
41+
if (PHINode *PhiNode = dyn_cast<PHINode>(&I)) {
42+
Changed |= processPhiNode(PhiNode, BB);
6543
continue;
6644
}
6745
break;
6846
}
6947
}
7048

71-
return Changed;
49+
return Changed;
7250
}
7351

74-
PreservedAnalyses FreezeMaskedDivRemPass::run(Function &F, FunctionAnalysisManager &FAM) {
75-
TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
76-
DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
77-
const auto b = runOnFunction(F, TTI, DT);
52+
PreservedAnalyses FreezeMaskedDivRemPass::run(Function &F,
53+
FunctionAnalysisManager &FAM) {
54+
const auto b = runOnFunction(F);
7855

79-
return b ? PreservedAnalyses::none() : PreservedAnalyses::all();
80-
}
56+
return b ? PreservedAnalyses::none() : PreservedAnalyses::all();
57+
}

third_party/intel/lib/LLVMIR/LLVMPasses.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ struct FreezeMaskedDivRemPass : PassInfoMixin<FreezeMaskedDivRemPass> {
88
static StringRef name() { return "FreezeMaskedDivRemPass"; }
99
};
1010

11-
}
11+
} // namespace llvm

third_party/intel/triton_xpu.cc

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -208,22 +208,21 @@ void init_triton_intel(py::module &&m) {
208208
fpm.addPass(BreakStructPhiNodesPass());
209209
fpm.addPass(InstCombinePass());
210210
});
211-
#if 1
212211
pb.registerPeepholeEPCallback(
213212
[&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) {
214213
// The Triton masked load pattern can generate instances where the
215-
// mask false path appears to cause undefined behavior during
216-
// computation. Even though the result of that behavior will never be
217-
// used, LLVM can choose to optimize away the false path resulting in
218-
// an incorrect result for the kernel. Adding `DivRemPairsPass`
219-
// introduces freeze instructions which prevent UB from leaking into
220-
// div/rem instructions.
221-
// fpm.addPass(DivRemPairsPass());
214+
// mask value causes undefined behavior in sdiv/srem instructions. The
215+
// language allows this UB as the result of those arithmetic
216+
// instructions is never used, and control flow to avoid computation
217+
// of these instructions would negatively affect performance. But,
218+
// LLVM SimplifyCFG aggressively marks code paths with undefined
219+
// behavior as dead. This can result in removal of the mask path and
220+
// incorrect results from legal Triton kernels due to masked elements
221+
// being used in computation. Run a pass to add a freeze instruction
222+
// between masked loads and sdiv/srem to signal to LLVM we consider
223+
// the sdiv/srem operands to be well defined.
222224
fpm.addPass(FreezeMaskedDivRemPass());
223225
});
224-
#else
225-
mpm.addPass(createModuleToFunctionPassAdaptor(FreezeMaskedDivRemPass()));
226-
#endif
227226
mpm.addPass(pb.buildPerModuleDefaultPipeline(opt));
228227
mpm.run(*mod, mam);
229228
});

0 commit comments

Comments
 (0)