@@ -39,10 +39,10 @@ def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
3939
4040def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
4141 let description = [{
42- An interface for operations that use or allocate Arm SME tiles. These
43- operations need to be assigned a tile ID, an i32 attribute, which specifies
44- which virtual tile within the ZA storage to use. The number of tiles
45- available depends on the type of the tile. This is summarized below:
42+ An interface for operations that use Arm SME tiles. These operations need to
43+ be assigned a tile ID, an i32 attribute, which specifies which virtual tile
44+ within the ZA storage to use. The number of tiles available depends on the
45+ type of the tile. This is summarized below:
4646
4747 | Tile Vector Types | Possible Tile IDs |
4848 |-------------------------------------------------------------------------|---------------------|
@@ -51,10 +51,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
5151 | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>` | 0 to 3 (inclusive) |
5252 | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>` | 0 to 7 (inclusive) |
5353 | `vector<[1]x[1]xi128>` | 0 to 15 (inclusive) |
54-
55- Operations that allocate a new tile (such as arm_sme.get_tile), are used as
56- the roots for tile allocation, with all operations that (transitively)
57- depend on a root being assigned the same tile ID.
5854 }];
5955 let methods = [
6056 InterfaceMethod<
@@ -84,20 +80,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
8480 return op->getAttrOfType<mlir::IntegerAttr>("tile_id");
8581 }]
8682 >,
87- InterfaceMethod<
88- [{
89- The type of tile this operation allocates. Returns none (std::nullopt)
90- if this operation does not allocate a tile.
91- }],
92- /*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
93- /*methodName=*/"getAllocatedTileType",
94- /*arguments=*/(ins),
95- /*methodBody=*/[{}],
96- /*defaultImpl=*/ [{
97- // This operation does not allocate a tile.
98- return std::nullopt;
99- }]
100- >,
10183 InterfaceMethod<
10284 "Returns the VectorType of the tile used by this operation.",
10385 /*returnType=*/"VectorType",
@@ -106,30 +88,13 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
10688 ];
10789
10890 let extraSharedClassDeclaration = [{
109- // A helper to create a new operation and propagate this operations tile ID.
110- template<typename T, typename... Args>
111- T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
112- auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
113- if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
114- tileOp.setTileId($_op.getTileId());
115- return op;
116- }
117-
118- // A helper to replace this operation and forward its tile ID (if present).
119- template<typename T, typename... Args>
120- T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
121- auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
122- rewriter.replaceOp($_op, newOp);
123- return newOp;
124- }
125-
12691 bool isInMemoryTile() {
12792 auto tileId = getTileId();
12893 return tileId && tileId.getInt() >= kInMemoryTileIdBase;
12994 }
13095 }];
13196
132- let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId ($_op); }];
97+ let verify = [{ return detail::verifyArmSMETileOpInterface ($_op); }];
13398}
13499
135100//===----------------------------------------------------------------------===//
@@ -255,30 +220,30 @@ def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
255220class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
256221 Op<ArmSME_Dialect, mnemonic, traits> {}
257222
258- def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
259- let summary = "Returns a SME virtual tile";
223+ def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface, Pure ]> {
224+ let summary = "Creates an undefined value of SME virtual tile type ";
260225 let description = [{
261- Allocates a new SME "virtual tile" within a function. The contents of the
262- tile returned from this operation are undefined.
226+ Creates a new SME "virtual tile" value within a function. The contents of
227+ the tile returned from this operation are undefined.
263228
264229 Example 1:
265230
266231 ```mlir
267- // Allocate an 8-bit element "virtual tile"
232+ // Create an 8-bit element "virtual tile" value:
268233 %za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
269234 ```
270235
271236 Example 2:
272237
273238 ```mlir
274- // Allocate two 16-bit element "virtual tiles"
239+ // Create two 16-bit element "virtual tiles" values:
275240 %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
276241 %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
277242 ```
278243
279244 Example 3:
280245 ```mlir
281- // Allocate an 128-bit element "virtual tile"
246+ // Create an 128-bit element "virtual tile" value:
282247 %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
283248 ```
284249 }];
@@ -290,37 +255,15 @@ def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
290255 VectorType getTileType() {
291256 return ::llvm::cast<VectorType>(getTile().getType());
292257 }
293-
294- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
295- return arm_sme::getSMETileType(getTileType());
296- }
297- }];
298- }
299-
300- def MaterializeSSATileOp : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
301- let summary = "SME tile placeholder";
302- let description = [{
303- A placeholder to preserve dataflow while lowering to SME intrinsics (which
304- do not take or return SME virtual tile values). This operation is intended
305- to be DCE'd once all ArmSME operations have been lowered.
306-
307- This operation is not intended to be used outside of the ArmSME -> LLVM
308- conversion.
309258 }];
310- let results = (outs SMETile:$tile);
311- let assemblyFormat = "attr-dict `:` type($tile)";
312259}
313260
314- //
315- // Tile reset.
316- //
317-
318- def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
319- let summary = "Initialize the two-dimensional ZA array with 0s";
261+ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface, Pure]> {
262+ let summary = "Creates a zero-initialized value of SME virtual tile type";
320263 let results = (outs SMETile:$res);
321264 let description = [{
322- Initialise ZA with 0. This operation is convenient wrapper for the SME
323- `zero` intrinsic and instruction .
265+ Creates a new SME "virtual tile" value within a function. The contents of
266+ the tile returned from this operation are zero-initialized .
324267
325268 Example 1: Zero an 8-bit element ZA tile.
326269
@@ -338,16 +281,39 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
338281 VectorType getVectorType() {
339282 return ::llvm::cast<VectorType>(getRes().getType());
340283 }
341- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
342- return arm_sme::getSMETileType(getVectorType());
343- }
344284 VectorType getTileType() {
345285 return getVectorType();
346286 }
347287 }];
348288 let assemblyFormat = "attr-dict `:` type($res)";
349289}
350290
291+ def CopyTileOp : ArmSME_Op<"copy_tile", [
292+ Pure,
293+ ArmSMETileOpInterface,
294+ AllTypesMatch<["tile", "result"]>
295+ ]> {
296+ let summary = "Copies an SME tile value";
297+ let arguments = (ins SMETile:$tile);
298+ let results = (outs SMETile:$result);
299+ let description = [{
300+ Copies an SME "virtual tile" value to a new SSA value. This operation is
301+ primarily intended to be used to normalize the IR prior to tile allocation.
302+
303+ Example:
304+
305+ ```mlir
306+ %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
307+ ```
308+ }];
309+ let extraClassDeclaration = [{
310+ VectorType getTileType() {
311+ return ::llvm::cast<VectorType>(getResult().getType());
312+ }
313+ }];
314+ let assemblyFormat = "$tile attr-dict `:` type($result)";
315+ }
316+
351317def TileLoadOp : ArmSME_Op<"tile_load", [
352318 ArmSMETileOpInterface,
353319 AttrSizedOperandSegments,
@@ -417,9 +383,6 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
417383 VectorType getVectorType() {
418384 return ::llvm::cast<VectorType>(getResult().getType());
419385 }
420- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
421- return arm_sme::getSMETileType(getVectorType());
422- }
423386 VectorType getTileType() {
424387 return getVectorType();
425388 }
@@ -545,7 +508,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
545508 ```
546509 }];
547510 let arguments = (ins
548- Arg<AnyMemRef, "the reference to load from">:$base, SVEPredicate:$mask,
511+ Arg<AnyMemRef, "the reference to load from", [MemRead] >:$base, SVEPredicate:$mask,
549512 SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
550513 ArmSME_TileSliceLayoutAttr:$layout
551514 );
@@ -630,7 +593,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
630593}
631594
632595def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
633- ArmSMETileOpInterface,
596+ ArmSMETileOpInterface, Pure,
634597 AllTypesMatch<["tile", "result"]>,
635598 TypesMatchWith<
636599 "type of 'vector' matches type of 'tile' slice",
@@ -679,7 +642,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
679642}
680643
681644def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
682- ArmSMETileOpInterface,
645+ ArmSMETileOpInterface, Pure,
683646 TypesMatchWith<
684647 "type of 'result' matches type of 'tile' slice",
685648 "tile", "result",
@@ -736,6 +699,7 @@ class OuterProductResultTileTypeConstraint<string operand> :
736699
737700def OuterProductOp :
738701 ArmSME_Op<"outerproduct", [
702+ Pure,
739703 ArmSMETileOpInterface,
740704 AttrSizedOperandSegments,
741705 AllTypesMatch<["lhs", "rhs"]>,
@@ -802,12 +766,6 @@ let arguments = (ins
802766 VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
803767 VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
804768 VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
805- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
806- // The outerproduct op allocates a new tile if no accumulator is passed.
807- if (!getAcc())
808- return arm_sme::getSMETileType(getResultType());
809- return std::nullopt;
810- }
811769 VectorType getTileType() {
812770 return getResultType();
813771 }
@@ -819,6 +777,7 @@ class OuterProductWideningBase<string mnemonic,
819777 list<Type> allowedResultVectorTypes,
820778 int numOuterProducts> :
821779 ArmSME_Op<mnemonic, [
780+ Pure,
822781 ArmSMETileOpInterface,
823782 AttrSizedOperandSegments,
824783 AllTypesMatch<["lhs", "rhs"]>,
@@ -857,12 +816,6 @@ class OuterProductWideningBase<string mnemonic,
857816 VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
858817 VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
859818 VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
860- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
861- // The outerproduct op allocates a new tile if no accumulator is passed.
862- if (!getAcc())
863- return arm_sme::getSMETileType(getResultType());
864- return std::nullopt;
865- }
866819 VectorType getTileType() {
867820 return getResultType();
868821 }
0 commit comments