Skip to content

Commit b34f15d

Browse files
authored
[mlir][ArmSME] Add arm_sme.move_tile_slice_to_vector op (llvm#67652)
This adds a simple higher-level op for the tile-slice-to-vector intrinsics (and updates the existing vector.print lowering to use it). The op will be reused in later patches to implement the vector.insert/extract lowerings.
1 parent 6ce7461 commit b34f15d

File tree

7 files changed

+267
-43
lines changed

7 files changed

+267
-43
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,43 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
501501
}];
502502
}
503503

504+
// Defines `arm_sme.move_tile_slice_to_vector`: takes a tile and an index and
// produces a single 1-D slice of that tile as its result.
def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
505+
// Constraint: the type of `result` must equal the type of `tile` with its
// leading dimension (dim 0) dropped.
TypesMatchWith<
506+
"type of 'result' matches type of 'tile' slice",
507+
"tile", "result",
508+
"VectorType(VectorType::Builder(::llvm::cast<mlir::VectorType>($_self)).dropDim(0))">,
509+
]> {
510+
let summary = "Move slice of a 2-D tile to a 1-D scalable vector";
511+
let description = [{
512+
The tile slice to vector operation extracts a 1-D scalable slice from a 2-D
513+
scalable tile at the given index. A tile slice is a 1-D vector of
514+
horizontally or vertically contiguous elements within a ZA tile. Horizontal
515+
tile slices are currently assumed when lowering to intrinsics.
516+
517+
Example 1: Extract `vector<[16]xi8>` from tile at the given index.
518+
```mlir
519+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
520+
```
521+
522+
Example 2: Extract `vector<[2]xf64>` from tile at the given index.
523+
```mlir
524+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
525+
```
526+
}];
527+
528+
// Operands: the tile to read from and the `index`-typed slice position.
let arguments = (ins SMETile:$tile, Index:$tile_slice_index);
529+
let results = (outs SVEVector:$result);
530+
531+
// Adds a `getSliceType()` accessor that returns the result's vector type.
let extraClassDeclaration = [{
532+
VectorType getSliceType() { return getResult().getType(); }
533+
}];
534+
535+
// Custom syntax: `%tile[%index] attr-dict : <result type> from <tile type>`.
let assemblyFormat = [{
536+
$tile `[` $tile_slice_index `]` attr-dict
537+
`:` type($result) `from` type($tile)
538+
}];
539+
}
540+
504541
//===----------------------------------------------------------------------===//
505542
// ArmSME Intrinsic op definitions
506543
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,8 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
191191
};
192192

193193
/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
194-
/// extracting them via a MOVA, then printing with a 1D `vector.print`.
194+
/// extracting them via `arm_sme.move_tile_slice_to_vector`, then printing with
195+
/// a 1D `vector.print`.
195196
///
196197
/// BEFORE:
197198
/// ```mlir
@@ -202,16 +203,11 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
202203
/// %c0 = arith.constant 0 : index
203204
/// %c1 = arith.constant 1 : index
204205
/// %c4 = arith.constant 4 : index
205-
/// %ptrue = arith.constant dense<true> : vector<[4]xi1>
206-
/// %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xf32> to i32
207206
/// %vscale = vector.vscale
208207
/// %svl_s = arith.muli %c4, %vscale : index
209-
/// %cst = arith.constant dense<0.000000e+00> : vector<[4]xf32>
210208
/// scf.for %i = %c0 to %svl_s step %c1 {
211-
/// %slice_idx = arith.index_cast %i : index to i32
212-
/// %tile_slice = "arm_sme.intr.read.horiz"
213-
/// (%cst, %ptrue, %tile_id, %slice_idx)
214-
/// : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
209+
/// %tile_slice = arm_sme.move_tile_slice_to_vector %tile[%i]
210+
/// : vector<[4]xf32> from vector<[4]x[4]xf32>
215211
/// vector.print %tile_slice : vector<[4]xf32>
216212
/// }
217213
/// ```
@@ -229,23 +225,6 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
229225

230226
auto loc = printOp.getLoc();
231227

232-
// Create an 'all true' predicate for each tile row.
233-
auto predicateType =
234-
VectorType::get(vectorType.getDimSize(1), rewriter.getI1Type(), true);
235-
auto allTruePredicate = rewriter.create<arith::ConstantOp>(
236-
loc, DenseElementsAttr::get(predicateType, true));
237-
238-
// Cast tile to i32 tile ID.
239-
auto tileId =
240-
rewriter.create<arm_sme::CastVectorToTile>(loc, printOp.getSource());
241-
Value tileIdI32 = castTileIDToI32(tileId, loc, rewriter);
242-
243-
// Zero destination/fallback for tile slice extraction.
244-
auto rowType = VectorType::get(vectorType.getDimSize(1),
245-
vectorType.getElementType(), true);
246-
auto zeroVector = rewriter.create<arith::ConstantOp>(
247-
loc, rowType, rewriter.getZeroAttr(rowType));
248-
249228
// Create a loop over the rows of the tile.
250229
auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
251230
auto minTileRows =
@@ -259,10 +238,8 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
259238
rewriter.setInsertionPointToStart(forOp.getBody());
260239
// Extract the current row from the tile.
261240
Value rowIndex = forOp.getInductionVar();
262-
auto rowIndexI32 = rewriter.create<arith::IndexCastOp>(
263-
loc, rewriter.getI32Type(), rowIndex);
264-
auto tileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
265-
loc, rowType, zeroVector, allTruePredicate, tileIdI32, rowIndexI32);
241+
auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
242+
loc, printOp.getSource(), rowIndex);
266243
// Print the row with a 1D vector.print.
267244
rewriter.create<vector::PrintOp>(loc, tileSlice,
268245
printOp.getPunctuation());

mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,48 @@ struct MoveVectorToTileSliceToArmSMELowering
402402
}
403403
};
404404

405+
/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. Only horizontal
406+
/// tile slices are currently supported.
///
/// The op is replaced by an `arm_sme.intr.read.horiz` whose operands are built
/// here: a zero vector of the slice type, an all-true predicate, the tile ID as
/// an i32, and the slice index cast to i32. The unnamed `OpAdaptor` parameter
/// is unused; the tile and slice index are read from the op itself.
407+
struct MoveTileSliceToVectorArmSMELowering
408+
: public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
409+
using ConvertOpToLLVMPattern<
410+
arm_sme::MoveTileSliceToVectorOp>::ConvertOpToLLVMPattern;
411+
412+
LogicalResult
413+
matchAndRewrite(arm_sme::MoveTileSliceToVectorOp moveTileSliceToVector,
414+
OpAdaptor,
415+
ConversionPatternRewriter &rewriter) const override {
416+
auto loc = moveTileSliceToVector.getLoc();
417+
auto sliceType = moveTileSliceToVector.getSliceType();
418+
auto tile = moveTileSliceToVector.getTile();
419+
auto sliceIndex = moveTileSliceToVector.getTileSliceIndex();
420+
421+
// Cast the tile to a tile ID, then convert that ID to i32 for the intrinsic.
422+
auto tileId = rewriter.create<arm_sme::CastVectorToTile>(loc, tile);
423+
Value tileIdI32 = castTileIDToI32(tileId, loc, rewriter);
424+
425+
// Create an 'all true' predicate: same shape as the slice, i1 elements.
426+
auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
427+
auto allTruePredicate = rewriter.create<arith::ConstantOp>(
428+
loc, DenseElementsAttr::get(predicateType, true));
429+
430+
// Zero vector of the slice type, used as the destination/fallback operand.
431+
auto zeroVector = rewriter.create<arith::ConstantOp>(
432+
loc, sliceType, rewriter.getZeroAttr(sliceType));
433+
434+
// Cast the tile slice index from index to i32 for the intrinsic.
435+
auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
436+
loc, rewriter.getI32Type(), sliceIndex);
437+
438+
// Replace the op with 'arm_sme.intr.read.horiz' to extract the tile slice.
439+
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
440+
moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
441+
tileIdI32, sliceIndexI32);
442+
443+
return success();
444+
}
445+
};
446+
405447
/// Lower `vector.outerproduct` to SME MOPA intrinsics.
406448
///
407449
/// Example:
@@ -525,9 +567,9 @@ void mlir::configureArmSMELegalizeForExportTarget(
525567
arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
526568
arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
527569
arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
528-
arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_write_horiz,
529-
arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable,
530-
arm_sme::aarch64_sme_za_disable>();
570+
arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
571+
arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_mopa,
572+
arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
531573
target.addLegalOp<GetTileID>();
532574
target.addIllegalOp<vector::OuterProductOp>();
533575

@@ -561,6 +603,7 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
561603
patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
562604
patterns
563605
.add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
564-
LoadTileSliceToArmSMELowering, MoveVectorToTileSliceToArmSMELowering,
606+
LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
607+
MoveVectorToTileSliceToArmSMELowering,
565608
VectorOuterProductToArmSMELowering>(converter);
566609
}

mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,11 @@ func.func @arm_sme_tile_print(%tile: vector<[4]x[4]xf32>)
6666
}
6767
// CHECK-LABEL: func.func @arm_sme_tile_print(
6868
// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>) {
69-
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
70-
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
71-
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
72-
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
73-
// CHECK-DAG: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
74-
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
75-
// CHECK-DAG: %[[ZERO_VECTOR:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
76-
// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
69+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
70+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
71+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
72+
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
73+
// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
7774
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
78-
// CHECK-NEXT: %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_cast %[[TILE_SLICE_INDEX]] : index to i32
79-
// CHECK-NEXT: %[[TILE_SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VECTOR]], %[[PTRUE]], %[[TILE_ID]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
75+
// CHECK-NEXT: %[[TILE_SLICE:.*]] = arm_sme.move_tile_slice_to_vector %[[TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
8076
// CHECK-NEXT: vector.print %[[TILE_SLICE]] : vector<[4]xf32>

mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,3 +399,89 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s
399399
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
400400
return
401401
}
402+
403+
404+
//===----------------------------------------------------------------------===//
405+
// arm_sme.move_tile_slice_to_vector
406+
//===----------------------------------------------------------------------===//
407+
408+
// -----
409+
410+
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i8
411+
// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
412+
func.func @arm_sme_move_tile_slice_to_vector_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]xi8> {
413+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
414+
return %slice : vector<[16]xi8>
415+
}
416+
417+
// -----
418+
419+
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i16
420+
// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
421+
func.func @arm_sme_move_tile_slice_to_vector_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index) -> vector<[8]xi16> {
422+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xi16> from vector<[8]x[8]xi16>
423+
return %slice : vector<[8]xi16>
424+
}
425+
426+
// -----
427+
428+
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i32
429+
// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
430+
func.func @arm_sme_move_tile_slice_to_vector_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> vector<[4]xi32> {
431+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xi32> from vector<[4]x[4]xi32>
432+
return %slice : vector<[4]xi32>
433+
}
434+
435+
// -----
436+
437+
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i64
438+
// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
439+
func.func @arm_sme_move_tile_slice_to_vector_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index) -> vector<[2]xi64> {
440+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xi64> from vector<[2]x[2]xi64>
441+
return %slice : vector<[2]xi64>
442+
}
443+
444+
// -----
445+
446+
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i128
447+
// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
448+
func.func @arm_sme_move_tile_slice_to_vector_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
449+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[1]xi128> from vector<[1]x[1]xi128>
450+
return %slice : vector<[1]xi128>
451+
}
452+
453+
// -----
454+
455+
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f16
456+
// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
457+
func.func @arm_sme_move_tile_slice_to_vector_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index) -> vector<[8]xf16> {
458+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xf16> from vector<[8]x[8]xf16>
459+
return %slice : vector<[8]xf16>
460+
}
461+
462+
// -----
463+
464+
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_bf16
465+
// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
466+
func.func @arm_sme_move_tile_slice_to_vector_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> vector<[8]xbf16> {
467+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xbf16> from vector<[8]x[8]xbf16>
468+
return %slice : vector<[8]xbf16>
469+
}
470+
471+
// -----
472+
473+
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f32
474+
// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
475+
func.func @arm_sme_move_tile_slice_to_vector_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]xf32> {
476+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
477+
return %slice : vector<[4]xf32>
478+
}
479+
480+
// -----
481+
482+
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f64
483+
// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
484+
func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> {
485+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
486+
return %slice : vector<[2]xf64>
487+
}

mlir/test/Dialect/ArmSME/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,11 @@ func.func @arm_sme_move_vector_to_tile_slice_f32__bad_vector_type(%vector : vect
8989
%0 = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[8]xf32> into vector<[4]x[4]xf32>
9090
return %0 : vector<[4]x[4]xf32>
9191
}
92+
93+
// -----
94+
95+
func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> {
96+
// expected-error@+1 {{op failed to verify that type of 'result' matches type of 'tile' slice}}
97+
%0 = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32>
98+
return %0 : vector<[2]xf64>
99+
}

mlir/test/Dialect/ArmSME/roundtrip.mlir

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,3 +1058,80 @@ func.func @arm_sme_move_vector_to_tile_slice_f64(%vector : vector<[2]xf64>, %til
10581058
arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64>
10591059
return
10601060
}
1061+
1062+
1063+
//===----------------------------------------------------------------------===//
1064+
// arm_sme.move_tile_slice_to_vector
1065+
//===----------------------------------------------------------------------===//
1066+
1067+
// -----
1068+
1069+
func.func @arm_sme_move_tile_slice_to_vector_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]xi8> {
1070+
// CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[16]xi8> from vector<[16]x[16]xi8>
1071+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
1072+
return %slice : vector<[16]xi8>
1073+
}
1074+
1075+
// -----
1076+
1077+
func.func @arm_sme_move_tile_slice_to_vector_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index) -> vector<[8]xi16> {
1078+
// CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[8]xi16> from vector<[8]x[8]xi16>
1079+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xi16> from vector<[8]x[8]xi16>
1080+
return %slice : vector<[8]xi16>
1081+
}
1082+
1083+
// -----
1084+
1085+
func.func @arm_sme_move_tile_slice_to_vector_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> vector<[4]xi32> {
1086+
// CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[4]xi32> from vector<[4]x[4]xi32>
1087+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xi32> from vector<[4]x[4]xi32>
1088+
return %slice : vector<[4]xi32>
1089+
}
1090+
1091+
// -----
1092+
1093+
func.func @arm_sme_move_tile_slice_to_vector_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index) -> vector<[2]xi64> {
1094+
// CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[2]xi64> from vector<[2]x[2]xi64>
1095+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xi64> from vector<[2]x[2]xi64>
1096+
return %slice : vector<[2]xi64>
1097+
}
1098+
1099+
// -----
1100+
1101+
func.func @arm_sme_move_tile_slice_to_vector_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
1102+
// CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[1]xi128> from vector<[1]x[1]xi128>
1103+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[1]xi128> from vector<[1]x[1]xi128>
1104+
return %slice : vector<[1]xi128>
1105+
}
1106+
1107+
// -----
1108+
1109+
func.func @arm_sme_move_tile_slice_to_vector_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index) -> vector<[8]xf16> {
1110+
// CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[8]xf16> from vector<[8]x[8]xf16>
1111+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xf16> from vector<[8]x[8]xf16>
1112+
return %slice : vector<[8]xf16>
1113+
}
1114+
1115+
// -----
1116+
1117+
func.func @arm_sme_move_tile_slice_to_vector_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> vector<[8]xbf16> {
1118+
// CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[8]xbf16> from vector<[8]x[8]xbf16>
1119+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xbf16> from vector<[8]x[8]xbf16>
1120+
return %slice : vector<[8]xbf16>
1121+
}
1122+
1123+
// -----
1124+
1125+
func.func @arm_sme_move_tile_slice_to_vector_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]xf32> {
1126+
// CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[4]xf32> from vector<[4]x[4]xf32>
1127+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
1128+
return %slice : vector<[4]xf32>
1129+
}
1130+
1131+
// -----
1132+
1133+
func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> {
1134+
// CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[2]xf64> from vector<[2]x[2]xf64>
1135+
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
1136+
return %slice : vector<[2]xf64>
1137+
}

0 commit comments

Comments
 (0)