Skip to content

Commit 075fdaa

Browse files
authored
[μKernels]: AMX - optimal register allocation (#1076)
Changes/Relaxation in `AMX` pass to support optimal register allocation.
1 parent 232e854 commit 075fdaa

File tree

3 files changed

+71
-4
lines changed

3 files changed

+71
-4
lines changed

lib/TPP/Transforms/VectorContractToAMX.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ static SmallVector<Value> createTileMuls(OpBuilder &builder, Location loc,
388388
SmallVector<Value> results;
389389
int numIterArgs = 0;
390390
for (unsigned i = 0; i < aLoadTiles.size(); i++) {
391-
for (unsigned j = 0; j < aLoadTiles.size(); j++) {
391+
for (unsigned j = 0; j < bLoadTiles.size(); j++) {
392392
auto amx =
393393
resType.getElementType().isFloat()
394394
? builder.create<amx::TileMulFOp>(loc, resType, aLoadTiles[i],
@@ -515,10 +515,10 @@ struct VectorContractToAMXPattern
515515
auto accType = cast<ShapedType>(accDefiningOp.getType());
516516
int64_t M = accType.getDimSize(0);
517517
int64_t N = accType.getDimSize(1);
518-
// M and N must be equal and divisible by 16.
519-
if (M != N || M % 16 != 0 || N % 16 != 0)
518+
// M and N must be divisible by 16.
519+
if (M % 16 != 0 || N % 16 != 0)
520520
return rewriter.notifyMatchFailure(
521-
op, "Output matrix dimensions must be equal and divisible by 16");
521+
op, "Output matrix dimensions must be divisible by 16");
522522

523523
auto accSubview = accDefiningOp.getBase();
524524
Location loc = op.getLoc();
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// RUN: tpp-run -e optimal_blocking --entry-point-result=void -print --splat-to-random --init-type normal -seed 123 %s > %t.1
2+
// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=48,16,32" --loop-invariant-code-motion --vectorization-pass --hoist-vector-transfer --vector-contract-to-amx | tpp-run -e optimal_blocking --entry-point-result=void -print --splat-to-random --init-type normal -seed 123 > %t.2
3+
// RUN: fpcmp -r 0.001 %t.1 %t.2
4+
5+
func.func @optimal_blocking(%arg0: memref<1x48x16x2xbf16>, %arg1: memref<1x16x16x2xbf16>, %arg2: memref<48x16xf32>) -> memref<48x16xf32> {
6+
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>, affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<1x48x16x2xbf16>, memref<1x16x16x2xbf16>) outs(%arg2 : memref<48x16xf32>) {
7+
^bb0(%in: bf16, %in_1: bf16, %out: f32):
8+
%a = arith.extf %in : bf16 to f32
9+
%b = arith.extf %in_1 : bf16 to f32
10+
%1 = arith.mulf %a, %b : f32
11+
%2 = arith.addf %out, %1 : f32
12+
linalg.yield %2 : f32
13+
}
14+
return %arg2 : memref<48x16xf32>
15+
}
16+
17+
18+
// RUN: tpp-run -e optimal_blocking_1x3 --entry-point-result=void -print --splat-to-random --init-type normal -seed 123 %s > %t.1
19+
// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=16,48,32" --loop-invariant-code-motion --vectorization-pass --hoist-vector-transfer --vector-contract-to-amx | tpp-run -e optimal_blocking_1x3 --entry-point-result=void -print --splat-to-random --init-type normal -seed 123 > %t.2
20+
// RUN: fpcmp -r 0.001 %t.1 %t.2
21+
22+
func.func @optimal_blocking_1x3(%arg0: memref<1x16x16x2xbf16>, %arg1: memref<1x16x48x2xbf16>, %arg2: memref<16x48xf32>) -> memref<16x48xf32> {
23+
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>, affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<1x16x16x2xbf16>, memref<1x16x48x2xbf16>) outs(%arg2 : memref<16x48xf32>) {
24+
^bb0(%in: bf16, %in_1: bf16, %out: f32):
25+
%a = arith.extf %in : bf16 to f32
26+
%b = arith.extf %in_1 : bf16 to f32
27+
%1 = arith.mulf %a, %b : f32
28+
%2 = arith.addf %out, %1 : f32
29+
linalg.yield %2 : f32
30+
}
31+
return %arg2 : memref<16x48xf32>
32+
}

test/Passes/pass-vector-contract-to-amx.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,41 @@ func.func @entry(%arg0: memref<8x32x32x32xbf16>, %arg1: memref<2x32x16x32x2xbf16
204204

205205
// -----
206206

207+
func.func @optimal_register_blocking_3x1(%arg0: memref<1x48x16x2xbf16>, %arg1: memref<1x16x16x2xbf16>, %arg2: memref<48x16xf32>) -> memref<48x16xf32> {
208+
%0 = ub.poison : f32
209+
%1 = ub.poison : bf16
210+
%c0 = arith.constant 0 : index
211+
%c48 = arith.constant 48 : index
212+
%c16 = arith.constant 16 : index
213+
%c1 = arith.constant 1 : index
214+
scf.for %arg3 = %c0 to %c48 step %c48 {
215+
scf.for %arg4 = %c0 to %c16 step %c16 {
216+
%subview = memref.subview %arg2[%arg3, %arg4] [48, 16] [1, 1] : memref<48x16xf32> to memref<48x16xf32, strided<[16, 1], offset: ?>>
217+
%2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<48x16xf32, strided<[16, 1], offset: ?>>, vector<48x16xf32>
218+
%3 = scf.for %arg5 = %c0 to %c1 step %c1 iter_args(%arg6 = %2) -> (vector<48x16xf32>) {
219+
%4 = scf.for %arg7 = %c0 to %c16 step %c16 iter_args(%arg8 = %arg6) -> (vector<48x16xf32>) {
220+
%subview_0 = memref.subview %arg0[%arg5, %arg3, %arg7, 0] [1, 48, 16, 2] [1, 1, 1, 1] : memref<1x48x16x2xbf16> to memref<1x48x16x2xbf16, strided<[1536, 32, 2, 1], offset: ?>>
221+
%subview_1 = memref.subview %arg1[%arg5, %arg7, %arg4, 0] [1, 16, 16, 2] [1, 1, 1, 1] : memref<1x16x16x2xbf16> to memref<1x16x16x2xbf16, strided<[512, 32, 2, 1], offset: ?>>
222+
%5 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x48x16x2xbf16, strided<[1536, 32, 2, 1], offset: ?>>, vector<1x48x16x2xbf16>
223+
%6 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x16x16x2xbf16, strided<[512, 32, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
224+
%7 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>, affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %arg8 : vector<1x48x16x2xbf16>, vector<1x16x16x2xbf16> into vector<48x16xf32>
225+
scf.yield %7 : vector<48x16xf32>
226+
}
227+
scf.yield %4 : vector<48x16xf32>
228+
}
229+
vector.transfer_write %3, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<48x16xf32>, memref<48x16xf32, strided<[16, 1], offset: ?>>
230+
}
231+
}
232+
return %arg2 : memref<48x16xf32>
233+
}
234+
235+
// CHECK-LABEL: func.func @optimal_register_blocking_3x1
236+
// CHECK-COUNT-3: amx.tile_load
237+
// CHECK-COUNT-3: amx.tile_mulf
238+
// CHECK-COUNT-3: amx.tile_store
239+
240+
// -----
241+
207242
// This tests shows the lowering of a mixed precision vector.contract
208243
// (i8 x i8 -> i32) to AMX dialect.
209244
func.func @entry(%arg0: memref<4x16x64x64xi8>, %arg1: memref<16x16x16x64x4xi8>, %arg2: memref<4x16x64x64xi32>) {

0 commit comments

Comments
 (0)