@@ -62,7 +62,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 #tmem1 = #ttng.tensor_memory_encoding<blockM = 64, blockN = 128, unpacked = true>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
   // CHECK: ttg.tensor_memory_size = 512
-  // CHECK: alloc_tensor_memory
+  // CHECK: alloc_tensor_memory_re_use
   tt.func public @alloc_tensor_memory_re_use(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
     %true = arith.constant true
     %c1 = arith.constant 1 : i32
@@ -113,6 +113,50 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 
 // -----
 
+#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = true>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK: ttg.tensor_memory_size = 128
+  // CHECK: alloc_tensor_memory_re_use_liverange_end_collision
+  tt.func public @alloc_tensor_memory_re_use_liverange_end_collision(
+      %arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>,
+      %lb: index, %ub: index, %step: index) {
+    %true = arith.constant true
+    %c1 = arith.constant 1 : i32
+    %c0 = arith.constant 0 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #blocked>
+    %cst0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #blocked>
+    %cst1 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #blocked>
+    %cst2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #blocked>
+
+    // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32}
+    %a = ttng.tmem_alloc %cst0 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+
+    // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 64 : i32, tensor_memory_row_offset = 0 : i32}
+    %b = ttng.tmem_alloc %cst : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+
+    scf.for %i = %lb to %ub step %step {
+      ttng.tmem_store %cst2, %a, %true : tensor<128x64xf32, #blocked> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+      ttng.tmem_store %cst2, %b, %true : tensor<128x64xf32, #blocked> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+      scf.yield
+    }
+    // Live ranges of both allocations end at the same time, at the boundary of the loop. Make sure we can handle this case.
+
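+    // %c and %d below can then reuse the columns freed by %a and %b (offsets 0 and 64).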
+    // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32}
+    %c = ttng.tmem_alloc %cst0 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+
+    // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 64 : i32, tensor_memory_row_offset = 0 : i32}
+    %d = ttng.tmem_alloc %cst : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+
+    ttng.tmem_store %cst2, %c, %true : tensor<128x64xf32, #blocked> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+    ttng.tmem_store %cst2, %d, %true : tensor<128x64xf32, #blocked> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+
+    tt.return
+  }
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true, CTASplitM = 2>
 #tmem1 = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = true, CTASplitN = 2>