Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ target_link_libraries(triton-opt PRIVATE
TritonTransforms
TritonGPUTransforms
TritonNvidiaGPUTransforms
TritonIntelLLVMIR
MLIRGPUToROCDLTransforms
${dialect_libs}
${conversion_libs}
Expand Down Expand Up @@ -88,6 +89,7 @@ target_link_libraries(triton-llvm-opt PRIVATE
LLVMSupport
LLVMOption
LLVMCodeGen
TritonIntelLLVMIR
TritonIntelGPUIR
)
export_executable_symbols_for_plugins(triton-llvm-opt)
Expand Down
8 changes: 8 additions & 0 deletions bin/triton-llvm-opt.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/// Trimmed down clone of llvm opt to be able to test triton custom llvm ir
/// passes.
#include "lib/Target/LLVMIR/LLVMPasses.h"
#include "third_party/intel/lib/LLVMIR/LLVMPasses.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
Expand Down Expand Up @@ -42,6 +43,11 @@ static cl::opt<bool>
llvm::cl::desc("run pass to break phi struct"),
cl::init(false));

static cl::opt<bool> FreezeMaskedDivRem(
"freeze-masked-div-rem",
llvm::cl::desc("run pass to insert freeze between masked load and div/rem"),
cl::init(false));

namespace {
static std::function<Error(Module *)> makeOptimizingPipeline() {
return [](Module *m) -> Error {
Expand All @@ -62,6 +68,8 @@ static std::function<Error(Module *)> makeOptimizingPipeline() {
llvm::FunctionPassManager fpm;
if (BreakStructPhiNodes)
fpm.addPass(BreakStructPhiNodesPass());
if (FreezeMaskedDivRem)
fpm.addPass(FreezeMaskedDivRemPass());
mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm)));
mpm.run(*m, mam);
return Error::success();
Expand Down
84 changes: 84 additions & 0 deletions python/test/regression/test_divide.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# flake8: noqa: F821, F841
import torch
import pytest

import triton
import triton.language as tl

aten = torch.ops.aten


def patch_kernel(template, to_replace):
kernel = triton.JITFunction(template.fn)
for key, value in to_replace.items():
kernel.src = kernel.src.replace(key, value)
return kernel


@pytest.mark.parametrize("float_div", [True, False])
@pytest.mark.parametrize("floor", [True, False])
@pytest.mark.parametrize("trunc", [True, False])
def test_divide(float_div, floor, trunc, device):
# regression test for various division cases

@triton.jit
def divide_kernel(a, b, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel, XBLOCK: tl.constexpr):
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(a + (x0), xmask)
tmp2 = tl.load(b + (x0), xmask)
# custom bits
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 / tmp3
tmp5 = tl.where((tmp0 < 0) != (tmp2 < 0), tl.where(tmp0 % tmp2 != 0, tmp0 // tmp2 - 1, tmp0 // tmp2),
tmp0 // tmp2)
tmp6 = tmp0 // tmp2
GENERATE_OUTPUTS_HERE

torch.manual_seed(0)

outputs_float_div = "tl.store(out_ptr0 + (x0), tmp4, xmask)\n tl.store(out_ptr3 + (x0), tmp4, xmask)" if float_div else ""
outputs_floor = " tl.store(out_ptr1 + (x0), tmp5, xmask)\n tl.store(out_ptr4 + (x0), tmp5, xmask)" if floor else ""
outputs_trunc = " tl.store(out_ptr2 + (x0), tmp6, xmask)" if trunc else ""

divide_kernel = patch_kernel(divide_kernel,
{"GENERATE_OUTPUTS_HERE": f"{outputs_float_div}\n{outputs_floor}\n{outputs_trunc}"})

def launch_triton(a, b):
output0 = torch.zeros_like(a)
output1 = torch.zeros_like(a)
output2 = torch.zeros_like(a)
output3 = torch.zeros_like(a)
output4 = torch.zeros_like(a)

n_elements = output0.numel()

grid = lambda meta: (triton.cdiv(n_elements, meta['XBLOCK']), )

divide_kernel[grid](a, b, output0, output1, output2, output3, output4, n_elements, XBLOCK=128)

return (output0, output1, output2, output3, output4)

def launch_torch(a, b):
return (
aten.div(a, b, rounding_mode=None) if float_div is True else torch.zeros_like(a),
aten.div(a, b, rounding_mode="floor") if floor is True else torch.zeros_like(a),
aten.div(a, b, rounding_mode="trunc") if trunc is True else torch.zeros_like(a),
a / b if float_div is True else torch.zeros_like(a),
a // b if floor is True else torch.zeros_like(a),
)

a = torch.randint(2**32, 2**40, [100, 100], device=device)
b = torch.randint(-10, -1, [100, 100], device=device)

for iter in range(100):
triton_result = launch_triton(a, b)
torch_result = launch_torch(a, b)

for i in range(5):
torch.testing.assert_close(
triton_result[i], torch_result[i], check_dtype=False, msg=lambda msg:
f"Float: {float_div}, Floor: {floor}, Trunc: {trunc}\nIteration {iter}, {i} failed\n{msg}")
57 changes: 57 additions & 0 deletions test/LLVMIR/freeze-masked-div-rem.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
; RUN: triton-llvm-opt -freeze-masked-div-rem %s | FileCheck %s

define void @phi_div_of_zero_okay(i8 noundef %x, i8 %i, ptr %v) {
; CHECK-LABEL: @phi_div_of_zero_okay(
entry:
%cmp = icmp ult i8 %i, 9
br i1 %cmp, label %if.then, label %if.end

if.then:
%y = load i8, ptr %v, align 8
br label %if.end

if.end:
%yy = phi i8 [ %y, %if.then ], [ 0, %entry ]
; CHECK: [[F0:%.*]] = freeze i8 %yy
; CHECK-NEXT: %z = sdiv i8 %x, [[F0:%.*]]
%z = sdiv i8 %x, %yy
br i1 %cmp, label %if2.then, label %if2.end

if2.then:
store i8 %z, ptr %v, align 8
br label %if2.end

if2.end:
ret void
}

define void @two_phi_div_of_zero_okay(i8 noundef %x, i8 %i, ptr %v) {
; CHECK-LABEL: @two_phi_div_of_zero_okay(
entry:
%cmp = icmp ult i8 %i, 9
br i1 %cmp, label %if.then, label %if.end

if.then:
%y = load i8, ptr %v, align 8
%vv = getelementptr inbounds i64, ptr %v, i64 1
%b = load i8, ptr %vv, align 8
br label %if.end

if.end:
%bb = phi i8 [ %b, %if.then ], [ undef, %entry ]
%yy = phi i8 [ %y, %if.then ], [ 0, %entry ]
; CHECK: [[F0:%.*]] = freeze i8 %yy
; CHECK-NEXT: %z = sdiv i8 %x, [[F0:%.*]]
%z = sdiv i8 %x, %yy
; CHECK: [[F1:%.*]] = freeze i8 %bb
; CHECK-NEXT: %zz = sdiv i8 %x, [[F1:%.*]]
%zz = sdiv i8 %x, %bb
br i1 %cmp, label %if2.then, label %if2.end

if2.then:
store i8 %z, ptr %v, align 8
br label %if2.end

if2.end:
ret void
}
1 change: 1 addition & 0 deletions third_party/intel/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_subdirectory(Analysis)
add_subdirectory(Dialect)
add_subdirectory(GPUToTritonGEN)
add_subdirectory(LLVMIR)
add_subdirectory(Target)
add_subdirectory(TritonAnnotateModule)
add_subdirectory(TritonGENToLLVM)
Expand Down
6 changes: 6 additions & 0 deletions third_party/intel/lib/LLVMIR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
add_triton_library(TritonIntelLLVMIR
LLVMIRFreezeMaskedDivRem.cpp

DEPENDS
LLVMIRIncGen
)
51 changes: 51 additions & 0 deletions third_party/intel/lib/LLVMIR/LLVMIRFreezeMaskedDivRem.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#include "LLVMPasses.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

static bool processPhiNode(PHINode *PhiNode) {
if (none_of(PhiNode->incoming_values(), [](Use &U) {
Constant *C = dyn_cast<Constant>(&U);
return isa<UndefValue>(U) || C && C->isNullValue();
})) {
return false;
}

bool Changed = false;
BasicBlock *BB = const_cast<BasicBlock *>(PhiNode->getParent());
for (Instruction &I : *BB) {
if (I.getOpcode() == Instruction::SDiv ||
I.getOpcode() == Instruction::SRem) {
const size_t OpIdx = 1;
if (I.getOperand(OpIdx) == PhiNode) {
auto *freezePhi = new FreezeInst(
PhiNode, PhiNode->getName() + ".frozen", I.getIterator());
I.setOperand(OpIdx, freezePhi);
Changed = true;
}
}
}
return Changed;
}

static bool runOnFunction(Function &F) {
bool Changed = false;

for (BasicBlock &BB : F) {
for (PHINode &PhiNode : BB.phis()) {
Changed |= processPhiNode(&PhiNode);
}
}

return Changed;
}

PreservedAnalyses FreezeMaskedDivRemPass::run(Function &F,
FunctionAnalysisManager &FAM) {
const auto b = runOnFunction(F);

return b ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
11 changes: 11 additions & 0 deletions third_party/intel/lib/LLVMIR/LLVMPasses.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"

namespace llvm {

struct FreezeMaskedDivRemPass : PassInfoMixin<FreezeMaskedDivRemPass> {
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
static StringRef name() { return "FreezeMaskedDivRemPass"; }
};

} // namespace llvm
16 changes: 16 additions & 0 deletions third_party/intel/triton_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "intel/include/TritonAnnotateModule/Passes.h"
#include "intel/include/TritonIntelGPUToLLVM/Passes.h"
#include "intel/include/TritonToTritonGPUWarp/Passes.h"
#include "intel/lib/LLVMIR/LLVMPasses.h"

#include "triton/Target/SPIRV/SPIRVTranslation.h"
#include "triton/Tools/Sys/GetEnv.hpp"
Expand Down Expand Up @@ -204,6 +205,21 @@ void init_triton_intel(py::module &&m) {
fpm.addPass(BreakStructPhiNodesPass());
fpm.addPass(InstCombinePass());
});
pb.registerPeepholeEPCallback(
[&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) {
// The Triton masked load pattern can generate instances where the
// mask value causes undefined behavior in sdiv/srem instructions. The
// language allows this UB as the result of those arithmetic
// instructions is never used, and control flow to avoid computation
// of these instructions would negatively affect performance. But,
// LLVM SimplifyCFG aggressively marks code paths with undefined
// behavior as dead. This can result in removal of the mask path and
// incorrect results from legal Triton kernels due to masked elements
// being used in computation. Run a pass to add a freeze instruction
// between masked loads and sdiv/srem to signal to LLVM we consider
// the sdiv/srem operands to be well defined.
fpm.addPass(FreezeMaskedDivRemPass());
});
mpm.addPass(pb.buildPerModuleDefaultPipeline(opt));
mpm.run(*mod, mam);
});
Expand Down