Skip to content

Commit a30f5cd

Browse files
authored
[triton-raise-block-ptr]: Avoid generating unnecessary tt.make_tensor_ptr in a loop. (#3172)
This PR simplifies the code generation logic in the `triton-raise-block-pointer` transformation by: - reusing block ptrs created outside of the loop (rather than rematerializing them in the loop) - using `tt.advance` to do ptr arithmetic on block pointers --------- Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 44e076a commit a30f5cd

File tree

2 files changed

+219
-295
lines changed

2 files changed

+219
-295
lines changed

test/Triton/raise-block-pointer.mlir

Lines changed: 45 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -336,26 +336,26 @@ tt.func public @wrap_side_by_side_masked(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32
336336
}
337337

338338

339-
// CHECK: tt.func @test_addptr_for_accumulation([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>, [[PARAM_2_:%.+]]: !tt.ptr<bf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) {
340-
// CHECK-DAG: [[CST_3_i32:%.+]] = arith.constant 3 : i32
341-
// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
342-
// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
343-
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
344-
// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
345-
// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
346-
// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
347-
// CHECK-DAG: [[CST_5_i64:%.+]] = arith.constant 5 : i64
339+
// CHECK: tt.func @test_addptr_for_accumulation([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>, [[PARAM_2_:%.+]]: !tt.ptr<bf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) {
340+
// CHECK-DAG: [[CST_3_i32:%.+]] = arith.constant 3 : i32
341+
// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
342+
// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
343+
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
344+
// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
345+
// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
346+
// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
347+
// CHECK-DAG: [[CST_5_i64:%.+]] = arith.constant 5 : i64
348348
// CHECK: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x256xbf16>>
349349
// CHECK: [[VAR_2_:%.+]] = tt.load [[VAR_1_]] : !tt.ptr<tensor<4x256xbf16>>
350-
// CHECK: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[PARAM_3_]]) -> (tensor<4x256xbf16>, i32) {
351-
// CHECK: [[VAR_7_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[VAR_arg7_]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x256xbf16>>
352-
// CHECK: [[VAR_8_:%.+]] = tt.load [[VAR_7_]] : !tt.ptr<tensor<4x256xbf16>>
353-
// CHECK: [[VAR_9_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_8_]] : tensor<4x256xbf16>
354-
// CHECK: [[VAR_10_:%.+]] = arith.addi [[VAR_arg7_]], [[CST_3_i32]] : i32
355-
// CHECK: scf.yield [[VAR_9_]], [[VAR_10_]] : tensor<4x256xbf16>, i32
350+
// CHECK: [[VAR_3_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x256xbf16>>
351+
// CHECK: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[VAR_3_]]) -> (tensor<4x256xbf16>, !tt.ptr<tensor<4x256xbf16>>) {
352+
// CHECK: [[VAR_5_:%.+]] = tt.load [[VAR_arg7_]] : !tt.ptr<tensor<4x256xbf16>>
353+
// CHECK: [[VAR_6_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_5_]] : tensor<4x256xbf16>
354+
// CHECK: [[VAR_7_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_0_i32]], [[CST_3_i32]]{{\]}} : <tensor<4x256xbf16>>
355+
// CHECK: scf.yield [[VAR_6_]], [[VAR_7_]] : tensor<4x256xbf16>, !tt.ptr<tensor<4x256xbf16>>
356356
// CHECK: }
357-
// CHECK: [[VAR_6_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x256xbf16>>
358-
// CHECK: tt.store [[VAR_6_]], [[VAR_4_]]#0 : !tt.ptr<tensor<4x256xbf16>>
357+
// CHECK: [[VAR_5_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x256xbf16>>
358+
// CHECK: tt.store [[VAR_5_]], [[VAR_4_]]#0 : !tt.ptr<tensor<4x256xbf16>>
359359
// CHECK: tt.return
360360
// CHECK: }
361361
module {
@@ -505,22 +505,22 @@ module {
505505
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
506506
// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
507507
// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
508-
// CHECK-DAG: [[CST_1024_i32:%.+]] = arith.constant 1024 : i32
508+
// CHECK-DAG: [[CST_1024:%.+]] = arith.constant 1024 : i32
509509
// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
510-
// CHECK: [[VAR_0_:%.+]]:5 = scf.for [[VAR_arg2_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg3_:%.+]] = [[CST_1_]], [[VAR_arg4_:%.+]] = [[CST_2_]], [[VAR_arg5_:%.+]] = [[CST_3_]], [[VAR_arg6_:%.+]] = [[CST_1024_i32]], [[VAR_arg7_:%.+]] = [[CST_1024_i32]]) -> (index, index, index, i32, i32) {
511-
// CHECK: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_arg7_]]] {order = array<i32>} : <tensor<256xbf16>>
512-
// CHECK: [[VAR_2_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_arg6_]]] {order = array<i32>} : <tensor<256xbf16>>
513-
// CHECK: [[VAR_3_:%.+]] = tt.load [[VAR_2_]] : !tt.ptr<tensor<256xbf16>>
514-
// CHECK: tt.store [[VAR_1_]], [[VAR_3_]] : !tt.ptr<tensor<256xbf16>>
515-
// CHECK: [[VAR_4_:%.+]] = arith.addi [[VAR_arg6_]], [[CST_3_i32]] : i32
510+
// CHECK-DAG: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[CST_1024]]] {order = array<i32>} : <tensor<256xbf16>>
511+
// CHECK-DAG: [[VAR_2_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[CST_1024]]] {order = array<i32>} : <tensor<256xbf16>>
512+
// CHECK: [[VAR_0_:%.+]]:5 = scf.for [[VAR_arg2_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg3_:%.+]] = [[CST_1_]], [[VAR_arg4_:%.+]] = [[VAR_1_]], [[VAR_arg5_:%.+]] = [[CST_2_]], [[VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[CST_3_]]) -> (index, !tt.ptr<tensor<256xbf16>>, index, !tt.ptr<tensor<256xbf16>>, index) {
513+
// CHECK: [[VAR_3_:%.+]] = tt.load [[VAR_arg4_]] : !tt.ptr<tensor<256xbf16>>
514+
// CHECK: tt.store [[VAR_arg6_]], [[VAR_3_]] : !tt.ptr<tensor<256xbf16>>
515+
// CHECK: [[VAR_4_:%.+]] = tt.advance [[VAR_arg4_]], {{\[}}[[CST_3_i32]]{{\]}} : <tensor<256xbf16>>
516516
// CHECK: [[VAR_5_:%.+]] = arith.addi [[VAR_arg3_]], [[CST_3_]] : index
517-
// CHECK: [[VAR_6_:%.+]] = arith.addi [[VAR_arg4_]], [[CST_3_]] : index
518-
// CHECK: [[VAR_7_:%.+]] = arith.addi [[VAR_arg5_]], [[CST_3_]] : index
517+
// CHECK: [[VAR_6_:%.+]] = arith.addi [[VAR_arg5_]], [[CST_3_]] : index
518+
// CHECK: [[VAR_7_:%.+]] = arith.addi [[VAR_arg7_]], [[CST_3_]] : index
519519
// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_5_]], [[VAR_6_]] : index
520520
// CHECK: [[VAR_9_:%.+]] = arith.addi [[VAR_8_]], [[VAR_7_]] : index
521521
// CHECK: [[VAR_10_:%.+]] = arith.index_cast [[VAR_9_]] : index to i32
522-
// CHECK: [[VAR_11_:%.+]] = arith.addi [[VAR_arg7_]], [[VAR_10_]] : i32
523-
// CHECK: scf.yield [[VAR_5_]], [[VAR_6_]], [[VAR_7_]], [[VAR_4_]], [[VAR_11_]] : index, index, index, i32, i32
522+
// CHECK: [[VAR_11_:%.+]] = tt.advance [[VAR_arg6_]], {{\[}}[[VAR_10_]]{{\]}} : <tensor<256xbf16>>
523+
// CHECK: scf.yield [[VAR_5_]], [[VAR_4_]], [[VAR_6_]], [[VAR_11_]], [[VAR_7_]] : index, !tt.ptr<tensor<256xbf16>>, index, !tt.ptr<tensor<256xbf16>>, index
524524
// CHECK: }
525525
// CHECK: tt.return
526526
// CHECK: }
@@ -568,12 +568,12 @@ module {
568568
// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
569569
// CHECK-DAG: [[CST_1024_i32:%.+]] = arith.constant 1024 : i32
570570
// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
571-
// CHECK: [[VAR_0_:%.+]] = scf.for [[VAR_arg1_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg2_:%.+]] = [[CST_1024_i32]]) -> (i32) {
572-
// CHECK: [[VAR_1_:%.+]] = arith.addi [[VAR_arg2_]], [[CST_3_i32]] : i32
573-
// CHECK: [[VAR_2_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_1_]]] {order = array<i32>} : <tensor<256xbf16>>
571+
// CHECK: [[VAR_0_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[CST_1024_i32]]] {order = array<i32>} : <tensor<256xbf16>>
572+
// CHECK: [[VAR_1_:%.+]] = scf.for [[VAR_arg1_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg2_:%.+]] = [[VAR_0_]]) -> (!tt.ptr<tensor<256xbf16>>) {
573+
// CHECK: [[VAR_2_:%.+]] = tt.advance [[VAR_arg2_]], {{\[}}[[CST_3_i32]]{{\]}} : <tensor<256xbf16>>
574574
// CHECK: [[VAR_3_:%.+]] = tt.load [[VAR_2_]] : !tt.ptr<tensor<256xbf16>>
575575
// CHECK: tt.store [[VAR_2_]], [[VAR_3_]] : !tt.ptr<tensor<256xbf16>>
576-
// CHECK: scf.yield [[VAR_1_]] : i32
576+
// CHECK: scf.yield [[VAR_2_]] : !tt.ptr<tensor<256xbf16>>
577577
// CHECK: }
578578
// CHECK: tt.return
579579
// CHECK: }
@@ -609,13 +609,12 @@ module {
609609
// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
610610
// CHECK-DAG: [[CST_1024_i32:%.+]] = arith.constant 1024 : i32
611611
// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
612-
// CHECK: [[VAR_0_:%.+]] = scf.for [[VAR_arg1_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg2_:%.+]] = [[CST_1024_i32]]) -> (i32) {
613-
614-
// CHECK: [[VAR_2_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_arg2_]]] {order = array<i32>} : <tensor<256xbf16>>
615-
// CHECK: [[VAR_3_:%.+]] = tt.load [[VAR_2_]] : !tt.ptr<tensor<256xbf16>>
616-
// CHECK: tt.store [[VAR_2_]], [[VAR_3_]] : !tt.ptr<tensor<256xbf16>>
617-
// CHECK: [[VAR_3_:%.+]] = arith.addi [[VAR_arg2_]], [[CST_3_i32]] : i32
618-
// CHECK: scf.yield [[VAR_3_]] : i32
612+
// CHECK: [[VAR_0_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[CST_1024_i32]]] {order = array<i32>} : <tensor<256xbf16>>
613+
// CHECK: [[VAR_1_:%.+]] = scf.for [[VAR_arg1_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg2_:%.+]] = [[VAR_0_]]) -> (!tt.ptr<tensor<256xbf16>>) {
614+
// CHECK: [[VAR_2_:%.+]] = tt.load [[VAR_arg2_]] : !tt.ptr<tensor<256xbf16>>
615+
// CHECK: tt.store [[VAR_arg2_]], [[VAR_2_]] : !tt.ptr<tensor<256xbf16>>
616+
// CHECK: [[VAR_3_:%.+]] = tt.advance [[VAR_arg2_]], {{\[}}[[CST_3_i32]]{{\]}} : <tensor<256xbf16>>
617+
// CHECK: scf.yield [[VAR_3_]] : !tt.ptr<tensor<256xbf16>>
619618
// CHECK: }
620619
// CHECK: tt.return
621620
// CHECK: }
@@ -645,11 +644,13 @@ module {
645644
// CHECK: tt.func @matmul_kernel
646645
// CHECK: tt.make_tensor_ptr %arg0
647646
// CHECK: tt.make_tensor_ptr %arg1
648-
// CHECK: tt.dot
649-
// CHECK: %[[VAL_1:.*]] = arith.addi
650-
// CHECK: %[[VAL_2:.*]] = arith.divui
651-
// CHECK: tt.make_tensor_ptr %arg0, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}%[[VAL_2]], %[[VAL_1]]]
652-
// CHECK: tt.make_tensor_ptr %arg1
647+
// CHECK: scf.for
648+
// CHECK: [[LOAD1:%.*]] = tt.load [[ARG10:%.*]], {{.*}}, {{.*}} : !tt.ptr<tensor<64x32xf16>>
649+
// CHECK: [[LOAD2:%.*]] = tt.load [[ARG11:%.*]], {{.*}}, {{.*}} : !tt.ptr<tensor<32x128xf16>>
650+
// CHECK: [[DOT:%.*]] = tt.dot [[LOAD1]], [[LOAD2]]
651+
// CHECK: [[ADV1:%.*]] = tt.advance [[ARG10]], {{.*}} : <tensor<64x32xf16>>
652+
// CHECK: [[ADV2:%.*]] = tt.advance [[ARG11]], {{.*}} : <tensor<32x128xf16>>
653+
// CHECK: scf.yield [[DOT]], [[ADV1]], [[ADV2]]
653654
module {
654655
tt.func @matmul_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16> , %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) -> tensor<64x128xf16> {
655656
%cst = arith.constant dense<0.000000e+00> : tensor<64x128xf32>

0 commit comments

Comments
 (0)