@@ -6,20 +6,22 @@ module {
6
6
%c64_i32 = arith.constant 64 : i32
7
7
%c8_i32 = arith.constant 8 : i32
8
8
%0 = arith.extsi %arg2 : i32 to i64
9
- %desc = tt.make_tensor_descriptor %arg0 , [%arg1 , %arg2 ], [%0 , %c1_i64 ] : <f32 >, <tensor <16 x128 xf32 >>
10
- %load = tt.descriptor_load %desc [%c8_i32 , %c64_i32 ] : !tt.tensordesc <tensor <16 x128 xf32 >> -> tensor <16 x128 xf32 >
9
+ %desc1 = tt.make_tensor_descriptor %arg0 , [%arg1 , %arg2 ], [%0 , %c1_i64 ] : <f32 >, <tensor <16 x128 xf32 >>
10
+ %load1 = tt.descriptor_load %desc1 [%c8_i32 , %c64_i32 ] : !tt.tensordesc <tensor <16 x128 xf32 >> -> tensor <16 x128 xf32 >
11
11
tt.return
12
12
}
13
13
// CHECK: tt.func public @test_load([[PARAM_0:%.+]]: !tt.ptr<f32>, [[PARAM_1:%.+]]: i32, [[PARAM_2:%.+]]: i32) {
14
14
// CHECK-NOT: tt.make_tensor_descriptor
15
15
// CHECK-NOT: tt.descriptor_load
16
+ // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
16
17
// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
17
18
// CHECK-DAG: [[CST_64_i32:%.+]] = arith.constant 64 : i32
18
19
// CHECK-DAG: [[CST_8_i32:%.+]] = arith.constant 8 : i32
19
20
// CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64
20
21
// CHECK-DAG: [[EXTSI_PARAM_2:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
21
- // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<16x128xf32>>
22
- // CHECK: [[LOAD:%.+]] = tt.load [[TENSOR_PTR]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<16x128xf32>>
22
+ // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<16x128xf32>>
23
+ // CHECK: [[TENSOR_PTR1:%.+]] = tt.advance [[TENSOR_PTR]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] : <tensor<16x128xf32>>
24
+ // CHECK: [[LOAD1:%.+]] = tt.load [[TENSOR_PTR1]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<16x128xf32>>
23
25
// CHECK: tt.return
24
26
// CHECK: }
25
27
@@ -29,21 +31,23 @@ module {
29
31
%c8_i32 = arith.constant 8 : i32
30
32
%cst = arith.constant dense <1.000000e+00 > : tensor <16 x128 xf32 >
31
33
%0 = arith.extsi %arg2 : i32 to i64
32
- %desc = tt.make_tensor_descriptor %arg0 , [%arg1 , %arg2 ], [%0 , %c1_i64 ] : <f32 >, <tensor <16 x128 xf32 >>
33
- tt.descriptor_store %desc [%c8_i32 , %c64_i32 ], %cst : !tt.tensordesc <tensor <16 x128 xf32 >>, tensor <16 x128 xf32 >
34
+ %desc1 = tt.make_tensor_descriptor %arg0 , [%arg1 , %arg2 ], [%0 , %c1_i64 ] : <f32 >, <tensor <16 x128 xf32 >>
35
+ tt.descriptor_store %desc1 [%c8_i32 , %c64_i32 ], %cst : !tt.tensordesc <tensor <16 x128 xf32 >>, tensor <16 x128 xf32 >
34
36
tt.return
35
37
}
36
38
// CHECK: tt.func public @test_store([[PARAM_0:%.+]]: !tt.ptr<f32>, [[PARAM_1:%.+]]: i32, [[PARAM_2:%.+]]: i32) {
37
39
// CHECK-NOT: tt.make_tensor_descriptor
38
40
// CHECK-NOT: tt.descriptor_store
41
+ // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
39
42
// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
40
43
// CHECK-DAG: [[CST_64_i32:%.+]] = arith.constant 64 : i32
41
44
// CHECK-DAG: [[CST_8_i32:%.+]] = arith.constant 8 : i32
42
45
// CHECK-DAG: [[CST:%.+]] = arith.constant dense<1.000000e+00> : tensor<16x128xf32>
43
46
// CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64
44
47
// CHECK-DAG: [[EXTSI_PARAM_2:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
45
- // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<16x128xf32>>
46
- // CHECK: tt.store [[TENSOR_PTR]], [[CST]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<16x128xf32>>
48
+ // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<16x128xf32>>
49
+ // CHECK: [[TENSOR_PTR1:%.+]] = tt.advance [[TENSOR_PTR]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] : <tensor<16x128xf32>>
50
+ // CHECK: tt.store [[TENSOR_PTR1]], [[CST]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<16x128xf32>>
47
51
// CHECK: tt.return
48
52
// CHECK: }
49
53
}
0 commit comments