
Commit e15146b

update transpose injection when converting tt.reduce to linalg.reduce (#312)
When converting tt.reduce to linalg.reduce, we insert a transpose so that the reduction axis becomes 0. Once the reduce axis is 0, the remaining dimensions can be collapsed, which makes the reduction easier and more efficient to vectorize. The pseudocode is as follows.

Before:

```
... = tt.reduce axis=[x]
```

After, if x != 0:

```
%v = linalg.transpose permutation = [x, ...]
... = linalg.reduce %v dimensions = [0]
```

The test case is `python/examples/test_reduce.py::test_reduce_max`.
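For intuition, here is a standalone check of the rewrite (a sketch, not code from this commit: it uses `torch`, which the tests below already depend on, and max as the combine op). Building the permutation `[axis, <remaining dims in order>]` and reducing the transposed tensor over dimension 0 reproduces the original reduction along `axis`:

```python
import torch

def reduce_via_transpose(x: torch.Tensor, axis: int) -> torch.Tensor:
    # Same permutation the conversion pattern builds: the reduce axis
    # first, then the remaining dimensions in their original order.
    order = [axis] + [i for i in range(x.dim()) if i != axis]
    # After the transpose, the reduction always runs over dimension 0.
    return x.permute(order).amax(dim=0)

x = torch.arange(32 * 4 * 16, dtype=torch.float32).reshape(32, 4, 16)
for axis in range(x.dim()):
    assert torch.equal(reduce_via_transpose(x, axis), x.amax(dim=axis))
```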
1 parent 859ede9 commit e15146b

File tree: 6 files changed, +90 -37 lines

.clang-format

Lines changed: 1 addition & 0 deletions
```diff
@@ -0,0 +1 @@
+triton/.clang-format
```

include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp

Lines changed: 18 additions & 7 deletions
```diff
@@ -11,9 +11,9 @@
 #include "triton-shared/Analysis/MaskAnalysis.h"
 #include "triton-shared/Analysis/OpFoldResultUtils.h"
 #include "triton-shared/Analysis/PtrAnalysis.h"
+#include "triton-shared/Conversion/TritonArithToLinalg/ConversionTools.h"
 #include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
 #include "triton-shared/Utils/Utils.h"
-#include "triton-shared/Conversion/TritonArithToLinalg/ConversionTools.h"

 #include "triton/Dialect/Triton/IR/Dialect.h"

@@ -1284,11 +1284,22 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {

     auto rop = reductionOps.front();
     auto axis = op.getAxis();
-    auto isVectorReduce = sourceType.getRank() == 1;
-
-    if (axis == sourceType.getRank() - 1 && !isVectorReduce) {
-      source = getTransposedValue(source, op.getLoc(), rewriter);
-      axis = sourceType.getRank() - 2;
+    auto rank = sourceType.getRank();
+    auto isVectorReduce = (rank == 1);
+
+    // if it is not a vector reduce, we can transpose the source
+    // so that the reduction axis is the first dimension.
+    if (!isVectorReduce && axis != 0) {
+      SmallVector<int32_t> order;
+      order.reserve(rank);
+      order.push_back(axis);
+      for (int i = 0; i < rank; ++i) {
+        if (i != axis) {
+          order.push_back(i);
+        }
+      }
+      source = getTransposedValue(source, op.getLoc(), rewriter, order);
+      axis = 0;
     }

     bool convertToF32Precision = requiresF32Conversion(resType, rop);

@@ -1334,7 +1345,7 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
             })
            .getResult(0);

-    if (sourceType.getRank() == 1) {
+    if (isVectorReduce) {
      finalResult =
          rewriter.create<tensor::ExtractOp>(loc, constantType, finalResult);
     }
```
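The comment in the new code points at why axis 0 is the preferred position. A rough sketch of the vectorization argument (illustrative Python with assumed shapes, not code from this repository): once the reduce axis is first, the remaining dimensions can be flattened into one contiguous trailing dimension, so each accumulation step becomes a single contiguous vector operation.

```python
import torch

def axis0_reduce_collapsed(x: torch.Tensor) -> torch.Tensor:
    # With the reduce axis at dimension 0, dims 1..rank-1 collapse into one
    # contiguous column; each loop step below is then one vector-wide max
    # over all remaining elements, which is straightforward to vectorize.
    rest = x.shape[1:]
    flat = x.reshape(x.shape[0], -1)
    acc = flat[0].clone()
    for row in flat[1:]:
        acc = torch.maximum(acc, row)
    return acc.reshape(rest)

x = torch.rand(8, 32, 16)
assert torch.equal(axis0_reduce_collapsed(x), x.amax(dim=0))
```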

python/examples/test_reduce.py

Lines changed: 49 additions & 14 deletions
```diff
@@ -1,5 +1,6 @@
 import torch
-
+import pytest
+import math
 import triton
 from triton.backends.compiler import GPUTarget
 import triton.language as tl
@@ -35,27 +36,61 @@ def test(device):
     x = torch.rand([n_cols, n_rows], device=device, dtype=torch.float32)
     output = torch.empty([n_cols], device=device, dtype=x.dtype)
     BLOCK_SIZE = n_rows
-    grid = lambda meta: (n_cols,)
+    grid = lambda meta: (n_cols, )

     reduce_kernel_2d[grid](x, output, x.stride(0), n_rows, BLOCK_SIZE=BLOCK_SIZE)
     ans = torch.sum(x, dim=1)
     torch.testing.assert_close(output, ans, rtol=0.001, atol=1e-5)

     # TODO: need to check some conditions otherwise the code below does not make any difference for the test
     src = triton.compiler.ASTSource(
-        fn=reduce_kernel_2d,
-        signature={"x_ptr": "*fp32",
-                   "output_ptr": "*fp32",
-                   "stride": "i32",
-                   "n_elements": "i32",
-                   "BLOCK_SIZE": "constexpr"},
-        constexprs={"BLOCK_SIZE": 32}
-    )
-    ret = triton.compile(
-        src,
-        target=GPUTarget(device, 0, 0)
-    )
+        fn=reduce_kernel_2d, signature={
+            "x_ptr": "*fp32", "output_ptr": "*fp32", "stride": "i32", "n_elements": "i32", "BLOCK_SIZE": "constexpr"
+        }, constexprs={"BLOCK_SIZE": 32})
+    ret = triton.compile(src, target=GPUTarget(device, 0, 0))
     print(ret.asm["ttir"])
     print(ret.asm["ttsharedir"])
     print(ret.asm["llir"])
     print(ret.asm["obj"])
+
+
+@pytest.mark.interpreter
+@pytest.mark.parametrize("dtype_str", ["int32", "float32"])
+@pytest.mark.parametrize("shape", [(128, 2, 4), (64, 2, 4), (32, 2, 4), (2, 4, 32), (2, 4, 2)])
+@pytest.mark.parametrize("axis", [0, 1, 2])
+def test_reduce_max(dtype_str, shape, axis, device):
+
+    @triton.jit
+    def kernel(
+        In,
+        Out,
+        in_shape1: tl.constexpr,
+        in_shape2: tl.constexpr,
+        in_shape3: tl.constexpr,
+        ou_shape1: tl.constexpr,
+        ou_shape2: tl.constexpr,
+        axis: tl.constexpr,
+    ):
+        in_desc = tl.make_tensor_descriptor(
+            base=In,
+            shape=[in_shape1 * in_shape2 * in_shape3],
+            strides=[1],
+            block_shape=[in_shape1 * in_shape2 * in_shape3],
+        )
+        out_desc = tl.make_tensor_descriptor(
+            base=Out,
+            shape=[ou_shape1 * ou_shape2],
+            strides=[1],
+            block_shape=[ou_shape1 * ou_shape2],
+        )
+        val = in_desc.load([0]).reshape(in_shape1, in_shape2, in_shape3)
+        output = tl.max(val, axis=axis)
+        out_desc.store([0], output.reshape(out_desc.block_shape))
+
+    input = torch.arange(math.prod(shape), dtype=getattr(torch, dtype_str),
+                         device="cpu").reshape(shape).to(device=device)
+    expected, indices = torch.max(input, dim=axis)
+    actual = torch.zeros(expected.shape, dtype=getattr(torch, dtype_str), device=device)
+    kernel[(1, )](input, actual, *shape, *expected.shape, axis=axis)
+
+    assert torch.equal(expected, actual)
```
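The three lit-test updates below all exercise the same case: a 32x256x16 sum over the middle axis, which now lowers to a `linalg.transpose` with `permutation = [1, 0, 2]` feeding a `linalg.reduce` over dimension 0. A quick equivalence check of that specific permutation (a sketch in float32; the tests themselves use bf16):

```python
import torch

x = torch.rand(32, 256, 16)
# Permutation [1, 0, 2] moves the reduce axis (dim 1) to the front, giving a
# 256x32x16 tensor; summing it over dimension 0 matches the original
# dimension-1 sum and yields the expected 32x16 result.
assert torch.allclose(x.permute(1, 0, 2).sum(dim=0), x.sum(dim=1))
```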

test/Conversion/StructuredToMemref/reducesum_middle_dim.mlir

Lines changed: 3 additions & 1 deletion
```diff
@@ -53,9 +53,11 @@ module {
 // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32x256x16xbf16>
 // CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<32x256x16xbf16, strided<[256, 1, 1]>> to memref<32x256x16xbf16>
 // CHECK-DAG: [[VAR_0_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32x256x16xbf16>
+// CHECK: [[VAL_7:%.+]] = tensor.empty() : tensor<256x32x16xbf16>
+// CHECK: [[VAL_8:%.+]] = linalg.transpose ins([[VAR_0_]] : tensor<32x256x16xbf16>) outs([[VAL_7]] : tensor<256x32x16xbf16>) permutation = [1, 0, 2]
 // CHECK-DAG: [[VAR_1_:%.+]] = tensor.empty() : tensor<32x16xbf16>
 // CHECK: [[VAR_2_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_1_]] : tensor<32x16xbf16>) -> tensor<32x16xbf16>
-// CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_0_]] : tensor<32x256x16xbf16>) outs([[VAR_2_]] : tensor<32x16xbf16>) dimensions = [1]
+// CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAL_8]] : tensor<256x32x16xbf16>) outs([[VAR_2_]] : tensor<32x16xbf16>) dimensions = [0]
 // CHECK: ([[in_:.+]]: bf16, [[init_:.+]]: bf16) {
 // CHECK: [[VAR_3_:%.+]] = arith.addf [[in_]], [[init_]] : bf16
 // CHECK: linalg.yield [[VAR_3_]] : bf16
```

test/Conversion/TritonArithToLinalg/reducesum_middle_dim.mlir

Lines changed: 3 additions & 1 deletion
```diff
@@ -132,9 +132,11 @@
 // CHECK: linalg.yield [[VAR_29_6_]] : !tt.ptr<bf16>
 // CHECK: } -> tensor<32x256x16x!tt.ptr<bf16>>
 // CHECK-DAG: [[LOAD_VAR_25_MEM_:%.+]] = tt.load [[VAR_25_]] : tensor<32x256x16x!tt.ptr<bf16>>
+// CHECK: [[VAL_72:%.+]] = tensor.empty() : tensor<256x32x16xbf16>
+// CHECK: [[VAL_73:%.+]] = linalg.transpose ins([[LOAD_VAR_25_MEM_]] : tensor<32x256x16xbf16>) outs([[VAL_72]] : tensor<256x32x16xbf16>) permutation = [1, 0, 2]
 // CHECK-DAG: [[VAR_27_:%.+]] = tensor.empty() : tensor<32x16xbf16>
 // CHECK: [[VAR_28_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_27_]] : tensor<32x16xbf16>) -> tensor<32x16xbf16>
-// CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[LOAD_VAR_25_MEM_]] : tensor<32x256x16xbf16>) outs([[VAR_28_]] : tensor<32x16xbf16>) dimensions = [1]
+// CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAL_73]] : tensor<256x32x16xbf16>) outs([[VAR_28_]] : tensor<32x16xbf16>) dimensions = [0]
 // CHECK: ([[in_]]: bf16, [[in_]]it: bf16) {
 // CHECK: [[VAR_29_7_:%.+]] = arith.addf [[in_]], [[in_]]it : bf16
 // CHECK: linalg.yield [[VAR_29_7_]] : bf16
```

test/Conversion/TritonToLinalg/reducesum_middle_dim.mlir

Lines changed: 16 additions & 14 deletions
```diff
@@ -39,20 +39,22 @@ module {
   }
 }
 // CHECK-LABEL: func.func @kernel(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: memref<32x16xbf16>, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32) {
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 256 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : bf16
-// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [32, 256, 16], strides: {{\[}}%[[VAL_6]], 1, 1] : memref<*xbf16> to memref<32x256x16xbf16, strided<[?, 1, 1]>>
-// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<32x256x16xbf16>
-// CHECK: memref.copy %[[VAL_8]], %[[VAL_9]] : memref<32x256x16xbf16, strided<[?, 1, 1]>> to memref<32x256x16xbf16>
-// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<32x256x16xbf16>
-// CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<32x16xbf16>
-// CHECK: %[[VAL_12:.*]] = linalg.fill ins(%[[VAL_7]] : bf16) outs(%[[VAL_11]] : tensor<32x16xbf16>) -> tensor<32x16xbf16>
-// CHECK: %[[VAL_13:.*]] = linalg.reduce ins(%[[VAL_10]] : tensor<32x256x16xbf16>) outs(%[[VAL_12]] : tensor<32x16xbf16>) dimensions = [1]
-// CHECK: (%[[VAL_14:.*]]: bf16, %[[VAL_15:.*]]: bf16) {
-// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : bf16
-// CHECK: linalg.yield %[[VAL_16]] : bf16
+// CHECK-SAME: %[[ARG0:.*]]: memref<*xbf16>, %[[ARG1:.*]]: memref<*xbf16>, %[[ARG2:.*]]: memref<32x16xbf16>, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: i32, %[[ARG7:.*]]: i32, %[[ARG8:.*]]: i32) {
+// CHECK-DAG: %[[VAL_0:.*]] = arith.constant 0.000000e+00 : bf16
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 256 : index
+// CHECK: %[[VAL_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [32, 256, 16], strides: {{\[}}%[[VAL_1]], 1, 1] : memref<*xbf16> to memref<32x256x16xbf16, strided<[?, 1, 1]>>
+// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<32x256x16xbf16>
+// CHECK: memref.copy %[[VAL_2]], %[[VAL_3]] : memref<32x256x16xbf16, strided<[?, 1, 1]>> to memref<32x256x16xbf16>
+// CHECK: %[[VAL_4:.*]] = bufferization.to_tensor %[[VAL_3]] restrict writable : memref<32x256x16xbf16> to tensor<32x256x16xbf16>
+// CHECK: %[[VAL_5:.*]] = tensor.empty() : tensor<256x32x16xbf16>
+// CHECK: %[[VAL_6:.*]] = linalg.transpose ins(%[[VAL_4]] : tensor<32x256x16xbf16>) outs(%[[VAL_5]] : tensor<256x32x16xbf16>) permutation = [1, 0, 2]
+// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<32x16xbf16>
+// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_0]] : bf16) outs(%[[VAL_7]] : tensor<32x16xbf16>) -> tensor<32x16xbf16>
+// CHECK: %[[VAL_9:.*]] = linalg.reduce ins(%[[VAL_6]] : tensor<256x32x16xbf16>) outs(%[[VAL_8]] : tensor<32x16xbf16>) dimensions = [0]
+// CHECK: (%[[VAL_10:.*]]: bf16, %[[VAL_11:.*]]: bf16) {
+// CHECK: %[[VAL_12:.*]] = arith.addf %[[VAL_10]], %[[VAL_11]] : bf16
+// CHECK: linalg.yield %[[VAL_12]] : bf16
 // CHECK: }
-// CHECK: bufferization.materialize_in_destination %[[VAL_13]] in writable %[[VAL_2]]
+// CHECK: bufferization.materialize_in_destination %[[VAL_9]] in writable %[[ARG2]] : (tensor<32x16xbf16>, memref<32x16xbf16>) -> ()
 // CHECK: return
 // CHECK: }
```
