Skip to content

Commit ff742b4

Browse files
[tensor-descriptor]: Extend support when tensor descriptor created in control flow (#4152)
Enhance layout propagation and tensor descriptor lowering to support cases where descriptors or pointers are created within control flow constructs.

Signed-off-by: Tiotto, Ettore <[email protected]>
Co-authored-by: Whitney Tsang <[email protected]>
1 parent 664df8a commit ff742b4

File tree

4 files changed: +287 -130 lines changed

test/Triton/Intel/TensorDescToBlockPointer/basic.mlir

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -6,20 +6,22 @@ module {
66
%c64_i32 = arith.constant 64 : i32
77
%c8_i32 = arith.constant 8 : i32
88
%0 = arith.extsi %arg2 : i32 to i64
9-
%desc = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <f32>, <tensor<16x128xf32>>
10-
%load = tt.descriptor_load %desc[%c8_i32, %c64_i32] : !tt.tensordesc<tensor<16x128xf32>> -> tensor<16x128xf32>
9+
%desc1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <f32>, <tensor<16x128xf32>>
10+
%load1 = tt.descriptor_load %desc1[%c8_i32, %c64_i32] : !tt.tensordesc<tensor<16x128xf32>> -> tensor<16x128xf32>
1111
tt.return
1212
}
1313
// CHECK: tt.func public @test_load([[PARAM_0:%.+]]: !tt.ptr<f32>, [[PARAM_1:%.+]]: i32, [[PARAM_2:%.+]]: i32) {
1414
// CHECK-NOT: tt.make_tensor_descriptor
1515
// CHECK-NOT: tt.descriptor_load
16+
// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
1617
// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
1718
// CHECK-DAG: [[CST_64_i32:%.+]] = arith.constant 64 : i32
1819
// CHECK-DAG: [[CST_8_i32:%.+]] = arith.constant 8 : i32
1920
// CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64
2021
// CHECK-DAG: [[EXTSI_PARAM_2:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
21-
// CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<16x128xf32>>
22-
// CHECK: [[LOAD:%.+]] = tt.load [[TENSOR_PTR]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<16x128xf32>>
22+
// CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<16x128xf32>>
23+
// CHECK: [[TENSOR_PTR1:%.+]] = tt.advance [[TENSOR_PTR]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] : <tensor<16x128xf32>>
24+
// CHECK: [[LOAD1:%.+]] = tt.load [[TENSOR_PTR1]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<16x128xf32>>
2325
// CHECK: tt.return
2426
// CHECK: }
2527

@@ -29,21 +31,23 @@ module {
2931
%c8_i32 = arith.constant 8 : i32
3032
%cst = arith.constant dense<1.000000e+00> : tensor<16x128xf32>
3133
%0 = arith.extsi %arg2 : i32 to i64
32-
%desc = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <f32>, <tensor<16x128xf32>>
33-
tt.descriptor_store %desc[%c8_i32, %c64_i32], %cst : !tt.tensordesc<tensor<16x128xf32>>, tensor<16x128xf32>
34+
%desc1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <f32>, <tensor<16x128xf32>>
35+
tt.descriptor_store %desc1[%c8_i32, %c64_i32], %cst : !tt.tensordesc<tensor<16x128xf32>>, tensor<16x128xf32>
3436
tt.return
3537
}
3638
// CHECK: tt.func public @test_store([[PARAM_0:%.+]]: !tt.ptr<f32>, [[PARAM_1:%.+]]: i32, [[PARAM_2:%.+]]: i32) {
3739
// CHECK-NOT: tt.make_tensor_descriptor
3840
// CHECK-NOT: tt.descriptor_store
41+
// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
3942
// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
4043
// CHECK-DAG: [[CST_64_i32:%.+]] = arith.constant 64 : i32
4144
// CHECK-DAG: [[CST_8_i32:%.+]] = arith.constant 8 : i32
4245
// CHECK-DAG: [[CST:%.+]] = arith.constant dense<1.000000e+00> : tensor<16x128xf32>
4346
// CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64
4447
// CHECK-DAG: [[EXTSI_PARAM_2:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
45-
// CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<16x128xf32>>
46-
// CHECK: tt.store [[TENSOR_PTR]], [[CST]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<16x128xf32>>
48+
// CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<16x128xf32>>
49+
// CHECK: [[TENSOR_PTR1:%.+]] = tt.advance [[TENSOR_PTR]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] : <tensor<16x128xf32>>
50+
// CHECK: tt.store [[TENSOR_PTR1]], [[CST]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<16x128xf32>>
4751
// CHECK: tt.return
4852
// CHECK: }
4953
}

0 commit comments

Comments (0)