@@ -198,3 +198,38 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
198198 tt.return
199199 }
200200} // end module
201+
202+ // -----
203+
204+ // CHECK-NOT: #triton_gpu.shared<{{.*}} order = [2, 0, 1]
205+ // CHECK: #triton_gpu.shared<{{.*}} order = [2, 1, 0]
206+ // CHECK-NOT: #triton_gpu.shared<{{.*}} order = [2, 0, 1]
207+
208+ // CHECK-LABEL: tt.func public @slowest_dim_is_batch
209+ #blocked = #triton_gpu.blocked <{sizePerThread = [1 , 1 , 2 ], threadsPerWarp = [4 , 1 , 16 ], warpsPerCTA = [4 , 1 , 1 ], order = [2 , 1 , 0 ]}>
210+ #blocked1 = #triton_gpu.blocked <{sizePerThread = [1 , 1 , 8 ], threadsPerWarp = [16 , 1 , 4 ], warpsPerCTA = [4 , 1 , 1 ], order = [2 , 0 , 1 ]}>
211+ #blocked2 = #triton_gpu.blocked <{sizePerThread = [1 , 2 ], threadsPerWarp = [1 , 64 ], warpsPerCTA = [1 , 4 ], order = [1 , 0 ]}>
212+ #blocked5 = #triton_gpu.blocked <{sizePerThread = [1 , 1 , 2 ], threadsPerWarp = [16 , 1 , 4 ], warpsPerCTA = [4 , 1 , 1 ], order = [2 , 0 , 1 ]}>
213+ module attributes {" triton_gpu.num-ctas" = 1 : i32 , " triton_gpu.num-warps" = 4 : i32 , triton_gpu.target = " hip:gfx90a" , " triton_gpu.threads-per-warp" = 64 : i32 } {
214+ tt.func public @slowest_dim_is_batch (%arg0: tensor <1 x512 x!tt.ptr <f32 >, #blocked2 >, %arg1: tensor <64 x8 x32 x!tt.ptr <f32 >, #blocked1 >, %arg2: tensor <64 x1 x32 x!tt.ptr <f32 >, #blocked >) attributes {noinline = false } {
215+ %cst = arith.constant dense <0.000000e+00 > : tensor <64 x1 x32 xf32 , #blocked >
216+ %cst_0 = arith.constant dense <512 > : tensor <1 x512 xi32 , #blocked2 >
217+ %cst_1 = arith.constant dense <128 > : tensor <64 x8 x32 xi32 , #blocked1 >
218+ %c1_i32 = arith.constant 1 : i32
219+ %c5_i32 = arith.constant 2 : i32
220+ %c0_i32 = arith.constant 0 : i32
221+ %33:3 = scf.for %arg7 = %c0_i32 to %c5_i32 step %c1_i32 iter_args (%arg8 = %cst , %arg9 = %arg0 , %arg10 = %arg1 ) -> (tensor <64 x1 x32 xf32 , #blocked >, tensor <1 x512 x!tt.ptr <f32 >, #blocked2 >, tensor <64 x8 x32 x!tt.ptr <f32 >, #blocked1 >) : i32 {
222+ %39 = tt.load %arg9 : tensor <1 x512 x!tt.ptr <f32 >, #blocked2 >
223+ %40 = tt.load %arg10 : tensor <64 x8 x32 x!tt.ptr <f32 >, #blocked1 >
224+ %41 = tt.reshape %39 {allow_reorder = true } : tensor <1 x512 xf32 , #blocked2 > -> tensor <64 x1 x8 xf32 , #blocked5 >
225+ %43 = triton_gpu.convert_layout %41 : tensor <64 x1 x8 xf32 , #blocked5 > -> tensor <64 x1 x8 xf32 , #triton_gpu.dot_op <{opIdx = 0 , parent = #blocked }>>
226+ %44 = triton_gpu.convert_layout %40 : tensor <64 x8 x32 xf32 , #blocked1 > -> tensor <64 x8 x32 xf32 , #triton_gpu.dot_op <{opIdx = 1 , parent = #blocked }>>
227+ %45 = tt.dot %43 , %44 , %arg8 : tensor <64 x1 x8 xf32 , #triton_gpu.dot_op <{opIdx = 0 , parent = #blocked }>> * tensor <64 x8 x32 xf32 , #triton_gpu.dot_op <{opIdx = 1 , parent = #blocked }>> -> tensor <64 x1 x32 xf32 , #blocked >
228+ %46 = tt.addptr %arg9 , %cst_0 : tensor <1 x512 x!tt.ptr <f32 >, #blocked2 >, tensor <1 x512 xi32 , #blocked2 >
229+ %47 = tt.addptr %arg10 , %cst_1 : tensor <64 x8 x32 x!tt.ptr <f32 >, #blocked1 >, tensor <64 x8 x32 xi32 , #blocked1 >
230+ scf.yield %45 , %46 , %47 : tensor <64 x1 x32 xf32 , #blocked >, tensor <1 x512 x!tt.ptr <f32 >, #blocked2 >, tensor <64 x8 x32 x!tt.ptr <f32 >, #blocked1 >
231+ }
232+ tt.store %arg2 , %33#0 : tensor <64 x1 x32 x!tt.ptr <f32 >, #blocked >
233+ tt.return
234+ }
235+ }