Skip to content

Commit d39ee1f

Browse files
authored
[IR] Add poison value to triton IR and use in frontend in place of undef (#4896)
`LLVM::UndefOp` only supports LLVM native types, so instead this uses the closest equivalent in mlir which is `ub::PoisonOp`. As the name implies, this lowers to poison in llvm instead of undef but I'm not sure the difference matters to us as we don't have a UB policy anyway.
1 parent ddb7098 commit d39ee1f

File tree

8 files changed

+36
-17
lines changed

8 files changed

+36
-17
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
210210
MLIRSCFToControlFlow
211211
MLIRIndexToLLVM
212212
MLIRGPUToROCDLTransforms
213+
MLIRUBToLLVM
213214

214215
# LLVM
215216
LLVMPasses

include/triton/Dialect/Triton/IR/TritonDialect.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ def Triton_Dialect : Dialect {
2828
"arith::ArithDialect",
2929
"math::MathDialect",
3030
"scf::SCFDialect",
31-
"cf::ControlFlowDialect"
31+
"cf::ControlFlowDialect",
32+
"ub::UBDialect"
3233
];
3334

3435
let extraClassDeclaration = [{

lib/Dialect/Triton/IR/Dialect.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "triton/Dialect/Triton/IR/Types.h"
33

44
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
5+
#include "mlir/Dialect/UB/IR/UBOps.h"
56
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
67
#include "llvm/ADT/StringSwitch.h"
78
#include "llvm/ADT/TypeSwitch.h"

python/src/ir.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
77
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
88
#include "mlir/Dialect/Index/IR/IndexDialect.h"
9-
#include "mlir/Dialect/Index/IR/IndexOps.h"
109
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1110
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
11+
#include "mlir/Dialect/UB/IR/UBOps.h"
1212
#include "mlir/IR/Builders.h"
1313
#include "mlir/IR/BuiltinOps.h"
1414
#include "mlir/IR/Diagnostics.h"
@@ -22,12 +22,11 @@
2222
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
2323
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
2424
#include "mlir/Transforms/LocationSnapshot.h"
25-
#include "mlir/Transforms/Passes.h"
2625

27-
#include "triton/Analysis/Allocation.h"
2826
#include "triton/Dialect/Triton/IR/Dialect.h"
2927
#include "triton/Dialect/Triton/IR/Types.h"
3028
#include "triton/Dialect/Triton/IR/Utility.h"
29+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
3130
#include "triton/Tools/Sys/GetEnv.hpp"
3231
#include "llvm/Support/SourceMgr.h"
3332

@@ -225,7 +224,8 @@ void init_triton_ir(py::module &&m) {
225224
registry.insert<TritonDialect, ::mlir::triton::gpu::TritonGPUDialect,
226225
math::MathDialect, arith::ArithDialect, index::IndexDialect,
227226
scf::SCFDialect, ::mlir::gpu::GPUDialect,
228-
cf::ControlFlowDialect, LLVM::LLVMDialect>();
227+
cf::ControlFlowDialect, LLVM::LLVMDialect,
228+
mlir::ub::UBDialect>();
229229
mlir::LLVM::registerInlinerInterface(registry);
230230
registerBuiltinDialectTranslation(registry);
231231
registerLLVMDialectTranslation(registry);
@@ -1529,10 +1529,9 @@ void init_triton_ir(py::module &&m) {
15291529
[](TritonOpBuilder &self, Value &condition) {
15301530
self.create<LLVM::AssumeOp>(condition);
15311531
})
1532-
// Undef
1533-
.def("create_undef",
1532+
.def("create_poison",
15341533
[](TritonOpBuilder &self, Type &type) -> Value {
1535-
return self.create<LLVM::UndefOp>(type);
1534+
return self.create<ub::PoisonOp>(type);
15361535
})
15371536
.def("create_histogram",
15381537
[](TritonOpBuilder &self, Value operand, int numBins) -> Value {

python/test/unit/language/test_core.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4882,6 +4882,25 @@ def kernel(Semaphore, Out, total: tl.constexpr):
48824882
assert out.item() >= 0
48834883

48844884

4885+
@triton.jit
4886+
def return_poison(x):
4887+
a = False
4888+
if a:
4889+
return x
4890+
4891+
4892+
def test_poison_return(device):
4893+
4894+
@triton.jit
4895+
def kernel(Out):
4896+
tl.store(Out, return_poison(0))
4897+
4898+
a = torch.empty((), device=device, dtype=torch.int32)
4899+
h = kernel[(1, )](a)
4900+
assert "ub.poison" in h.asm["ttir"], h.asm["ttir"]
4901+
assert "poison" in h.asm["llir"], h.asm["llir"]
4902+
4903+
48854904
# -----------------------
48864905
# test extra
48874906
# -----------------------

python/triton/compiler/code_generator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def visit_FunctionDef(self, node):
448448
self.prototype.ret_types = list(self.ret_type) if isinstance(self.ret_type, tuple) else [self.ret_type]
449449
self.fn.reset_type(self.prototype.to_ir(self.builder))
450450
self.builder.ret([
451-
self.builder.create_undef(ty.to_ir(self.builder))
451+
self.builder.create_poison(ty.to_ir(self.builder))
452452
for ty in self.prototype.ret_types
453453
if self.ret_type is not None
454454
])
@@ -954,7 +954,7 @@ def visit_For(self, node):
954954
ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed)
955955
step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed)
956956
# Create placeholder for the loop induction variable
957-
iv = self.builder.create_undef(iv_ir_type)
957+
iv = self.builder.create_poison(iv_ir_type)
958958
self.set_value(node.target.id, language.core.tensor(iv, iv_type))
959959

960960
with enter_sub_region(self) as sr:

third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
99
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
1010
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
11+
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
1112
#include "mlir/Dialect/Index/IR/IndexDialect.h"
1213
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1314
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
@@ -221,6 +222,7 @@ struct ConvertTritonAMDGPUToLLVM
221222
patterns);
222223
mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns,
223224
targetInfo, commonBenefit);
225+
mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns);
224226
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) {
225227
return signalPassFailure();
226228
}

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,25 @@
11
#include "Dialect/NVGPU/IR/Dialect.h"
22
#include "TritonNVIDIAGPUToLLVM/Passes.h"
33
#include "TritonNVIDIAGPUToLLVM/Utility.h"
4-
#include "mlir/Analysis/DataFlowFramework.h"
54
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
65
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
76
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
8-
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
97
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
108
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
11-
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
9+
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
1210
#include "mlir/Dialect/Index/IR/IndexDialect.h"
13-
#include "mlir/Dialect/Index/IR/IndexOps.h"
1411
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1512
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
1613
#include "mlir/Pass/Pass.h"
17-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1814
#include "triton/Analysis/Allocation.h"
1915
#include "triton/Analysis/AxisInfo.h"
2016
#include "triton/Analysis/Membar.h"
17+
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
2118
#include "triton/Dialect/Triton/IR/Dialect.h"
2219
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
2320
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
2421

2522
#include "PatternTritonGPUOpToLLVM.h"
26-
#include "Utility.h"
2723
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
2824
#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
2925

@@ -36,7 +32,6 @@ namespace triton {
3632

3733
using namespace mlir;
3834
using namespace mlir::triton::NVIDIA;
39-
namespace ttng = mlir::triton::nvidia_gpu;
4035

4136
namespace {
4237

@@ -168,6 +163,7 @@ struct ConvertTritonGPUToLLVM
168163
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
169164
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
170165
patterns);
166+
mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns);
171167
mlir::triton::populateViewOpToLLVMPatterns(typeConverter, patterns,
172168
benefit);
173169
mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns,

0 commit comments

Comments
 (0)