Skip to content

Commit 78c13a5

Browse files
authored
Insert freeze between masked loads and sdiv/srem instructions (#2775)
Closes #2726. From the code comments: The Triton masked load pattern can generate instances where the mask value causes undefined behavior in sdiv/srem instructions. The language allows this UB as the result of those arithmetic instructions is never used, and control flow to avoid computation of these instructions would negatively affect performance. But, LLVM SimplifyCFG aggressively marks code paths with undefined behavior as dead. This can result in removal of the mask path and incorrect results from legal Triton kernels due to masked elements being used in computation. Run a pass to add a freeze instruction between masked loads and sdiv/srem to signal to LLVM we consider the sdiv/srem operands to be well defined. The strategy here is to basically invalidate the assumptions under which SimplifyCFG can remove UB for sdiv/srem. The rationale is that, unlike C/C++, Triton explicitly allows UB in sdiv/srem instructions (likely because the hardware Triton is targeting allows that). Inserting a `freeze` instruction both signals that we expect the behavior of sdiv/srem to be well defined and hides the constant 0 in the phi from SimplifyCFG's UB optimizations. The pass needs to run after every instance of `InstCombine` because the LLVM optimization that removes UB only occurs if the sdiv/srem are in the same BB as the phi, which can happen after any `InstCombine`. Note that the directory structure for this pass is a little different than `BreakStructPhiNodesPass` because we are already using those directories in `third_party` for MLIR code. If we want to change that, I can open an issue but let's do it separately from this PR. ---------
1 parent 02346d9 commit 78c13a5

File tree

9 files changed

+236
-0
lines changed

9 files changed

+236
-0
lines changed

bin/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ target_link_libraries(triton-opt PRIVATE
1313
TritonTransforms
1414
TritonGPUTransforms
1515
TritonNvidiaGPUTransforms
16+
TritonIntelLLVMIR
1617
MLIRGPUToROCDLTransforms
1718
${dialect_libs}
1819
${conversion_libs}
@@ -88,6 +89,7 @@ target_link_libraries(triton-llvm-opt PRIVATE
8889
LLVMSupport
8990
LLVMOption
9091
LLVMCodeGen
92+
TritonIntelLLVMIR
9193
TritonIntelGPUIR
9294
)
9395
export_executable_symbols_for_plugins(triton-llvm-opt)

bin/triton-llvm-opt.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
/// Trimmed down clone of llvm opt to be able to test triton custom llvm ir
22
/// passes.
33
#include "lib/Target/LLVMIR/LLVMPasses.h"
4+
#include "third_party/intel/lib/LLVMIR/LLVMPasses.h"
45
#include "llvm/CodeGen/CommandFlags.h"
56
#include "llvm/IR/Constants.h"
67
#include "llvm/IR/DataLayout.h"
@@ -42,6 +43,11 @@ static cl::opt<bool>
4243
llvm::cl::desc("run pass to break phi struct"),
4344
cl::init(false));
4445

46+
static cl::opt<bool> FreezeMaskedDivRem(
47+
"freeze-masked-div-rem",
48+
llvm::cl::desc("run pass to insert freeze between masked load and div/rem"),
49+
cl::init(false));
50+
4551
namespace {
4652
static std::function<Error(Module *)> makeOptimizingPipeline() {
4753
return [](Module *m) -> Error {
@@ -62,6 +68,8 @@ static std::function<Error(Module *)> makeOptimizingPipeline() {
6268
llvm::FunctionPassManager fpm;
6369
if (BreakStructPhiNodes)
6470
fpm.addPass(BreakStructPhiNodesPass());
71+
if (FreezeMaskedDivRem)
72+
fpm.addPass(FreezeMaskedDivRemPass());
6573
mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm)));
6674
mpm.run(*m, mam);
6775
return Error::success();
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# flake8: noqa: F821, F841
2+
import torch
3+
import pytest
4+
5+
import triton
6+
import triton.language as tl
7+
8+
aten = torch.ops.aten
9+
10+
11+
def patch_kernel(template, to_replace):
12+
kernel = triton.JITFunction(template.fn)
13+
for key, value in to_replace.items():
14+
kernel.src = kernel.src.replace(key, value)
15+
return kernel
16+
17+
18+
@pytest.mark.parametrize("float_div", [True, False])
19+
@pytest.mark.parametrize("floor", [True, False])
20+
@pytest.mark.parametrize("trunc", [True, False])
21+
def test_divide(float_div, floor, trunc, device):
22+
# regression test for various division cases
23+
24+
@triton.jit
25+
def divide_kernel(a, b, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel, XBLOCK: tl.constexpr):
26+
xoffset = tl.program_id(0) * XBLOCK
27+
xindex = xoffset + tl.arange(0, XBLOCK)[:]
28+
xmask = xindex < xnumel
29+
x0 = xindex
30+
tmp0 = tl.load(a + (x0), xmask)
31+
tmp2 = tl.load(b + (x0), xmask)
32+
# custom bits
33+
tmp1 = tmp0.to(tl.float32)
34+
tmp3 = tmp2.to(tl.float32)
35+
tmp4 = tmp1 / tmp3
36+
tmp5 = tl.where((tmp0 < 0) != (tmp2 < 0), tl.where(tmp0 % tmp2 != 0, tmp0 // tmp2 - 1, tmp0 // tmp2),
37+
tmp0 // tmp2)
38+
tmp6 = tmp0 // tmp2
39+
GENERATE_OUTPUTS_HERE
40+
41+
torch.manual_seed(0)
42+
43+
outputs_float_div = "tl.store(out_ptr0 + (x0), tmp4, xmask)\n tl.store(out_ptr3 + (x0), tmp4, xmask)" if float_div else ""
44+
outputs_floor = " tl.store(out_ptr1 + (x0), tmp5, xmask)\n tl.store(out_ptr4 + (x0), tmp5, xmask)" if floor else ""
45+
outputs_trunc = " tl.store(out_ptr2 + (x0), tmp6, xmask)" if trunc else ""
46+
47+
divide_kernel = patch_kernel(divide_kernel,
48+
{"GENERATE_OUTPUTS_HERE": f"{outputs_float_div}\n{outputs_floor}\n{outputs_trunc}"})
49+
50+
def launch_triton(a, b):
51+
output0 = torch.zeros_like(a)
52+
output1 = torch.zeros_like(a)
53+
output2 = torch.zeros_like(a)
54+
output3 = torch.zeros_like(a)
55+
output4 = torch.zeros_like(a)
56+
57+
n_elements = output0.numel()
58+
59+
grid = lambda meta: (triton.cdiv(n_elements, meta['XBLOCK']), )
60+
61+
divide_kernel[grid](a, b, output0, output1, output2, output3, output4, n_elements, XBLOCK=128)
62+
63+
return (output0, output1, output2, output3, output4)
64+
65+
def launch_torch(a, b):
66+
return (
67+
aten.div(a, b, rounding_mode=None) if float_div is True else torch.zeros_like(a),
68+
aten.div(a, b, rounding_mode="floor") if floor is True else torch.zeros_like(a),
69+
aten.div(a, b, rounding_mode="trunc") if trunc is True else torch.zeros_like(a),
70+
a / b if float_div is True else torch.zeros_like(a),
71+
a // b if floor is True else torch.zeros_like(a),
72+
)
73+
74+
a = torch.randint(2**32, 2**40, [100, 100], device=device)
75+
b = torch.randint(-10, -1, [100, 100], device=device)
76+
77+
for iter in range(100):
78+
triton_result = launch_triton(a, b)
79+
torch_result = launch_torch(a, b)
80+
81+
for i in range(5):
82+
torch.testing.assert_close(
83+
triton_result[i], torch_result[i], check_dtype=False, msg=lambda msg:
84+
f"Float: {float_div}, Floor: {floor}, Trunc: {trunc}\nIteration {iter}, {i} failed\n{msg}")
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
; RUN: triton-llvm-opt -freeze-masked-div-rem %s | FileCheck %s
2+
3+
; A phi that merges a loaded value with the constant 0 placeholder feeds the
; divisor of an sdiv in the same block. The pass must insert a freeze of the
; phi immediately before the sdiv and make the sdiv divide by that freeze.
define void @phi_div_of_zero_okay(i8 noundef %x, i8 %i, ptr %v) {
; CHECK-LABEL: @phi_div_of_zero_okay(
entry:
  %cmp = icmp ult i8 %i, 9
  br i1 %cmp, label %if.then, label %if.end

if.then:
  %y = load i8, ptr %v, align 8
  br label %if.end

if.end:
  %yy = phi i8 [ %y, %if.then ], [ 0, %entry ]
  ; The second pattern uses the captured name (no ":regex") so the test fails
  ; if the sdiv still divides by the raw phi.
  ; CHECK: [[F0:%.*]] = freeze i8 %yy
  ; CHECK-NEXT: %z = sdiv i8 %x, [[F0]]
  %z = sdiv i8 %x, %yy
  br i1 %cmp, label %if2.then, label %if2.end

if2.then:
  store i8 %z, ptr %v, align 8
  br label %if2.end

if2.end:
  ret void
}
27+
28+
; Two phis in the same block feed two different sdivs: one merges a load with
; the constant 0, the other merges a load with undef. Each sdiv must end up
; dividing by a freeze of its own phi, inserted directly in front of it.
define void @two_phi_div_of_zero_okay(i8 noundef %x, i8 %i, ptr %v) {
; CHECK-LABEL: @two_phi_div_of_zero_okay(
entry:
  %cmp = icmp ult i8 %i, 9
  br i1 %cmp, label %if.then, label %if.end

if.then:
  %y = load i8, ptr %v, align 8
  %vv = getelementptr inbounds i64, ptr %v, i64 1
  %b = load i8, ptr %vv, align 8
  br label %if.end

if.end:
  %bb = phi i8 [ %b, %if.then ], [ undef, %entry ]
  %yy = phi i8 [ %y, %if.then ], [ 0, %entry ]
  ; The use sites reference the captured names (no ":regex") so the test
  ; fails if either sdiv still divides by its raw phi.
  ; CHECK: [[F0:%.*]] = freeze i8 %yy
  ; CHECK-NEXT: %z = sdiv i8 %x, [[F0]]
  %z = sdiv i8 %x, %yy
  ; CHECK: [[F1:%.*]] = freeze i8 %bb
  ; CHECK-NEXT: %zz = sdiv i8 %x, [[F1]]
  %zz = sdiv i8 %x, %bb
  br i1 %cmp, label %if2.then, label %if2.end

if2.then:
  store i8 %z, ptr %v, align 8
  br label %if2.end

if2.end:
  ret void
}

third_party/intel/lib/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_subdirectory(Analysis)
22
add_subdirectory(Dialect)
33
add_subdirectory(GPUToTritonGEN)
4+
add_subdirectory(LLVMIR)
45
add_subdirectory(Target)
56
add_subdirectory(TritonAnnotateModule)
67
add_subdirectory(TritonGENToLLVM)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
add_triton_library(TritonIntelLLVMIR
2+
LLVMIRFreezeMaskedDivRem.cpp
3+
4+
DEPENDS
5+
LLVMIRIncGen
6+
)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#include "LLVMPasses.h"
2+
#include "llvm/Analysis/TargetTransformInfo.h"
3+
#include "llvm/Analysis/ValueTracking.h"
4+
#include "llvm/IR/Dominators.h"
5+
#include "llvm/IR/Instructions.h"
6+
7+
using namespace llvm;
8+
9+
/// Freeze the divisor of every sdiv/srem in \p PhiNode's block that divides
/// by \p PhiNode, provided the phi may carry a placeholder value.
///
/// The phi qualifies when at least one incoming value is an UndefValue or a
/// null constant. Only instructions located in the same basic block as the
/// phi are inspected, and only operand 1 (the divisor) is considered.
///
/// A single freeze is created lazily, directly in front of the first matching
/// sdiv/srem in block order, and is reused by every later match in the block.
/// That first user precedes the others within the block, so the freeze
/// dominates all of its uses, and all rewritten instructions observe the same
/// frozen value instead of one independent freeze each.
///
/// \param PhiNode the phi to inspect; must not be null.
/// \returns true if at least one instruction was rewritten.
static bool processPhiNode(PHINode *PhiNode) {
  const bool HasPlaceholderIncoming =
      any_of(PhiNode->incoming_values(), [](const Use &U) {
        // Inspect the value held by the use rather than the Use object.
        const Value *Incoming = U.get();
        if (isa<UndefValue>(Incoming))
          return true;
        const auto *C = dyn_cast<Constant>(Incoming);
        return C && C->isNullValue();
      });
  if (!HasPlaceholderIncoming)
    return false;

  // sdiv/srem take the divisor as their second operand.
  constexpr unsigned DivisorIdx = 1;

  FreezeInst *Frozen = nullptr;
  for (Instruction &I : *PhiNode->getParent()) {
    const unsigned Opcode = I.getOpcode();
    if (Opcode != Instruction::SDiv && Opcode != Instruction::SRem)
      continue;
    if (I.getOperand(DivisorIdx) != PhiNode)
      continue;

    // Inserting in front of the instruction currently being visited does not
    // disturb the iteration over the block.
    if (!Frozen)
      Frozen = new FreezeInst(PhiNode, PhiNode->getName() + ".frozen",
                              I.getIterator());
    I.setOperand(DivisorIdx, Frozen);
  }
  return Frozen != nullptr;
}
33+
34+
/// Apply processPhiNode to every phi of every basic block in \p F.
///
/// \returns true if any phi caused an instruction to be rewritten.
static bool runOnFunction(Function &F) {
  bool AnyRewritten = false;

  for (BasicBlock &Block : F)
    for (PHINode &Phi : Block.phis())
      // Every phi is processed unconditionally; the flag only records
      // whether any of them produced a change.
      if (processPhiNode(&Phi))
        AnyRewritten = true;

  return AnyRewritten;
}
45+
46+
/// New-pass-manager entry point: rewrite \p F and report which analyses
/// remain valid.
///
/// \param F   the function to transform.
/// \param FAM the function analysis manager; not consulted by this pass.
/// \returns PreservedAnalyses::all() when nothing was rewritten; otherwise a
///          set preserving the CFG analyses only.
PreservedAnalyses FreezeMaskedDivRemPass::run(Function &F,
                                              FunctionAnalysisManager &FAM) {
  if (!runOnFunction(F))
    return PreservedAnalyses::all();

  // The transformation only inserts freeze instructions and rewires an
  // operand; no block or terminator is added, removed or changed, so
  // analyses that depend solely on the CFG stay valid.
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  return PA;
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#ifndef TRITON_THIRD_PARTY_INTEL_LIB_LLVMIR_LLVMPASSES_H
#define TRITON_THIRD_PARTY_INTEL_LIB_LLVMIR_LLVMPASSES_H

#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"

namespace llvm {

/// Function pass that inserts a freeze between a phi carrying a masked-load
/// placeholder and the sdiv/srem instructions that use that phi as divisor.
///
/// Intended to be registered with the new pass manager (for example through
/// a PassBuilder extension-point callback).
struct FreezeMaskedDivRemPass : PassInfoMixin<FreezeMaskedDivRemPass> {
  /// Run the transformation on \p F.
  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);

  /// Name reported for this pass.
  static StringRef name() { return "FreezeMaskedDivRemPass"; }
};

} // namespace llvm

#endif // TRITON_THIRD_PARTY_INTEL_LIB_LLVMIR_LLVMPASSES_H

third_party/intel/triton_xpu.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "intel/include/TritonAnnotateModule/Passes.h"
1717
#include "intel/include/TritonIntelGPUToLLVM/Passes.h"
1818
#include "intel/include/TritonToTritonGPUWarp/Passes.h"
19+
#include "intel/lib/LLVMIR/LLVMPasses.h"
1920

2021
#include "triton/Target/SPIRV/SPIRVTranslation.h"
2122
#include "triton/Tools/Sys/GetEnv.hpp"
@@ -204,6 +205,21 @@ void init_triton_intel(py::module &&m) {
204205
fpm.addPass(BreakStructPhiNodesPass());
205206
fpm.addPass(InstCombinePass());
206207
});
208+
pb.registerPeepholeEPCallback(
209+
[&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) {
210+
// The Triton masked load pattern can generate instances where the
211+
// mask value causes undefined behavior in sdiv/srem instructions. The
212+
// language allows this UB as the result of those arithmetic
213+
// instructions is never used, and control flow to avoid computation
214+
// of these instructions would negatively affect performance. But,
215+
// LLVM SimplifyCFG aggressively marks code paths with undefined
216+
// behavior as dead. This can result in removal of the mask path and
217+
// incorrect results from legal Triton kernels due to masked elements
218+
// being used in computation. Run a pass to add a freeze instruction
219+
// between masked loads and sdiv/srem to signal to LLVM we consider
220+
// the sdiv/srem operands to be well defined.
221+
fpm.addPass(FreezeMaskedDivRemPass());
222+
});
207223
mpm.addPass(pb.buildPerModuleDefaultPipeline(opt));
208224
mpm.run(*mod, mam);
209225
});

0 commit comments

Comments
 (0)