@@ -7,10 +7,10 @@ tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1:
77 %c2_i32 = arith.constant 2 : i32
88 %c1_i64 = arith.constant 1 : i64
99 %c2_i64 = arith.constant 2 : i64
10- %c3_i64 = arith.constant 3 : i64
10+ %c4_i64 = arith.constant 4 : i64
1111 %c1024_i64 = arith.constant 1024 : i64
1212 %cst = arith.constant dense <0.000000e+00 > : tensor <256 x256 xf32 >
13- %0 = tt.make_tensor_ptr %arg1 , [%c2_i64 , %c1_i64 , %c1024_i64 ], [%c3_i64 , %c1024_i64 , %c1_i64 ], [%c2_i32 , %c1_i32 , %c0_i32 ] {order = array<i32 : 2 , 1 , 0 >} : <tensor <1 x32 x256 xbf16 >>
13+ %0 = tt.make_tensor_ptr %arg1 , [%c2_i64 , %c1_i64 , %c1024_i64 ], [%c1024_i64 , %c4_i64 , %c1_i64 ], [%c2_i32 , %c1_i32 , %c0_i32 ] {order = array<i32 : 2 , 1 , 0 >} : <tensor <1 x32 x256 xbf16 >>
1414 %1 = tt.load %arg0 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <256 x32 xbf16 >>
1515 %3 = tt.load %0 {boundaryCheck = array<i32 : 1 , 2 >} : !tt.ptr <tensor <1 x32 x256 xbf16 >>
1616 %4 = tt.reshape %3 : tensor <1 x32 x256 xbf16 > -> tensor <32 x256 xbf16 >
@@ -19,12 +19,14 @@ tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1:
1919}
2020// CHECK-LABEL: fuseLoadWithReshape1
2121// CHECK-NOT: tt.reshape
22- // CHECK: [[MUL1:%.*]] = arith.muli %c3_i64, %c2_i64 : i64
23- // CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c1024_i64 : i64
24- // CHECK: [[TRUNC:%.*]] = arith.trunci %c3_i64 : i64 to i32
25- // CHECK: [[MUL2:%.*]] = arith.muli [[TRUNC]], %c2_i32 : i32
26- // CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c0_i32 : i32
27- // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1_i64, [[ADD1]]], [%c1024_i64, %c1_i64], [%c1_i32, [[ADD2]]] {order = array<i32: 1, 0>} : <tensor<32x256xbf16>>
22+ // CHECK: [[DIV:%.*]] = arith.divui %c1024_i64, %c4_i64 : i64
23+ // CHECK: [[MUL1:%.*]] = arith.muli %c2_i64, [[DIV]] : i64
24+ // CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c1_i64 : i64
25+ // CHECK: [[TRUNC:%.*]] = arith.trunci [[DIV]] : i64 to i32
26+ // CHECK: [[MUL2:%.*]] = arith.muli %c2_i32, [[TRUNC]] : i32
27+ // CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c1_i32 : i32
28+
29+ // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [[[ADD1]], %c1024_i64], [%c4_i64, %c1_i64], [[[ADD2]], %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xbf16>>
2830// CHECK: [[LOAD_B:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
2931// CHECK: tt.dot {{.*}}, [[LOAD_B]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
3032
@@ -34,14 +36,14 @@ tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1:
3436// COM: where the 'make_tensor_ptr' result is not loop carried.
3537tt.func public @fuseLoadWithReshape2 (%arg0: !tt.ptr <tensor <32 x256 xbf16 >>, %arg1: !tt.ptr <bf16 >) {
3638 %c0_i32 = arith.constant 0 : i32
37- %c1_i32 = arith.constant 1 : i32
38- %c1_i64 = arith.constant 1 : i64
3939 %c32_i32 = arith.constant 32 : i32
4040 %c1024_i32 = arith.constant 1024 : i32
41+ %c32_i64 = arith.constant 32 : i64
42+ %c1_i64 = arith.constant 1 : i64
4143 %c512_i64 = arith.constant 512 : i64
4244 %c1024_i64 = arith.constant 1024 : i64
4345 %cst = arith.constant dense <0.000000e+00 > : tensor <256 x256 xf32 >
44- %0 = tt.make_tensor_ptr %arg1 , [%c512_i64 , %c1024_i64 , %c1_i64 ], [%c512_i64 , %c1_i64 , %c1024_i64 ], [%c1_i32 , %c32_i32 , %c0_i32 ] {order = array<i32 : 2 , 0 , 1 >} : <tensor <1 x256 x32 xbf16 >>
46+ %0 = tt.make_tensor_ptr %arg1 , [%c512_i64 , %c1024_i64 , %c32_i64 ], [%c1024_i64 , %c1_i64 , %c512_i64 ], [%c32_i32 , %c32_i32 , %c0_i32 ] {order = array<i32 : 2 , 0 , 1 >} : <tensor <1 x256 x32 xbf16 >>
4547 %res:2 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args (%arg4 = %cst , %arg5 = %c0_i32 ) -> (tensor <256 x256 xf32 >, i32 ) : i32 {
4648 %1 = tt.load %arg0 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <32 x256 xbf16 >>
4749 %3 = tt.load %0 {boundaryCheck = array<i32 : 2 , 1 >} : !tt.ptr <tensor <1 x256 x32 xbf16 >>
@@ -54,19 +56,20 @@ tt.func public @fuseLoadWithReshape2(%arg0: !tt.ptr<tensor<32x256xbf16>>, %arg1:
5456}
5557// CHECK-LABEL: fuseLoadWithReshape2
5658// CHECK-NOT: tt.reshape
57- // CHECK: [[MUL1:%.*]] = arith.muli %c512_i64, %c512_i64 : i64
58- // CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c1024_i64 : i64
59- // CHECK: [[TRUNC:%.*]] = arith.trunci %c512_i64 : i64 to i32
60- // CHECK: [[MUL2:%.*]] = arith.muli [[TRUNC]], %c1_i32 : i32
61- // CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c32_i32 : i32
62- // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [[[ADD1]], %c1_i64], [%c1_i64, %c1024_i64], [[[ADD2]], %c0_i32] {order = array<i32: 0, 1>} : <tensor<256x32xbf16>>
59+ // CHECK: [[DIV:%.*]] = arith.divui %c1024_i64, %c512_i64 : i64
60+ // CHECK: [[MUL1:%.*]] = arith.muli %c512_i64, [[DIV]] : i64
61+ // CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c32_i64 : i64
62+ // CHECK: [[TRUNC:%.*]] = arith.trunci [[DIV]] : i64 to i32
63+ // CHECK: [[MUL2:%.*]] = arith.muli %c32_i32, [[TRUNC]] : i32
64+ // CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c0_i32 : i32
65+ // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1024_i64, [[ADD1]]], [%c1_i64, %c512_i64], [%c32_i32, [[ADD2]]] {order = array<i32: 0, 1>} : <tensor<256x32xbf16>>
6366// CHECK: scf.for
6467// CHECK: [[LOAD_A:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 1, 0>} : !tt.ptr<tensor<256x32xbf16>>
6568// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
6669
6770// -----
6871
69- // COM: tt.load -> tt.reshape -> tt.dot chain, in a loop
72+ // COM: tt.load -> tt.reshape -> tt.dot chain, in a loop
7073// COM: where the 'make_tensor_ptr' result is loop carried.
7174tt.func public @test_matmul (%a_ptr: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %b_ptr: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %c_ptr: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %M: i32 {tt.divisibility = 16 : i32 }, %N: i32 {tt.divisibility = 16 : i32 }, %K: i32 {tt.divisibility = 16 : i32 }, %stride_am: i32 {tt.divisibility = 16 : i32 }, %stride_bk: i32 {tt.divisibility = 16 : i32 }, %stride_cm: i32 {tt.divisibility = 16 : i32 }) {
7275 %c127_i32 = arith.constant 127 : i32
@@ -118,12 +121,13 @@ tt.func public @test_matmul(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %
118121}
119122// CHECK-LABEL: test_matmul
120123// CHECK-NOT: tt.reshape
121- // CHECK: [[MUL1:%.*]] = arith.muli %c1_i64, %c1_i64 : i64
122- // CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %16 : i64
123- // CHECK: [[TRUNC:%.*]] = arith.trunci %c1_i64 : i64 to i32
124- // CHECK: [[MUL2:%.*]] = arith.muli [[TRUNC]], %c0_i32 : i32
125- // CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c0_i32 : i32
126- // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg0, [%15, [[ADD1]]], [%17, %c1_i64], [%14, [[ADD2]]] {order = array<i32: 1, 0>} : <tensor<256x32xf32>>
124+ // CHECK: [[DIV:%.*]] = arith.divui %c1_i64, %17 : i64
125+ // CHECK: [[MUL1:%.*]] = arith.muli %c1_i64, [[DIV]] : i64
126+ // CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %15 : i64
127+ // CHECK: [[TRUNC:%.*]] = arith.trunci [[DIV]] : i64 to i32
128+ // CHECK: [[MUL2:%.*]] = arith.muli %c0_i32, [[TRUNC]] : i32
129+ // CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %14 : i32
130+ // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg0, [[[ADD1]], %16], [%17, %c1_i64], [[[ADD2]], %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf32>>
127131// CHECK: scf.for {{.*}} = %c0_i32 to {{.*}} step %c32_i32 iter_args([[ARG:%.*]] = [[PTR]]
128132// CHECK: [[LOAD_A:%.*]] = tt.load [[ARG]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf32>>
129133// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf32> * tensor<32x128xf32> -> tensor<256x128xf32>
0 commit comments