@@ -90,3 +90,19 @@ tt.func @tma_scatter(%arg0: !tt.tensordesc<tensor<1x128xbf16, #nvmma_128>>, %arg
 }
 
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 2, 0]}>
+// CHECK: #[[$SHARED:.+]] = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: @rank_reducing_load
+  tt.func public @rank_reducing_load(%arg0: !tt.tensordesc<tensor<1x256x32xf32, #shared>>) -> tensor<256x32xf32, #blocked> {
+    %c32_i32 = arith.constant 32 : i32
+    // CHECK: %[[A:.+]] = ttg.local_alloc : () -> !ttg.memdesc<256x32xf32, #[[$SHARED]], #smem, mutable>
+    // CHECK: ttng.async_tma_copy_global_to_local %{{.+}}[%{{.+}}, %{{.+}}, %{{.+}}] %[[A]],
+    %l = tt.descriptor_load %arg0[%c32_i32, %c32_i32, %c32_i32] : !tt.tensordesc<tensor<1x256x32xf32, #shared>> -> tensor<256x32xf32, #blocked>
+    tt.return %l : tensor<256x32xf32, #blocked>
+  }
+}