@@ -336,26 +336,26 @@ tt.func public @wrap_side_by_side_masked(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32
336336}
337337
338338
339- // CHECK: tt.func @test_addptr_for_accumulation([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>, [[PARAM_2_:%.+]]: !tt.ptr<bf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) {
340- // CHECK-DAG: [[CST_3_i32:%.+]] = arith.constant 3 : i32
341- // CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
342- // CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
343- // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
344- // CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
345- // CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
346- // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
347- // CHECK-DAG: [[CST_5_i64:%.+]] = arith.constant 5 : i64
339+ // CHECK: tt.func @test_addptr_for_accumulation([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>, [[PARAM_2_:%.+]]: !tt.ptr<bf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) {
340+ // CHECK-DAG: [[CST_3_i32:%.+]] = arith.constant 3 : i32
341+ // CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
342+ // CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
343+ // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
344+ // CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
345+ // CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
346+ // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
347+ // CHECK-DAG: [[CST_5_i64:%.+]] = arith.constant 5 : i64
348348// CHECK: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x256xbf16>>
349349// CHECK: [[VAR_2_:%.+]] = tt.load [[VAR_1_]] : !tt.ptr<tensor<4x256xbf16>>
350- // CHECK: [[VAR_4_ :%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_2_ ]], [[VAR_arg7_:%.+]] = [[PARAM_3_]]) -> ( tensor<4x256xbf16>, i32) {
351- // CHECK: [[VAR_7_ :%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[ VAR_arg7_]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x256xbf16>>
352- // CHECK: [[VAR_8_ :%.+]] = tt.load [[VAR_7_ ]] : !tt.ptr<tensor<4x256xbf16>>
353- // CHECK: [[VAR_9_ :%.+]] = arith.addf [[VAR_arg6_]], [[VAR_8_ ]] : tensor<4x256xbf16>
354- // CHECK: [[VAR_10_ :%.+]] = arith.addi [[VAR_arg7_]], [[ CST_3_i32]] : i32
355- // CHECK: scf.yield [[VAR_9_ ]], [[VAR_10_ ]] : tensor<4x256xbf16>, i32
350+ // CHECK: [[VAR_3_ :%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_ ]], [[CST_0_i32]]] {order = array<i32>} : < tensor<4x256xbf16>>
351+ // CHECK: [[VAR_4_ :%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_2_]], [[ VAR_arg7_:%.+]] = [[VAR_3_]]) -> (tensor<4x256xbf16>, !tt.ptr <tensor<4x256xbf16>>) {
352+ // CHECK: [[VAR_5_ :%.+]] = tt.load [[VAR_arg7_ ]] : !tt.ptr<tensor<4x256xbf16>>
353+ // CHECK: [[VAR_6_ :%.+]] = arith.addf [[VAR_arg6_]], [[VAR_5_ ]] : tensor<4x256xbf16>
354+ // CHECK: [[VAR_7_ :%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_0_i32]], [[ CST_3_i32]]{{\]}} : <tensor<4x256xbf16>>
355+ // CHECK: scf.yield [[VAR_6_ ]], [[VAR_7_ ]] : tensor<4x256xbf16>, !tt.ptr<tensor<4x256xbf16>>
356356// CHECK: }
357- // CHECK: [[VAR_6_ :%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x256xbf16>>
358- // CHECK: tt.store [[VAR_6_ ]], [[VAR_4_]]#0 : !tt.ptr<tensor<4x256xbf16>>
357+ // CHECK: [[VAR_5_ :%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x256xbf16>>
358+ // CHECK: tt.store [[VAR_5_ ]], [[VAR_4_]]#0 : !tt.ptr<tensor<4x256xbf16>>
359359// CHECK: tt.return
360360// CHECK: }
361361module {
@@ -505,22 +505,22 @@ module {
505505// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
506506// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
507507// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
508- // CHECK-DAG: [[CST_1024_i32 :%.+]] = arith.constant 1024 : i32
508+ // CHECK-DAG: [[CST_1024 :%.+]] = arith.constant 1024 : i32
509509// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
510- // CHECK: [[VAR_0_ :%.+]]:5 = scf.for [[VAR_arg2_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg3_:%.+]] = [[CST_1_]], [[VAR_arg4_:%.+]] = [[CST_2_]], [[VAR_arg5_:%.+]] = [[CST_3_]], [[VAR_arg6_:%.+]] = [[CST_1024_i32]], [[VAR_arg7_:%.+]] = [[CST_1024_i32]]) -> (index, index, index, i32, i32) {
511- // CHECK: [[VAR_1_ :%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_arg7_ ]]] {order = array<i32>} : <tensor<256xbf16>>
512- // CHECK: [[VAR_2_ :%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[ VAR_arg6_]]] {order = array<i32>} : <tensor<256xbf16>>
513- // CHECK: [[VAR_3_:%.+]] = tt.load [[VAR_2_ ]] : !tt.ptr<tensor<256xbf16>>
514- // CHECK: tt.store [[VAR_1_ ]], [[VAR_3_]] : !tt.ptr<tensor<256xbf16>>
515- // CHECK: [[VAR_4_:%.+]] = arith.addi [[VAR_arg6_ ]], [[ CST_3_i32]] : i32
510+ // CHECK-DAG : [[VAR_1_ :%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[CST_1024]]] {order = array<i32>} : <tensor<256xbf16>>
511+ // CHECK-DAG : [[VAR_2_ :%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[CST_1024 ]]] {order = array<i32>} : <tensor<256xbf16>>
512+ // CHECK: [[VAR_0_:%.+]]:5 = scf.for [[VAR_arg2_ :%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg3_:%.+]] = [[CST_1_]], [[VAR_arg4_:%.+]] = [[VAR_1_]], [[VAR_arg5_:%.+]] = [[CST_2_]], [[ VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[CST_3_]]) -> (index, !tt.ptr<tensor<256xbf16>>, index, !tt.ptr <tensor<256xbf16>>, index) {
513+ // CHECK: [[VAR_3_:%.+]] = tt.load [[VAR_arg4_ ]] : !tt.ptr<tensor<256xbf16>>
514+ // CHECK: tt.store [[VAR_arg6_ ]], [[VAR_3_]] : !tt.ptr<tensor<256xbf16>>
515+ // CHECK: [[VAR_4_:%.+]] = tt.advance [[VAR_arg4_ ]], {{\[}}[[ CST_3_i32]]{{\]}} : <tensor<256xbf16>>
516516// CHECK: [[VAR_5_:%.+]] = arith.addi [[VAR_arg3_]], [[CST_3_]] : index
517- // CHECK: [[VAR_6_:%.+]] = arith.addi [[VAR_arg4_ ]], [[CST_3_]] : index
518- // CHECK: [[VAR_7_:%.+]] = arith.addi [[VAR_arg5_ ]], [[CST_3_]] : index
517+ // CHECK: [[VAR_6_:%.+]] = arith.addi [[VAR_arg5_ ]], [[CST_3_]] : index
518+ // CHECK: [[VAR_7_:%.+]] = arith.addi [[VAR_arg7_ ]], [[CST_3_]] : index
519519// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_5_]], [[VAR_6_]] : index
520520// CHECK: [[VAR_9_:%.+]] = arith.addi [[VAR_8_]], [[VAR_7_]] : index
521521// CHECK: [[VAR_10_:%.+]] = arith.index_cast [[VAR_9_]] : index to i32
522- // CHECK: [[VAR_11_:%.+]] = arith.addi [[VAR_arg7_ ]], [[ VAR_10_]] : i32
523- // CHECK: scf.yield [[VAR_5_]], [[VAR_6_ ]], [[VAR_7_ ]], [[VAR_4_ ]], [[VAR_11_ ]] : index, index , index, i32, i32
522+ // CHECK: [[VAR_11_:%.+]] = tt.advance [[VAR_arg6_ ]], {{\[}}[[ VAR_10_]]{{\]}} : <tensor<256xbf16>>
523+ // CHECK: scf.yield [[VAR_5_]], [[VAR_4_ ]], [[VAR_6_ ]], [[VAR_11_ ]], [[VAR_7_ ]] : index, !tt.ptr<tensor<256xbf16>> , index, !tt.ptr<tensor<256xbf16>>, index
524524// CHECK: }
525525// CHECK: tt.return
526526// CHECK: }
@@ -568,12 +568,12 @@ module {
568568// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
569569// CHECK-DAG: [[CST_1024_i32:%.+]] = arith.constant 1024 : i32
570570// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
571- // CHECK: [[VAR_0_:%.+]] = scf.for [[VAR_arg1_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg2_:%.+]] = [[ CST_1024_i32]]) -> ( i32) {
572- // CHECK: [[VAR_1_:%.+]] = arith.addi [[VAR_arg2_]], [[CST_3_i32 ]] : i32
573- // CHECK: [[VAR_2_:%.+]] = tt.make_tensor_ptr [[PARAM_0_ ]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_1_]]] {order = array<i32> } : <tensor<256xbf16>>
571+ // CHECK: [[VAR_0_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[ CST_1024_i32]]] {order = array< i32>} : <tensor<256xbf16>>
572+ // CHECK: [[VAR_1_:%.+]] = scf.for [[VAR_arg1_:%.+]] = [[CST_0_ ]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg2_:%.+]] = [[VAR_0_]]) -> (!tt.ptr<tensor<256xbf16>>) {
573+ // CHECK: [[VAR_2_:%.+]] = tt.advance [[VAR_arg2_ ]], {{\[}}[[CST_3_i32]] {{\]} } : <tensor<256xbf16>>
574574// CHECK: [[VAR_3_:%.+]] = tt.load [[VAR_2_]] : !tt.ptr<tensor<256xbf16>>
575575// CHECK: tt.store [[VAR_2_]], [[VAR_3_]] : !tt.ptr<tensor<256xbf16>>
576- // CHECK: scf.yield [[VAR_1_ ]] : i32
576+ // CHECK: scf.yield [[VAR_2_ ]] : !tt.ptr<tensor<256xbf16>>
577577// CHECK: }
578578// CHECK: tt.return
579579// CHECK: }
@@ -609,13 +609,12 @@ module {
609609// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
610610// CHECK-DAG: [[CST_1024_i32:%.+]] = arith.constant 1024 : i32
611611// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
612- // CHECK: [[VAR_0_:%.+]] = scf.for [[VAR_arg1_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg2_:%.+]] = [[CST_1024_i32]]) -> (i32) {
613-
614- // CHECK: [[VAR_2_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_arg2_]]] {order = array<i32>} : <tensor<256xbf16>>
615- // CHECK: [[VAR_3_:%.+]] = tt.load [[VAR_2_]] : !tt.ptr<tensor<256xbf16>>
616- // CHECK: tt.store [[VAR_2_]], [[VAR_3_]] : !tt.ptr<tensor<256xbf16>>
617- // CHECK: [[VAR_3_:%.+]] = arith.addi [[VAR_arg2_]], [[CST_3_i32]] : i32
618- // CHECK: scf.yield [[VAR_3_]] : i32
612+ // CHECK: [[VAR_0_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[CST_1024_i32]]] {order = array<i32>} : <tensor<256xbf16>>
613+ // CHECK: [[VAR_1_:%.+]] = scf.for [[VAR_arg1_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg2_:%.+]] = [[VAR_0_]]) -> (!tt.ptr<tensor<256xbf16>>) {
614+ // CHECK: [[VAR_2_:%.+]] = tt.load [[VAR_arg2_]] : !tt.ptr<tensor<256xbf16>>
615+ // CHECK: tt.store [[VAR_arg2_]], [[VAR_2_]] : !tt.ptr<tensor<256xbf16>>
616+ // CHECK: [[VAR_3_:%.+]] = tt.advance [[VAR_arg2_]], {{\[}}[[CST_3_i32]]{{\]}} : <tensor<256xbf16>>
617+ // CHECK: scf.yield [[VAR_3_]] : !tt.ptr<tensor<256xbf16>>
619618// CHECK: }
620619// CHECK: tt.return
621620// CHECK: }
@@ -645,11 +644,13 @@ module {
645644// CHECK: tt.func @matmul_kernel
646645// CHECK: tt.make_tensor_ptr %arg0
647646// CHECK: tt.make_tensor_ptr %arg1
648- // CHECK: tt.dot
649- // CHECK: %[[VAL_1:.*]] = arith.addi
650- // CHECK: %[[VAL_2:.*]] = arith.divui
651- // CHECK: tt.make_tensor_ptr %arg0, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}%[[VAL_2]], %[[VAL_1]]]
652- // CHECK: tt.make_tensor_ptr %arg1
647+ // CHECK: scf.for
648+ // CHECK: [[LOAD1:%.*]] = tt.load [[ARG10:%.*]], {{.*}}, {{.*}} : !tt.ptr<tensor<64x32xf16>>
649+ // CHECK: [[LOAD2:%.*]] = tt.load [[ARG11:%.*]], {{.*}}, {{.*}} : !tt.ptr<tensor<32x128xf16>>
650+ // CHECK: [[DOT:%.*]] = tt.dot [[LOAD1]], [[LOAD2]]
651+ // CHECK: [[ADV1:%.*]] = tt.advance [[ARG10]], {{.*}} : <tensor<64x32xf16>>
652+ // CHECK: [[ADV2:%.*]] = tt.advance [[ARG11]], {{.*}} : <tensor<32x128xf16>>
653+ // CHECK: scf.yield [[DOT]], [[ADV1]], [[ADV2]]
653654module {
654655 tt.func @matmul_kernel (%arg0: !tt.ptr <f16 >, %arg1: !tt.ptr <f16 > , %arg2: !tt.ptr <f16 >, %arg3: i32 , %arg4: i32 , %arg5: i32 , %arg6: i32 , %arg7: i32 ) -> tensor <64 x128 xf16 > {
655656 %cst = arith.constant dense <0.000000e+00 > : tensor <64 x128 xf32 >
0 commit comments