Skip to content

Commit 2b7ca9b

Browse files
authored
[triton-raise-block-ptr]: Fix tt.advance offset computation (#3287)
Resolve #3284.
---------
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 7ad67a0 commit 2b7ca9b

File tree

8 files changed

+92
-86
lines changed

8 files changed

+92
-86
lines changed

test/Triton/Intel/RaiseToBlockPointers/addptr_dim1.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ module {
8585
// CHECK-DAG: [[VAR_3_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_0_i64]], [[CST_1_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<1x256xbf16>>
8686
// CHECK-DAG: [[VAR_4_:%.+]] = tt.addptr [[VAR_2_]], [[VAR_1_]] : tensor<1x256x!tt.ptr<bf16>>, tensor<1x256xi32>
8787
// CHECK-DAG: [[VAR_5_:%.+]] = tt.load [[VAR_3_]] : !tt.ptr<tensor<1x256xbf16>>
88-
// CHECK-DAG: [[VAR_6_:%.+]] = tt.advance [[VAR_3_]], {{\[}}[[CST_0_i32]], [[PARAM_1_]]] : <tensor<1x256xbf16>>
88+
// CHECK-DAG: [[VAR_6_:%.+]] = tt.advance [[VAR_3_]], {{\[}}[[PARAM_1_]], [[CST_0_i32]]] : <tensor<1x256xbf16>>
8989
// CHECK: tt.store [[VAR_6_]], [[VAR_5_]] : !tt.ptr<tensor<1x256xbf16>>
9090
// CHECK: [[VAR_7_:%.+]]:2 = scf.for [[VAR_arg2_:%.+]] = {{.*}} iter_args([[VAR_arg3_:%.+]] = [[CST_0_]], [[VAR_arg4_:%.+]] = [[VAR_4_]]) -> (tensor<4x256xbf16>, tensor<1x256x!tt.ptr<bf16>>) {
9191
// CHECK: [[VAR_9_:%.+]] = tt.broadcast [[VAR_arg4_]] : tensor<1x256x!tt.ptr<bf16>> -> tensor<4x256x!tt.ptr<bf16>>

test/Triton/Intel/RaiseToBlockPointers/addptr_for_accumulation.mlir

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,20 +64,18 @@ module {
6464
// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : i32
6565
// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : i64
6666
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i64
67-
// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
6867
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i64
6968
// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
70-
// CHECK: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
71-
// CHECK: [[VAR_2_:%.+]] = tt.load [[VAR_1_]] : !tt.ptr<tensor<4x256xbf16>>
72-
// CHECK: [[VAR_3_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
73-
// CHECK: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = {{.*}} iter_args([[VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[VAR_3_]]) -> (tensor<4x256xbf16>, !tt.ptr<tensor<4x256xbf16>>) {
74-
// CHECK: [[VAR_8_:%.+]] = tt.load [[VAR_arg7_]] : !tt.ptr<tensor<4x256xbf16>>
75-
// CHECK-DAG: [[VAR_9_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_8_]] : tensor<4x256xbf16>
76-
// CHECK-DAG: [[VAR_10_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_0_i32]], [[CST_3_]]] : <tensor<4x256xbf16>>
77-
// CHECK: scf.yield [[VAR_9_]], [[VAR_10_]] : tensor<4x256xbf16>, !tt.ptr<tensor<4x256xbf16>>
69+
// CHECK: [[VAR_0_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
70+
// CHECK: [[VAR_1_:%.+]] = tt.load [[VAR_0_]] : !tt.ptr<tensor<4x256xbf16>>
71+
// CHECK: [[VAR_2_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
72+
// CHECK: [[VAR_3_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = {{.*}} iter_args([[VAR_arg6_:%.+]] = [[VAR_1_]], [[VAR_arg7_:%.+]] = [[VAR_2_]]) -> (tensor<4x256xbf16>, !tt.ptr<tensor<4x256xbf16>>) {
73+
// CHECK: [[VAR_5_:%.+]] = tt.load [[VAR_arg7_]] : !tt.ptr<tensor<4x256xbf16>>
74+
// CHECK-DAG: [[VAR_6_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_5_]] : tensor<4x256xbf16>
75+
// CHECK-DAG: [[VAR_7_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_3_]], [[CST_0_i32]]] : <tensor<4x256xbf16>>
76+
// CHECK: scf.yield [[VAR_6_]], [[VAR_7_]] : tensor<4x256xbf16>, !tt.ptr<tensor<4x256xbf16>>
7877
// CHECK: }
79-
// COM: to sizes: [4, 256], strides: [1, [[CST_5_]]{{.}}, offsets: {{.}}[[VAR_5_]], 0], shape: [0, 0], order: [] : <bf16> to tensor<4x256x!tt.ptr<bf16>>
80-
// CHECK: [[VAR_6_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
81-
// CHECK: tt.store [[VAR_6_]], [[VAR_4_]]#0 : !tt.ptr<tensor<4x256xbf16>>
78+
// CHECK: [[VAR_4_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
79+
// CHECK: tt.store [[VAR_4_]], [[VAR_3_]]#0 : !tt.ptr<tensor<4x256xbf16>>
8280
// CHECK: tt.return
8381
// CHECK: }

test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ module {
106106
// CHECK-DAG: [[CST_64_i32:%.+]] = arith.constant 64 : i32
107107
// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
108108
// CHECK-DAG: [[CST_1_i32:%.+]] = arith.constant 1 : i32
109+
// CHECK: [[VAR_17_:%.+]] = arith.muli {{.*}}, [[CST_128_i32]] : i32
110+
// CHECK: [[VAR_18_:%.+]] = arith.muli {{.*}}, [[CST_256_i32]] : i32
109111
// CHECK: [[VAR_20_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64
110112
// CHECK: [[VAR_21_:%.+]] = arith.extsi [[PARAM_7_]] : i32 to i64
111113
// CHECK: [[VAR_22_:%.+]] = arith.divui {{.*}}, [[PARAM_6_]] : i32
@@ -122,12 +124,20 @@ module {
122124
// CHECK-DAG: [[VAR_41_:%.+]] = tt.load [[VAR_arg15_]] : !tt.ptr<tensor<64x256xbf16>>
123125
// CHECK: [[VAR_42_:%.+]] = tt.dot [[VAR_40_]], [[VAR_41_]], [[VAR_cst_]], inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32>
124126
// CHECK-DAG: [[VAR_43_:%.+]] = arith.addf [[VAR_arg13_]], [[VAR_42_]] : tensor<128x256xf32>
125-
// CHECK-DAG: [[VAR_44_:%.+]] = tt.advance [[VAR_arg14_]], {{\[}}[[CST_0_i32]], [[VAR_29_]]] : <tensor<128x64xbf16>>
126-
// CHECK-DAG: [[VAR_45_:%.+]] = tt.advance [[VAR_arg15_]], {{\[}}[[CST_0_i32]], [[VAR_30_]]] : <tensor<64x256xbf16>>
127-
// CHECK: scf.yield [[VAR_43_]], [[VAR_44_]], [[VAR_45_]] : tensor<128x256xf32>, !tt.ptr<tensor<128x64xbf16>>, !tt.ptr<tensor<64x256xbf16>>
127+
// CHECK-DAG: [[VAR_44_:%.+]] = arith.divui [[VAR_29_]], [[PARAM_6_]] : i32
128+
// CHECK-DAG: [[VAR_45_:%.+]] = tt.advance [[VAR_arg14_]], {{\[}}[[VAR_44_]], [[CST_0_i32]]] : <tensor<128x64xbf16>>
129+
// CHECK-DAG: [[VAR_46_:%.+]] = arith.divui [[VAR_30_]], [[PARAM_8_]] : i32
130+
// CHECK-DAG: [[VAR_47_:%.+]] = tt.advance [[VAR_arg15_]], {{\[}}[[VAR_46_]], [[CST_0_i32]]] : <tensor<64x256xbf16>>
131+
// CHECK: scf.yield [[VAR_43_]], [[VAR_45_]], [[VAR_47_]] : tensor<128x256xf32>, !tt.ptr<tensor<128x64xbf16>>, !tt.ptr<tensor<64x256xbf16>>
128132
// CHECK: }
129133
// CHECK-DAG: [[VAR_32_:%.+]] = arith.truncf [[VAR_31_]]#0 : tensor<128x256xf32> to tensor<128x256xbf16>
130-
// CHECK: [[VAR_39_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : <tensor<128x256xbf16>>
134+
// CHECK-DAG: [[VAR_33_:%.+]] = arith.muli [[VAR_17_]], [[PARAM_10_]] : i32
135+
// CHECK-DAG: [[VAR_34_:%.+]] = arith.extsi [[PARAM_10_]] : i32 to i64
136+
// CHECK-DAG: [[VAR_35_:%.+]] = arith.muli [[VAR_18_]], [[PARAM_11_]] : i32
137+
// CHECK-DAG: [[VAR_36_:%.+]] = arith.extsi [[PARAM_11_]] : i32 to i64
138+
// CHECK-DAG: [[VAR_37_:%.+]] = arith.divui [[VAR_33_]], [[PARAM_10_]] : i32
139+
// CHECK-DAG: [[VAR_38_:%.+]] = arith.divui [[VAR_35_]], [[PARAM_11_]] : i32
140+
// CHECK: [[VAR_39_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_34_]], [[VAR_36_]]], {{\[}}[[VAR_37_]], [[VAR_38_]]] {{.*}} : <tensor<128x256xbf16>>
131141
// CHECK: tt.store [[VAR_39_]], [[VAR_32_]] : !tt.ptr<tensor<128x256xbf16>>
132142
// CHECK: tt.return
133143
// CHECK: }

test/Triton/Intel/RaiseToBlockPointers/raise-block-pointer.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ tt.func public @wrap_side_by_side_masked(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32
342342
// CHECK: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[VAR_3_]]) -> (tensor<4x256xbf16>, !tt.ptr<tensor<4x256xbf16>>) {
343343
// CHECK: [[VAR_5_:%.+]] = tt.load [[VAR_arg7_]] : !tt.ptr<tensor<4x256xbf16>>
344344
// CHECK: [[VAR_6_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_5_]] : tensor<4x256xbf16>
345-
// CHECK: [[VAR_7_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_0_i32]], [[CST_3_i32]]{{\]}} : <tensor<4x256xbf16>>
345+
// CHECK: [[VAR_7_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_3_i32]], [[CST_0_i32]]{{\]}} : <tensor<4x256xbf16>>
346346
// CHECK: scf.yield [[VAR_6_]], [[VAR_7_]] : tensor<4x256xbf16>, !tt.ptr<tensor<4x256xbf16>>
347347
// CHECK: }
348348
// CHECK: [[VAR_5_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x256xbf16>>

test/Triton/Intel/RaiseToBlockPointers/wraparound_side_by_side.mlir

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,18 @@ module {
7979
// CHECK-DAG: [[VAR_14_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64
8080
// CHECK-DAG: [[VAR_15_:%.+]] = arith.extsi [[PARAM_7_]] : i32 to i64
8181
// CHECK: [[VAR_16_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_14_]], [[VAR_15_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<4x4xf32>>
82-
// CHECK: [[VAR_21_:%.+]] = arith.cmpi slt, [[VAR_13_]], {{.*}} : tensor<4x1xi32>
83-
// CHECK: [[VAR_22_:%.+]] = tt.broadcast [[VAR_21_]] : tensor<4x1xi1> -> tensor<4x4xi1>
84-
// CHECK-DAG: [[VAR_23_:%.+]] = arith.muli [[PARAM_4_]], [[CST_4_i32]] : i32
85-
// CHECK-DAG: [[VAR_24_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_i32]] : i32
86-
// CHECK-DAG: [[VAR_25_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = {{.*}} iter_args([[VAR_arg9_:%.+]] = [[VAR_12_]], [[VAR_arg10_:%.+]] = [[VAR_16_]]) -> (!tt.ptr<tensor<4x4xf32>>, !tt.ptr<tensor<4x4xf32>>) : i32 {
87-
// CHECK: [[VAR_26_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_22_]], [[CST_]] : !tt.ptr<tensor<4x4xf32>>
88-
// CHECK: tt.store [[VAR_arg10_]], [[VAR_26_]] : !tt.ptr<tensor<4x4xf32>>
89-
// CHECK-DAG: [[VAR_27_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[CST_0_i32]], [[VAR_23_]]] : <tensor<4x4xf32>>
90-
// CHECK-DAG: [[VAR_28_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[CST_0_i32]], [[VAR_24_]]] : <tensor<4x4xf32>>
91-
// CHECK: scf.yield [[VAR_27_]], [[VAR_28_]] : !tt.ptr<tensor<4x4xf32>>, !tt.ptr<tensor<4x4xf32>>
82+
// CHECK: [[VAR_17_:%.+]] = arith.cmpi slt, [[VAR_13_]], {{.*}} : tensor<4x1xi32>
83+
// CHECK: [[VAR_18_:%.+]] = tt.broadcast [[VAR_17_]] : tensor<4x1xi1> -> tensor<4x4xi1>
84+
// CHECK-DAG: [[VAR_19_:%.+]] = arith.muli [[PARAM_4_]], [[CST_4_i32]] : i32
85+
// CHECK-DAG: [[VAR_20_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_i32]] : i32
86+
// CHECK-DAG: [[VAR_21_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = {{.*}} iter_args([[VAR_arg9_:%.+]] = [[VAR_12_]], [[VAR_arg10_:%.+]] = [[VAR_16_]]) -> (!tt.ptr<tensor<4x4xf32>>, !tt.ptr<tensor<4x4xf32>>) : i32 {
87+
// CHECK: [[VAR_22_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_18_]], [[CST_]] : !tt.ptr<tensor<4x4xf32>>
88+
// CHECK: tt.store [[VAR_arg10_]], [[VAR_22_]] : !tt.ptr<tensor<4x4xf32>>
89+
// CHECK: [[VAR_23_:%.+]] = arith.divui [[VAR_19_]], [[PARAM_4_]] : i32
90+
// CHECK: [[VAR_24_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[VAR_23_]], [[CST_0_i32]]] : <tensor<4x4xf32>>
91+
// CHECK: [[VAR_25_:%.+]] = arith.divui [[VAR_20_]], [[PARAM_6_]] : i32
92+
// CHECK: [[VAR_26_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[VAR_25_]], [[CST_0_i32]]] : <tensor<4x4xf32>>
93+
// CHECK: scf.yield [[VAR_24_]], [[VAR_26_]] : !tt.ptr<tensor<4x4xf32>>, !tt.ptr<tensor<4x4xf32>>
9294
// CHECK: }
9395
// CHECK: tt.return
9496
// CHECK: }

test/Triton/Intel/RaiseToBlockPointers/wraparound_stacked.mlir

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,11 @@ module {
8383
// CHECK: [[VAR_20_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = {{.*}} iter_args([[VAR_arg9_:%.+]] = [[VAR_12_]], [[VAR_arg10_:%.+]] = [[VAR_16_]]) -> (!tt.ptr<tensor<4x4xf32>>, !tt.ptr<tensor<4x4xf32>>) : i32 {
8484
// CHECK: [[VAR_21_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_18_]], [[CST_]] : !tt.ptr<tensor<4x4xf32>>
8585
// CHECK: tt.store [[VAR_arg10_]], [[VAR_21_]] : !tt.ptr<tensor<4x4xf32>>
86-
// CHECK-DAG: [[VAR_22_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[CST_0_i32]], [[VAR_19_]]] : <tensor<4x4xf32>>
87-
// CHECK-DAG: [[VAR_23_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[CST_0_i32]], [[VAR_19_]]] : <tensor<4x4xf32>>
88-
// CHECK: scf.yield [[VAR_22_]], [[VAR_23_]] : !tt.ptr<tensor<4x4xf32>>, !tt.ptr<tensor<4x4xf32>>
86+
// CHECK: [[VAR_22_:%.+]] = arith.divui [[VAR_19_]], [[PARAM_4_]] : i32
87+
// CHECK: [[VAR_23_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[VAR_22_]], [[CST_0_i32]]] : <tensor<4x4xf32>>
88+
// CHECK: [[VAR_24_:%.+]] = arith.divui [[VAR_19_]], [[PARAM_6_]] : i32
89+
// CHECK-DAG: [[VAR_25_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[VAR_24_]], [[CST_0_i32]]] : <tensor<4x4xf32>>
90+
// CHECK: scf.yield [[VAR_23_]], [[VAR_25_]] : !tt.ptr<tensor<4x4xf32>>, !tt.ptr<tensor<4x4xf32>>
8991
// CHECK: }
9092
// CHECK: tt.return
9193
// CHECK: }

test/Triton/Intel/RaiseToBlockPointers/wraparound_unsupported_add_offset.mlir

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,20 +91,21 @@ module {
9191
// CHECK-DAG: [[VAR_16_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
9292
// CHECK-DAG: [[VAR_17_:%.+]] = arith.extsi [[arg6_]] : i32 to i64
9393
// CHECK-DAG: [[VAR_18_:%.+]] = arith.extsi [[arg7_]] : i32 to i64
94-
// CHECK: [[VAR_25_:%.+]] = tt.make_tensor_ptr [[arg1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_17_]], [[VAR_18_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<4x4xf32>>
95-
// CHECK: [[VAR_26_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32>
96-
// CHECK-DAG: [[VAR_27_:%.+]] = tt.broadcast [[VAR_26_]] : tensor<4x1xi1> -> tensor<4x4xi1>
97-
// CHECK-DAG: [[VAR_28_:%.+]] = arith.muli [[arg4_]], [[CST_4_i32]] : i32
94+
// CHECK: [[VAR_19_:%.+]] = tt.make_tensor_ptr [[arg1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_17_]], [[VAR_18_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<4x4xf32>>
95+
// CHECK: [[VAR_20_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32>
96+
// CHECK-DAG: [[VAR_21_:%.+]] = tt.broadcast [[VAR_20_]] : tensor<4x1xi1> -> tensor<4x4xi1>
97+
// CHECK-DAG: [[VAR_22_:%.+]] = arith.muli [[arg4_]], [[CST_4_i32]] : i32
9898
// CHECK-NOT: separator of consecutive DAGs
99-
// CHECK-DAG: [[VAR_29_:%.+]] = tt.splat [[VAR_28_]] : i32 -> tensor<4x4xi32>
100-
// CHECK-DAG: [[VAR_30_:%.+]] = arith.muli [[arg5_]], [[CST_4_i32]] : i32
99+
// CHECK-DAG: [[VAR_23_:%.+]] = tt.splat [[VAR_22_]] : i32 -> tensor<4x4xi32>
100+
// CHECK-DAG: [[VAR_24_:%.+]] = arith.muli [[arg5_]], [[CST_4_i32]] : i32
101101
// CHECK-NOT: separator of consecutive DAGs
102-
// CHECK-DAG: [[VAR_31_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_i32]] to [[CST_2_i32]] step [[CST_1_i32]] iter_args([[VAR_arg9_:%.+]] = [[VAR_15_]], [[VAR_arg10_:%.+]] = [[VAR_25_]]) -> (tensor<4x4x!tt.ptr<f32>>, !tt.ptr<tensor<4x4xf32>>)
103-
// CHECK: [[VAR_32_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_27_]], [[VAR_cst_]] : tensor<4x4x!tt.ptr<f32>>
104-
// CHECK: tt.store [[VAR_arg10_]], [[VAR_32_]] : !tt.ptr<tensor<4x4xf32>>
105-
// CHECK-DAG: [[VAR_33_:%.+]] = tt.addptr [[VAR_arg9_]], [[VAR_29_]] : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
106-
// CHECK-DAG: [[VAR_34_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[CST_0_i32]], [[VAR_30_]]] : <tensor<4x4xf32>>
107-
// CHECK: scf.yield [[VAR_33_]], [[VAR_34_]] : tensor<4x4x!tt.ptr<f32>>, !tt.ptr<tensor<4x4xf32>>
102+
// CHECK-DAG: [[VAR_25_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_i32]] to [[CST_2_i32]] step [[CST_1_i32]] iter_args([[VAR_arg9_:%.+]] = [[VAR_15_]], [[VAR_arg10_:%.+]] = [[VAR_19_]]) -> (tensor<4x4x!tt.ptr<f32>>, !tt.ptr<tensor<4x4xf32>>)
103+
// CHECK: [[VAR_26_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_21_]], [[VAR_cst_]] : tensor<4x4x!tt.ptr<f32>>
104+
// CHECK: tt.store [[VAR_arg10_]], [[VAR_26_]] : !tt.ptr<tensor<4x4xf32>>
105+
// CHECK-DAG: [[VAR_27_:%.+]] = tt.addptr [[VAR_arg9_]], [[VAR_23_]] : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
106+
// CHECK-DAG: [[VAR_28_:%.+]] = arith.divui [[VAR_24_]], [[arg6_]] : i32
107+
// CHECK-DAG: [[VAR_29_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[VAR_28_]], [[CST_0_i32]]] : <tensor<4x4xf32>>
108+
// CHECK: scf.yield [[VAR_27_]], [[VAR_29_]] : tensor<4x4x!tt.ptr<f32>>, !tt.ptr<tensor<4x4xf32>>
108109
// CHECK: }
109110
// CHECK: tt.return
110111
// CHECK: }

0 commit comments

Comments
 (0)