
Commit cb466ef

MakeRangeOp needs to consider the start of the range (#304)

The conversion for MakeRangeOp needs to consider the start of the range. It currently assumes all ranges are constructed from 0 to the size.

Authored-by: Daniel Donenfeld <[email protected]>
1 parent a0fa823 commit cb466ef
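
For context, a hedged sketch of the Triton pattern this commit fixes: a kernel whose tl.arange starts at a nonzero value, which lowers to tt.make_range with a nonzero start attribute. The kernel name, launch grid, and buffer setup below are illustrative, not taken from the commit.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def arange_kernel(out_ptr):
    # tl.arange(512, 768) lowers to tt.make_range {start = 512, end = 768}.
    # Before this fix, the linalg lowering ignored `start` and yielded 0..255.
    offs = tl.arange(512, 768)
    tl.store(out_ptr + (offs - 512), offs)


# Hypothetical launch: out should end up holding 512, 513, ..., 767.
out = torch.empty(256, dtype=torch.int32, device="cuda")
arange_kernel[(1,)](out)
```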

6 files changed (+25 −6 lines)
include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp

Lines changed: 6 additions & 0 deletions
@@ -817,6 +817,12 @@ struct MakeRangeConverter : public OpConversionPattern<triton::MakeRangeOp> {
           Value index = nestedBuilder.create<linalg::IndexOp>(loc, 0);
           Value res = nestedBuilder.create<arith::IndexCastOp>(
               loc, type.getElementType(), index);
+          if (op.getStart()) {
+            auto start = rewriter.create<mlir::arith::ConstantIntOp>(
+                op.getLoc(), op.getStart(),
+                type.getElementType().getIntOrFloatBitWidth());
+            res = nestedBuilder.create<arith::AddIOp>(loc, res, start);
+          }
           nestedBuilder.create<linalg::YieldOp>(loc, res);
         });

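In short, the lowered linalg.generic body used to yield only the index_cast of linalg.index; the new guard materializes the range's start as a constant and adds it. A small Python model of the before/after semantics (the helper is illustrative; the values mirror the updated tests below):

```python
def make_range(start: int, size: int, fixed: bool = True) -> list[int]:
    """Model the lowered body: linalg.index -> index_cast (-> addi start)."""
    if fixed:
        return [i + start for i in range(size)]  # new: start is added in
    return list(range(size))                     # old: start was ignored


assert make_range(512, 256)[:3] == [512, 513, 514]         # tt.make_range {512, 768}
assert make_range(512, 256, fixed=False)[:3] == [0, 1, 2]  # pre-fix behavior
```
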
python/examples/conftest.py

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@ def with_allocator():
         "test_addptr",
         "test_transpose",
         "test_trans_4d",
+        "test_arange",
     }

test/Conversion/StructuredToMemref/use_end_chain.mlir

Lines changed: 6 additions & 2 deletions
@@ -41,6 +41,8 @@ module {
 // CHECK-LABEL: func.func @kernel
 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) {
 // CHECK-DAG: [[CST_6_:%.+]] = arith.constant 6 : i32
+// CHECK-DAG: [[CST_512_:%.+]] = arith.constant 512 : i32
+// CHECK-DAG: [[CST_1024_:%.+]] = arith.constant 1024 : i32
 // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<256x128xi32>
 // CHECK-NOT: separator of consecutive DAGs
 // CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_6_]] : i32) outs([[VAR_0_]] : tensor<256x128xi32>) -> tensor<256x128xi32>
@@ -49,7 +51,8 @@ module {
 // CHECK: ^bb0([[IN_0_:%.+]]: i32):
 // CHECK: [[VAR_13_:%.+]] = linalg.index 0 : index
 // CHECK: [[VAR_14_:%.+]] = arith.index_cast [[VAR_13_]] : index to i32
-// CHECK: linalg.yield [[VAR_14_]] : i32
+// CHECK: [[VAL_24:%.+]] = arith.addi [[VAR_14_]], [[CST_512_]] : i32
+// CHECK: linalg.yield [[VAL_24]] : i32
 // CHECK: } -> tensor<256xi32>
 // CHECK: [[VAR_expanded_:%.+]] = tensor.expand_shape [[VAR_3_]] {{.}}[0, 1]{{.}} output_shape [256, 1] : tensor<256xi32> into tensor<256x1xi32>
 // CHECK: [[VAR_4_:%.+]] = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel"]} ins([[VAR_expanded_]] : tensor<256x1xi32>) outs([[VAR_0_]] : tensor<256x128xi32>) attrs = {broadcastDims = array<i64: 1>} {
@@ -61,7 +64,8 @@ module {
 // CHECK: ^bb0([[IN_3_:%.+]]: i32):
 // CHECK: [[VAR_13_1_:%.+]] = linalg.index 0 : index
 // CHECK: [[VAR_14_1_:%.+]] = arith.index_cast [[VAR_13_1_]] : index to i32
-// CHECK: linalg.yield [[VAR_14_1_]] : i32
+// CHECK: [[VAL_25:%.+]] = arith.addi [[VAR_14_1_]], [[CST_1024_]] : i32
+// CHECK: linalg.yield [[VAL_25]] : i32
 // CHECK: } -> tensor<128xi32>
 // CHECK: [[VAR_expanded_0_:%.+]] = tensor.expand_shape [[VAR_6_]] {{.}}[0, 1]{{.}} output_shape [1, 128] : tensor<128xi32> into tensor<1x128xi32>
 // CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map3, #map2], iterator_types = ["parallel", "parallel"]} ins([[VAR_expanded_0_]] : tensor<1x128xi32>) outs([[VAR_0_]] : tensor<256x128xi32>) attrs = {broadcastDims = array<i64: 0>} {

test/Conversion/StructuredToMemref/use_mid_chain.mlir

Lines changed: 3 additions & 1 deletion
@@ -41,12 +41,14 @@ module {
 // CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func.func @kernel
 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: memref<*xi32>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) {
+// CHECK-DAG: [[VAL_25:%.+]] = arith.constant 512 : i32
 // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<256xi32>
 // CHECK: [[VAR_1_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<256xi32>) {
 // CHECK: ^bb0([[IN_0_:%.+]]: i32):
 // CHECK: [[VAR_5_:%.+]] = linalg.index 0 : index
 // CHECK: [[VAR_6_:%.+]] = arith.index_cast [[VAR_5_]] : index to i32
-// CHECK: linalg.yield [[VAR_6_]] : i32
+// CHECK: [[VAL_24:%.+]] = arith.addi [[VAR_6_]], [[VAL_25]] : i32
+// CHECK: linalg.yield [[VAL_24]] : i32
 // CHECK: } -> tensor<256xi32>
 // CHECK-DAG: [[VAR_expanded_:%.+]] = tensor.expand_shape [[VAR_1_]] {{.}}[0, 1]{{.}} output_shape [256, 1] : tensor<256xi32> into tensor<256x1xi32>
 // CHECK-DAG: [[VAR_2_:%.+]] = tensor.empty() : tensor<256x128xi32>

test/Conversion/TritonToLinalg/use_end_chain.mlir

Lines changed: 6 additions & 2 deletions
@@ -37,14 +37,17 @@ module {
 // CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) {
 // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 6 : index
 // CHECK-DAG: %[[VAL_7:.*]] = arith.constant 6 : i32
+// CHECK-DAG: %[[CST_512_:.*]] = arith.constant 512 : i32
+// CHECK-DAG: %[[CST_1024_:.*]] = arith.constant 1024 : i32
 // CHECK: %[[VAL_30:.*]] = tensor.empty() : tensor<256x128xi32>
 // CHECK: %[[VAL_31:.*]] = linalg.fill ins(%[[VAL_7]] : i32) outs(%[[VAL_30]] : tensor<256x128xi32>) -> tensor<256x128xi32>
 // CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<256xi32>
 // CHECK: %[[VAL_9:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%[[VAL_8]] : tensor<256xi32>) {
 // CHECK: ^bb0(%[[VAL_10:.*]]: i32):
 // CHECK: %[[VAL_11:.*]] = linalg.index 0 : index
 // CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_11]] : index to i32
-// CHECK: linalg.yield %[[VAL_12]] : i32
+// CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_12]], %[[CST_512_]] : i32
+// CHECK: linalg.yield %[[VAL_55]] : i32
 // CHECK: } -> tensor<256xi32>
 // CHECK: %[[VAL_13:.*]] = tensor.expand_shape %[[VAL_14:.*]] {{\[\[}}0, 1]] output_shape [256, 1] : tensor<256xi32> into tensor<256x1xi32>
 // CHECK: %[[VAL_15:.*]] = tensor.empty() : tensor<256x128xi32>
@@ -57,7 +60,8 @@ module {
 // CHECK: ^bb0(%[[VAL_21:.*]]: i32):
 // CHECK: %[[VAL_22:.*]] = linalg.index 0 : index
 // CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_22]] : index to i32
-// CHECK: linalg.yield %[[VAL_23]] : i32
+// CHECK: %[[VAL_56:.*]] = arith.addi %[[VAL_23]], %[[CST_1024_]] : i32
+// CHECK: linalg.yield %[[VAL_56]] : i32
 // CHECK: } -> tensor<128xi32>
 // CHECK: %[[VAL_24:.*]] = tensor.expand_shape %[[VAL_25:.*]] {{\[\[}}0, 1]] output_shape [1, 128] : tensor<128xi32> into tensor<1x128xi32>
 // CHECK: %[[VAL_26:.*]] = tensor.empty() : tensor<256x128xi32>

test/Conversion/TritonToLinalg/use_mid_chain.mlir

Lines changed: 3 additions & 1 deletion
@@ -38,12 +38,14 @@ module {
 // CHECK-LABEL: func.func @kernel(
 // CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: memref<*xi32>, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32) {
 // CHECK-DAG: %[[VAL_7:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[VAL_25:.*]] = arith.constant 512 : i32
 // CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<256xi32>
 // CHECK: %[[VAL_9:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%[[VAL_8]] : tensor<256xi32>) {
 // CHECK: ^bb0(%[[VAL_10:.*]]: i32):
 // CHECK: %[[VAL_11:.*]] = linalg.index 0 : index
 // CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_11]] : index to i32
-// CHECK: linalg.yield %[[VAL_12]] : i32
+// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_12]], %[[VAL_25]] : i32
+// CHECK: linalg.yield %[[VAL_24]] : i32
 // CHECK: } -> tensor<256xi32>
 // CHECK: %[[VAL_13:.*]] = tensor.expand_shape %[[VAL_14:.*]] {{\[\[}}0, 1]] output_shape [256, 1] : tensor<256xi32> into tensor<256x1xi32>
 // CHECK: %[[VAL_15:.*]] = tensor.empty() : tensor<256x128xi32>
