
Commit 0b68388

Fix descriptor type being lost in the frontend after control flow (triton-lang#5086)

After control flow, every IR value was being wrapped in a tensor, which lost the value's frontend type (e.g. a tensor descriptor). This adds two new methods, `_value._flatten_ir()` and `_value._unflatten_ir()`, which convert an arbitrary value to/from a list of `ir.value`.
1 parent 4480f86 commit 0b68388
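
As a rough illustration of the contract described above (a toy sketch only; the real `_flatten_ir`/`_unflatten_ir` signatures in the frontend may differ), a structured value flattens itself to the bare `ir.value` handles that control-flow lowering carries, and is rebuilt afterwards with its original Python type rather than being re-wrapped as a tensor:

# Toy sketch, not Triton frontend code: strings stand in for ir.value handles,
# and the method shapes are illustrative assumptions.
class ToyTensorDescriptor:

    def __init__(self, handle, block_shape):
        self.handle = handle              # stand-in for the backing ir.value
        self.block_shape = block_shape    # frontend-only metadata

    def _flatten_ir(self):
        # Only raw IR handles cross the control-flow boundary.
        return [self.handle]

    def _unflatten_ir(self, handles):
        # Rebuild a value of the same frontend type from the flat handles.
        return ToyTensorDescriptor(handles[0], self.block_shape)


desc = ToyTensorDescriptor("ir.value<0>", [8, 128])
flat = desc._flatten_ir()             # what the loop region actually carries
rebuilt = desc._unflatten_ir(flat)    # the descriptor type survives control flow
assert isinstance(rebuilt, ToyTensorDescriptor)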

File tree

7 files changed: +228 lines, -58 lines


include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 1 deletion
@@ -280,7 +280,7 @@ def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods<In
 def TTG_GlobalScratchAllocOp : TTG_Op<"global_scratch_alloc", [MemoryEffects<[MemAlloc<GlobalMemory>]>]> {
   let summary = "allocate a global memory buffer";
   let description = [{
-    This operation allocates a buffer in global memory.
+    This operation allocates a buffer in global memory that is private to the current program.
   }];
   let arguments = (
     ins

python/src/ir.cc

Lines changed: 10 additions & 0 deletions
@@ -249,6 +249,16 @@ void init_triton_ir(py::module &&m) {
       .def("is_integer",
            [](Type &self, unsigned width) { return self.isInteger(width); })
       .def("is_fp16", &Type::isF16)
+      .def("__eq__",
+           [](Type &self, py::object &other) {
+             Type *other_ty = py::cast<Type *>(other);
+             return (other_ty != nullptr) && (*other_ty == self);
+           })
+      .def("__ne__",
+           [](Type &self, py::object &other) {
+             Type *other_ty = py::cast<Type *>(other);
+             return (other_ty == nullptr) || (*other_ty != self);
+           })
       .def("__str__", [](Type &self) {
         std::string str;
         llvm::raw_string_ostream os(str);
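
On the Python side, these bindings give `ir.type` handles value equality, which the frontend can use to verify that a loop-carried value keeps the same MLIR types. A minimal sketch of such a check (hypothetical helper, not code from this commit; it assumes two lists of `ir.type` objects):

# Hypothetical helper: works on any objects with __eq__/__ne__, including the
# ir.type handles exposed by the bindings above.
def check_loop_carried_types(before, after):
    if len(before) != len(after):
        raise TypeError("loop-carried value changed its number of IR handles")
    for old_ty, new_ty in zip(before, after):
        if old_ty != new_ty:  # relies on ir.type.__ne__
            raise TypeError(f"loop-carried type changed: {old_ty} vs {new_ty}")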

python/test/unit/hopper/test_experimental_tma.py

Lines changed: 74 additions & 0 deletions
@@ -460,3 +460,77 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
     assert "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned" in kernel.asm["ptx"]
     if BLOCK_M >= 64 and BLOCK_N >= 64:
         assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"]
+
+
+@triton.jit
+def kernel_make_tensor_desciptor_loop_carried(a_ptr, M, N, MBLOCK: tl.constexpr, NBLOCK: tl.constexpr):
+    # Test that descriptors work with
+    pid = tl.program_id(0)
+    moffset = MBLOCK * pid
+
+    a_desc = tl._experimental_make_tensor_descriptor(
+        a_ptr,
+        shape=[M, N],
+        strides=[N, 1],
+        block_shape=[MBLOCK, NBLOCK],
+    )
+
+    for i in range(0, N, NBLOCK):
+        assert isinstance(a_desc, tl._experimental_tensor_descriptor)
+        if i % (3 * NBLOCK) == 0:
+            a_desc = tl._experimental_make_tensor_descriptor(
+                a_ptr,
+                shape=[M, N],
+                strides=[N, 1],
+                block_shape=[MBLOCK, NBLOCK],
+            )
+            assert isinstance(a_desc, tl._experimental_tensor_descriptor)
+        assert isinstance(a_desc, tl._experimental_tensor_descriptor)
+        a = a_desc.load([moffset, i])
+        a_desc.store([moffset, i], a + 10)
+
+    n = 0
+    while n < N:
+        assert isinstance(a_desc, tl._experimental_tensor_descriptor)
+        if n % (3 * NBLOCK) == 0:
+            assert isinstance(a_desc, tl._experimental_tensor_descriptor)
+            a_desc = tl._experimental_make_tensor_descriptor(
+                a_ptr,
+                shape=[M, N],
+                strides=[N, 1],
+                block_shape=[MBLOCK, NBLOCK],
+            )
+            assert isinstance(a_desc, tl._experimental_tensor_descriptor)
+        a = a_desc.load([moffset, n])
+        a_desc.store([moffset, n], a + 5)
+
+        n += NBLOCK
+
+
+@requires_tma
+def test_experimental_make_tensor_descriptor_loop_carried():
+    device = "cuda"
+    M, N = 8192, 8192
+    torch.manual_seed(42)
+    A = torch.randn((M, N), dtype=torch.float32, device=device)
+    MBLOCK, NBLOCK = 8, 128
+    grid = (triton.cdiv(M, MBLOCK), )
+
+    def alloc_fn(size: int, align: int, stream: Optional[int]):
+        assert size == 128 * grid[0]
+        assert align == 128
+        assert stream == 0
+        return torch.empty(size, dtype=torch.int8, device="cuda")
+
+    triton.set_allocator(alloc_fn)
+
+    ref_out = A + 15
+    kernel = kernel_make_tensor_desciptor_loop_carried[grid](
+        A,
+        M,
+        N,
+        MBLOCK,
+        NBLOCK,
+    )
+    torch.testing.assert_close(ref_out, A)
+    assert "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned" in kernel.asm["ptx"]

python/triton/_utils.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+from typing import Tuple, List, Any
+
+# Poor man's PyTree
+
+
+def list_list_flatten(x: List[List[Any]]) -> Tuple[List[int], List[Any]]:
+    spec = []
+    flat = []
+    for l in x:
+        spec.append(len(l))
+        flat.extend(l)
+    return spec, flat
+
+
+def list_list_unflatten(spec: List[int], flat: List[Any]) -> List[List[Any]]:
+    ret = []
+    idx = 0
+    for size in spec:
+        ret.append(flat[idx:idx + size])
+        idx += size
+    assert idx == len(flat)
+    return ret
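
As a quick sanity check of these helpers (the module path `triton._utils` follows from the new file's location), the flatten/unflatten pair round-trips a list of lists:

from triton._utils import list_list_flatten, list_list_unflatten

nested = [[1, 2], [3], [], [4, 5, 6]]
spec, flat = list_list_flatten(nested)
assert spec == [2, 1, 0, 3]         # one inner-list length per entry
assert flat == [1, 2, 3, 4, 5, 6]   # all elements, in order
assert list_list_unflatten(spec, flat) == nested  # exact round-trip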
