3030
3131include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
3232include "mlir/Interfaces/SideEffectInterfaces.td"
33+ include "mlir/IR/AttrTypeBase.td"
34+ include "mlir/IR/BuiltinTypes.td"
3335
3436//===----------------------------------------------------------------------===//
3537// AMX dialect definition.
@@ -55,8 +57,77 @@ def AMX_Dialect : Dialect {
5557 For details, see the Intel documentation:
5658 https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
5759 }];
60+ let useDefaultTypePrinterParser = 1;
5861}
5962
63+ //===----------------------------------------------------------------------===//
64+ // AMX Tile definition.
65+ //===----------------------------------------------------------------------===//
66+
67+ class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
68+ : TypeDef<AMX_Dialect, typeName, traits> {
69+ let mnemonic = typeMnemonic;
70+ }
71+
72+ def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
73+ let cppFunctionName = "isValidTileTypeElementType";
74+ }
75+
76+ def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
77+ let summary = "AMX 2D tile to be used by AMX opertaions.";
78+
79+ let description = [{
80+ This type is used to represent values in AMX tile registers. All AMX operations
81+ work on AMX tiles and these tiles cannot be used in other operations directly.
82+ LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
83+ element type for IR verification and lowering to LLVMIR dialect.
84+ }];
85+
86+ let parameters = (ins
87+ ArrayRefParameter<"int64_t">:$shape,
88+ AMX_TileTypeElementType:$elementType
89+ );
90+
91+ let builders = [
92+ TypeBuilderWithInferredContext<(ins
93+ "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
94+ return $_get(elementType.getContext(), shape, elementType);
95+ }]>
96+ ];
97+
98+ let extraClassDeclaration = [{
99+ /// Returns if this type is ranked (always true).
100+ bool hasRank() const { return true; }
101+
102+ /// Clone this tile type with the given shape and element type. If the
103+ /// provided shape is `std::nullopt`, the current shape of the type is used.
104+ TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
105+ Type elementType) const {
106+ return get(shape.value_or(getShape()), elementType);
107+ }
108+ }];
109+
110+ let hasCustomAssemblyFormat = 1;
111+ let skipDefaultBuilders = 1;
112+ }
113+
114+ def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
115+ CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
116+
117+ class AMXTileOf<list<Type> allowedTypes> :
118+ ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
119+ "::mlir::amx::TileType">;
120+
121+ def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
122+
123+ def AMXTileF32 : AMXTileOf<[F32]>;
124+
125+ def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
126+
127+ def AMXTileI32 : AMXTileOf<[I32]>;
128+
129+ def AMXTileI8 : AMXTileOf<[I8]>;
130+
60131//===----------------------------------------------------------------------===//
61132// AMX Op and IntrOp definitions.
62133//===----------------------------------------------------------------------===//
@@ -88,17 +159,16 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
88159 Example:
89160
90161 ```mlir
91- %0 = amx.tile_zero : vector <16x16xbf16>
162+ %0 = amx.tile_zero : !amx.tile <16x16xbf16>
92163 ```
93164 }];
94- let results = (outs
95- VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
165+ let results = (outs AnyAMXTile:$res);
96166 let extraClassDeclaration = [{
97- VectorType getVectorType () {
98- return ::llvm::cast<VectorType >(getRes().getType());
167+ TileType getTileType () {
168+ return ::llvm::cast<TileType >(getRes().getType());
99169 }
100170 }];
101- let assemblyFormat = "attr-dict `:` type($res)";
171+ let assemblyFormat = "attr-dict `:` qualified( type($res) )";
102172 let hasVerifier = 1;
103173}
104174
@@ -117,23 +187,22 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
117187 Example:
118188
119189 ```mlir
120- %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector <16x64xi8>
190+ %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile <16x64xi8>
121191 ```
122192 }];
123193 let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
124194 Variadic<Index>:$indices);
125- let results = (outs
126- VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
195+ let results = (outs AnyAMXTile:$res);
127196 let extraClassDeclaration = [{
128197 MemRefType getMemRefType() {
129198 return ::llvm::cast<MemRefType>(getBase().getType());
130199 }
131- VectorType getVectorType () {
132- return ::llvm::cast<VectorType >(getRes().getType());
200+ TileType getTileType () {
201+ return ::llvm::cast<TileType >(getRes().getType());
133202 }
134203 }];
135204 let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
136- "type($base) `into` type($res)";
205+ "type($base) `into` qualified( type($res) )";
137206 let hasVerifier = 1;
138207}
139208
@@ -148,22 +217,22 @@ def TileStoreOp : AMX_Op<"tile_store"> {
148217 Example:
149218
150219 ```mlir
151- amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector <16x64xi8>
220+ amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile <16x64xi8>
152221 ```
153222 }];
154223 let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
155224 Variadic<Index>:$indices,
156- VectorOfRankAndType<[2], [F32, BF16, I32, I8]> :$val);
225+ AnyAMXTile :$val);
157226 let extraClassDeclaration = [{
158227 MemRefType getMemRefType() {
159228 return ::llvm::cast<MemRefType>(getBase().getType());
160229 }
161- VectorType getVectorType () {
162- return ::llvm::cast<VectorType >(getVal().getType());
230+ TileType getTileType () {
231+ return ::llvm::cast<TileType >(getVal().getType());
163232 }
164233 }];
165234 let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
166- "type($base) `,` type($val)";
235+ "type($base) `,` qualified( type($val) )";
167236 let hasVerifier = 1;
168237}
169238
@@ -184,26 +253,27 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
184253
185254 ```mlir
186255 %0 = amx.tile_mulf %a, %b, %c
187- : vector <16x32xbf16>, vector <16x32xbf16>, vector <16x16xf32>
256+ : !amx.tile <16x32xbf16>, !amx.tile <16x32xbf16>, !amx.tile <16x16xf32>
188257 ```
189258 }];
190- let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]> :$lhs,
191- VectorOfRankAndType<[2], [F32, BF16]> :$rhs,
192- VectorOfRankAndType<[2], [F32, BF16]> :$acc);
193- let results = (outs VectorOfRankAndType<[2], [F32, BF16]> :$res);
259+ let arguments = (ins AMXTileF16OrBF16 :$lhs,
260+ AMXTileF16OrBF16 :$rhs,
261+ AMXTileF32 :$acc);
262+ let results = (outs AMXTileF32 :$res);
194263 let extraClassDeclaration = [{
195- VectorType getLhsVectorType () {
196- return ::llvm::cast<VectorType >(getLhs().getType());
264+ TileType getLhsTileType () {
265+ return ::llvm::cast<TileType >(getLhs().getType());
197266 }
198- VectorType getRhsVectorType () {
199- return ::llvm::cast<VectorType >(getRhs().getType());
267+ TileType getRhsTileType () {
268+ return ::llvm::cast<TileType >(getRhs().getType());
200269 }
201- VectorType getVectorType () {
202- return ::llvm::cast<VectorType >(getRes().getType());
270+ TileType getTileType () {
271+ return ::llvm::cast<TileType >(getRes().getType());
203272 }
204273 }];
205274 let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
206- "type($lhs) `,` type($rhs) `,` type($acc) ";
275+ "qualified(type($lhs)) `,` qualified(type($rhs))"
276+ " `,` qualified(type($acc)) ";
207277 let hasVerifier = 1;
208278}
209279
@@ -223,29 +293,29 @@ def TileMulIOp : AMX_Op<"tile_muli", [
223293
224294 ```mlir
225295 %0 = amx.tile_muli %a zext, %b zext, %c
226- : vector <16x64xi8>, vector <16x64xi8>, vector <16x16xi32>
296+ : !amx.tile <16x64xi8>, !amx.tile <16x64xi8>, !amx.tile <16x16xi32>
227297 ```
228298 }];
229- let arguments = (ins VectorOfRankAndType<[2], [I32, I8]> :$lhs,
230- VectorOfRankAndType<[2], [I32, I8]> :$rhs,
231- VectorOfRankAndType<[2], [I32, I8]> :$acc,
299+ let arguments = (ins AMXTileI8 :$lhs,
300+ AMXTileI8 :$rhs,
301+ AMXTileI32 :$acc,
232302 UnitAttr:$isZextLhs,
233303 UnitAttr:$isZextRhs
234304 );
235- let results = (outs VectorOfRankAndType<[2], [I32, I8]> :$res);
305+ let results = (outs AMXTileI32 :$res);
236306 let extraClassDeclaration = [{
237- VectorType getLhsVectorType () {
238- return ::llvm::cast<VectorType >(getLhs().getType());
307+ TileType getLhsTileType () {
308+ return ::llvm::cast<TileType >(getLhs().getType());
239309 }
240- VectorType getRhsVectorType () {
241- return ::llvm::cast<VectorType >(getRhs().getType());
310+ TileType getRhsTileType () {
311+ return ::llvm::cast<TileType >(getRhs().getType());
242312 }
243- VectorType getVectorType () {
244- return ::llvm::cast<VectorType >(getRes().getType());
313+ TileType getTileType () {
314+ return ::llvm::cast<TileType >(getRes().getType());
245315 }
246316 }];
247317 let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
248- "type($lhs) `,` type($rhs) `,` type($acc) ";
318+ "qualified( type($lhs)) `,` qualified( type($rhs)) `,` qualified( type($acc) ) ";
249319 let hasVerifier = 1;
250320}
251321
@@ -286,6 +356,12 @@ def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
286356 AnyInteger,
287357 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
288358
359+ // Dot product of f16 tiles into f32 tile.
360+ def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
361+ Arguments<(ins AnyInteger,
362+ AnyInteger,
363+ AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
364+
289365// Dot product of i8 tiles into i32 tile (with sign/sign extension).
290366def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
291367 Arguments<(ins AnyInteger,
0 commit comments