1414#ifndef ARMSME_OPS
1515#define ARMSME_OPS
1616
17+ include "mlir/IR/EnumAttr.td"
1718include "mlir/IR/OpBase.td"
1819include "mlir/Interfaces/SideEffectInterfaces.td"
1920include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
@@ -36,6 +37,7 @@ def ArmSME_Dialect : Dialect {
3637 https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
3738 }];
3839 let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"];
40+ let useDefaultAttributePrinterParser = 1;
3941}
4042
4143//===----------------------------------------------------------------------===//
@@ -83,6 +85,24 @@ def TileElementWidthMatchesTileID : TypesMatchWith<
8385 "::llvm::cast<VectorType>($_self).getElementType())"
8486 ".getWidth())">;
8587
88+ //===----------------------------------------------------------------------===//
89+ // ArmSME attr definitions
90+ //===----------------------------------------------------------------------===//
91+
92+ def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [
93+ I32EnumAttrCase<"Horizontal", 0, "horizontal">,
94+ I32EnumAttrCase<"Vertical", 1, "vertical">,
95+ ]> {
96+ let cppNamespace = "::mlir::arm_sme";
97+ let genSpecializedAttr = 0;
98+ }
99+
100+ /// An attribute that specifies the layout of a tile slice in a tile.
101+ def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
102+ "layout"> {
103+ let assemblyFormat = "`<` $value `>`";
104+ }
105+
86106//===----------------------------------------------------------------------===//
87107// ArmSME op definitions
88108//===----------------------------------------------------------------------===//
@@ -240,28 +260,33 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
240260 let description = [{
241261 Loads a 2D SME "virtual tile" from memory defined by a base and indices,
242262 with the shape defined by the 2D scalable vector type of the result tile.
243- The slice of memory must be contiguous. The memref must be either rank 1 or
244- rank 2 with dynamic dimensions, since the operation is scalable, and the
245- element type must be a scalar that matches the element type of the result.
263+ An optional tile slice layout attribute specifies whether the slices of the
264+ tile being loaded are horizontal (default) or vertical. The slice of memory
265+ must be contiguous. The memref must be either rank 1 or rank 2 with dynamic
266+ dimensions, since the operation is scalable, and the element type must be a
267+ scalar that matches the element type of the result.
246268
247- Example 1: Load an 8-bit element ZA tile from memory (ZA0.B).
269+ Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
248270 ```mlir
249271 %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
250272 ```
251273
252- Example 2: Load a FP 32-bit element ZA tile from memory.
274+ Example 2: Load a FP 32-bit element ZA tile with vertical layout from memory.
253275 ```mlir
254- %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
276+ %tile = arm_sme.tile_load %base[%c0, %c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
255277 ```
256278
257- Example 3: Load a 128-bit element ZA tile from memory.
279+ Example 3: Load a 128-bit element ZA tile with horizontal layout (default) from memory.
258280 ```mlir
259- %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
281+ %tile = arm_sme.tile_load %base[%c0, %c0], <horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
260282 ```
261283 }];
262284 let arguments = (ins
263- Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
264- Variadic<Index>:$indices);
285+ Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
286+ Variadic<Index>:$indices,
287+ DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
288+ "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
289+ );
265290 let results = (outs SMETile:$result);
266291
267292 let extraClassDeclaration = [{
@@ -274,37 +299,42 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
274299 }];
275300
276301 let assemblyFormat =
277- "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
302+ "$base `[` $indices `]` (`,` $layout^)? attr-dict "
303+ "`:` type($base) `,` type($result)";
278304}
279305
280306def TileStoreOp : ArmSME_Op<"tile_store"> {
281307 let summary = "Tile store operation";
282308 let description = [{
283309 Stores a 2D SME "virtual tile" to memory defined by a base and indices,
284310 with the shape defined by the 2D scalable vector type of the tile being
285- stored. The slice of memory must be contiguous. The memref must be either
286- rank 1 or rank 2 with dynamic dimensions, since the operation is scalable,
287- and the element type must be a scalar that matches the element type of the
288- result.
311+ stored. An optional tile slice layout attribute specifies whether the
312+ slices of the tile being stored are horizontal (default) or vertical. The
313+ slice of memory must be contiguous. The memref must be either rank 1 or
314+ rank 2 with dynamic dimensions, since the operation is scalable, and the
315+ element type must be a scalar that matches the element type of the result.
289316
290- Example 1: Store an 8-bit element ZA tile to memory (ZA0.B).
317+ Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B).
291318 ```mlir
292319 arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
293320 ```
294321
295- Example 2: Store a FP 32-bit element ZA tile to memory.
322+ Example 2: Store a FP 32-bit element ZA tile with vertical layout to memory.
296323 ```mlir
297- arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
324+ arm_sme.tile_store %tile, %base[%c0, %c0], <vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
298325 ```
299326
300- Example 3: Store a 128-bit element ZA tile to memory.
327+ Example 3: Store a 128-bit element ZA tile with horizontal (default) layout to memory.
301328 ```mlir
302- arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
329+ arm_sme.tile_store %tile, %base[%c0, %c0], <horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
303330 ```
304331 }];
305332 let arguments = (ins SMETile:$valueToStore,
306- Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
307- Variadic<Index>:$indices);
333+ Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
334+ Variadic<Index>:$indices,
335+ DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
336+ "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
337+ );
308338 let extraClassDeclaration = [{
309339 MemRefType getMemRefType() {
310340 return ::llvm::cast<MemRefType>(getBase().getType());
@@ -314,8 +344,9 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
314344 }
315345 }];
316346
317- let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
318- "`:` type($base) `,` type($valueToStore)";
347+ let assemblyFormat =
348+ "$valueToStore `,` $base `[` $indices `]` (`,` $layout^)? attr-dict "
349+ "`:` type($base) `,` type($valueToStore)";
319350}
320351
321352def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
@@ -326,31 +357,36 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
326357 Loads a 1D tile slice from memory into a 2D SME "virtual tile". The tile
327358 slice is defined by the dimension of the 2D scalable vector type pointed by
328359 the index. A tile slice index describes where in the input tile the tile
329- slice is loaded to. The updated tile is returned as the result.
360+ slice is loaded to. An optional tile slice layout attribute specifies
361+ whether the tile slice being loaded at the given index is horizontal
362+ (default) or vertical. The updated tile is returned as the result.
330363
331364 The slice of memory read is defined by a base and indices and must be
332365 contiguous. The memref must be either rank 1 or rank 2, have dynamic
333366 dimensions since the operation is scalable, and the element type must be a
334367 scalar that matches the element type of the result.
335368
336- Example 1: Load a vector<[16]xi8> tile slice from memory into tile at given index.
369+ Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index.
337370 ```mlir
338371 %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
339372 ```
340373
341- Example 2: Load a vector<[4]xf32> tile slice from memory into tile at given index.
374+ Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
342375 ```mlir
343- %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
376+ %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
344377 ```
345378
346- Example 3: Load a vector<[1]xi128> tile slice from memory into tile at given index.
379+ Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
347380 ```mlir
348- %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
381+ %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
349382 ```
350383 }];
351384 let arguments = (ins
352- Arg<AnyMemRef, "the reference to load from">:$base,
353- SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index);
385+ Arg<AnyMemRef, "the reference to load from">:$base,
386+ SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
387+ DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
388+ "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
389+ );
354390 let results = (outs SMETile:$result);
355391
356392 let extraClassDeclaration = [{
@@ -363,7 +399,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
363399 }];
364400
365401 let assemblyFormat = [{
366- $base `[` $indices `]` `,` $tile `,` $tile_slice_index
402+ $base `[` $indices `]` `,` $tile `,` $tile_slice_index (`,` $layout^)?
367403 attr-dict `:` type($base) `,` type($result)
368404 }];
369405}
@@ -374,31 +410,36 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
374410 Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
375411 slice is defined by the dimension of the 2D scalable vector type pointed by
376412 the index. A tile slice index describes where in the input tile the tile
377- slice is stored from.
413+ slice is stored from. An optional tile slice layout attribute specifies
414+ whether the tile slice being stored from the given index is horizontal
415+ (default) or vertical.
378416
379417 The slice of memory written is defined by a base and indices and must be
380418 contiguous. The memref must be either rank 1 or rank 2, have dynamic
381419 dimensions since the operation is scalable, and the element type must be a
382420 scalar that matches the element type of the input tile.
383421
384- Example 1: Store vector<[16]xi8> tile slice from tile at given index to memory.
422+ Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory.
385423 ```mlir
386424 arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
387425 ```
388426
389- Example 2: Store vector<[4]xf32> tile slice from tile at given index to memory.
427+ Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
390428 ```mlir
391- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
429+ arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
392430 ```
393431
394- Example 3: Store a vector<[1]xi128> tile slice from tile at given index to memory.
432+ Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
395433 ```mlir
396- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
434+ arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
397435 ```
398436 }];
399437 let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
400- Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
401- Variadic<Index>:$indices);
438+ Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
439+ Variadic<Index>:$indices,
440+ DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
441+ "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
442+ );
402443 let extraClassDeclaration = [{
403444 MemRefType getMemRefType() {
404445 return ::llvm::cast<MemRefType>(getBase().getType());
@@ -409,7 +450,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
409450 }];
410451
411452 let assemblyFormat = [{
412- $tile `,` $tile_slice_index `,` $base `[` $indices `]`
453+ $tile `,` $tile_slice_index `,` $base `[` $indices `]` (`,` $layout^)?
413454 attr-dict `:` type($base) `,` type($tile)
414455 }];
415456}
0 commit comments