@@ -35,11 +35,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
     %c1 = arith.constant 1 : i32
     %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
     %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
+      %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
       %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0>
+      %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr<f16>, #blocked1>
       %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x256xf16, #dotOp1>
       %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
-      %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
-      %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr<f16>, #blocked1>
       triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>
       triton_gpu.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable>
       scf.yield %3 : tensor<256x256xf32, #mma>
@@ -64,11 +64,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
     %c1 = arith.constant 1 : i32
     %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
     %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
+      %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr<f16>, #blocked>
       %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0>
+      %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr<f16>, #blocked1>
       %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1>
       %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
-      %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr<f16>, #blocked>
-      %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr<f16>, #blocked1>
       triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
       triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>
       scf.yield %3 : tensor<256x256xf32, #mma>
@@ -81,8 +81,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
 // Should NOT apply: tile size 256x64x128 with single dot
 // CHECK-LABEL: sink_2nd_load_256x64x128
 // CHECK: %[[tileA:.*]] = tt.load
-// CHECK-NEXT: %[[tileB:.*]] = tt.load
 // CHECK-NEXT: local_load
+// CHECK-NEXT: %[[tileB:.*]] = tt.load
 // CHECK-NEXT: local_load
 // CHECK-NEXT: tt.dot
 // CHECK-NEXT: triton_gpu.local_store %[[tileA]]
@@ -93,11 +93,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
     %c1 = arith.constant 1 : i32
     %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma>
     %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x64xf32, #mma>) : i32 {
+      %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
       %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0>
+      %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr<f16>, #blocked1>
       %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x64xf16, #dotOp1>
       %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x64xf16, #dotOp1> -> tensor<256x64xf32, #mma>
-      %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
-      %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr<f16>, #blocked1>
       triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>
       triton_gpu.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable>
       scf.yield %3 : tensor<256x64xf32, #mma>
@@ -110,8 +110,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
 // Should NOT apply: tile size 256x256x32 with single dot
 // CHECK-LABEL: sink_2nd_load_256x256x32
 // CHECK: %[[tileA:.*]] = tt.load
-// CHECK-NEXT: %[[tileB:.*]] = tt.load
 // CHECK-NEXT: local_load
+// CHECK-NEXT: %[[tileB:.*]] = tt.load
 // CHECK-NEXT: local_load
 // CHECK-NEXT: tt.dot
 // CHECK-NEXT: triton_gpu.local_store %[[tileA]]
@@ -122,11 +122,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
     %c1 = arith.constant 1 : i32
     %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
     %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
+      %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr<f16>, #blocked>
       %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x32xf16, #dotOp0>
+      %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr<f16>, #blocked1>
       %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x256xf16, #dotOp1>
       %3 = tt.dot %1, %2, %arg1 : tensor<256x32xf16, #dotOp0> * tensor<32x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
-      %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr<f16>, #blocked>
-      %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr<f16>, #blocked1>
       triton_gpu.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable>
       triton_gpu.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable>
       scf.yield %3 : tensor<256x256xf32, #mma>
@@ -142,8 +142,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
 // Should NOT apply: the 2nd load has a user before the dot
 // CHECK-LABEL: sink_2nd_load_128x128x128_user_before_dot
 // CHECK: %[[tileA:.*]] = tt.load
-// CHECK-NEXT: %[[tileB:.*]] = tt.load
 // CHECK-NEXT: local_load
+// CHECK-NEXT: %[[tileB:.*]] = tt.load
 // CHECK-NEXT: local_load
 // CHECK-NEXT: tt.store
 // CHECK-NEXT: tt.dot
@@ -154,10 +154,10 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
     %c1 = arith.constant 1 : i32
     %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
     %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<128x128xf32, #mma>) : i32 {
-      %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0>
-      %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1>
       %4 = tt.load %A_ptr : tensor<128x128x!tt.ptr<f16>, #blocked>
+      %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0>
       %5 = tt.load %B_ptr : tensor<128x128x!tt.ptr<i64>, #blocked>
+      %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1>
       tt.store %B_ptr, %5 : tensor<128x128x!tt.ptr<i64>, #blocked>
       %3 = tt.dot %1, %2, %arg1 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma>
       triton_gpu.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable>
@@ -174,12 +174,12 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
 // Category 3: two dots in the for loop. Make sure the optimization is not applied
 // should NOT apply: two dots
 // CHECK-LABEL: sink_2nd_load_256x256x64_two_dot
-// CHECK: triton_gpu.local_load
+// CHECK: tt.load
+// CHECK-NEXT: tt.load
+// CHECK-NEXT: triton_gpu.local_load
 // CHECK-NEXT: triton_gpu.local_load
 // CHECK-NEXT: tt.dot
 // CHECK-NEXT: tt.dot
-// CHECK-NEXT: tt.load
-// CHECK-NEXT: tt.load
 // CHECK-NEXT: triton_gpu.local_store
 // CHECK-NEXT: triton_gpu.local_store
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>