Skip to content

Commit 915cc70

Browse files
authored
[TensorDesc] Add fallback for reduction ops (#6829)
The fallback supports most of the same dtype/kind combinations as the native path, with the exception of min/max on float types, which are only implemented in the frontend at the moment.
1 parent 22150c4 commit 915cc70

File tree

3 files changed

+209
-123
lines changed

3 files changed

+209
-123
lines changed

lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "llvm/ADT/SmallVector.h"
1717
#include "llvm/ADT/SmallVectorExtras.h"
1818
#include "llvm/Support/LogicalResult.h"
19+
#include "llvm/Support/raw_ostream.h"
1920
#include <mlir/Dialect/Arith/IR/Arith.h>
2021
#include <mlir/Dialect/Func/Transforms/FuncConversions.h>
2122
#include <mlir/IR/Builders.h>
@@ -253,8 +254,6 @@ struct RewriteLoadPattern : OpConversionPattern<triton::DescriptorLoadOp> {
253254
ConversionPatternRewriter &rewriter) const override {
254255
auto loc = op.getLoc();
255256
const auto blockShape = op.getDesc().getType().getBlockType().getShape();
256-
const auto rank = blockShape.size();
257-
258257
auto descTy = op.getDesc().getType();
259258
auto desc = unpackDescriptor(descTy, adaptor.getDesc());
260259
auto offsets = castToI64(rewriter, op.getIndices());
@@ -279,7 +278,6 @@ struct RewriteStorePattern : OpConversionPattern<triton::DescriptorStoreOp> {
279278
auto loc = op.getLoc();
280279
auto descTy = op.getDesc().getType();
281280
const auto blockShape = descTy.getBlockType().getShape();
282-
const auto rank = blockShape.size();
283281
auto desc = unpackDescriptor(descTy, adaptor.getDesc());
284282
auto offsets = castToI64(rewriter, op.getIndices());
285283

@@ -360,6 +358,68 @@ struct RewriteScatterPattern
360358
}
361359
};
362360

361+
/**
 * @brief Map a descriptor reduction kind to the atomic RMW op used when the
 * reduction is rewritten to pointer-based code.
 *
 * The result depends on the element type of the descriptor's block type:
 *  - ADD becomes ADD for integer elements and FADD for every other type.
 *  - MIN / MAX become UMIN / UMAX for unsigned integers and MIN / MAX for
 *    signed integers; any other element type has no mapping.
 *  - AND / OR / XOR map directly, without looking at the element type.
 *
 * @param kind The reduction kind carried by the descriptor reduce op.
 * @param ty   The tensor descriptor type whose block element type is checked.
 * @return The matching RMWOp, or std::nullopt when no mapping exists.
 */
std::optional<RMWOp> translateReduceKind(DescriptorReduceKind kind,
                                         TensorDescType ty) {
  const auto elemTy = ty.getBlockType().getElementType();

  // Chooses between the unsigned and signed flavour of an integer-only op;
  // element types that are neither yield no mapping.
  const auto bySignedness = [&elemTy](RMWOp unsignedOp,
                                      RMWOp signedOp) -> std::optional<RMWOp> {
    if (elemTy.isUnsignedInteger())
      return unsignedOp;
    if (elemTy.isSignedInteger())
      return signedOp;
    return std::nullopt;
  };

  switch (kind) {
  case DescriptorReduceKind::ADD:
    if (elemTy.isInteger())
      return RMWOp::ADD;
    return RMWOp::FADD;
  case DescriptorReduceKind::MIN:
    return bySignedness(RMWOp::UMIN, RMWOp::MIN);
  case DescriptorReduceKind::MAX:
    return bySignedness(RMWOp::UMAX, RMWOp::MAX);
  case DescriptorReduceKind::AND:
    return RMWOp::AND;
  case DescriptorReduceKind::OR:
    return RMWOp::OR;
  case DescriptorReduceKind::XOR:
    return RMWOp::XOR;
  default:
    break;
  }
  // Unknown reduction kind.
  return std::nullopt;
}
392+
393+
/**
 * @brief Rewrites a tensor descriptor reduce op into an atomic RMW on
 * pointers computed from the unpacked descriptor.
 *
 * The reduction kind is first translated with translateReduceKind(). If no
 * RMW op exists for the descriptor's element type, an error is emitted on
 * the op and the pattern fails. Otherwise a triton::AtomicRMWOp is created
 * with release semantics at GPU scope, using the pointer and mask produced
 * by generatePtr() / generateMask(), and the original op is erased.
 */
struct RewriteReducePattern : OpConversionPattern<triton::DescriptorReduceOp> {
  using OpConversionPattern<triton::DescriptorReduceOp>::OpConversionPattern;

  llvm::LogicalResult
  matchAndRewrite(triton::DescriptorReduceOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto descTy = op.getDesc().getType();

    // Reject unsupported kind/type combinations before creating any IR.
    auto rmwOp = translateReduceKind(op.getKind(), descTy);
    if (!rmwOp) {
      std::string msgstring;
      llvm::raw_string_ostream msg(msgstring);
      msg << "Cannot fallback on descriptor atomic op, unsupported for type "
          << descTy.getBlockType().getElementType();
      return op->emitError(msgstring);
    }

    const auto blockShape = descTy.getBlockType().getShape();
    auto desc = unpackDescriptor(descTy, adaptor.getDesc());
    auto offsets = castToI64(rewriter, op.getIndices());

    // The RMW result is not needed; the op is created for its side effect.
    rewriter.create<triton::AtomicRMWOp>(
        loc, descTy.getSignlessBlockType(), *rmwOp,
        generatePtr(rewriter, loc, blockShape, desc, offsets), op.getSrc(),
        generateMask(rewriter, loc, blockShape, desc, offsets),
        MemSemantic::RELEASE, MemSyncScope::GPU);

    // Erase through the rewriter so the conversion driver can track the
    // change; erasing the op directly bypasses that bookkeeping.
    rewriter.eraseOp(op);
    return success();
  }
};
422+
363423
/**
364424
* @brief This implements the pass for converting triton tensor descriptor
365425
* loads/stores into indexed loads/stores.
@@ -428,9 +488,10 @@ class TritonRewriteTensorDescriptorToPointerPass
428488
mlir::scf::populateSCFStructuralTypeConversions(converter, patterns);
429489
triton::populateArithTypeConversions(converter, patterns);
430490

431-
patterns.add<RewriteMakeTensorDesc, RewriteLoadPattern, RewriteStorePattern,
432-
RewriteGatherPattern, RewriteScatterPattern>(converter,
433-
&getContext());
491+
patterns
492+
.add<RewriteMakeTensorDesc, RewriteLoadPattern, RewriteStorePattern,
493+
RewriteGatherPattern, RewriteScatterPattern, RewriteReducePattern>(
494+
converter, &getContext());
434495

435496
ConversionConfig config;
436497
config.buildMaterializations = false;

python/test/unit/cuda/test_tensor_descriptor.py

Lines changed: 1 addition & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,124 +1,10 @@
11
import pytest
22
import torch
3-
import numpy as np
43

54
import triton
6-
from triton.compiler.errors import CompilationError
75
import triton.language as tl
8-
from triton._internal_testing import is_interpreter, numpy_random, to_triton, requires_tma, unwrap_tensor, tma_dtypes, to_numpy
6+
from triton._internal_testing import is_interpreter, numpy_random, to_triton, requires_tma, unwrap_tensor, tma_dtypes
97
from triton.tools.tensor_descriptor import TensorDescriptor
10-
from typing import Optional
11-
12-
SUPPORTED_REDUCE_DTYPES = {
13-
"add": {tl.uint32, tl.int32, tl.uint64, tl.float32, tl.float16, tl.bfloat16},
14-
"min": {tl.uint32, tl.int32, tl.uint64, tl.int64, tl.float16, tl.bfloat16},
15-
"max": {tl.uint32, tl.int32, tl.uint64, tl.int64, tl.float16, tl.bfloat16},
16-
"and": {tl.uint32, tl.int32, tl.uint64, tl.int64},
17-
"or": {tl.uint32, tl.int32, tl.uint64, tl.int64},
18-
"xor": {tl.uint32, tl.int32, tl.uint64, tl.int64},
19-
}
20-
21-
22-
def min_op(a, b):
23-
out = np.minimum(to_numpy(a), to_numpy(b))
24-
return unwrap_tensor(to_triton(out, device=a.device))
25-
26-
27-
def max_op(a, b):
28-
out = np.maximum(to_numpy(a), to_numpy(b))
29-
return unwrap_tensor(to_triton(out, device=a.device))
30-
31-
32-
REDUCE_OP = {
33-
"add": lambda a, b: unwrap_tensor(a) + unwrap_tensor(b),
34-
"min": min_op,
35-
"max": max_op,
36-
"and": lambda a, b: torch.bitwise_and(unwrap_tensor(a), unwrap_tensor(b)),
37-
"or": lambda a, b: torch.bitwise_or(unwrap_tensor(a), unwrap_tensor(b)),
38-
"xor": lambda a, b: torch.bitwise_xor(unwrap_tensor(a), unwrap_tensor(b)),
39-
}
40-
41-
42-
@requires_tma
43-
# TODO: interpreter support
44-
# @pytest.mark.interpreter
45-
@pytest.mark.parametrize("kind", ["add", "min", "max", "and", "or", "xor"])
46-
@pytest.mark.parametrize("dtype_str", tma_dtypes)
47-
@pytest.mark.parametrize("num_ctas", [1, 2])
48-
@pytest.mark.parametrize("descriptor", ["host", "device"])
49-
@pytest.mark.parametrize("M_BLOCK,N_BLOCK", [(2, 16), (8, 16), (8, 32), (8, 128), (512, 32), (1, 1024)])
50-
def test_tensor_descriptor_reduce(kind, descriptor, dtype_str, num_ctas, M_BLOCK, N_BLOCK):
51-
52-
@triton.jit(debug=True)
53-
def kernel(out_desc, out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr, kind: tl.constexpr):
54-
moffset = tl.program_id(0) * M_BLOCK
55-
noffset = tl.program_id(1) * N_BLOCK
56-
57-
midx = moffset + tl.arange(0, M_BLOCK)[:, None]
58-
nidx = noffset + tl.arange(0, N_BLOCK)[None, :]
59-
idx = midx * N + nidx
60-
61-
val = tl.load(a_ptr + idx)
62-
63-
if out_desc is None:
64-
desc = tl.make_tensor_descriptor(
65-
out_ptr,
66-
shape=[M, N],
67-
strides=[N, 1],
68-
block_shape=[M_BLOCK, N_BLOCK],
69-
)
70-
else:
71-
desc = out_desc
72-
73-
assert desc.shape[0] == M
74-
assert desc.shape[1] == N
75-
assert desc.strides[0] == N
76-
assert desc.strides[1] == 1
77-
assert desc.block_shape == [M_BLOCK, N_BLOCK]
78-
if kind == "add":
79-
desc.atomic_add([moffset, noffset], val)
80-
elif kind == "min":
81-
desc.atomic_min([moffset, noffset], val)
82-
elif kind == "max":
83-
desc.atomic_max([moffset, noffset], val)
84-
elif kind == "and":
85-
desc.atomic_and([moffset, noffset], val)
86-
elif kind == "or":
87-
desc.atomic_or([moffset, noffset], val)
88-
else:
89-
tl.static_assert(kind == "xor")
90-
desc.atomic_xor([moffset, noffset], val)
91-
92-
M, N = M_BLOCK * 2, N_BLOCK * 2
93-
rs = np.random.RandomState(seed=17)
94-
inp = to_triton(numpy_random((M, N), dtype_str, rs), device="cuda", dst_type=dtype_str)
95-
out = to_triton(numpy_random((M, N), dtype_str, rs), device="cuda", dst_type=dtype_str)
96-
97-
grid_m = M // M_BLOCK
98-
grid_n = N // N_BLOCK
99-
100-
if descriptor == "host":
101-
out_desc = TensorDescriptor.from_tensor(out, [M_BLOCK, N_BLOCK])
102-
else:
103-
104-
def alloc_fn(size: int, align: int, stream: Optional[int]):
105-
assert size == 128 * (grid_m * grid_n) * num_ctas
106-
assert align == 128
107-
assert stream == 0
108-
return torch.empty(size, dtype=torch.int8, device="cuda")
109-
110-
triton.set_allocator(alloc_fn)
111-
out_desc = None
112-
113-
supported = getattr(tl, dtype_str) in SUPPORTED_REDUCE_DTYPES[kind]
114-
if not supported:
115-
with pytest.raises(CompilationError):
116-
kernel[(grid_m, grid_n)](out_desc, out, inp, M, N, M_BLOCK, N_BLOCK, kind, num_ctas=num_ctas)
117-
return
118-
119-
expect = REDUCE_OP[kind](inp, out)
120-
kernel[(grid_m, grid_n)](out_desc, out, inp, M, N, M_BLOCK, N_BLOCK, kind, num_ctas=num_ctas)
121-
torch.testing.assert_close(expect, unwrap_tensor(out), check_dtype=False)
1228

1239

12410
@requires_tma

python/test/unit/language/test_tensor_descriptor.py

Lines changed: 141 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
import triton
66
import triton.language as tl
7-
from triton._internal_testing import is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes
7+
from triton._internal_testing import is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy
88
from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor
99
from typing import Optional
10-
from triton._internal_testing import is_cuda, is_hip
10+
from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3
11+
from triton.tools.tensor_descriptor import TensorDescriptor
12+
from triton import CompilationError
1113

1214

1315
@pytest.mark.interpreter
@@ -1434,3 +1436,140 @@ def alloc_fn(size: int, align: int, steam):
14341436

14351437
ref = torch_scatter_rows(input, idx, y, BLOCK_Y, X, Y)
14361438
torch.testing.assert_close(ref, output, atol=0, rtol=0)
1439+
1440+
1441+
# Building blocks for the per-kind dtype tables below. Note that tl.int64 is
# part of the integer group used by min/max/and/or/xor but is not listed for
# "add".
_REDUCE_ADD_DTYPES = (tl.uint32, tl.int32, tl.uint64, tl.float32, tl.float16, tl.bfloat16)
_REDUCE_INT_DTYPES = (tl.uint32, tl.int32, tl.uint64, tl.int64)
_REDUCE_HALF_DTYPES = (tl.float16, tl.bfloat16)

# Dtypes each reduction kind is expected to accept when the test runs on the
# native path (see `is_native` in test_tensor_descriptor_reduce).
NATIVE_SUPPORTED_REDUCE_DTYPES = {
    "add": set(_REDUCE_ADD_DTYPES),
    "min": set(_REDUCE_INT_DTYPES + _REDUCE_HALF_DTYPES),
    "max": set(_REDUCE_INT_DTYPES + _REDUCE_HALF_DTYPES),
    "and": set(_REDUCE_INT_DTYPES),
    "or": set(_REDUCE_INT_DTYPES),
    "xor": set(_REDUCE_INT_DTYPES),
}
# Dtypes each reduction kind is expected to accept on the fallback path.
# Differs from the native table only in that min/max exclude the float types.
FALLBACK_SUPPORTED_REDUCE_DTYPES = {
    "add": set(_REDUCE_ADD_DTYPES),
    "min": set(_REDUCE_INT_DTYPES),
    "max": set(_REDUCE_INT_DTYPES),
    "and": set(_REDUCE_INT_DTYPES),
    "or": set(_REDUCE_INT_DTYPES),
    "xor": set(_REDUCE_INT_DTYPES),
}
1457+
1458+
1459+
def min_op(a, b):
    """Reference elementwise minimum of `a` and `b`.

    Both operands are converted with `to_numpy`, combined with `np.minimum`,
    and the result is moved back to `a.device` via `to_triton` and unwrapped.
    """
    lhs, rhs = to_numpy(a), to_numpy(b)
    return unwrap_tensor(to_triton(np.minimum(lhs, rhs), device=a.device))
1462+
1463+
1464+
def max_op(a, b):
    """Reference elementwise maximum of `a` and `b`.

    Both operands are converted with `to_numpy`, combined with `np.maximum`,
    and the result is moved back to `a.device` via `to_triton` and unwrapped.
    """
    lhs, rhs = to_numpy(a), to_numpy(b)
    return unwrap_tensor(to_triton(np.maximum(lhs, rhs), device=a.device))
1467+
1468+
1469+
def _on_unwrapped(binary_fn):
    """Wrap `binary_fn` so both of its arguments pass through `unwrap_tensor` first."""

    def apply(a, b):
        return binary_fn(unwrap_tensor(a), unwrap_tensor(b))

    return apply


# Host-side reference implementation for each reduction kind, called as
# REDUCE_OP[kind](a, b).
REDUCE_OP = {
    "add": _on_unwrapped(lambda x, y: x + y),
    "min": min_op,
    "max": max_op,
    "and": _on_unwrapped(torch.bitwise_and),
    "or": _on_unwrapped(torch.bitwise_or),
    "xor": _on_unwrapped(torch.bitwise_xor),
}
1477+
1478+
# (kind, dtype_str, M_BLOCK, N_BLOCK) combinations that
# test_tensor_descriptor_reduce skips when is_hip_cdna3() is true.
# All entries share the (1, 1024) block shape.
REDUCE_SKIP_HIP_CDNA3 = [(kind, dtype_str, 1, 1024)
                         for kind, dtype_str in (("min", "int32"), ("max", "int32"), ("add", "bfloat16"))]
1483+
1484+
1485+
# TODO: interpreter support
1486+
# @pytest.mark.interpreter
1487+
@pytest.mark.parametrize("kind", ["add", "min", "max", "and", "or", "xor"])
@pytest.mark.parametrize("dtype_str", tma_dtypes)
@pytest.mark.parametrize("num_ctas", [1, 2])
@pytest.mark.parametrize("descriptor", ["host", "device"])
@pytest.mark.parametrize("M_BLOCK,N_BLOCK", [(2, 16), (8, 16), (8, 32), (8, 128), (512, 32), (1, 1024)])
def test_tensor_descriptor_reduce(kind, descriptor, dtype_str, num_ctas, M_BLOCK, N_BLOCK):
    """Check tensor-descriptor atomic reductions against a host-side reference.

    A (2 * M_BLOCK, 2 * N_BLOCK) input is reduced block by block into an
    output of the same shape, using either a host-built descriptor or one
    created inside the kernel. For dtype/kind pairs that are not in the
    applicable support table, the launch is expected to raise instead.
    """
    # The "native" path is a CUDA device with compute capability major >= 9;
    # everything else is checked against the fallback support table.
    is_native = is_cuda() and torch.cuda.get_device_capability()[0] >= 9
    if not is_native:
        if num_ctas != 1:
            skip_reason = "Multi-CTA not supported"
        elif descriptor == "host":
            skip_reason = "NYI: Host side tensor descriptor fallback"
        elif is_hip_cdna3() and (kind, dtype_str, M_BLOCK, N_BLOCK) in REDUCE_SKIP_HIP_CDNA3:
            skip_reason = "Broken on rocm"
        else:
            skip_reason = None
        if skip_reason is not None:
            pytest.skip(skip_reason)

    @triton.jit(debug=True)
    def kernel(out_desc, out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr, kind: tl.constexpr):
        # Top-left corner of the block handled by this program.
        moffset = tl.program_id(0) * M_BLOCK
        noffset = tl.program_id(1) * N_BLOCK

        # Row-major linear indices of the block inside the (M, N) input.
        rows = moffset + tl.arange(0, M_BLOCK)[:, None]
        cols = noffset + tl.arange(0, N_BLOCK)[None, :]
        val = tl.load(a_ptr + (rows * N + cols))

        # Build the descriptor in-kernel unless the host supplied one.
        if out_desc is None:
            desc = tl.make_tensor_descriptor(
                out_ptr,
                shape=[M, N],
                strides=[N, 1],
                block_shape=[M_BLOCK, N_BLOCK],
            )
        else:
            desc = out_desc

        assert desc.shape[0] == M
        assert desc.shape[1] == N
        assert desc.strides[0] == N
        assert desc.strides[1] == 1
        assert desc.block_shape == [M_BLOCK, N_BLOCK]
        if kind == "add":
            desc.atomic_add([moffset, noffset], val)
        elif kind == "min":
            desc.atomic_min([moffset, noffset], val)
        elif kind == "max":
            desc.atomic_max([moffset, noffset], val)
        elif kind == "and":
            desc.atomic_and([moffset, noffset], val)
        elif kind == "or":
            desc.atomic_or([moffset, noffset], val)
        else:
            tl.static_assert(kind == "xor")
            desc.atomic_xor([moffset, noffset], val)

    M, N = M_BLOCK * 2, N_BLOCK * 2
    # Input is drawn before the initial output from the same seeded stream.
    rng = np.random.RandomState(seed=17)
    inp = to_triton(numpy_random((M, N), dtype_str, rng), device="cuda", dst_type=dtype_str)
    out = to_triton(numpy_random((M, N), dtype_str, rng), device="cuda", dst_type=dtype_str)

    grid_m = M // M_BLOCK
    grid_n = N // N_BLOCK

    if descriptor == "host":
        out_desc = TensorDescriptor.from_tensor(out, [M_BLOCK, N_BLOCK])
    else:
        out_desc = None

        def alloc_fn(size: int, align: int, stream: Optional[int]):
            assert size == 128 * (grid_m * grid_n) * num_ctas
            assert align == 128
            assert stream == 0
            return torch.empty(size, dtype=torch.int8, device="cuda")

        triton.set_allocator(alloc_fn)

    def launch():
        kernel[(grid_m, grid_n)](out_desc, out, inp, M, N, M_BLOCK, N_BLOCK, kind, num_ctas=num_ctas)

    dtype = getattr(tl, dtype_str)
    native_supported = dtype in NATIVE_SUPPORTED_REDUCE_DTYPES[kind]
    fallback_supported = dtype in FALLBACK_SUPPORTED_REDUCE_DTYPES[kind]
    if is_native:
        supported = native_supported
    else:
        supported = fallback_supported

    if not supported:
        # Pairs outside the native table are expected to raise
        # CompilationError; pairs missing only from the fallback table are
        # expected to raise RuntimeError.
        exc_type = RuntimeError if native_supported else CompilationError
        with pytest.raises(exc_type):
            launch()
        return

    # Compute the reference before launching: the kernel updates `out` in place.
    expect = REDUCE_OP[kind](inp, out)
    launch()
    torch.testing.assert_close(expect, unwrap_tensor(out), check_dtype=False)

0 commit comments

Comments
 (0)