@@ -2,6 +2,7 @@
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [128, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
 #linear = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0], [64, 0], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
@@ -166,4 +167,68 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
     } {tt.num_stages = 2 : i64, tt.scheduled_max_stage = 1 : i32, tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 0 : i32}
     tt.return
   }
+
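+  // A local_alloc left in the default partition (0) splits the LHS path across
+  // three partitions: the TMA load runs in partition 2, the alloc in partition
+  // 0, and the transpose feeding the MMA in partition 1, so the rewrite must
+  // stage the tile through registers between two arefs.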
+  // FUNC-LABEL: @local_alloc_default_partition
+  // CHECK: @local_alloc_default_partition
+  tt.func @local_alloc_default_partition(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc<tensor<128x128xf16, #shared>>, %arg4: !tt.tensordesc<tensor<128x128xf16, #shared>>) {
+    %true = arith.constant true
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c128_i32 = arith.constant 128 : i32
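+    // Three single-slot arefs are created: the register-staged LHS copy in the
+    // transposed layout (#shared1), the RHS TMA buffer, and the raw LHS TMA
+    // buffer (both #shared).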
+    // CHECK: [[AREF_LHS_TRANS:%.*]] = nvws.aref.create {{.*}} : <[!ttg.memdesc<1x128x128xf16, #shared1, #smem, mutable>]>
+    // CHECK: [[AREF_RHS:%.*]] = nvws.aref.create {{.*}} : <[!ttg.memdesc<1x128x128xf16, #shared, #smem, mutable>]>
+    // CHECK: [[AREF_LHS:%.*]] = nvws.aref.create {{.*}} : <[!ttg.memdesc<1x128x128xf16, #shared, #smem, mutable>]>
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
+    %result, %token = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
+    %0 = ttng.tmem_store %cst, %result[%token], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+
+    %1 = scf.for %arg5 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg6 = %0) -> (!ttg.async.token) : i32 {
+      %2 = arith.muli %arg5, %c128_i32 {loop.cluster = 1 : i32, loop.stage = 0 : i32} : i32
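+      // Producer: partition 2 acquires the LHS aref slot and TMA-loads the
+      // tile straight into shared memory.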
+      // CHECK: [[AREF_LHS_PUT_BUF:%.*]], {{.*}} = nvws.aref.put.enter [[AREF_LHS]] {{.*}}ttg.partition = array<i32: 2>}
+      // CHECK: nvws.descriptor_load {{.*}} 32768 [[AREF_LHS_PUT_BUF]] {{.*}}ttg.partition = array<i32: 2>}
+
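+      // The default partition (0) pulls the tile back out of shared memory
+      // into registers...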
+      // CHECK: [[AREF_LHS_GET_BUF:%.*]], {{.*}} = nvws.aref.get.enter [[AREF_LHS]] {{.*}}ttg.partition = array<i32: 0>}
+      // CHECK: [[TMA_RES_REG:%.*]] = ttg.local_load [[AREF_LHS_GET_BUF]] {{.*}}ttg.partition = array<i32: 0>}
+
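+      // ...and re-stores it into the second aref, whose buffer carries the
+      // transposed #shared1 layout expected by the consumer.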
+      // CHECK: [[AREF_LHS_TRANS_PUT_BUF:%.*]], {{.*}} = nvws.aref.put.enter [[AREF_LHS_TRANS]] {{.*}}ttg.partition = array<i32: 0>}
+      // CHECK: ttg.local_store [[TMA_RES_REG]], [[AREF_LHS_TRANS_PUT_BUF]] {{.*}}ttg.partition = array<i32: 0>}
+
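+      // Consumer: partition 1 takes a transposed view of the staged tile as
+      // the MMA operand.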
+      // CHECK: [[AREF_LHS_TRANS_GET_BUF:%.*]], {{.*}} = nvws.aref.get.enter [[AREF_LHS_TRANS]] {{.*}}ttg.partition = array<i32: 1>}
+      // CHECK: [[LHS:%.*]] = ttg.memdesc_trans [[AREF_LHS_TRANS_GET_BUF]] {{.*}}ttg.partition = array<i32: 1>}
+
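+      // Input IR for the LHS: the ttg.partition attributes place the load in
+      // partition 2, the local_alloc in partition 0, and the trans in partition 1.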
+      %3 = tt.descriptor_load %arg3[%arg1, %2] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array<i32: 2>} : !tt.tensordesc<tensor<128x128xf16, #shared>> -> tensor<128x128xf16, #blocked2>
+      %5 = ttg.local_alloc %3 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array<i32: 0>} : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared1, #smem>
+      %lhs_trans = ttg.memdesc_trans %5 {loop.cluster = 0 : i32, loop.stage = 1 : i32, order = array<i32: 1, 0>, ttg.partition = array<i32: 1>} : !ttg.memdesc<128x128xf16, #shared1, #smem> -> !ttg.memdesc<128x128xf16, #shared, #smem>
+
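+      // RHS: load and local_alloc stay together in partition 2; only the transpose runs in partition 1.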
+      %4 = tt.descriptor_load %arg4[%arg2, %2] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array<i32: 2>} : !tt.tensordesc<tensor<128x128xf16, #shared>> -> tensor<128x128xf16, #blocked1>
+      %6 = ttg.local_alloc %4 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array<i32: 2>} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
+      %7 = ttg.memdesc_trans %6 {loop.cluster = 0 : i32, loop.stage = 1 : i32, order = array<i32: 1, 0>, ttg.partition = array<i32: 1>} : !ttg.memdesc<128x128xf16, #shared, #smem> -> !ttg.memdesc<128x128xf16, #shared1, #smem>
+
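+      // The MMA must consume the transposed view produced through the aref chain.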
+      // CHECK: ttng.tc_gen5_mma [[LHS]]
+      %8 = ttng.tc_gen5_mma %lhs_trans, %7, %result[%arg6], %true, %true {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array<i32: 1>} : !ttg.memdesc<128x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf16, #shared1, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+      scf.yield %8 : !ttg.async.token
+    } {tt.num_stages = 2 : i32, tt.scheduled_max_stage = 1 : i32, tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 0 : i32}
+    %result_0, %token_1 = ttng.tmem_load %result[%1] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+    "use"(%result_0) : (tensor<128x128xf32, #blocked>) -> ()
+    tt.return
+  }
 }