2929#define AMX
3030
3131include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
32+ include "mlir/Dialect/AMX/AMXInterfaces.td"
3233include "mlir/Interfaces/SideEffectInterfaces.td"
3334include "mlir/IR/AttrTypeBase.td"
3435include "mlir/IR/BuiltinTypes.td"
@@ -47,8 +48,6 @@ def AMX_Dialect : Dialect {
4748
4849 This `AMX` dialect provides a bridge between MLIR concepts such as
4950 vectors and memrefs and the lower level LLVM IR support of AMX.
50- The dialect is split into user-facing AMX ops (AMX_Op) and
51- backend-facing intrinsic ops (AMX_IntrOp).
5251
5352 Note that since configuration changes (implicit at dialect level) are
5453 costly, it is highly recommended to use the AMX dialect on same-shaped
@@ -135,21 +134,17 @@ def AMXTileI8 : AMXTileOf<[I8]>;
135134class AMX_Op<string mnemonic, list<Trait> traits = []> :
136135 Op<AMX_Dialect, mnemonic, traits> {}
137136
138- // The "internal" intrinsics are meant for compiler usage.
139- class AMX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
140- LLVM_IntrOpBase<AMX_Dialect, mnemonic,
141- "x86_" # !subst(".", "_", mnemonic) # "_internal",
142- [], [], traits, numResults>;
143-
144137//===----------------------------------------------------------------------===//
145- // AMX Op definitions (user facing).
138+ // AMX Op definitions
146139//===----------------------------------------------------------------------===//
147140
148141//
149142// Tile reset.
150143//
151144
152- def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
145+ def TileZeroOp : AMX_Op<"tile_zero", [Pure,
146+ AMXIntrinsicOpInterface
147+ ]> {
153148 let summary = "tile zero operation";
154149 let description = [{
155150 Zeroes the destination tile, with the shape defined by the 2-dim
@@ -167,6 +162,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
167162 TileType getTileType() {
168163 return ::llvm::cast<TileType>(getRes().getType());
169164 }
165+
166+ std::string getIntrinsicName() {
167+ return "llvm.x86.tilezero.internal";
168+ }
169+ SmallVector<Value> getIntrinsicOperands(
170+ ::mlir::ArrayRef<Value> operands,
171+ const ::mlir::LLVMTypeConverter &typeConverter,
172+ ::mlir::RewriterBase &rewriter);
170173 }];
171174 let assemblyFormat = "attr-dict `:` qualified(type($res))";
172175 let hasVerifier = 1;
@@ -176,7 +179,9 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
176179// Tile memory operations.
177180//
178181
179- def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
182+ def TileLoadOp : AMX_Op<"tile_load", [Pure,
183+ AMXIntrinsicOpInterface
184+ ]> {
180185 let summary = "tile load operation";
181186 let description = [{
182187 Loads a tile from memory defined by a base and indices, with the
@@ -200,13 +205,23 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
200205 TileType getTileType() {
201206 return ::llvm::cast<TileType>(getRes().getType());
202207 }
208+
209+ std::string getIntrinsicName() {
210+ return "llvm.x86.tileloadd64.internal";
211+ }
212+ SmallVector<Value> getIntrinsicOperands(
213+ ::mlir::ArrayRef<Value> operands,
214+ const ::mlir::LLVMTypeConverter &typeConverter,
215+ ::mlir::RewriterBase &rewriter);
203216 }];
204217 let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
205218 "type($base) `into` qualified(type($res))";
206219 let hasVerifier = 1;
207220}
208221
209- def TileStoreOp : AMX_Op<"tile_store"> {
222+ def TileStoreOp : AMX_Op<"tile_store", [
223+ AMXIntrinsicOpInterface
224+ ]> {
210225 let summary = "tile store operation";
211226 let description = [{
212227 Stores a tile to memory defined by a base and indices, with the
@@ -230,6 +245,14 @@ def TileStoreOp : AMX_Op<"tile_store"> {
230245 TileType getTileType() {
231246 return ::llvm::cast<TileType>(getVal().getType());
232247 }
248+
249+ std::string getIntrinsicName() {
250+ return "llvm.x86.tilestored64.internal";
251+ }
252+ SmallVector<Value> getIntrinsicOperands(
253+ ::mlir::ArrayRef<Value> operands,
254+ const ::mlir::LLVMTypeConverter &typeConverter,
255+ ::mlir::RewriterBase &rewriter);
233256 }];
234257 let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
235258 "type($base) `,` qualified(type($val))";
@@ -240,8 +263,10 @@ def TileStoreOp : AMX_Op<"tile_store"> {
240263// Tile arithmetic operations.
241264//
242265
243- def TileMulFOp : AMX_Op<"tile_mulf", [
244- Pure, AllTypesMatch<["acc", "res"]>]> {
266+ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
267+ AMXIntrinsicOpInterface,
268+ AllTypesMatch<["acc", "res"]>
269+ ]> {
245270 let summary = "tile multiplication operation (floating-point)";
246271 let description = [{
247272 Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
@@ -270,15 +295,30 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
270295 TileType getTileType() {
271296 return ::llvm::cast<TileType>(getRes().getType());
272297 }
298+
299+ std::string getIntrinsicName() {
300+ std::string intr = "llvm.x86.tdp";
301+ auto elementType =
302+ getLhsTileType().getElementType();
303+ intr += elementType.isF16() ? "fp16" : "bf16";
304+ intr += "ps.internal";
305+ return intr;
306+ }
307+ SmallVector<Value> getIntrinsicOperands(
308+ ::mlir::ArrayRef<Value> operands,
309+ const ::mlir::LLVMTypeConverter &typeConverter,
310+ ::mlir::RewriterBase &rewriter);
273311 }];
274312 let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
275313 "qualified(type($lhs)) `,` qualified(type($rhs))"
276314 " `,` qualified(type($acc)) ";
277315 let hasVerifier = 1;
278316}
279317
280- def TileMulIOp : AMX_Op<"tile_muli", [
281- Pure, AllTypesMatch<["acc", "res"]>]> {
318+ def TileMulIOp : AMX_Op<"tile_muli", [Pure,
319+ AMXIntrinsicOpInterface,
320+ AllTypesMatch<["acc", "res"]>
321+ ]> {
282322 let summary = "tile multiplication operation (integer)";
283323 let description = [{
284324 Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
@@ -313,77 +353,22 @@ def TileMulIOp : AMX_Op<"tile_muli", [
313353 TileType getTileType() {
314354 return ::llvm::cast<TileType>(getRes().getType());
315355 }
356+
357+ std::string getIntrinsicName() {
358+ std::string intr = "llvm.x86.tdpb";
359+ intr += getIsZextLhs() ? "u" : "s";
360+ intr += getIsZextRhs() ? "u" : "s";
361+ intr += "d.internal";
362+ return intr;
363+ }
364+ SmallVector<Value> getIntrinsicOperands(
365+ ::mlir::ArrayRef<Value> operands,
366+ const ::mlir::LLVMTypeConverter &typeConverter,
367+ ::mlir::RewriterBase &rewriter);
316368 }];
317369 let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
318370 "qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
319371 let hasVerifier = 1;
320372}
321373
322- //===----------------------------------------------------------------------===//
323- // AMX IntrOp definitions (LLVM compiler facing).
324- //===----------------------------------------------------------------------===//
325-
326- //
327- // Tile reset. Parameters define the tile size.
328- //
329-
330- def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>,
331- Arguments<(ins AnyInteger, AnyInteger)>;
332-
333- //
334- // Tile memory operations. Parameters define the tile size,
335- // base address, and stride between consecutive rows for the
336- // memory operation.
337- //
338-
339- def LLVM_x86_amx_tileloadd64 : AMX_IntrOp<"tileloadd64", 1>,
340- Arguments<(ins AnyInteger,
341- AnyInteger, LLVM_AnyPointer, AnyInteger)>;
342-
343- def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>,
344- Arguments<(ins AnyInteger,
345- AnyInteger, LLVM_AnyPointer, AnyInteger, LLVM_Type)>;
346-
347- //
348- // Tile multiplication operations (series of dot products). Parameters
349- // define the tile sizes and source and destination tiles for the
350- // operation. Note that the prefix "tdp" stands for tile dot product.
351- //
352-
353- // Dot product of bf16 tiles into f32 tile.
354- def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
355- Arguments<(ins AnyInteger,
356- AnyInteger,
357- AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
358-
359- // Dot product of f16 tiles into f32 tile.
360- def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
361- Arguments<(ins AnyInteger,
362- AnyInteger,
363- AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
364-
365- // Dot product of i8 tiles into i32 tile (with sign/sign extension).
366- def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
367- Arguments<(ins AnyInteger,
368- AnyInteger,
369- AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
370-
371- // Dot product of i8 tiles into i32 tile (with sign/zero extension).
372- def LLVM_x86_amx_tdpbsud : AMX_IntrOp<"tdpbsud", 1>,
373- Arguments<(ins AnyInteger,
374- AnyInteger,
375- AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
376-
377- // Dot product of i8 tiles into i32 tile (with zero/sign extension).
378- def LLVM_x86_amx_tdpbusd : AMX_IntrOp<"tdpbusd", 1>,
379- Arguments<(ins AnyInteger,
380- AnyInteger,
381- AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
382-
383- // Dot product of i8 tiles into i32 tile (with zero/zero extension).
384- def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
385- Arguments<(ins AnyInteger,
386- AnyInteger,
387- AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
388-
389374#endif // AMX
0 commit comments