@@ -353,3 +353,26 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
353353 tt.return %result1 , %token1 : !ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable >, !ttg.async.token
354354 }
355355}
356+
357+ // -----
358+
359+ #blocked = #ttg.blocked <{sizePerThread = [1 , 128 ], threadsPerWarp = [32 , 1 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
360+ #tmem = #ttng.tensor_memory_encoding <blockM = 128 , blockN = 128 , unpacked = true >
361+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , ttg.target = " cuda:100" , " ttg.threads-per-warp" = 32 : i32 } {
362+ // CHECK-LABEL: @sink_multiple_tmem_load
363+ tt.func public @sink_multiple_tmem_load (%m: !ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable >, %t: !ttg.async.token ) -> (tensor <128 x128 xf32 , #blocked >, tensor <128 x128 xf32 , #blocked >) {
364+ %cst = arith.constant dense <0.000000e+00 > : tensor <128 x128 xf32 , #blocked >
365+ %c0_i32 = arith.constant 0 : i32
366+ %c1_i32 = arith.constant 1 : i32
367+ %c2_i32 = arith.constant 2 : i32
368+ %res:2 = scf.for %i = %c0_i32 to %c2_i32 step %c1_i32 iter_args (%init0 = %cst , %init1 = %cst ) -> (tensor <128 x128 xf32 , #blocked >, tensor <128 x128 xf32 , #blocked >) : i32 {
369+ // Any order is fine, just make sure we don't reorder them in an infinite loop.
370+ // CHECK-COUNT-2: ttng.tmem_load
371+ // CHECK: scf.yield
372+ %l0 , %token_1 = ttng.tmem_load %m [%t ] : !ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable > -> tensor <128 x128 xf32 , #blocked >
373+ %l1 , %token_2 = ttng.tmem_load %m [%t ] : !ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable > -> tensor <128 x128 xf32 , #blocked >
374+ scf.yield %l0 , %l1 : tensor <128 x128 xf32 , #blocked >, tensor <128 x128 xf32 , #blocked >
375+ } {tt.scheduled_max_stage = 3 : i32 }
376+ tt.return %res#0 , %res#1 : tensor <128 x128 xf32 , #blocked >, tensor <128 x128 xf32 , #blocked >
377+ }
378+ }
0 commit comments