
Commit e7a375a

Implements unsplat support for single-element reduction. (#305)

This enables extracting an element from a single-value tensor using `.item()`.
1 parent 2483659 commit e7a375a
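
For orientation, the Python-level pattern this enables might look like the sketch below. This is a hypothetical kernel written to mirror the lit test added in this commit; it assumes a Triton frontend where `.item()` on a single-element tensor (referenced in the commit message) lowers to the rank-1, size-1 `tt.reduce` handled here:

```python
# Hedged sketch: hypothetical kernel matching the unsplat lit test below.
# Assumes .item() on a single-element tensor lowers to a rank-1, size-1
# tt.reduce whose body yields its first argument.
import triton
import triton.language as tl

@triton.jit
def unsplat_kernel(ptr):
    offs = tl.arange(0, 1)       # size-1 offset range
    x = tl.load(ptr + offs)      # tensor<1xi32>
    cond = x > 42                # tensor<1xi1>
    if cond.item():              # unsplat: tensor<1xi1> -> i1
        tl.store(ptr, 42)
```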

3 files changed: 78 additions, 0 deletions

include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp

Lines changed: 26 additions & 0 deletions
```diff
@@ -1393,6 +1393,24 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
     return success();
   }
 
+  LogicalResult
+  convertToTensorExtract(triton::ReduceOp op,
+                         typename triton::ReduceOp::Adaptor adaptor,
+                         ConversionPatternRewriter &rewriter) const {
+    assert(llvm::hasSingleElement(op.getSrcs()));
+
+    auto returnOp = cast<triton::ReduceReturnOp>(*op.getOps().begin());
+    assert(llvm::hasSingleElement(returnOp.getResult()));
+    assert(cast<BlockArgument>(returnOp.getResult().front()).getArgNumber() ==
+           0);
+
+    auto source = op.getSrcs().front();
+    auto zeroIdx =
+        rewriter.createOrFold<arith::ConstantIndexOp>(op.getLoc(), 0);
+    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, source, zeroIdx);
+    return success();
+  }
+
 public:
   LogicalResult
   matchAndRewrite(triton::ReduceOp op,
@@ -1409,6 +1427,14 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
                   "axis is within "
                   "operand's rank");
 
+    // Unsplat is implemented as a single-element, rank-1 reduction where
+    // the single element is yielded immediately. This can be simplified
+    // into a single-element extract.
+    if (llvm::hasSingleElement(op.getOps()) && sourceType.getRank() == 1 &&
+        sourceType.getShape()[0] == 1) {
+      return convertToTensorExtract(op, adaptor, rewriter);
+    }
+
     return convertToLinalgReduce(op, adaptor, rewriter);
   }
 };
```
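
The rewrite is sound because reducing a single element never runs the combine region, and the pattern additionally requires that the reduce body simply yields its first block argument (the `getArgNumber() == 0` assert). A plain-Python analogy, illustrative only and not part of the commit:

```python
# Illustrative analogy: reducing a one-element sequence never calls the
# combiner, so the result is just the first element, which is exactly
# what tensor.extract %src[0] produces after the rewrite.
from functools import reduce

def rank1_reduce(src, combine):
    # models tt.reduce with axis = 0 over a rank-1 tensor
    return reduce(combine, src)

src = [7]                        # models tensor<1xi32>
combine = lambda a, b: a         # body yields its first argument
assert rank1_reduce(src, combine) == src[0]  # == extract at index 0
```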

python/examples/conftest.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -73,6 +73,7 @@ def with_allocator():
     "test_addptr",
     "test_transpose",
     "test_trans_4d",
+    "test_unsplat",
     "test_arange",
 }
```

New file (lit test)

Lines changed: 51 additions & 0 deletions
```mlir
// RUN: triton-shared-opt --triton-to-linalg-experimental %s | FileCheck %s
module {
  tt.func public @unsplat_kernel(%arg0: !tt.ptr<i32> {maia.rank = 1 : i32, tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<42> : tensor<1xi32>
    %c42_i32 = arith.constant 42 : i32
    %0 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1x!tt.ptr<i32>>
    %1 = tt.load %0 : tensor<1x!tt.ptr<i32>>
    %2 = arith.cmpi sgt, %1, %cst : tensor<1xi32>
    %3 = "tt.reduce"(%2) <{axis = 0 : i32}> ({
    ^bb0(%arg1: i1, %arg2: i1):
      tt.reduce.return %arg1 : i1
    }) : (tensor<1xi1>) -> i1
    scf.if %3 {
      tt.store %arg0, %c42_i32 : !tt.ptr<i32>
    }
    tt.return
  }
}

// CHECK-DAG:   [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @unsplat_kernel
// CHECK-SAME:  ([[PARAM_0_:%.+]]: memref<*xi32> {maia.rank = 1 : i32, tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) {
// CHECK-DAG:   [[CST_42_:%.+]] = arith.constant 42 : i32
// CHECK-DAG:   [[CST_0_:%.+]] = arith.constant 0 : i32
// CHECK-DAG:   [[CST_0_1_:%.+]] = arith.constant 0 : index
// CHECK-DAG:   [[VAR_0_:%.+]] = tensor.empty() : tensor<1xi32>
// CHECK-NOT:   separator of consecutive DAGs
// CHECK-DAG:   [[VAR_1_:%.+]] = linalg.fill ins([[CST_42_]] : i32) outs([[VAR_0_]] : tensor<1xi32>) -> tensor<1xi32>
// CHECK-DAG:   [[VAR_2_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_0_]] : tensor<1xi32>) -> tensor<1xi32>
// CHECK-DAG:   [[VAR_cast_:%.+]] = memref.cast [[PARAM_0_]] : memref<*xi32> to memref<?xi32>
// CHECK-NOT:   separator of consecutive DAGs
// CHECK-DAG:   [[VAR_3_:%.+]] = bufferization.to_tensor [[VAR_cast_]] restrict : memref<?xi32> to tensor<?xi32>
// CHECK-DAG:   [[VAR_4_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_2_]] : tensor<1xi32>) outs([[VAR_0_]] : tensor<1xi32>) {
// CHECK:       ^bb0([[IN_0_:%.+]]: i32, [[IN_1_:%.+]]: i32):
// CHECK:         [[VAR_7_:%.+]] = arith.index_cast [[IN_0_]] : i32 to index
// CHECK:         [[VAR_extracted_0_:%.+]] = tensor.extract [[VAR_3_]]{{.}}[[VAR_7_]]{{.}} : tensor<?xi32>
// CHECK:         linalg.yield [[VAR_extracted_0_]] : i32
// CHECK:       } -> tensor<1xi32>
// CHECK:       [[VAR_5_:%.+]] = tensor.empty() : tensor<1xi1>
// CHECK:       [[VAR_6_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_4_]], [[VAR_1_]] : tensor<1xi32>, tensor<1xi32>) outs([[VAR_5_]] : tensor<1xi1>) {
// CHECK:       ^bb0([[IN_2_:%.+]]: i32, [[IN_3_:%.+]]: i32, [[IN_4_:%.+]]: i1):
// CHECK:         [[VAR_7_1_:%.+]] = arith.cmpi sgt, [[IN_2_]], [[IN_3_]] : i32
// CHECK:         linalg.yield [[VAR_7_1_]] : i1
// CHECK:       } -> tensor<1xi1>
// CHECK:       [[VAR_extracted_:%.+]] = tensor.extract [[VAR_6_]]{{.}}[[CST_0_1_]]{{.}} : tensor<1xi1>
// CHECK:       scf.if [[VAR_extracted_]] {
// CHECK:         [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[CST_0_1_]]{{.}}, sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>>
// CHECK:         affine.store [[CST_42_]], [[VAR_reinterpret_cast_]][0] : memref<1xi32, strided<[1], offset: ?>>
// CHECK:       }
// CHECK:       return
// CHECK:     }
```
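
The RUN line above gives the exact pipeline. As a quick sanity check outside of lit, one could drive it from Python; this is a sketch, assuming `triton-shared-opt` is on `PATH` and the test file is saved locally as `unsplat.mlir` (the actual file path is not shown in this commit view):

```python
# Sketch: drive the RUN line above from Python and check the outcome.
import subprocess

out = subprocess.run(
    ["triton-shared-opt", "--triton-to-linalg-experimental", "unsplat.mlir"],
    capture_output=True, text=True, check=True,
).stdout

# After this commit, the single-element tt.reduce should lower to a
# tensor.extract (per the CHECK lines above) rather than a linalg reduction.
assert "tensor.extract" in out
assert "linalg.reduce" not in out
```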
