@@ -6,11 +6,11 @@ tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1:
66 %c1_i32 = arith.constant 1 : i32
77 %c2_i32 = arith.constant 2 : i32
88 %c1_i64 = arith.constant 1 : i64
9- %c2_i64 = arith.constant 2 : i64
109 %c4_i64 = arith.constant 4 : i64
10+ %c64_i64 = arith.constant 4 : i64
1111 %c1024_i64 = arith.constant 1024 : i64
1212 %cst = arith.constant dense <0.000000e+00 > : tensor <256 x256 xf32 >
13- %0 = tt.make_tensor_ptr %arg1 , [%c2_i64 , %c1_i64 , %c1024_i64 ], [%c1024_i64 , %c4_i64 , %c1_i64 ], [%c2_i32 , %c1_i32 , %c0_i32 ] {order = array<i32 : 2 , 1 , 0 >} : <tensor <1 x32 x256 xbf16 >>
13+ %0 = tt.make_tensor_ptr %arg1 , [%c1_i64 , %c64_i64 , %c1024_i64 ], [%c1024_i64 , %c4_i64 , %c1_i64 ], [%c2_i32 , %c1_i32 , %c0_i32 ] {order = array<i32 : 2 , 1 , 0 >} : <tensor <1 x32 x256 xbf16 >>
1414 %1 = tt.load %arg0 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <256 x32 xbf16 >>
1515 %3 = tt.load %0 {boundaryCheck = array<i32 : 1 , 2 >} : !tt.ptr <tensor <1 x32 x256 xbf16 >>
1616 %4 = tt.reshape %3 : tensor <1 x32 x256 xbf16 > -> tensor <32 x256 xbf16 >
@@ -20,16 +20,17 @@ tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1:
2020// CHECK-LABEL: fuseLoadWithReshape1
2121// CHECK-NOT: tt.reshape
2222// CHECK: [[DIV:%.*]] = arith.divui %c1024_i64, %c4_i64 : i64
23- // CHECK: [[MUL1:%.*]] = arith.muli %c2_i64 , [[DIV]] : i64
24- // CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c1_i64 : i64
23+ // CHECK: [[MUL1:%.*]] = arith.muli %c1_i64 , [[DIV]] : i64
24+ // CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c4_i64_0 : i64
2525// CHECK: [[TRUNC:%.*]] = arith.trunci [[DIV]] : i64 to i32
2626// CHECK: [[MUL2:%.*]] = arith.muli %c2_i32, [[TRUNC]] : i32
2727// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c1_i32 : i32
2828// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [[[ADD1]], %c1024_i64], [%c4_i64, %c1_i64], [[[ADD2]], %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xbf16>>
29- // CHECK: [[TRUNC:%.*]] = arith.trunci %c1_i64 : i64 to i32
30- // CHECK: [[COND:%.*]] = arith.cmpi ult, [[ADD2]], [[TRUNC]] : i32
29+ // CHECK: [[ADD3:%.*]] = arith.addi %c1_i32, %c32_i32 : i32
30+ // CHECK: [[TRUNC:%.*]] = arith.trunci %c4_i64_0 : i64 to i32
31+ // CHECK: [[COND:%.*]] = arith.cmpi ult, [[ADD3]], [[TRUNC]] : i32
3132// CHECK: [[IF_RES:%.*]] = scf.if [[COND]] -> (tensor<32x256xbf16>) {
32- // CHECK: [[LOAD_B:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
33+ // CHECK: [[LOAD_B:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<32x256xbf16>>
3334// CHECK: scf.yield [[LOAD_B]] : tensor<32x256xbf16>
3435// CHECK: } else {
3536// CHECK: [[ZERO:%.*]] = arith.constant dense<0.000000e+00> : tensor<32x256xbf16>
@@ -71,7 +72,7 @@ tt.func public @fuseLoadWithReshape2(%arg0: !tt.ptr<tensor<32x256xbf16>>, %arg1:
7172// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c0_i32 : i32
7273// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1024_i64, [[ADD1]]], [%c1_i64, %c512_i64], [%c32_i32, [[ADD2]]] {order = array<i32: 0, 1>} : <tensor<256x32xbf16>>
7374// CHECK: scf.for
74- // CHECK: [[LOAD_A:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0>} : !tt.ptr<tensor<256x32xbf16>>
75+ // CHECK: [[LOAD_A:%.*]] = tt.load [[PTR]] : !tt.ptr<tensor<256x32xbf16>>
7576// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
7677
7778// -----
@@ -106,7 +107,7 @@ tt.func public @fuseLoadWithReshape3(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16
106107 %15 = arith.extsi %M : i32 to i64
107108 %16 = arith.extsi %K : i32 to i64
108109 %17 = arith.extsi %stride_am : i32 to i64
109- %18 = tt.make_tensor_ptr %a_ptr , [%c1_i64 , %15 , %16 ], [%c1_i64 , %17 , %c1_i64 ], [%c0_i32 , %14 , %c0_i32 ] {order = array<i32 : 2 , 1 , 0 >} : <tensor <1 x256 x32 xf32 >>
110+ %18 = tt.make_tensor_ptr %a_ptr , [%c1_i64 , %15 , %16 ], [%c1_i64 , %17 , %c1_i64 ], [%c0_i32 , %c128_i32 , %c0_i32 ] {order = array<i32 : 2 , 1 , 0 >} : <tensor <1 x256 x32 xf32 >>
110111 %19 = arith.muli %13 , %c128_i32 : i32
111112 %20 = arith.extsi %N : i32 to i64
112113 %21 = arith.extsi %stride_bk : i32 to i64
@@ -134,13 +135,15 @@ tt.func public @fuseLoadWithReshape3(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16
134135// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %15 : i64
135136// CHECK: [[TRUNC:%.*]] = arith.trunci [[DIV]] : i64 to i32
136137// CHECK: [[MUL2:%.*]] = arith.muli %c0_i32, [[TRUNC]] : i32
137- // CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %14 : i32
138+ // CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c128_i32 : i32
138139// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg0, [[[ADD1]], %16], [%17, %c1_i64], [[[ADD2]], %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf32>>
140+ // CHECK: [[CST_256:%.*]] = arith.constant 256 : i32
141+ // CHECK: [[ADD3:%.*]] = arith.addi %c128_i32, [[CST_256]] : i32
139142// CHECK: [[TRUNC:%.*]] = arith.trunci [[EXT_M]] : i64 to i32
140- // CHECK: [[COND:%.*]] = arith.cmpi ult, [[ADD2 ]], [[TRUNC]] : i32
143+ // CHECK: [[COND:%.*]] = arith.cmpi ult, [[ADD3 ]], [[TRUNC]] : i32
141144// CHECK: scf.for {{.*}} = %c0_i32 to {{.*}} step %c32_i32 iter_args([[ARG:%.*]] = [[PTR]]
142145// CHECK: [[IF_RES:%.*]] = scf.if [[COND]] -> (tensor<256x32xf32>) {
143- // CHECK: [[LOAD_A:%.*]] = tt.load [[ARG]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf32>>
146+ // CHECK: [[LOAD_A:%.*]] = tt.load [[ARG]] {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<256x32xf32>>
144147// CHECK: scf.yield [[LOAD_A]] : tensor<256x32xf32>
145148// CHECK: } else {
146149// CHECK: [[ZERO:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x32xf32>
@@ -153,7 +156,7 @@ tt.func public @fuseLoadWithReshape3(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16
153156
154157// COM: tt.load -> tt.reshape -> tt.dot chain, in 2 loops,
155158// COM: where the block ptr used by the loads in both loops is created by the same make_tensor_ptr operation.
156- tt.func public @fuseLoadWithTrans4 (%arg0: i32 , %arg1: !tt.ptr <f16 >, %arg2: !tt.ptr <f16 >) {
159+ tt.func public @fuseLoadWithReshape4 (%arg0: i32 , %arg1: !tt.ptr <f16 >, %arg2: !tt.ptr <f16 >) {
157160 %c0_i32 = arith.constant 0 : i32
158161 %c1_i32 = arith.constant 1 : i32
159162 %c2_i32 = arith.constant 2 : i32
@@ -185,7 +188,7 @@ tt.func public @fuseLoadWithTrans4(%arg0: i32, %arg1: !tt.ptr<f16>, %arg2: !tt.p
185188 tt.return
186189
187190}
188- // CHECK-LABEL: fuseLoadWithTrans4
191+ // CHECK-LABEL: fuseLoadWithReshape4
189192// CHECK-NOT: tt.reshape
190193// CHECK: [[DIV1:%.*]] = arith.divui %c256_i64, %c64_i64 : i64
191194// CHECK: [[MUL11:%.*]] = arith.muli %c1_i64, [[DIV1]] : i64
0 commit comments