@@ -13,6 +13,9 @@
 // CHECK-DAG: [[ACC_TMEM:#.*]] = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
 #acc_tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
 
+#lhs_layout = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#lhs_tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = false>
+
 #fp4_padded_shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8, fp4Padded = true, CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [2, 1, 0]}>
 
 module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
@@ -1247,6 +1250,55 @@ tt.func @local_alloc_into_mma(
   tt.return
 }
 
+// CHECK-LABEL: @shmem_sink_iterator_invalidation
+// CHECK-SAME: [[A_DESC:%arg[0-9]+]]: !tt.tensordesc
+// CHECK-SAME: [[B_DESC:%arg[0-9]+]]: !tt.tensordesc
+tt.func @shmem_sink_iterator_invalidation(
+  %k_tiles: i32,
+  %off_m: i32,
+  %off_n: i32,
+  %a_desc: !tt.tensordesc<tensor<128x64xf16, #shared>>,
+  %b_desc: !tt.tensordesc<tensor<128x64xf16, #shared>>
+) {
+  %true = arith.constant true
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+
+  %BLOCK_K = arith.constant 64 : i32
+  %zero = arith.constant dense<0.0> : tensor<128x128xf32, #acc_layout>
+
+  %result = scf.for %k = %c0_i32 to %k_tiles step %c1_i32
+      iter_args(%acc = %zero) -> tensor<128x128xf32, #acc_layout> : i32 {
+    %off_k = arith.muli %k, %BLOCK_K : i32
+
+    // CHECK: async_tma_copy_global_to_local [[B_DESC]]
+    %b_reg = tt.descriptor_load %b_desc[%off_n, %off_k] : !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #oper_layout>
+    // CHECK: wait_barrier [[B_EMPTY:%[0-9]+]]
+    // CHECK: async_tma_copy_global_to_local [[A_DESC]][{{.*}}] [[B_DEST:%[0-9]+]], [[B_BAR:%[0-9]+]]
+    %a_reg = tt.descriptor_load %a_desc[%off_m, %off_k] : !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #oper_layout>
+
+    %a_shared = ttg.local_alloc %a_reg : (tensor<128x64xf16, #oper_layout>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
+    // CHECK: wait_barrier [[B_BAR]]
+    // CHECK-NEXT: [[B:%.*]] = ttg.local_load [[B_DEST]]
+    // CHECK-NEXT: arrive_barrier [[B_EMPTY]]
+    // CHECK-NEXT: memdesc_trans
+    %a = ttg.local_load %a_shared : !ttg.memdesc<128x64xf16, #shared, #smem> -> tensor<128x64xf16, #lhs_layout>
+    %b_shared = ttg.local_alloc %b_reg : (tensor<128x64xf16, #oper_layout>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
+    %b_T_shared = ttg.memdesc_trans %b_shared {order = array<i32: 1, 0>} : !ttg.memdesc<128x64xf16, #shared, #smem> -> !ttg.memdesc<64x128xf16, #shared_trans, #smem>
+    %c_tmem, %c_tok = ttng.tmem_alloc %acc : (tensor<128x128xf32, #acc_layout>) -> (!ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
+    %a_tmem = ttng.tmem_alloc %a : (tensor<128x64xf16, #lhs_layout>) -> !ttg.memdesc<128x64xf16, #lhs_tmem, #ttng.tensor_memory>
+    %mma_tok = ttng.tc_gen5_mma %a_tmem, %b_T_shared, %c_tmem[%c_tok], %true, %true : !ttg.memdesc<128x64xf16, #lhs_tmem, #ttng.tensor_memory>, !ttg.memdesc<64x128xf16, #shared_trans, #smem>, !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>
+
+    %c, %load_tok = ttng.tmem_load %c_tmem[%mma_tok] : !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #acc_layout>
+
+    scf.yield %c : tensor<128x128xf32, #acc_layout>
+
+  } {tt.warp_specialize, tt.num_stages = 2 : i32}
+
+  "use"(%result) : (tensor<128x128xf32, #acc_layout>) -> ()
+  tt.return
+}
+
 }
 
 // -----