Commit 448811d

[mlir][amx] Add write side effect to AMX tile creation ops (#155403)
Adds `MemWrite` side effect to `amx.tile_zero` and `amx.tile_load` ops.

Memory write models hardware populating AMX tiles with specified values through tile zero and load ops. Making the side effect explicit allows multiple op instances to be used as a compilation hint to use different AMX tile registers. This can prevent less efficient lowering through tile store-load copies compared to directly populating tiles with values.

To illustrate the trade-off:

Without explicit side effects, `CSE` optimizes two `amx.tile_zero` ops into a single op, which lowers to a copy for the second tile:
```
tilezero %tmm0
tilestored %tmm0, -2032(%rbp,%rbx) # 1024-byte Folded Spill
tileloadd -2032(%rbp,%rbx), %tmm1 # 1024-byte Folded Reload
```

By keeping the two `amx.tile_zero` ops and, thus, lowering to two separate intrinsic invocations, the two tile registers are zeroed out directly without the additional round trip through memory:
```
tilezero %tmm0
tilezero %tmm1
```

The same principle applies to `amx.tile_load` ops.
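The folding described above can be sketched at the MLIR level. This is an illustrative fragment rather than output captured from `mlir-opt`: the value names `%acc0` and `%acc1` are hypothetical, and the "after" forms are written as comments because they describe the expected effect of `-cse` as stated in the commit message, not a verified dump.

```mlir
// Input: two accumulator tiles intended to occupy separate tile registers.
// (%acc0 and %acc1 are hypothetical names chosen for this sketch.)
%acc0 = amx.tile_zero : !amx.tile<16x16xf32>
%acc1 = amx.tile_zero : !amx.tile<16x16xf32>

// Expected result of -cse while the op is still marked `Pure`:
// the two ops are identical and side-effect free, so the second one is
// erased and every use of %acc1 is rewritten to %acc0.
//
//   %acc0 = amx.tile_zero : !amx.tile<16x16xf32>
//
// Expected result of -cse with MemoryEffects<[MemWrite]> on the op:
// the write effect keeps CSE from merging the ops, so both remain.
//
//   %acc0 = amx.tile_zero : !amx.tile<16x16xf32>
//   %acc1 = amx.tile_zero : !amx.tile<16x16xf32>
```

In the `Pure` case a single SSA value feeds both accumulators, which is what leads to the tile store/reload copy shown in the assembly above once the two uses need distinct registers.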
1 parent 9155f51 commit 448811d

2 files changed: +42 -5 lines changed

mlir/include/mlir/Dialect/AMX/AMX.td

Lines changed: 10 additions & 5 deletions
@@ -142,14 +142,17 @@ class AMX_Op<string mnemonic, list<Trait> traits = []> :
 // Tile reset.
 //
 
-def TileZeroOp : AMX_Op<"tile_zero", [Pure,
-    AMXIntrinsicOpInterface
+def TileZeroOp : AMX_Op<"tile_zero", [
+    AMXIntrinsicOpInterface,
+    MemoryEffects<[MemWrite]>
   ]> {
   let summary = "tile zero operation";
   let description = [{
     Zeroes the destination tile, with the shape defined by the 2-dim
     vector type of the result. This is eventually lowered into the
     "tilezero" instruction with the corresponding tile configuration.
+    With memory-effects, each "tilezero" operation serves as a compilation
+    hint to use a separate tile register.
 
     Example:
 
@@ -179,15 +182,17 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure,
 // Tile memory operations.
 //
 
-def TileLoadOp : AMX_Op<"tile_load", [Pure,
-    AMXIntrinsicOpInterface
+def TileLoadOp : AMX_Op<"tile_load", [
+    AMXIntrinsicOpInterface,
+    MemoryEffects<[MemWrite]>
   ]> {
   let summary = "tile load operation";
   let description = [{
     Loads a tile from memory defined by a base and indices, with the
     shape defined by the 2-dim vector type of the result. This is
     eventually lowered into the "tileloadd" instruction with the
-    corresponding tile configuration.
+    corresponding tile configuration. With memory-effects, each "tileload"
+    operation serves as a compilation hint to use a separate tile register.
 
     Example:
 
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -cse -convert-vector-to-llvm="enable-amx" | FileCheck %s
+
+// With inclusion of memory side-effects, it is expected CSE not to fold multiple
+// "tileload" and "tilezero".
+// CHECK-LABEL: do_not_fold_tiles(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+func.func @do_not_fold_tiles(%arg0: memref<2x32x32xbf16>, %arg1: memref<2x16x32xbf16>) -> memref<16x32xf32> {
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c16 = arith.constant 16 : index
+  %alloca = memref.alloca() : memref<16x32xf32>
+  %0 = amx.tile_zero : !amx.tile<16x16xf32>
+  %1 = amx.tile_zero : !amx.tile<16x16xf32>
+  %2:2 = scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %0, %arg4 = %1) -> (!amx.tile<16x16xf32>, !amx.tile<16x16xf32>) {
+    %3 = amx.tile_load %arg0[%arg2, %c0, %c0] : memref<2x32x32xbf16> into !amx.tile<16x32xbf16>
+    %4 = amx.tile_load %arg0[%arg2, %c16, %c0] : memref<2x32x32xbf16> into !amx.tile<16x32xbf16>
+    %5 = amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !amx.tile<16x32xbf16>
+    %6 = amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !amx.tile<16x32xbf16>
+    %7 = amx.tile_mulf %3, %5, %arg3 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+    %8 = amx.tile_mulf %4, %6, %arg4 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+    scf.yield %7, %8 : !amx.tile<16x16xf32>, !amx.tile<16x16xf32>
+  }
+  amx.tile_store %alloca[%c0, %c0], %2#0 : memref<16x32xf32>, !amx.tile<16x16xf32>
+  amx.tile_store %alloca[%c0, %c16], %2#1 : memref<16x32xf32>, !amx.tile<16x16xf32>
+  return %alloca : memref<16x32xf32>
+}
