4 | 4 | #indices_layout = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
5 | 5 | #acc_layout = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
6 | 6 | #oper_layout = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
| 7 | +#b_layout = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
7 | 8 | #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
8 | 9 | #smem = #ttg.shared_memory
9 | 10 | #acc_tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
@@ -31,11 +32,16 @@ tt.func @matmul_change_desc_in_prologue(
31 | 32 | // CHECK-SAME: num_warps(1)
32 | 33 | // BASE-NOT: tt.make_tensor_descriptor
33 | 34 | // PIPELINE-NOT: tt.experimental_tensormap_create
| 35 | + // PIPELINE-COUNT-1: tc_gen5_mma
| 36 | + // PIPELINE-NOT: tc_gen5_mma
34 | 37 | // CHECK-LABEL: partition1
35 | 38 | // CHECK-SAME: num_warps(2)
36 | 39 | // BASE-COUNT-2: tt.make_tensor_descriptor
37 | 40 | // PIPELINE-COUNT-2: ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 512 : i32}
38 | 41 | // PIPELINE-COUNT-2: tt.experimental_tensormap_create
| 42 | + // PIPELINE-NOT: tt.experimental_tensormap_create
| 43 | + // PIPELINE-COUNT-2: async_tma_copy_global_to_local
| 44 | + // PIPELINE-NOT: async_tma_copy_global_to_local
39 | 45 | // CHECK-NOT: partition2
40 | 46 | scf.for %k = %c0_i32 to %k_tiles step %c1_i32 iter_args(%acc = %zero, %flag = %true, %a_desc = %a_desc_undef, %b_desc = %b_desc_undef) -> (tensor<128x128xf32, #acc_layout>, i1, !tt.tensordesc<tensor<128x64xf16, #shared>>, !tt.tensordesc<tensor<64x128xf16, #shared>>) : i32 {
41 | 47 | %do_prologue = "prologue_cond"(%k) : (i32) -> i1
@@ -108,6 +114,53 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use(
108 | 114 | tt.return
109 | 115 | }
110 | 116 |
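| | +// Matmul loop that mixes a TMA gather load (A operand) with a regular
| | +// tensor-of-pointers load (B operand). The CHECK lines below verify that the
| | +// regular load is pipelined into async_copy_global_to_local in one partition and
| | +// the gather is lowered to ttng.async_tma_gather in another.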
| 117 | +// CHECK-LABEL: @matmul_tma_and_regular_load
| 118 | +tt.func @matmul_tma_and_regular_load(
| 119 | + %a_desc: !tt.tensordesc<tensor<1x64xf16, #shared>>,
| 120 | + %b_ptr_init: tensor<64x128x!tt.ptr<f16>, #b_layout> {tt.divisibility = 16 : i32, tt.contiguity = 64 : i32}
| 121 | +) {
| 122 | + %c0_i32 = arith.constant 0 : i32
| 123 | + %c1_i32 = arith.constant 1 : i32
| 124 | + %true = arith.constant true
| 125 | + %false = arith.constant false
| 126 | + %zero = arith.constant dense<0.0> : tensor<128x128xf32, #acc_layout>
| 127 | + %k_tiles = arith.constant 32 : i32
| 128 | + // CHECK-LABEL: ttg.warp_specialize
| 129 | + // CHECK-LABEL: default
| 130 | + // CHECK-LABEL: partition0
| 131 | + // CHECK-SAME: num_warps(4)
| 132 | + // PIPELINE-COUNT-3: async_copy_global_to_local
| 133 | + // PIPELINE-NOT: async_copy_global_to_local
| 134 | + // CHECK-LABEL: partition1
| 135 | + // CHECK-SAME: num_warps(4)
| 136 | + // CHECK: [[INDICES:%.*]] = tt.splat %{{.*}} : i32 -> tensor<128xi32,
| 137 | + // CHECK: ttng.async_tma_gather %{{.*}}[[[INDICES]],
| 138 | + // CHECK-NOT: partition2
| 139 | + scf.for %k = %c0_i32 to %k_tiles step %c1_i32 iter_args(%acc = %zero, %flag = %true, %b_ptr = %b_ptr_init) -> (tensor<128x128xf32, #acc_layout>, i1, tensor<64x128x!tt.ptr<f16>, #b_layout>) : i32 {
| 140 | + %off_m, %offs_n, %off_k = "get_offsets"(%k) : (i32) -> (i32, tensor<64x128xi32, #b_layout>, i32)
| 141 | + %indices = tt.splat %off_m : i32 -> tensor<128xi32, #indices_layout>
| 142 | +
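| | + // A tile: gathered through the TMA descriptor using row indices splatted from %off_m.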
| 143 | + %a = tt.descriptor_gather %a_desc[%indices, %off_k] : (!tt.tensordesc<tensor<1x64xf16, #shared>>, tensor<128xi32, #indices_layout>, i32) -> tensor<128x64xf16, #oper_layout>
| 144 | +
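| | + // B tile: loaded with an ordinary tensor-of-pointers tt.load rather than a TMA descriptor.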
| 145 | + %b_ptrs = tt.addptr %b_ptr, %offs_n {tt.divisibility = dense<16> : tensor<64x128xi32>, tt.contiguity = dense<64> : tensor<64x128xi32>, tt.constancy = dense<1> : tensor<64x128xi32>} : tensor<64x128x!tt.ptr<f16>, #b_layout>, tensor<64x128xi32, #b_layout>
| 146 | + %b = tt.load %b_ptrs : tensor<64x128x!tt.ptr<f16>, #b_layout>
| 147 | +
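| | + // Stage both operands in shared memory and run the MMA; the accumulator lives in tensor memory.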
| 148 | + %a_shared = ttg.local_alloc %a : (tensor<128x64xf16, #oper_layout>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
| 149 | + %b_shared = ttg.local_alloc %b : (tensor<64x128xf16, #b_layout>) -> !ttg.memdesc<64x128xf16, #shared, #smem>
| 150 | + %c_tmem, %c_tok = ttng.tmem_alloc %acc : (tensor<128x128xf32, #acc_layout>) -> (!ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
| 151 | + %mma_tok = ttng.tc_gen5_mma %a_shared, %b_shared, %c_tmem[%c_tok], %flag, %true : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>
| 152 | + %c, %load_tok = ttng.tmem_load %c_tmem[%mma_tok] : !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #acc_layout>
| 153 | +
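| | + // First-iteration epilogue: read the accumulator and clear the use-accumulator flag for the next MMA.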
| 154 | + %do_epilogue = arith.cmpi eq, %k, %c0_i32 : i32
| 155 | + %use_acc = arith.select %do_epilogue, %false, %true : i1
| 156 | + scf.if %do_epilogue {
| 157 | + "acc_user"(%c) : (tensor<128x128xf32, #acc_layout>) -> ()
| 158 | + }
| 159 | + scf.yield %c, %use_acc, %b_ptrs : tensor<128x128xf32, #acc_layout>, i1, tensor<64x128x!tt.ptr<f16>, #b_layout>
| 160 | + } {tt.warp_specialize, tt.disallow_acc_multi_buffer, tt.num_stages = 2 : i32}
| 161 | + tt.return
| 162 | +}
| 163 | +
111 | 164 | }
112 | 165 |
113 | 166 | // -----