
Commit a94dee0

Add extra attribute in PrintOp to propagate signness info (intel#4363)
The core Triton team is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.**

--------

This PR aims to address triton-lang/triton#4248: correctly `device_print` a value when it is a signed integer. The signedness info is lost when lowering TTGIR to LLIR (e.g. `i32` is always **signless** in MLIR), but the **lowered** data type is currently used for constructing the format specifier in the `PrintOpToLLVM` implementation (triton-lang/triton#4248 (comment)), so a negative value is printed out as an unsigned int, which confuses users (a short standalone sketch of this reinterpretation follows this description).

A minimal reproducer is

```python
import torch
import triton
import triton.language as tl

@triton.jit
def print_kernel(ptr):
    value = tl.load(ptr)
    tl.device_print("value in kernel from device_print", value)

print_kernel[(1,)](torch.tensor(10, dtype=torch.int32).cuda())
print_kernel[(1,)](torch.tensor(-10, dtype=torch.int32).cuda())
print_kernel[(1,)](torch.tensor((1 << 31) + 1000, dtype=torch.uint32).cuda())
```

Currently, it prints

```
pid (0, 0, 0) idx () value in kernel from device_print: 10
...
pid (0, 0, 0) idx () value in kernel from device_print: 4294967286
...
pid (0, 0, 0) idx () value in kernel from device_print: 2147484648
```

(always as an unsigned int).

This PR adds an extra `isSigned` attribute to `PrintOp` to indicate whether each operand of the `PrintOp` should be printed as signed. With this, the program above now prints correctly:

```
pid (0, 0, 0) idx () value in kernel from device_print: 10
...
pid (0, 0, 0) idx () value in kernel from device_print: -10
...
pid (0, 0, 0) idx () value in kernel from device_print: 2147484648
```

Extra LIT tests and Python unit tests are added as well; also manually verified that they fail without the fix and pass with it, by running

```
$ pytest python/test/unit/language/test_subprocess.py
$ cd python/build/cmake.linux-x86_64-cpython-3.10; lit test
```

**Alternative considered**: add `uint32` to the Triton MLIR data type definitions and then rely on the Triton op data type to determine the format specifier, retaining the original signedness info, as in commit triton-lang/triton@f7a7407. However, as the PR reviewer pointed out, that means adding a new data type to the Triton IR just for this purpose, which is overkill and introduces unnecessary maintenance overhead, and is thus less ideal.

--------

Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them.

- [X] I am not making a trivial change, such as fixing a typo in a comment.
- [X] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [X] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [X] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [ ] I have not added any `lit` tests.
  - [X] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
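For reference, a minimal sketch (plain Python, independent of Triton; the helper name is illustrative only) of the two's-complement reinterpretation that a `%u` specifier performs on a negative 32-bit value, reproducing the 4294967286 seen in the buggy output above:

```python
import struct

def as_unsigned_u32(value: int) -> int:
    # Reinterpret a 32-bit two's-complement bit pattern as unsigned,
    # which is effectively what printing a signed i32 with "%u" does.
    return struct.unpack("<I", struct.pack("<i", value))[0]

print(as_unsigned_u32(10))    # 10
print(as_unsigned_u32(-10))   # 4294967286, matching the incorrect output above
```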
1 parent 9c7f9f8 commit a94dee0

File tree

11 files changed (+78, -28 lines)


include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 8 additions & 2 deletions

```diff
@@ -820,8 +820,14 @@ def TT_HistogramOp : TT_Op<"histogram", [Pure]> {
 //
 // Print Op
 //
-def TT_PrintOp : TT_Op<"print", [MemoryEffects<[MemWrite<GlobalMemory>]>]>,
-  Arguments<(ins StrAttr:$prefix, BoolAttr:$hex, Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
+def TT_PrintOp : TT_Op<"print", [SameVariadicOperandSize, MemoryEffects<[MemWrite<GlobalMemory>]>]> {
+  let arguments = (
+    ins
+    StrAttr:$prefix,
+    BoolAttr:$hex,
+    Variadic<AnyTypeOf<[TT_Type]>>:$args,
+    DenseI32ArrayAttr:$isSigned
+  );
   let summary = "Device-side print, as in CUDA for debugging";
   let description = [{
     `tt.print` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
```

lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp

Lines changed: 14 additions & 13 deletions

```diff
@@ -44,7 +44,10 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {
       return success();
     }
 
+    assert(op.getNumOperands() == op.getIsSigned().size());
+
     for (size_t i = 0; i < op.getNumOperands(); i++) {
+      bool isSigned = op.getIsSigned()[i] > 0;
       // Elements of the tensor that are resident in this GPU thread.
       auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter);
 
@@ -76,7 +79,7 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {
       if (!elems.empty()) {
         printTensor(op.getPrefix(), /*operand=*/i,
                     /*numOperands=*/op.getNumOperands(), elems, pid, indices,
-                    dimWidths, op.getHex(), rewriter);
+                    dimWidths, op.getHex(), rewriter, isSigned);
       }
     }
     rewriter.eraseOp(op);
@@ -87,7 +90,7 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {
                    ArrayRef<Value> elems, std::array<Value, 3> pid,
                    ArrayRef<SmallVector<Value>> indices,
                    ArrayRef<int> dimWidths, bool hex,
-                   ConversionPatternRewriter &rewriter) const {
+                   ConversionPatternRewriter &rewriter, bool isSigned) const {
     assert(!elems.empty());
     assert(elems.size() == indices.size());
     assert(dimWidths.size() == indices.front().size());
@@ -151,7 +154,8 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {
       }
 
       auto elem = elems[i];
-      os << getFormatSubstr(elem, hex);
+
+      os << getFormatSubstr(elem, hex, /*width=*/std::nullopt, isSigned);
       printfOperands.push_back(elem);
 
       // It's the same format string each iteration, but it's a lot easier if we
@@ -169,8 +173,10 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {
   }
 
   std::string getFormatSubstr(Value value, bool hex = false,
-                              std::optional<int> width = std::nullopt) const {
+                              std::optional<int> width = std::nullopt,
+                              bool isSigned = false) const {
     Type type = value.getType();
+    // If the `value` is a pointer, just return %p.
    if (isa<LLVM::LLVMPointerType>(type)) {
      return "%p";
    }
@@ -192,21 +198,16 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {
       prefix += std::to_string(*width);
     } else if (hex) {
       prefix += "0";
-      prefix += std::to_string(value.getType().getIntOrFloatBitWidth() / 4);
+      prefix += std::to_string(type.getIntOrFloatBitWidth() / 4);
     }
 
     if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
       return prefix + "f";
-    } else if (type.isSignedInteger()) {
-      if (type.getIntOrFloatBitWidth() == 64)
-        return prefix + "lli";
-      else
-        return prefix + "i";
-    } else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
+    } else if (type.isInteger()) {
       if (type.getIntOrFloatBitWidth() == 64)
-        return prefix + "llu";
+        return prefix + (isSigned ? "lli" : "llu");
       else
-        return prefix + "u";
+        return prefix + (isSigned ? "i" : "u");
     }
     assert(false && "not supported type");
     return "";
```

python/src/ir.cc

Lines changed: 5 additions & 5 deletions

```diff
@@ -1520,11 +1520,11 @@ void init_triton_ir(py::module &&m) {
       })
       .def("create_print",
            [](TritonOpBuilder &self, const std::string &prefix, bool hex,
-              const std::vector<Value> &values) -> void {
-             self.create<PrintOp>(
-                 StringAttr::get(self.getBuilder().getContext(),
-                                 llvm::StringRef(prefix)),
-                 hex, values);
+              const std::vector<Value> &values,
+              const std::vector<int32_t> &isSigned) -> void {
+             auto prefixAttr = StringAttr::get(self.getBuilder().getContext(),
+                                               llvm::StringRef(prefix));
+             self.create<PrintOp>(prefixAttr, hex, values, isSigned);
       })
       .def("create_assert",
            [](TritonOpBuilder &self, Value &condition,
```

python/test/unit/language/print_helper.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -104,6 +104,12 @@ def test_print(func: str, data_type: str, device: str):
     elif func == "device_print_scalar":
         scalar = torch.tensor(42, dtype=x.dtype, device="cuda")
         kernel_device_print_scalar[(1, )](scalar, num_warps=num_warps)
+    elif func == "device_print_negative":
+        x = -x
+        kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N)
+    elif func == "device_print_uint":
+        x = torch.arange((1 << 31), (1 << 31) + N, device=device).to(getattr(torch, data_type))
+        kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N)
     elif func == "print":
         kernel_print[(1, )](x, y, num_warps=num_warps, BLOCK=N)
     elif func == "device_print_large":
```

python/test/unit/language/test_subprocess.py

Lines changed: 9 additions & 2 deletions

```diff
@@ -38,6 +38,8 @@ def is_interpreter():
     ("device_print_hex", "int32"),
     ("device_print_hex", "int64"),
     ("device_print_pointer", "int32"),
+    ("device_print_negative", "int32"),
+    ("device_print_uint", "uint32"),
 ])
 def test_print(func_type: str, data_type: str, device: str):
     proc = subprocess.run(
@@ -62,9 +64,10 @@ def test_print(func_type: str, data_type: str, device: str):
     # Format is
     #   pid (<x>, <y>, <z>) idx (<i1>, <i2>, ...) <prefix> (operand <n>) <elem>
     expected_lines = Counter()
-    if func_type == "print" or func_type == "device_print":
+    if func_type in ("print", "device_print", "device_print_uint"):
         for i in range(N):
-            line = f"pid (0, 0, 0) idx ({i:3}) x: {i}"
+            offset = (1 << 31) if data_type == "uint32" else 0
+            line = f"pid (0, 0, 0) idx ({i:3}) x: {i + offset}"
             if data_type.startswith("float"):
                 line += ".000000"
             expected_lines[line] = 1
@@ -73,6 +76,10 @@ def test_print(func_type: str, data_type: str, device: str):
             if data_type.startswith("float"):
                 line += ".000000"
             expected_lines[line] = N
+    elif func_type == "device_print_negative":
+        for i in range(N):
+            line = f"pid (0, 0, 0) idx ({i:3}) x: {-i}"
+            expected_lines[line] = 1
     elif func_type == "device_print_hex":
         for i in range(N):
             line = f"pid (0, 0, 0) idx ({i:3}) x: 0x"
```

python/triton/language/semantic.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -1535,7 +1535,8 @@ def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.buil
         prefix = " " + prefix
 
     new_args = [arg.handle for arg in args]
-    return tl.tensor(builder.create_print(prefix, hex, new_args), tl.void)
+    is_signed = [arg.dtype in (tl.int1, tl.int8, tl.int16, tl.int32, tl.int64) for arg in args]
+    return tl.tensor(builder.create_print(prefix, hex, new_args, is_signed), tl.void)
 
 
 def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor:
```
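To illustrate the per-argument mapping the new list comprehension produces, a small standalone sketch (the dtype names here are plain strings standing in for the `tl` dtypes, purely for illustration):

```python
# Stand-in for the tl.int* dtypes checked in device_print above.
signed_dtypes = {"int1", "int8", "int16", "int32", "int64"}

# One flag per printed argument; unsigned and floating-point dtypes stay False.
arg_dtypes = ["int32", "uint32", "float32", "int64"]
is_signed = [dtype in signed_dtypes for dtype in arg_dtypes]
print(is_signed)  # [True, False, False, True]
```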

python/triton/runtime/interpreter.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -628,7 +628,10 @@ def create_extern_elementwise(self, libName, libPath, symbol, argList, retType,
     def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack):
         raise NotImplementedError("inline_asm not supported in interpreter mode")
 
-    def create_print(self, prefix, hex, values):
+    def create_print(self, prefix, hex, values, isSigned):
+        # NOTE: the `isSigned` variable is not really used here; because Signness is already known
+        # by `values` themselves in python interpreter, thus not really needed here;
+        # it is only used for triton PrintOpToLLVM to correctly construct the format specifier.
         # Interpreter's device_print function has a different format than Triton's device_print
         msg = f"({self.grid_idx[0]}, {self.grid_idx[1]}, {self.grid_idx[2]})"
         if prefix:
```

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 27 additions & 1 deletion

```diff
@@ -1639,7 +1639,33 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   // CHECK-LABEL: print_ptr
   // CHECK: llvm.call @vprintf(%{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
   tt.func @print_ptr(%arg0 : tensor<256x!tt.ptr<i32>, #blocked0>) {
-    tt.print "ptr: " {hex = false} : %arg0 : tensor<256x!tt.ptr<i32>, #blocked0>
+    tt.print "ptr: " {hex = false, isSigned = array<i32: 0>} : %arg0 : tensor<256x!tt.ptr<i32>, #blocked0>
+    tt.return
+  }
+}
+
+// -----
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // Test that %u format specifier is used if isSigned is false
+  // CHECK: llvm.mlir.global internal constant @printfFormat_0("{{.*}}int32 tensor: %u{{.*}}")
+  // CHECK-LABEL: print_int32_tensor_issigned_off
+  // CHECK: llvm.call @vprintf(%{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
+  tt.func @print_int32_tensor_issigned_off(%arg0 : i32) {
+    tt.print "int32 tensor: " {hex = false, isSigned = array<i32: 0>} : %arg0 : i32
+    tt.return
+  }
+}
+
+// -----
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // Test that %i format specifier is used if isSigned is true
+  // CHECK: llvm.mlir.global internal constant @printfFormat_0("{{.*}}int32 tensor: %i{{.*}}")
+  // CHECK-LABEL: print_int32_tensor_issigned_on
+  // CHECK: llvm.call @vprintf(%{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
+  tt.func @print_int32_tensor_issigned_on(%arg0 : i32) {
+    tt.print "int32 tensor: " {hex = false, isSigned = array<i32: 1>} : %arg0 : i32
     tt.return
   }
 }
```

test/Triton/ops.mlir

Lines changed: 1 addition & 1 deletion

```diff
@@ -186,7 +186,7 @@ tt.func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
 // CHECK-LABEL: @print_no_arg
 tt.func @print_no_arg(%arg0: !tt.ptr<f32>) {
   // CHECK: tt.print "test"
-  tt.print "test" { hex = false }
+  tt.print "test" { hex = false, isSigned = array<i32: 0>}
   %0 = tt.load %arg0 : !tt.ptr<f32>
   tt.store %arg0, %0 : !tt.ptr<f32>
   tt.return
```

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 1 addition & 1 deletion

```diff
@@ -119,7 +119,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :
     // CHECK: tt.dot {{.*}} -> tensor<2x16x16xf32, #[[MMA1]]>
     %11 = tt.dot %8, %9, %10, inputPrecision = tf32 : tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked3}>> * tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<2x16x16xf32, #blocked3>
     %12 = triton_gpu.convert_layout %11 : tensor<2x16x16xf32, #blocked3> -> tensor<2x16x16xf32, #blocked>
-    tt.print ": " {hex = false} : %12 : tensor<2x16x16xf32, #blocked>
+    tt.print ": " {hex = false, isSigned = array<i32: 0>} : %12 : tensor<2x16x16xf32, #blocked>
     tt.return
   }
 }
```
