@@ -2685,3 +2685,21 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp"
     tt.return
   }
 }
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [2, 1, 16, 1, 1], warpsPerCTA = [1, 1, 2, 2, 1], order = [4, 0, 1, 2, 3]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 32, 1, 1], warpsPerCTA = [1, 1, 1, 1, 4], order = [4, 3, 2, 1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [2, 1, 16, 1, 1], warpsPerCTA = [1, 2, 2, 1, 1], order = [4, 0, 3, 2, 1]}>
+#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 0, 1, 2, 3], hasLeadingOffset = false}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:100", "triton_gpu.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: lift_convert_to_local_load
+  // CHECK-NOT: convert_layout
+  // CHECK: tt.return
+  tt.func public @lift_convert_to_local_load(%arg0: !triton_gpu.memdesc<2x1x32x4x4xi8, #shared, #triton_gpu.shared_memory, mutable>) -> tensor<2x4x32x1x4xi8, #blocked2> {
+    %1 = triton_gpu.local_load %arg0 : !triton_gpu.memdesc<2x1x32x4x4xi8, #shared, #triton_gpu.shared_memory, mutable> -> tensor<2x1x32x4x4xi8, #blocked>
+    %2 = tt.trans %1 {order = array<i32: 0, 3, 2, 1, 4>} : tensor<2x1x32x4x4xi8, #blocked> -> tensor<2x4x32x1x4xi8, #blocked1>
+    %3 = triton_gpu.convert_layout %2 : tensor<2x4x32x1x4xi8, #blocked1> -> tensor<2x4x32x1x4xi8, #blocked2>
+    tt.return %3 : tensor<2x4x32x1x4xi8, #blocked2>
+  }
+}