@@ -341,13 +341,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
341
341
#tmem = #ttng.tensor_memory_encoding <blockM = 128 , blockN = 128 , unpacked = true >
342
342
module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , ttg.target = " cuda:100" , " ttg.threads-per-warp" = 32 : i32 } {
343
343
tt.func public @forward_tmem_load (%m: !ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable >, %t: !ttg.async.token ) -> (!ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable >, !ttg.async.token ) {
344
+ %true = arith.constant true
344
345
%result , %token0 = ttng.tmem_load %m [%t ] : !ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable > -> tensor <128 x128 xf32 , #blocked >
345
346
// HOIST-IF-LABEL: @forward_tmem_load
346
347
// HOIST-IF-SAME: %[[ARG0:.+]]: !ttg.memdesc<128x128xf32,
347
348
// HOIST-IF-SAME: %[[ARG1:.+]]: !ttg.async.token
348
349
// HOIST-IF-NEXT: tt.return %[[ARG0]], %[[ARG1]]
349
- %result1 , %token1 = ttng.tmem_alloc %result : (tensor <128 x128 xf32 , #blocked >) -> (!ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable >, !ttg.async.token )
350
- tt.return %result1 , %token1 : !ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable >, !ttg.async.token
350
+ %result1 , %token1 = ttng.tmem_alloc : () -> (!ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable >, !ttg.async.token )
351
+ %token2 = ttng.tmem_store %result , %result1 [%token1 ], %true : tensor <128 x128 xf32 , #blocked > -> !ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable >
352
+ tt.return %result1 , %token2 : !ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable >, !ttg.async.token
351
353
}
352
354
}
353
355
0 commit comments