Commit 7c26b4c

[mlir][test] Add e2e test for linalg.mmt4d + SVE
Adds an end-to-end test for computing matrix-multiplication using linalg.mmt4d,
combined with "scalable" tiling and "scalable" vectorisation. This is similar
to an existing example that does not use "scalable" sizes:
  * test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir
1 parent 77596b7 commit 7c26b4c

1 file changed: 398 additions, 0 deletions
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE:   -transform-interpreter -test-transform-dialect-erase-schedule \
// DEFINE:   -cse -canonicalize -test-lower-to-llvm
// DEFINE: %{entry_point} = main
// DEFINE: %{run} = mlir-runner -e %{entry_point} -entry-point-result=void \
// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils

// RUN: %{compile} | %{run} | FileCheck %s

//===----------------------------------------------------------------------===//
/// HIGH-LEVEL OVERVIEW
///
/// End-to-end test for computing matrix-multiplication using linalg.mmt4d. In
/// particular, demonstrates how the following MLIR sequence (implemented in
/// @matmul_via_mmt4d):
///
///   A_pack = linalg.pack A
///   B_pack = linalg.pack B
///   C_pack = linalg.pack C
///   out_pack = linalg.mmt4d(A_pack, B_pack, C_pack)
///
/// is equivalent to:
///
///   linalg.matmul(A, B, C)
///
/// (implemented in @matmul).
///
/// NOTES ON IMPLEMENTATION
///   1. The MMT4D example uses _scalable_ tile sizes for data tiling.
///      * The matrix-multiplication dimension that's scalable: N.
///
///   2. The lowering of linalg.mmt4d leverages scalable vectorisation.
///      * The matrix-multiplication dimension that's scalable: N (to match
///        the data tiling configuration).
///
///   3. Neither `linalg.pack` nor `linalg.unpack` is vectorised ATM.
///
///   4. The MMT4D and Pack/Unpack Ops are kept in separate functions to
///      isolate the corresponding lowerings and lowering configs.
///      * TODO: Ideally, we should consider fusion opportunities by moving
///        these Ops into one function.
//===----------------------------------------------------------------------===//

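// A rough sketch of the packed shapes used in this test (derived from the
// pack helpers below; the "?" dimensions depend on the runtime vscale value):
//   A: tensor<7x16xi32>  -- pack, inner tiles [8, 1]           --> tensor<1x16x8x1xi32>
//   B: tensor<16x13xi32> -- pack, inner tiles [8 * vscale, 1]  --> tensor<?x16x?x1xi32>
//   C: tensor<7x13xi32>  -- pack, inner tiles [8, 8 * vscale]  --> tensor<1x?x8x?xi32>
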
//===----------------------------------------------------------------------===//
// @main
//
// The main entry point that computes matrix multiplication via linalg.mmt4d
// and linalg.matmul. Note, the output should be independent of the underlying
// Linalg Op used, as well as SVE vector length.
//===----------------------------------------------------------------------===//
func.func @main() {
  // Allocate and initialise the inputs
  %A_empty = tensor.empty() : tensor<7x16xi32>
  %B_empty = tensor.empty() : tensor<16x13xi32>

  %c3 = arith.constant 3 : i32
  %c4 = arith.constant 4 : i32
  %A = linalg.fill ins(%c3 : i32) outs(%A_empty : tensor<7x16xi32>) -> tensor<7x16xi32>
  %B = linalg.fill ins(%c4 : i32) outs(%B_empty : tensor<16x13xi32>) -> tensor<16x13xi32>
  %C = arith.constant dense<[
    [ 1, 8, 15, 22, 29, 36, 43, 50, 57, 64, 71, 78, 85],
    [ 2, 9, 16, 23, 30, 37, 44, 51, 58, 65, 72, 79, 86],
    [ 3, 10, 17, 24, 31, 38, 45, 52, 59, 66, 73, 80, 87],
    [ 4, 11, 18, 25, 32, 39, 46, 53, 60, 67, 74, 81, 88],
    [ 5, 12, 19, 26, 33, 40, 47, 54, 61, 68, 75, 82, 89],
    [ 6, 13, 20, 27, 34, 41, 48, 55, 62, 69, 76, 83, 90],
    [ 7, 14, 21, 28, 35, 42, 49, 56, 63, 70, 77, 84, 91]
  ]> : tensor<7x13xi32>

  // VARIANT: Matrix multiplication via linalg.mmt4d
  // CHECK: Unranked Memref
  // CHECK: [193, 200, 207, 214, 221, 228, 235, 242, 249, 256, 263, 270, 277]
  // CHECK: [194, 201, 208, 215, 222, 229, 236, 243, 250, 257, 264, 271, 278]
  // CHECK: [195, 202, 209, 216, 223, 230, 237, 244, 251, 258, 265, 272, 279]
  // CHECK: [196, 203, 210, 217, 224, 231, 238, 245, 252, 259, 266, 273, 280]
  // CHECK: [197, 204, 211, 218, 225, 232, 239, 246, 253, 260, 267, 274, 281]
  // CHECK: [198, 205, 212, 219, 226, 233, 240, 247, 254, 261, 268, 275, 282]
  // CHECK: [199, 206, 213, 220, 227, 234, 241, 248, 255, 262, 269, 276, 283]
  %C_mmt4d = func.call @matmul_via_mmt4d(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
  %C_mmt4d_cast = tensor.cast %C_mmt4d : tensor<7x13xi32> to tensor<*xi32>
  vector.print str "--------------------------\n"
  vector.print str "RESULT FROM linalg.mmt4d:\n"
  vector.print str "--------------------------\n"
  call @printMemrefI32(%C_mmt4d_cast) : (tensor<*xi32>) -> ()

  // VARIANT: Matrix multiplication via linalg.matmul
  // CHECK: Unranked Memref
  // CHECK: [193, 200, 207, 214, 221, 228, 235, 242, 249, 256, 263, 270, 277]
  // CHECK: [194, 201, 208, 215, 222, 229, 236, 243, 250, 257, 264, 271, 278]
  // CHECK: [195, 202, 209, 216, 223, 230, 237, 244, 251, 258, 265, 272, 279]
  // CHECK: [196, 203, 210, 217, 224, 231, 238, 245, 252, 259, 266, 273, 280]
  // CHECK: [197, 204, 211, 218, 225, 232, 239, 246, 253, 260, 267, 274, 281]
  // CHECK: [198, 205, 212, 219, 226, 233, 240, 247, 254, 261, 268, 275, 282]
  // CHECK: [199, 206, 213, 220, 227, 234, 241, 248, 255, 262, 269, 276, 283]
  %C_matmul = func.call @matmul(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
  %C_matmul_cast = tensor.cast %C_matmul : tensor<7x13xi32> to tensor<*xi32>
  vector.print str "\n--------------------------\n"
  vector.print str "RESULT FROM linalg.matmul:\n"
  vector.print str "--------------------------\n"
  call @printMemrefI32(%C_matmul_cast) : (tensor<*xi32>) -> ()

  return
}

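// A quick sanity check of the CHECK values in @main above: every element of A
// is 3 and every element of B is 4, so each result element is
// C_init[i][j] + 16 * 3 * 4 = C_init[i][j] + 192, e.g. C_init[0][0] = 1 gives
// 193 and C_init[6][12] = 91 gives 283.
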
//===----------------------------------------------------------------------===//
// @matmul
//
// Implements matrix-multiplication via linalg.matmul
//===----------------------------------------------------------------------===//
func.func private @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
  %C_matmul = linalg.matmul ins(%A, %B: tensor<7x16xi32>, tensor<16x13xi32>)
                            outs(%C: tensor<7x13xi32>) -> tensor<7x13xi32>

  return %C_matmul : tensor<7x13xi32>
}

//===----------------------------------------------------------------------===//
// @pack_lhs
//
// Implements packing for the A matrix (LHS) in matrix multiplication. The
// inner tile sizes are static: [8, 1].
//===----------------------------------------------------------------------===//
func.func private @pack_lhs(%A: tensor<7x16xi32>) -> tensor<1x16x8x1xi32> {
  %pad = arith.constant 0 : i32

  %A_pack_empty = tensor.empty() : tensor<1x16x8x1xi32>
  %A_pack = linalg.pack %A
    padding_value(%pad : i32)
    inner_dims_pos = [0, 1]
    inner_tiles = [8, 1]
    into %A_pack_empty : tensor<7x16xi32> -> tensor<1x16x8x1xi32>

  return %A_pack : tensor<1x16x8x1xi32>
}

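// Layout sketch (assuming standard linalg.pack semantics): with
// inner_dims_pos = [0, 1] and inner_tiles = [8, 1],
//   A_pack[0][k][i][0] == A[i][k]   for i < 7,
// and the remaining row of the 8-element inner tile holds the padding value 0.
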
//===----------------------------------------------------------------------===//
// @pack_rhs
//
// Implements packing for the B matrix (RHS) in matrix multiplication. The
// inner tile size is "scalable": 8 * vscale.
//===----------------------------------------------------------------------===//
func.func private @pack_rhs(%B: tensor<16x13xi32>) -> tensor<?x16x?x1xi32> {
  %pad = arith.constant 0 : i32

  // Compute the outer tile size.
  %vs = vector.vscale
  %c8 = arith.constant 8 : index
  %vs_c8 = arith.muli %vs, %c8 : index
  %c13 = arith.constant 13 : index
  %outer_tile_size = arith.ceildivui %c13, %vs_c8 : index

  %B_pack_empty = tensor.empty(%outer_tile_size, %vs_c8) : tensor<?x16x?x1xi32>
  %B_pack = linalg.pack %B
    padding_value(%pad : i32)
    outer_dims_perm = [1, 0]
    inner_dims_pos = [1, 0]
    inner_tiles = [%vs_c8, 1]
    into %B_pack_empty : tensor<16x13xi32> -> tensor<?x16x?x1xi32>

  return %B_pack : tensor<?x16x?x1xi32>
}

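// Worked example: with vscale = 1 the inner tile is 8 and the outer tile
// count is ceildiv(13, 8) = 2, i.e. B packs into tensor<2x16x8x1xi32>; with
// vscale = 2 the inner tile is 16 and B packs into tensor<1x16x16x1xi32>.
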
//===----------------------------------------------------------------------===//
// @pack_acc
//
// Implements packing for the C matrix (accumulator) in matrix multiplication.
// The inner tile size is "scalable": 8 * vscale.
//===----------------------------------------------------------------------===//
func.func private @pack_acc(%C: tensor<7x13xi32>) -> tensor<1x?x8x?xi32> {
  %pad = arith.constant 0 : i32

  // Compute the outer tile size.
  %c13 = arith.constant 13 : index
  %vs = vector.vscale
  %c8 = arith.constant 8 : index
  %vs_c8 = arith.muli %vs, %c8 : index
  %outer_tile_size = arith.ceildivui %c13, %vs_c8 : index

  %C_pack_empty = tensor.empty(%outer_tile_size, %vs_c8) : tensor<1x?x8x?xi32>
  %C_pack = linalg.pack %C
    padding_value(%pad : i32)
    outer_dims_perm = [0, 1]
    inner_dims_pos = [0, 1]
    inner_tiles = [8, %vs_c8]
    into %C_pack_empty : tensor<7x13xi32> -> tensor<1x?x8x?xi32>

  return %C_pack : tensor<1x?x8x?xi32>
}

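// Worked example: with vscale = 1, C packs into tensor<1x2x8x8xi32>; with
// vscale = 2, into tensor<1x1x8x16xi32>. This is why the packed matrices
// printed from @matmul_via_mmt4d change with the SVE vector length while the
// final result does not.
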
//===----------------------------------------------------------------------===//
// @unpack_acc
//
// Implements unpacking for the C matrix (accumulator) in matrix
// multiplication. The inner tile size is "scalable": 8 * vscale.
//===----------------------------------------------------------------------===//
func.func private @unpack_acc(%C_packed: tensor<1x?x8x?xi32>) -> tensor<7x13xi32> {
  %vs = vector.vscale
  %c8 = arith.constant 8 : index
  %vs_c8 = arith.muli %vs, %c8 : index

  %C_out_empty = tensor.empty() : tensor<7x13xi32>
  %C_out_unpack = linalg.unpack %C_packed
    outer_dims_perm = [0, 1]
    inner_dims_pos = [0, 1]
    inner_tiles = [8, %vs_c8]
    into %C_out_empty : tensor<1x?x8x?xi32> -> tensor<7x13xi32>

  return %C_out_unpack : tensor<7x13xi32>
}

//===----------------------------------------------------------------------===//
// Helper functions for printing
//===----------------------------------------------------------------------===//
func.func private @print_pack_A(%A_pack : tensor<1x16x8x1xi32>) -> () {
  %A_pack_cast = tensor.cast %A_pack : tensor<1x16x8x1xi32> to tensor<*xi32>
  call @printMemrefI32(%A_pack_cast) : (tensor<*xi32>) -> ()

  return
}

func.func private @print_pack_B(%B_pack : tensor<?x16x?x1xi32>) -> () {
  %B_pack_cast = tensor.cast %B_pack : tensor<?x16x?x1xi32> to tensor<*xi32>
  call @printMemrefI32(%B_pack_cast) : (tensor<*xi32>) -> ()

  return
}

func.func private @print_pack_C(%C_pack : tensor<1x?x8x?xi32>) -> () {
  %C_pack_cast = tensor.cast %C_pack : tensor<1x?x8x?xi32> to tensor<*xi32>
  call @printMemrefI32(%C_pack_cast) : (tensor<*xi32>) -> ()

  return
}

//===----------------------------------------------------------------------===//
// @matmul_via_mmt4d
//
// Implements matrix-multiplication via linalg.mmt4d
//===----------------------------------------------------------------------===//
func.func private @matmul_via_mmt4d(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
  // Pack input matrices
  %A_pack = func.call @pack_lhs(%A): (tensor<7x16xi32>) -> tensor<1x16x8x1xi32>
  %B_pack = func.call @pack_rhs(%B): (tensor<16x13xi32>) -> tensor<?x16x?x1xi32>
  %C_pack = func.call @pack_acc(%C): (tensor<7x13xi32>) -> tensor<1x?x8x?xi32>

  // Print the packed matrices (this is the only _visible_ part that changes
  // when adjusting the SVE vector size).
  func.call @print_pack_A(%A_pack) : (tensor<1x16x8x1xi32>) -> ()
  func.call @print_pack_B(%B_pack) : (tensor<?x16x?x1xi32>) -> ()
  func.call @print_pack_C(%C_pack) : (tensor<1x?x8x?xi32>) -> ()

  // MMT4D
  %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<1x16x8x1xi32>, tensor<?x16x?x1xi32>)
                        outs(%C_pack : tensor<1x?x8x?xi32>) -> tensor<1x?x8x?xi32>

  // Unpack the output
  %C_out_unpack = func.call @unpack_acc(%mmt4d) : (tensor<1x?x8x?xi32>) -> tensor<7x13xi32>

  return %C_out_unpack : tensor<7x13xi32>
}

//===----------------------------------------------------------------------===//
// TD Sequence
//===----------------------------------------------------------------------===//
module @transforms attributes { transform.with_named_sequence } {
  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.consumed}) {
    //==========================================================================
    // HANDLE MMT4D
    //==========================================================================
    %mmt4d = transform.collect_matching @match_mmt4d in %module : (!transform.any_op) -> (!transform.any_op)
    %mmt4d_func = transform.get_parent_op %mmt4d {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">

    // Step 1: Tile
    // Tile parallel dims (note, the N dim is scalable!)
    %tiled_mmt4d_parallel, %_:4 = transform.structured.tile_using_for %mmt4d tile_sizes [1, 1, 0, 8, [8], 0]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    // Tile reduction dims
    %tiled_mmt4d, %_1:2 = transform.structured.tile_using_for %tiled_mmt4d_parallel tile_sizes [0, 0, 1, 0, 0, 1]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)

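    // Note (for reference): linalg.mmt4d iterates over (M, N, K, m, n, k), so
    // the tile sizes above map to M = 1, N = 1, m = 8 and n = [8] (scalable),
    // with the K and k reduction dims tiled to 1 in the second step.
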
    // Step 2: Vectorize linalg.mmt4d (note, the N dim is scalable!)
    transform.structured.vectorize %tiled_mmt4d
      vector_sizes [1, 1, 1, 8, [8], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op

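    // With these vector sizes each iteration computes an 8 x (8 * vscale)
    // tile of the accumulator, i.e. (roughly) a vector<1x1x8x[8]xi32> before
    // the unit dims are folded away further down.
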
    // Step 3: Simplify
    // vector.multi_reduction --> vector.contract
    // Generates a 6-dim vector.contract with the dims matching the original MMT4D Op
    // and with the following split into parallel and reduction dims:
    //   * parallel, parallel, reduction, parallel, parallel, reduction
    transform.apply_patterns to %mmt4d_func {
      transform.apply_patterns.vector.reduction_to_contract
      // Reduce the rank of xfer ops. This transforms vector.contract to be
      // more matmul-like and to enable the lowering to outer product Ops.
      transform.apply_patterns.vector.transfer_permutation_patterns
    } : !transform.op<"func.func">

    // Hoisting and LICM - not strictly required
    %mmt4d_func_h = transform.structured.hoist_redundant_vector_transfers %mmt4d_func
      : (!transform.op<"func.func">) -> !transform.op<"func.func">
    %all_loops = transform.structured.match interface{LoopLikeInterface} in %mmt4d_func_h
      : (!transform.op<"func.func">) -> !transform.any_op
    transform.apply_licm to %all_loops : !transform.any_op
    transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op

    // Simplification
    transform.apply_patterns to %mmt4d_func_h {
      transform.apply_patterns.vector.reduction_to_contract
      transform.apply_patterns.vector.cast_away_vector_leading_one_dim
      transform.apply_patterns.canonicalization
    } : !transform.op<"func.func">

    //==========================================================================
    // HANDLE PACK + UNPACK
    //==========================================================================
    %pack = transform.structured.match ops{["linalg.pack"]} in %module : (!transform.any_op) -> !transform.any_op
    %unpack = transform.structured.match ops{["linalg.unpack"]} in %module : (!transform.any_op) -> !transform.any_op

    // 1.1 Tile the linalg.pack Op so that we can decompose it into e.g. tensor.pad
    // and other lower-level Ops (see step 2.1)
    %tiled_pack_op_p, %loops_pack:2 = transform.structured.tile_using_for %pack tile_sizes [1, 1]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)

    // 1.2 Tile the linalg.unpack Op so that we can decompose it into e.g. tensor.pad
    // and other lower-level Ops (see step 2.2)
    %tiled_unpack_op_p, %loops_unpack:2 = transform.structured.tile_using_for %unpack tile_sizes [8, 1]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)

    // 2.1. Decompose tiled PackOp into lower-level Ops + simplify
    %func_op_pack = transform.get_parent_op %tiled_pack_op_p {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op_pack {
      transform.apply_patterns.linalg.decompose_pack_unpack
      transform.apply_patterns.linalg.decompose_pad
    } : !transform.op<"func.func">

    transform.apply_patterns to %func_op_pack {
      transform.apply_patterns.tensor.fold_tensor_subset_ops
      transform.apply_patterns.canonicalization
    } : !transform.op<"func.func">

    // 2.2. Decompose tiled UnpackOp into lower-level Ops + simplify
    %func_op_unpack = transform.get_parent_op %tiled_unpack_op_p {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op_unpack {
      transform.apply_patterns.linalg.decompose_pack_unpack
    } : !transform.op<"func.func">

    transform.apply_patterns to %func_op_unpack {
      transform.apply_patterns.tensor.fold_tensor_subset_ops
      transform.apply_patterns.canonicalization
    } : !transform.op<"func.func">

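    // Once decomposed, the tiled linalg.pack/linalg.unpack Ops are expected to
    // turn into plain tensor/linalg Ops (roughly: tensor.pad, linalg.transpose
    // and tensor.insert_slice/extract_slice), which then lower through the
    // default pipeline without vectorisation (see note 3 in the overview).
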
    //==========================================================================
    // BUFFERIZATION
    //==========================================================================
    %bufferize = transform.bufferization.one_shot_bufferize %module
      {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op

    //==========================================================================
    // SIMPLIFY THE CONTRACT Op
    //==========================================================================
    %contract = transform.collect_matching @match_contract in %bufferize : (!transform.any_op) -> (!transform.any_op)
    %contract_func = transform.get_parent_op %contract {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">

    // Drop trailing unit dims (the corresponding pattern works only
    // post-bufferization)
    transform.apply_patterns to %contract_func {
      transform.apply_patterns.tensor.fold_tensor_subset_ops
      transform.apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops
      transform.apply_patterns.canonicalization
    } : !transform.op<"func.func">

    //==========================================================================
    // LOWER CONTRACT TO FMA
    //==========================================================================
    transform.apply_patterns to %contract_func {
      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
      transform.apply_patterns.vector.lower_outerproduct
    } : !transform.op<"func.func">

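    // Note: vector.contract is first rewritten as vector.outerproduct Ops,
    // which are then unrolled into elementwise Ops on the scalable vectors
    // (for i32 data this is a multiply + add rather than a true FMA).
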
    transform.yield
  }

  //==========================================================================
  // TD MATCHERS (helper hooks)
  //==========================================================================
  transform.named_sequence @match_mmt4d(
      %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.operation_name %entry ["linalg.mmt4d"] : !transform.any_op
    transform.yield %entry : !transform.any_op
  }

  transform.named_sequence @match_contract(
      %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.operation_name %entry ["vector.contract"] : !transform.any_op
    transform.yield %entry : !transform.any_op
  }
}

//===----------------------------------------------------------------------===//
// Function signatures
//===----------------------------------------------------------------------===//
func.func private @printMemrefI32(%ptr : tensor<*xi32>)
