3030
3131include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
3232include "mlir/Interfaces/SideEffectInterfaces.td"
33+ include "mlir/IR/AttrTypeBase.td"
34+ include "mlir/IR/BuiltinTypes.td"
3335
3436//===----------------------------------------------------------------------===//
3537// AMX dialect definition.
@@ -55,8 +57,69 @@ def AMX_Dialect : Dialect {
5557 For details, see the Intel documentation:
5658 https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
5759 }];
60+ let useDefaultTypePrinterParser = 1;
5861}
5962
63+ //===----------------------------------------------------------------------===//
64+ // AMX Tile definition.
65+ //===----------------------------------------------------------------------===//
66+
67+ class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
68+ : TypeDef<AMX_Dialect, typeName, traits> {
69+ let mnemonic = typeMnemonic;
70+ }
71+
72+ def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
73+ let cppFunctionName = "isValidTileTypeElementType";
74+ }
75+
76+ def AMX_TileType : AMX_Type<"Tile", "amx.tile", [ShapedTypeInterface, ValueSemantics]> {
77+ let summary = "AMX 2D tile to be used by AMX opertaions.";
78+
79+ let description = [{
80+ This type is used to represent values in AMX tile registers. All AMX operations
81+ work on AMX tiles and these tiles cannot be used in other operations directly.
82+ LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
83+ element type for IR verification and lowering to LLVMIR dialect.
84+ }];
85+
86+ let parameters = (ins
87+ ArrayRefParameter<"int64_t">:$shape,
88+ AMX_TileTypeElementType:$elementType
89+ );
90+
91+ let builders = [
92+ TypeBuilderWithInferredContext<(ins
93+ "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
94+ return $_get(elementType.getContext(), shape, elementType);
95+ }]>
96+ ];
97+
98+ let extraClassDeclaration = [{
99+ /// Returns if this type is ranked (always true).
100+ bool hasRank() const { return true; }
101+
102+ /// Clone this tile type with the given shape and element type. If the
103+ /// provided shape is `std::nullopt`, the current shape of the type is used.
104+ TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
105+ Type elementType) const {
106+ return get(shape.value_or(getShape()), elementType);
107+ }
108+ }];
109+
110+ let hasCustomAssemblyFormat = 1;
111+ let skipDefaultBuilders = 1;
112+ }
113+
114+ def IsAMXTilePred : CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">;
115+
116+ def IsAMX2DTilePred : And<[IsAMXTilePred,
117+ CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
118+
119+ class AMX2DTileOf<list<Type> allowedTypes> :
120+ ShapedContainerType<allowedTypes, IsAMX2DTilePred, "tile",
121+ "::mlir::amx::TileType">;
122+
60123//===----------------------------------------------------------------------===//
61124// AMX Op and IntrOp definitions.
62125//===----------------------------------------------------------------------===//
@@ -88,14 +151,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
88151 Example:
89152
90153 ```mlir
91- %0 = amx.tile_zero : vector <16x16xbf16>
154+ %0 = amx.tile_zero : <16x16xbf16>
92155 ```
93156 }];
94157 let results = (outs
95- VectorOfRankAndType<[2], [F32 , BF16, I32, I8]>:$res);
158+ AMX2DTileOf<[F32, F16 , BF16, I32, I8]>:$res);
96159 let extraClassDeclaration = [{
97- VectorType getVectorType () {
98- return ::llvm::cast<VectorType >(getRes().getType());
160+ TileType getTileType () {
161+ return ::llvm::cast<TileType >(getRes().getType());
99162 }
100163 }];
101164 let assemblyFormat = "attr-dict `:` type($res)";
@@ -117,19 +180,19 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
117180 Example:
118181
119182 ```mlir
120- %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector <16x64xi8>
183+ %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into <16x64xi8>
121184 ```
122185 }];
123186 let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
124187 Variadic<Index>:$indices);
125188 let results = (outs
126- VectorOfRankAndType<[2], [F32 , BF16, I32, I8]>:$res);
189+ AMX2DTileOf<[F32, F16 , BF16, I32, I8]>:$res);
127190 let extraClassDeclaration = [{
128191 MemRefType getMemRefType() {
129192 return ::llvm::cast<MemRefType>(getBase().getType());
130193 }
131- VectorType getVectorType () {
132- return ::llvm::cast<VectorType >(getRes().getType());
194+ TileType getTileType () {
195+ return ::llvm::cast<TileType >(getRes().getType());
133196 }
134197 }];
135198 let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
@@ -148,18 +211,18 @@ def TileStoreOp : AMX_Op<"tile_store"> {
148211 Example:
149212
150213 ```mlir
151- amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector <16x64xi8>
214+ amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, <16x64xi8>
152215 ```
153216 }];
154217 let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
155218 Variadic<Index>:$indices,
156- VectorOfRankAndType<[2], [F32 , BF16, I32, I8]>:$val);
219+ AMX2DTileOf<[F32, F16 , BF16, I32, I8]>:$val);
157220 let extraClassDeclaration = [{
158221 MemRefType getMemRefType() {
159222 return ::llvm::cast<MemRefType>(getBase().getType());
160223 }
161- VectorType getVectorType () {
162- return ::llvm::cast<VectorType >(getVal().getType());
224+ TileType getTileType () {
225+ return ::llvm::cast<TileType >(getVal().getType());
163226 }
164227 }];
165228 let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
@@ -183,23 +246,22 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
183246 Example:
184247
185248 ```mlir
186- %0 = amx.tile_mulf %a, %b, %c
187- : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
249+ %0 = amx.tile_mulf %a, %b, %c : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
188250 ```
189251 }];
190- let arguments = (ins VectorOfRankAndType<[2], [F32 , BF16]>:$lhs,
191- VectorOfRankAndType<[2], [F32 , BF16]>:$rhs,
192- VectorOfRankAndType<[2], [ F32, BF16 ]>:$acc);
193- let results = (outs VectorOfRankAndType<[2], [ F32, BF16 ]>:$res);
252+ let arguments = (ins AMX2DTileOf<[F16 , BF16]>:$lhs,
253+ AMX2DTileOf<[F16 , BF16]>:$rhs,
254+ AMX2DTileOf<[ F32]>:$acc);
255+ let results = (outs AMX2DTileOf<[ F32]>:$res);
194256 let extraClassDeclaration = [{
195- VectorType getLhsVectorType () {
196- return ::llvm::cast<VectorType >(getLhs().getType());
257+ TileType getLhsTileType () {
258+ return ::llvm::cast<TileType >(getLhs().getType());
197259 }
198- VectorType getRhsVectorType () {
199- return ::llvm::cast<VectorType >(getRhs().getType());
260+ TileType getRhsTileType () {
261+ return ::llvm::cast<TileType >(getRhs().getType());
200262 }
201- VectorType getVectorType () {
202- return ::llvm::cast<VectorType >(getRes().getType());
263+ TileType getTileType () {
264+ return ::llvm::cast<TileType >(getRes().getType());
203265 }
204266 }];
205267 let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
@@ -222,26 +284,25 @@ def TileMulIOp : AMX_Op<"tile_muli", [
222284 Example:
223285
224286 ```mlir
225- %0 = amx.tile_muli %a zext, %b zext, %c
226- : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
287+ %0 = amx.tile_muli %a zext, %b zext, %c : <16x64xi8>, <16x64xi8>, <16x16xi32>
227288 ```
228289 }];
229- let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
230- VectorOfRankAndType<[2], [I32, I8]>:$rhs,
231- VectorOfRankAndType<[2], [ I32, I8 ]>:$acc,
290+ let arguments = (ins AMX2DTileOf<[ I8]>:$lhs,
291+ AMX2DTileOf<[ I8]>:$rhs,
292+ AMX2DTileOf<[ I32]>:$acc,
232293 UnitAttr:$isZextLhs,
233294 UnitAttr:$isZextRhs
234295 );
235- let results = (outs VectorOfRankAndType<[2], [ I32, I8 ]>:$res);
296+ let results = (outs AMX2DTileOf<[ I32]>:$res);
236297 let extraClassDeclaration = [{
237- VectorType getLhsVectorType () {
238- return ::llvm::cast<VectorType >(getLhs().getType());
298+ TileType getLhsTileType () {
299+ return ::llvm::cast<TileType >(getLhs().getType());
239300 }
240- VectorType getRhsVectorType () {
241- return ::llvm::cast<VectorType >(getRhs().getType());
301+ TileType getRhsTileType () {
302+ return ::llvm::cast<TileType >(getRhs().getType());
242303 }
243- VectorType getVectorType () {
244- return ::llvm::cast<VectorType >(getRes().getType());
304+ TileType getTileType () {
305+ return ::llvm::cast<TileType >(getRes().getType());
245306 }
246307 }];
247308 let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
@@ -286,6 +347,12 @@ def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
286347 AnyInteger,
287348 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
288349
350+ // Dot product of f16 tiles into f32 tile.
351+ def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
352+ Arguments<(ins AnyInteger,
353+ AnyInteger,
354+ AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
355+
289356// Dot product of i8 tiles into i32 tile (with sign/sign extension).
290357def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
291358 Arguments<(ins AnyInteger,
0 commit comments