// DEFINE: %{compile} = mlir-opt %s \
// DEFINE:   -transform-interpreter -test-transform-dialect-erase-schedule \
// DEFINE:   -cse -canonicalize -test-lower-to-llvm
// DEFINE: %{entry_point} = main
// DEFINE: %{run} = mlir-runner -e %{entry_point} -entry-point-result=void \
// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils

// RUN: %{compile} | %{run} | FileCheck %s

//===----------------------------------------------------------------------===//
/// HIGH-LEVEL OVERVIEW
///
/// End-to-end test for computing matrix multiplication using linalg.mmt4d. In
/// particular, it demonstrates that the following MLIR sequence (implemented
/// in @matmul_via_mmt4d):
///
///   A_pack = linalg.pack A
///   B_pack = linalg.pack B
///   C_pack = linalg.pack C
///   out_pack = linalg.mmt4d(A_pack, B_pack, C_pack)
///
/// is equivalent to:
///
///   linalg.matmul(A, B, C)
///
/// (implemented in @matmul).
///
/// NOTES ON IMPLEMENTATION
///   1. The MMT4D example uses _scalable_ tile sizes for data tiling.
///      * The matrix-multiplication dimension that's scalable: N.
///
///   2. The lowering of linalg.mmt4d leverages scalable vectorisation.
///      * The matrix-multiplication dimension that's scalable: N (to match
///        the data tiling configuration).
///
///   3. Neither `linalg.pack` nor `linalg.unpack` is vectorised ATM.
///
///   4. The MMT4D and Pack/Unpack Ops are kept in separate functions to
///      isolate the corresponding lowerings and lowering configs.
///      * TODO: Ideally, we should consider fusion opportunities by moving
///        these Ops into one function.
//===----------------------------------------------------------------------===//
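
// For illustration, the packed shapes produced by the pack helpers below work
// out as follows (a sketch; the actual values depend on the runtime value of
// vscale, i.e. the SVE vector length):
//   * vscale = 1 (128-bit vectors), inner N tile = 8:
//       A: 7x16  -> 1x16x8x1   (M padded: 7 -> 8)
//       B: 16x13 -> 2x16x8x1   (N padded: 13 -> 16)
//       C: 7x13  -> 1x2x8x8    (M and N padded)
//   * vscale = 2 (256-bit vectors), inner N tile = 16:
//       A: 7x16  -> 1x16x8x1   (unchanged, the LHS tile sizes are static)
//       B: 16x13 -> 1x16x16x1  (N padded: 13 -> 16)
//       C: 7x13  -> 1x1x8x16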

//===----------------------------------------------------------------------===//
// @main
//
// The main entry point that computes matrix multiplication via linalg.mmt4d
// and linalg.matmul. Note that the output should be independent of the
// underlying Linalg Op used, as well as of the SVE vector length.
//===----------------------------------------------------------------------===//
func.func @main() {
  // Allocate and initialise the inputs
  %A_empty = tensor.empty() : tensor<7x16xi32>
  %B_empty = tensor.empty() : tensor<16x13xi32>

  %c3 = arith.constant 3 : i32
  %c4 = arith.constant 4 : i32
  %A = linalg.fill ins(%c3 : i32) outs(%A_empty : tensor<7x16xi32>) -> tensor<7x16xi32>
  %B = linalg.fill ins(%c4 : i32) outs(%B_empty : tensor<16x13xi32>) -> tensor<16x13xi32>
  %C = arith.constant dense<[
    [ 1,  8, 15, 22, 29, 36, 43, 50, 57, 64, 71, 78, 85],
    [ 2,  9, 16, 23, 30, 37, 44, 51, 58, 65, 72, 79, 86],
    [ 3, 10, 17, 24, 31, 38, 45, 52, 59, 66, 73, 80, 87],
    [ 4, 11, 18, 25, 32, 39, 46, 53, 60, 67, 74, 81, 88],
    [ 5, 12, 19, 26, 33, 40, 47, 54, 61, 68, 75, 82, 89],
    [ 6, 13, 20, 27, 34, 41, 48, 55, 62, 69, 76, 83, 90],
    [ 7, 14, 21, 28, 35, 42, 49, 56, 63, 70, 77, 84, 91]
  ]> : tensor<7x13xi32>

  // VARIANT: Matrix multiplication via linalg.mmt4d
  // CHECK: Unranked Memref
  // CHECK: [193, 200, 207, 214, 221, 228, 235, 242, 249, 256, 263, 270, 277]
  // CHECK: [194, 201, 208, 215, 222, 229, 236, 243, 250, 257, 264, 271, 278]
  // CHECK: [195, 202, 209, 216, 223, 230, 237, 244, 251, 258, 265, 272, 279]
  // CHECK: [196, 203, 210, 217, 224, 231, 238, 245, 252, 259, 266, 273, 280]
  // CHECK: [197, 204, 211, 218, 225, 232, 239, 246, 253, 260, 267, 274, 281]
  // CHECK: [198, 205, 212, 219, 226, 233, 240, 247, 254, 261, 268, 275, 282]
  // CHECK: [199, 206, 213, 220, 227, 234, 241, 248, 255, 262, 269, 276, 283]
  %C_mmt4d = func.call @matmul_via_mmt4d(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
  %C_mmt4d_cast = tensor.cast %C_mmt4d : tensor<7x13xi32> to tensor<*xi32>
  vector.print str "--------------------------\n"
  vector.print str "RESULT FROM linalg.mmt4d:\n"
  vector.print str "--------------------------\n"
  call @printMemrefI32(%C_mmt4d_cast) : (tensor<*xi32>) -> ()

  // VARIANT: Matrix multiplication via linalg.matmul
  // CHECK: Unranked Memref
  // CHECK: [193, 200, 207, 214, 221, 228, 235, 242, 249, 256, 263, 270, 277]
  // CHECK: [194, 201, 208, 215, 222, 229, 236, 243, 250, 257, 264, 271, 278]
  // CHECK: [195, 202, 209, 216, 223, 230, 237, 244, 251, 258, 265, 272, 279]
  // CHECK: [196, 203, 210, 217, 224, 231, 238, 245, 252, 259, 266, 273, 280]
  // CHECK: [197, 204, 211, 218, 225, 232, 239, 246, 253, 260, 267, 274, 281]
  // CHECK: [198, 205, 212, 219, 226, 233, 240, 247, 254, 261, 268, 275, 282]
  // CHECK: [199, 206, 213, 220, 227, 234, 241, 248, 255, 262, 269, 276, 283]
  %C_matmul = func.call @matmul(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
  %C_matmul_cast = tensor.cast %C_matmul : tensor<7x13xi32> to tensor<*xi32>
  vector.print str "\n--------------------------\n"
  vector.print str "RESULT FROM linalg.matmul:\n"
  vector.print str "--------------------------\n"
  call @printMemrefI32(%C_matmul_cast) : (tensor<*xi32>) -> ()

  return
}

//===----------------------------------------------------------------------===//
// @matmul
//
// Implements matrix multiplication via linalg.matmul
//===----------------------------------------------------------------------===//
func.func private @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
  %C_matmul = linalg.matmul ins(%A, %B: tensor<7x16xi32>, tensor<16x13xi32>)
                            outs(%C: tensor<7x13xi32>) -> tensor<7x13xi32>

  return %C_matmul : tensor<7x13xi32>
}

//===----------------------------------------------------------------------===//
// @pack_lhs
//
// Implements packing for the A matrix (LHS) in matrix multiplication. The
// inner tile sizes are static: 8 (M) and 1 (K).
//===----------------------------------------------------------------------===//
func.func private @pack_lhs(%A: tensor<7x16xi32>) -> tensor<1x16x8x1xi32> {
  %pad = arith.constant 0 : i32

  %A_pack_empty = tensor.empty() : tensor<1x16x8x1xi32>
  %A_pack = linalg.pack %A
    padding_value(%pad : i32)
    inner_dims_pos = [0, 1]
    inner_tiles = [8, 1]
    into %A_pack_empty : tensor<7x16xi32> -> tensor<1x16x8x1xi32>

  return %A_pack : tensor<1x16x8x1xi32>
}

//===----------------------------------------------------------------------===//
// @pack_rhs
//
// Implements packing for the B matrix (RHS) in matrix multiplication. The
// inner tile size is "scalable": 8 * vscale.
//===----------------------------------------------------------------------===//
func.func private @pack_rhs(%B: tensor<16x13xi32>) -> tensor<?x16x?x1xi32> {
  %pad = arith.constant 0 : i32

  // Compute the extent of the outer N dim, i.e. how many scalable inner tiles
  // are needed to cover N = 13.
  %vs = vector.vscale
  %c8 = arith.constant 8 : index
  %vs_c8 = arith.muli %vs, %c8 : index
  %c13 = arith.constant 13 : index
  %outer_tile_size = arith.ceildivui %c13, %vs_c8 : index
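  // (For vscale = 1 this is ceil(13 / 8) = 2; for vscale = 2 it is 1.)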

  %B_pack_empty = tensor.empty(%outer_tile_size, %vs_c8) : tensor<?x16x?x1xi32>
  %B_pack = linalg.pack %B
    padding_value(%pad : i32)
    outer_dims_perm = [1, 0]
    inner_dims_pos = [1, 0]
    inner_tiles = [%vs_c8, 1]
    into %B_pack_empty : tensor<16x13xi32> -> tensor<?x16x?x1xi32>

  return %B_pack : tensor<?x16x?x1xi32>
}

//===----------------------------------------------------------------------===//
// @pack_acc
//
// Implements packing for the C matrix (accumulator) in matrix multiplication.
// The inner tile size is "scalable": 8 * vscale.
//===----------------------------------------------------------------------===//
func.func private @pack_acc(%C: tensor<7x13xi32>) -> tensor<1x?x8x?xi32> {
  %pad = arith.constant 0 : i32

  // Compute the extent of the outer N dim (as in @pack_rhs).
  %c13 = arith.constant 13 : index
  %vs = vector.vscale
  %c8 = arith.constant 8 : index
  %vs_c8 = arith.muli %vs, %c8 : index
  %outer_tile_size = arith.ceildivui %c13, %vs_c8 : index

  %C_pack_empty = tensor.empty(%outer_tile_size, %vs_c8) : tensor<1x?x8x?xi32>
  %C_pack = linalg.pack %C
    padding_value(%pad : i32)
    outer_dims_perm = [0, 1]
    inner_dims_pos = [0, 1]
    inner_tiles = [8, %vs_c8]
    into %C_pack_empty : tensor<7x13xi32> -> tensor<1x?x8x?xi32>

  return %C_pack : tensor<1x?x8x?xi32>
}

//===----------------------------------------------------------------------===//
// @unpack_acc
//
// Implements unpacking for the C matrix (accumulator) in matrix
// multiplication. The inner tile size is "scalable": 8 * vscale.
//===----------------------------------------------------------------------===//
func.func private @unpack_acc(%C_packed: tensor<1x?x8x?xi32>) -> tensor<7x13xi32> {
  %vs = vector.vscale
  %c8 = arith.constant 8 : index
  %vs_c8 = arith.muli %vs, %c8 : index

  %C_out_empty = tensor.empty() : tensor<7x13xi32>
  %C_out_unpack = linalg.unpack %C_packed
    outer_dims_perm = [0, 1]
    inner_dims_pos = [0, 1]
    inner_tiles = [8, %vs_c8]
    into %C_out_empty : tensor<1x?x8x?xi32> -> tensor<7x13xi32>

  return %C_out_unpack : tensor<7x13xi32>
}

//===----------------------------------------------------------------------===//
// Helper methods for printing
//===----------------------------------------------------------------------===//
func.func private @print_pack_A(%A_pack : tensor<1x16x8x1xi32>) -> () {
  %A_pack_cast = tensor.cast %A_pack : tensor<1x16x8x1xi32> to tensor<*xi32>
  call @printMemrefI32(%A_pack_cast) : (tensor<*xi32>) -> ()

  return
}

func.func private @print_pack_B(%B_pack : tensor<?x16x?x1xi32>) -> () {
  %B_pack_cast = tensor.cast %B_pack : tensor<?x16x?x1xi32> to tensor<*xi32>
  call @printMemrefI32(%B_pack_cast) : (tensor<*xi32>) -> ()

  return
}

func.func private @print_pack_C(%C_pack : tensor<1x?x8x?xi32>) -> () {
  %C_pack_cast = tensor.cast %C_pack : tensor<1x?x8x?xi32> to tensor<*xi32>
  call @printMemrefI32(%C_pack_cast) : (tensor<*xi32>) -> ()

  return
}

//===----------------------------------------------------------------------===//
// @matmul_via_mmt4d
//
// Implements matrix multiplication via linalg.mmt4d
//===----------------------------------------------------------------------===//
func.func private @matmul_via_mmt4d(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
  // Pack input matrices
  %A_pack = func.call @pack_lhs(%A): (tensor<7x16xi32>) -> tensor<1x16x8x1xi32>
  %B_pack = func.call @pack_rhs(%B): (tensor<16x13xi32>) -> tensor<?x16x?x1xi32>
  %C_pack = func.call @pack_acc(%C): (tensor<7x13xi32>) -> tensor<1x?x8x?xi32>

  // Print the packed matrices (this is the only _visible_ part that changes
  // when adjusting the SVE vector size).
  func.call @print_pack_A(%A_pack) : (tensor<1x16x8x1xi32>) -> ()
  func.call @print_pack_B(%B_pack) : (tensor<?x16x?x1xi32>) -> ()
  func.call @print_pack_C(%C_pack) : (tensor<1x?x8x?xi32>) -> ()

  // MMT4D
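  // Note: linalg.mmt4d computes matmul on the pre-packed, tiled operands;
  // roughly,
  //   C[m1, n1, m0, n0] += sum_{k1, k0} A[m1, k1, m0, k0] * B[n1, k1, n0, k0]
  // i.e. the RHS tiles are "transposed" (hence the "t" in mmt4d).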
  %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<1x16x8x1xi32>, tensor<?x16x?x1xi32>)
                        outs(%C_pack : tensor<1x?x8x?xi32>) -> tensor<1x?x8x?xi32>

  // Unpack the output
  %C_out_unpack = func.call @unpack_acc(%mmt4d) : (tensor<1x?x8x?xi32>) -> tensor<7x13xi32>

  return %C_out_unpack : tensor<7x13xi32>
}

//===----------------------------------------------------------------------===//
// TD Sequence
//===----------------------------------------------------------------------===//
module @transforms attributes { transform.with_named_sequence } {
  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.consumed}) {
    //==========================================================================
    // HANDLE MMT4D
    //==========================================================================
    %mmt4d = transform.collect_matching @match_mmt4d in %module : (!transform.any_op) -> (!transform.any_op)
    %mmt4d_func = transform.get_parent_op %mmt4d {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">

    // Step 1: Tile
    // Tile parallel dims (note, the N dim is scalable!)
    %tiled_mmt4d_parallel, %_:4 = transform.structured.tile_using_for %mmt4d tile_sizes [1, 1, 0, 8, [8], 0]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    // Tile reduction dims
    %tiled_mmt4d, %_1:2 = transform.structured.tile_using_for %tiled_mmt4d_parallel tile_sizes [0, 0, 1, 0, 0, 1]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
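    // (The six entries in the tile-size lists above map onto the mmt4d
    // iteration dims, i.e. M1, N1, K1, M0, N0 and K0 in the usual mmt4d
    // convention, so only the inner N dim (N0) gets a scalable tile size.)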

    // Step 2: Vectorize linalg.mmt4d (note, the N dim is scalable!)
    transform.structured.vectorize %tiled_mmt4d
      vector_sizes [1, 1, 1, 8, [8], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
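    // (`assume_dynamic_dims_match_vec_sizes` tells the vectoriser that the
    // dynamic, i.e. scalable, dims can be assumed to match the vector sizes
    // above, so no masking should be required for them.)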

    // Step 3: Simplify
    // vector.multi_reduction --> vector.contract
    // Generates a 6-dim vector.contract with the dims matching the original
    // MMT4D Op and with the following split into parallel and reduction dims:
    //    * parallel, parallel, reduction, parallel, parallel, reduction
    transform.apply_patterns to %mmt4d_func {
      transform.apply_patterns.vector.reduction_to_contract
      // Reduce the rank of xfer ops. This transforms vector.contract to be
      // more matmul-like and enables the lowering to outer-product Ops.
      transform.apply_patterns.vector.transfer_permutation_patterns
    } : !transform.op<"func.func">

    // Hoisting and LICM - not strictly required
    %mmt4d_func_h = transform.structured.hoist_redundant_vector_transfers %mmt4d_func
      : (!transform.op<"func.func">) -> !transform.op<"func.func">
    %all_loops = transform.structured.match interface{LoopLikeInterface} in %mmt4d_func_h
      : (!transform.op<"func.func">) -> !transform.any_op
    transform.apply_licm to %all_loops : !transform.any_op
    transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op

    // Simplification
    transform.apply_patterns to %mmt4d_func_h {
      transform.apply_patterns.vector.reduction_to_contract
      transform.apply_patterns.vector.cast_away_vector_leading_one_dim
      transform.apply_patterns.canonicalization
    } : !transform.op<"func.func">

    //==========================================================================
    // HANDLE PACK + UNPACK
    //==========================================================================
    %pack = transform.structured.match ops{["linalg.pack"]} in %module : (!transform.any_op) -> !transform.any_op
    %unpack = transform.structured.match ops{["linalg.unpack"]} in %module : (!transform.any_op) -> !transform.any_op

    // 1.1 Tile the linalg.pack Op so that we can decompose it into e.g. tensor.pad
    // and other lower-level Ops (see step 2.1)
    %tiled_pack_op_p, %loops_pack:2 = transform.structured.tile_using_for %pack tile_sizes [1, 1]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)

    // 1.2 Tile the linalg.unpack Op so that we can decompose it into e.g. tensor.pad
    // and other lower-level Ops (see step 2.2)
    %tiled_unpack_op_p, %loops_unpack:2 = transform.structured.tile_using_for %unpack tile_sizes [8, 1]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)

    // 2.1. Decompose tiled PackOp into lower-level Ops + simplify
    %func_op_pack = transform.get_parent_op %tiled_pack_op_p {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op_pack {
      transform.apply_patterns.linalg.decompose_pack_unpack
      transform.apply_patterns.linalg.decompose_pad
    } : !transform.op<"func.func">

    transform.apply_patterns to %func_op_pack {
      transform.apply_patterns.tensor.fold_tensor_subset_ops
      transform.apply_patterns.canonicalization
    } : !transform.op<"func.func">

    // 2.2. Decompose tiled UnpackOp into lower-level Ops + simplify
    %func_op_unpack = transform.get_parent_op %tiled_unpack_op_p {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op_unpack {
      transform.apply_patterns.linalg.decompose_pack_unpack
    } : !transform.op<"func.func">

    transform.apply_patterns to %func_op_unpack {
      transform.apply_patterns.tensor.fold_tensor_subset_ops
      transform.apply_patterns.canonicalization
    } : !transform.op<"func.func">
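    // (After decomposition the pack/unpack Ops are gone; what remains are
    // plain tensor/linalg Ops, e.g. tensor.pad, linalg.transpose and
    // tensor.insert_slice/extract_slice, which then bufferize like any other
    // Ops in the next step.)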

    //==========================================================================
    // BUFFERIZATION
    //==========================================================================
    %bufferize = transform.bufferization.one_shot_bufferize %module
      {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op

    //==========================================================================
    // SIMPLIFY THE CONTRACT Op
    //==========================================================================
    %contract = transform.collect_matching @match_contract in %bufferize : (!transform.any_op) -> (!transform.any_op)
    %contract_func = transform.get_parent_op %contract {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">

    // Drop trailing unit dims (the corresponding pattern works only
    // post-bufferization)
    transform.apply_patterns to %contract_func {
      transform.apply_patterns.tensor.fold_tensor_subset_ops
      transform.apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops
      transform.apply_patterns.canonicalization
    } : !transform.op<"func.func">

    //==========================================================================
    // LOWER CONTRACT TO FMA
    //==========================================================================
    transform.apply_patterns to %contract_func {
      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
      transform.apply_patterns.vector.lower_outerproduct
    } : !transform.op<"func.func">
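    // (Roughly: lower_contraction with the "outerproduct" strategy rewrites
    // the vector.contract as a sequence of vector.outerproduct Ops, and
    // lower_outerproduct then expands those into (scalable) vector.fma Ops,
    // hence "LOWER CONTRACT TO FMA".)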

    transform.yield
  }

  //==========================================================================
  // TD MATCHERS (helper hooks)
  //==========================================================================
  transform.named_sequence @match_mmt4d(
      %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.operation_name %entry ["linalg.mmt4d"] : !transform.any_op
    transform.yield %entry : !transform.any_op
  }

  transform.named_sequence @match_contract(
      %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.operation_name %entry ["vector.contract"] : !transform.any_op
    transform.yield %entry : !transform.any_op
  }
}

//===----------------------------------------------------------------------===//
// Function signatures
//===----------------------------------------------------------------------===//
func.func private @printMemrefI32(%ptr : tensor<*xi32>)