Skip to content

Commit e7e6fa3

Browse files
authored
Improve getFinalValue and remove no longer needed helper functions (intel#3282)
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 2247beb commit e7e6fa3

File tree

7 files changed

+212
-262
lines changed

7 files changed

+212
-262
lines changed

test/Triton/Intel/RaiseToBlockPointers/addptr_mul_value_const.mlir

Lines changed: 10 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -38,16 +38,15 @@ module {
3838
// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
3939
// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
4040
// CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32
41-
// CHECK-DAG: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
42-
// CHECK: [[VAR_2_:%.+]] = arith.muli [[PARAM_2_]], [[CST_2048_i32]] : i32
43-
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64
44-
// CHECK-DAG: [[VAR_4_:%.+]] = arith.addi [[VAR_0_]], [[VAR_2_]] : i32
45-
// CHECK-DAG: [[VAR_5_:%.+]] = arith.addi [[VAR_3_]], [[CST_1_i64]] : i64
46-
// CHECK-DAG: [[VAR_6_:%.+]] = arith.trunci [[VAR_5_]] : i64 to i32
47-
// CHECK-DAG: [[VAR_7_:%.+]] = arith.divui [[VAR_4_]], [[VAR_6_]] : i32
48-
// CHECK-DAG: [[VAR_8_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[VAR_5_]]], {{\[}}[[VAR_7_]]] {{.*}} : <tensor<1024xbf16>>
49-
// CHECK-DAG: [[VAR_9_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_0_]]] {{.*}} : <tensor<1024xbf16>>
50-
// CHECK: [[VAR_10_:%.+]] = tt.load [[VAR_8_]] : !tt.ptr<tensor<1024xbf16>>
51-
// CHECK: tt.store [[VAR_9_]], [[VAR_10_]] : !tt.ptr<tensor<1024xbf16>>
41+
// CHECK: [[VAR_1_:%.+]] = arith.muli [[PARAM_2_]], [[CST_2048_i32]] : i32
42+
// CHECK-DAG: [[VAR_2_:%.+]] = arith.extsi [[PARAM_2_]] : i32 to i64
43+
// CHECK-DAG: [[VAR_3_:%.+]] = arith.addi [[VAR_0_]], [[VAR_1_]] : i32
44+
// CHECK-DAG: [[VAR_4_:%.+]] = arith.addi [[VAR_2_]], [[CST_1_i64]] : i64
45+
// CHECK-DAG: [[VAR_5_:%.+]] = arith.trunci [[VAR_4_]] : i64 to i32
46+
// CHECK-DAG: [[VAR_6_:%.+]] = arith.divui [[VAR_3_]], [[VAR_5_]] : i32
47+
// CHECK-DAG: [[VAR_7_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[VAR_4_]]], {{\[}}[[VAR_6_]]] {{.*}} : <tensor<1024xbf16>>
48+
// CHECK-DAG: [[VAR_8_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_0_]]] {{.*}} : <tensor<1024xbf16>>
49+
// CHECK: [[VAR_9_:%.+]] = tt.load [[VAR_7_]] : !tt.ptr<tensor<1024xbf16>>
50+
// CHECK: tt.store [[VAR_8_]], [[VAR_9_]] : !tt.ptr<tensor<1024xbf16>>
5251
// CHECK: tt.return
5352
// CHECK: }

test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir

Lines changed: 22 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -106,24 +106,28 @@ module {
106106
// CHECK-DAG: [[CST_64_i32:%.+]] = arith.constant 64 : i32
107107
// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
108108
// CHECK-DAG: [[CST_1_i32:%.+]] = arith.constant 1 : i32
109-
// CHECK: [[VAR_28_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : <tensor<128x64xbf16>>
110-
// CHECK: [[VAR_31_:%.+]] = arith.index_cast [[PARAM_9_]] : i32 to index
111-
// CHECK: [[VAR_32_:%.+]] = arith.index_cast [[VAR_31_]] : index to i64
112-
// CHECK: [[VAR_38_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : <tensor<64x256xbf16>>
113-
// CHECK-DAG: [[VAR_39_:%.+]] = arith.muli [[PARAM_7_]], [[CST_64_i32]] : i32
114-
// CHECK-DAG: [[VAR_40_:%.+]] = arith.muli [[PARAM_8_]], [[CST_64_i32]] : i32
115-
// CHECK: [[VAR_41_:%.+]]:3 = scf.for {{.*}} iter_args([[VAR_arg13_:%.+]] = [[VAR_cst_]], [[VAR_arg14_:%.+]] = [[VAR_28_]], [[VAR_arg15_:%.+]] = [[VAR_38_]]) -> (tensor<128x256xf32>, !tt.ptr<tensor<128x64xbf16>>, !tt.ptr<tensor<64x256xbf16>>) : i32 {
116-
// CHECK-DAG: [[VAR_54_:%.+]] = tt.load [[VAR_arg14_]] : !tt.ptr<tensor<128x64xbf16>>
117-
// CHECK-DAG: [[VAR_55_:%.+]] = tt.load [[VAR_arg15_]] : !tt.ptr<tensor<64x256xbf16>>
118-
// CHECK: [[VAR_56_:%.+]] = tt.dot [[VAR_54_]], [[VAR_55_]], [[VAR_cst_]], inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32>
119-
// CHECK-DAG: [[VAR_57_:%.+]] = arith.addf [[VAR_arg13_]], [[VAR_56_]] : tensor<128x256xf32>
120-
// CHECK-DAG: [[VAR_58_:%.+]] = tt.advance [[VAR_arg14_]], {{\[}}[[CST_0_i32]], [[VAR_39_]]] : <tensor<128x64xbf16>>
121-
// CHECK-DAG: [[VAR_59_:%.+]] = tt.advance [[VAR_arg15_]], {{\[}}[[CST_0_i32]], [[VAR_40_]]] : <tensor<64x256xbf16>>
122-
// CHECK: scf.yield [[VAR_57_]], [[VAR_58_]], [[VAR_59_]] : tensor<128x256xf32>, !tt.ptr<tensor<128x64xbf16>>, !tt.ptr<tensor<64x256xbf16>>
109+
// CHECK: [[VAR_20_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64
110+
// CHECK: [[VAR_21_:%.+]] = arith.extsi [[PARAM_7_]] : i32 to i64
111+
// CHECK: [[VAR_22_:%.+]] = arith.divui {{.*}}, [[PARAM_6_]] : i32
112+
// CHECK: [[VAR_23_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_20_]], [[VAR_21_]]], {{\[}}[[VAR_22_]], [[CST_0_i32]]] {{.*}} : <tensor<128x64xbf16>>
113+
// CHECK: [[VAR_24_:%.+]] = arith.extsi [[PARAM_8_]] : i32 to i64
114+
// CHECK: [[VAR_25_:%.+]] = arith.muli {{.*}}, [[PARAM_9_]] : i32
115+
// CHECK: [[VAR_26_:%.+]] = arith.extsi [[PARAM_9_]] : i32 to i64
116+
// CHECK: [[VAR_27_:%.+]] = arith.divui [[VAR_25_]], [[PARAM_9_]] : i32
117+
// CHECK: [[VAR_28_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_24_]], [[VAR_26_]]], {{\[}}[[CST_0_i32]], [[VAR_27_]]] {{.*}} : <tensor<64x256xbf16>>
118+
// CHECK-DAG: [[VAR_29_:%.+]] = arith.muli [[PARAM_7_]], [[CST_64_i32]] : i32
119+
// CHECK-DAG: [[VAR_30_:%.+]] = arith.muli [[PARAM_8_]], [[CST_64_i32]] : i32
120+
// CHECK: [[VAR_31_:%.+]]:3 = scf.for {{.*}} iter_args([[VAR_arg13_:%.+]] = [[VAR_cst_]], [[VAR_arg14_:%.+]] = [[VAR_23_]], [[VAR_arg15_:%.+]] = [[VAR_28_]]) -> (tensor<128x256xf32>, !tt.ptr<tensor<128x64xbf16>>, !tt.ptr<tensor<64x256xbf16>>) : i32 {
121+
// CHECK-DAG: [[VAR_40_:%.+]] = tt.load [[VAR_arg14_]] : !tt.ptr<tensor<128x64xbf16>>
122+
// CHECK-DAG: [[VAR_41_:%.+]] = tt.load [[VAR_arg15_]] : !tt.ptr<tensor<64x256xbf16>>
123+
// CHECK: [[VAR_42_:%.+]] = tt.dot [[VAR_40_]], [[VAR_41_]], [[VAR_cst_]], inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32>
124+
// CHECK-DAG: [[VAR_43_:%.+]] = arith.addf [[VAR_arg13_]], [[VAR_42_]] : tensor<128x256xf32>
125+
// CHECK-DAG: [[VAR_44_:%.+]] = tt.advance [[VAR_arg14_]], {{\[}}[[CST_0_i32]], [[VAR_29_]]] : <tensor<128x64xbf16>>
126+
// CHECK-DAG: [[VAR_45_:%.+]] = tt.advance [[VAR_arg15_]], {{\[}}[[CST_0_i32]], [[VAR_30_]]] : <tensor<64x256xbf16>>
127+
// CHECK: scf.yield [[VAR_43_]], [[VAR_44_]], [[VAR_45_]] : tensor<128x256xf32>, !tt.ptr<tensor<128x64xbf16>>, !tt.ptr<tensor<64x256xbf16>>
123128
// CHECK: }
124-
// CHECK-DAG: [[VAR_42_:%.+]] = arith.truncf [[VAR_41_]]#0 : tensor<128x256xf32> to tensor<128x256xbf16>
125-
// CHECK-DAG: [[VAR_43_:%.+]] = arith.index_cast [[PARAM_10_]] : i32 to index
126-
// CHECK: [[VAR_53_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : <tensor<128x256xbf16>>
127-
// CHECK: tt.store [[VAR_53_]], [[VAR_42_]] : !tt.ptr<tensor<128x256xbf16>>
129+
// CHECK-DAG: [[VAR_32_:%.+]] = arith.truncf [[VAR_31_]]#0 : tensor<128x256xf32> to tensor<128x256xbf16>
130+
// CHECK: [[VAR_39_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : <tensor<128x256xbf16>>
131+
// CHECK: tt.store [[VAR_39_]], [[VAR_32_]] : !tt.ptr<tensor<128x256xbf16>>
128132
// CHECK: tt.return
129133
// CHECK: }

0 commit comments

Comments (0)