@@ -1165,3 +1165,39 @@ module attributes {"ttg.num-warps" = 4 : i32} {
11651165 tt.return %result : !ttg.memdesc <2 x2 xf16 , #shared , #smem , mutable >
11661166 }
11671167}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, colStride = 1>

module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
  // Regression test: the accumulator tmem_alloc lives outside the loop while the
  // MMA operands are loop block arguments (iter_args). The pass must handle
  // operands whose defining op is a block argument without crashing.
  // CHECK-LABEL: @tc_gen5_mma_alloc_block_arg
  tt.func @tc_gen5_mma_alloc_block_arg(%lb : index, %ub : index, %step : index,
                  %A_ptr: tensor<128x128x!tt.ptr<f16>, #blocked1> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[1, 16]> : tensor<2xi32>},
                  %B_ptr: tensor<128x128x!tt.ptr<f16>, #blocked1> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[1, 16]> : tensor<2xi32>},
                  %acc_init: tensor<128x128xf32, #blocked1>) -> () {
    %true = arith.constant true
    %acc_tm = ttng.tmem_alloc %acc_init : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
    %zero = arith.constant dense<0.0> : tensor<128x128xf16, #blocked1>
    // CHECK: ttng.tmem_alloc
    // CHECK: scf.for
    scf.for %iv = %lb to %ub step %step iter_args(%A = %zero, %B = %zero) -> (tensor<128x128xf16, #blocked1>, tensor<128x128xf16, #blocked1>) : index {
      // %A and %B are block arguments: ensure this doesn't crash.
      %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
      %B_sh = ttg.local_alloc %B : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
      ttng.tmem_store %acc_init, %acc_tm, %true : tensor<128x128xf32, #blocked1> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
      // CHECK: ttng.tc_gen5_mma
      ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
      // CHECK: ttng.tmem_load
      %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1>
      "use"(%acc_res) : (tensor<128x128xf32, #blocked1>) -> ()
      %A_next = tt.load %A_ptr : tensor<128x128x!tt.ptr<f16>, #blocked1>
      %B_next = tt.load %B_ptr : tensor<128x128x!tt.ptr<f16>, #blocked1>
      scf.yield %A_next, %B_next : tensor<128x128xf16, #blocked1>, tensor<128x128xf16, #blocked1>
    }
    tt.return
  }
}