@@ -35,11 +35,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
   %c1 = arith.constant 1 : i32
   %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
   %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
-    %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
     %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0>
-    %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr<f16>, #blocked1>
     %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x256xf16, #dotOp1>
     %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
+    %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
+    %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr<f16>, #blocked1>
     triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>
     triton_gpu.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable>
     scf.yield %3 : tensor<256x256xf32, #mma>
@@ -74,11 +74,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
   %c1 = arith.constant 1 : i32
   %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
   %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
-    %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr<f16>, #blocked>
     %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0>
-    %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr<f16>, #blocked1>
     %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1>
     %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
+    %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr<f16>, #blocked>
+    %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr<f16>, #blocked1>
     triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
     triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>
     scf.yield %3 : tensor<256x256xf32, #mma>
@@ -101,9 +101,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
 // Should NOT apply: tile size 256x64x128 with single dot
 // CHECK-LABEL: sink_2nd_load_256x64x128
 // CHECK: %[[tileA:.*]] = tt.load
-// CHECK-NEXT: local_load
 // CHECK-NEXT: %[[tileB:.*]] = tt.load
 // CHECK-NEXT: local_load
+// CHECK-NEXT: local_load
 // CHECK-NEXT: tt.dot
 // CHECK-NEXT: triton_gpu.local_store %[[tileA]]
 // CHECK-NEXT: triton_gpu.local_store %[[tileB]]
@@ -113,11 +113,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
   %c1 = arith.constant 1 : i32
   %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma>
   %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x64xf32, #mma>) : i32 {
-    %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
     %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0>
-    %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr<f16>, #blocked1>
     %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x64xf16, #dotOp1>
     %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x64xf16, #dotOp1> -> tensor<256x64xf32, #mma>
+    %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
+    %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr<f16>, #blocked1>
     triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>
     triton_gpu.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable>
     scf.yield %3 : tensor<256x64xf32, #mma>
@@ -140,9 +140,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
 // Should NOT apply: tile size 256x256x32 with single dot
 // CHECK-LABEL: sink_2nd_load_256x256x32
 // CHECK: %[[tileA:.*]] = tt.load
-// CHECK-NEXT: local_load
 // CHECK-NEXT: %[[tileB:.*]] = tt.load
 // CHECK-NEXT: local_load
+// CHECK-NEXT: local_load
 // CHECK-NEXT: tt.dot
 // CHECK-NEXT: triton_gpu.local_store %[[tileA]]
 // CHECK-NEXT: triton_gpu.local_store %[[tileB]]
@@ -152,11 +152,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
   %c1 = arith.constant 1 : i32
   %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
   %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
-    %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr<f16>, #blocked>
     %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x32xf16, #dotOp0>
-    %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr<f16>, #blocked1>
     %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x256xf16, #dotOp1>
     %3 = tt.dot %1, %2, %arg1 : tensor<256x32xf16, #dotOp0> * tensor<32x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
+    %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr<f16>, #blocked>
+    %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr<f16>, #blocked1>
     triton_gpu.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable>
     triton_gpu.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable>
     scf.yield %3 : tensor<256x256xf32, #mma>
@@ -181,9 +181,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
 // Should NOT apply: the 2nd load has a user before the dot
 // CHECK-LABEL: sink_2nd_load_128x128x128_user_before_dot
 // CHECK: %[[tileA:.*]] = tt.load
-// CHECK-NEXT: local_load
 // CHECK-NEXT: %[[tileB:.*]] = tt.load
 // CHECK-NEXT: local_load
+// CHECK-NEXT: local_load
 // CHECK-NEXT: tt.store
 // CHECK-NEXT: tt.dot
 // CHECK-NEXT: triton_gpu.local_store %[[tileA]]
@@ -193,10 +193,10 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
   %c1 = arith.constant 1 : i32
   %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
   %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<128x128xf32, #mma>) : i32 {
-    %4 = tt.load %A_ptr : tensor<128x128x!tt.ptr<f16>, #blocked>
     %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0>
-    %5 = tt.load %B_ptr : tensor<128x128x!tt.ptr<i64>, #blocked>
     %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1>
+    %4 = tt.load %A_ptr : tensor<128x128x!tt.ptr<f16>, #blocked>
+    %5 = tt.load %B_ptr : tensor<128x128x!tt.ptr<i64>, #blocked>
     tt.store %B_ptr, %5 : tensor<128x128x!tt.ptr<i64>, #blocked>
     %3 = tt.dot %1, %2, %arg1 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma>
     triton_gpu.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable>
@@ -213,12 +213,12 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
 // Category 3: two dots in the for loop. Make sure the optimization is not applied
 // should NOT apply: two dots
 // CHECK-LABEL: sink_2nd_load_256x256x64_two_dot
-// CHECK: tt.load
-// CHECK-NEXT: tt.load
-// CHECK-NEXT: triton_gpu.local_load
+// CHECK: triton_gpu.local_load
 // CHECK-NEXT: triton_gpu.local_load
 // CHECK-NEXT: tt.dot
 // CHECK-NEXT: tt.dot
+// CHECK-NEXT: tt.load
+// CHECK-NEXT: tt.load
 // CHECK-NEXT: triton_gpu.local_store
 // CHECK-NEXT: triton_gpu.local_store
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>