diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 10557658d5d7d..b496ee0114910 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1431,6 +1431,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " "dialect.">, + Option<"armI8MM", "enable-arm-i8mm", + "bool", /*default=*/"false", + "Enables the use of Arm FEAT_I8MM instructions while lowering " + "the vector dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h index 8665c8224cc45..232e2be29e574 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h @@ -20,6 +20,9 @@ class RewritePatternSet; void populateArmSVELegalizeForLLVMExportPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns); +void populateLowerContractionToSVEI8MMPatternPatterns( + RewritePatternSet &patterns); + /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM /// intrinsics. void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target); diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt index 330474a718e30..8e2620029c354 100644 --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass MLIRVectorToLLVM MLIRArmNeonDialect + MLIRArmNeonTransforms MLIRArmSVEDialect MLIRArmSVETransforms MLIRAMXDialect diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 0ee6dce9ee94b..293e01a5bf4d4 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmNeon/Transforms.h" #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPatterns(patterns); populateVectorRankReducingFMAPattern(patterns); populateVectorGatherLoweringPatterns(patterns); + if (armI8MM) { + if (armNeon) + arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns); + if (armSVE) + populateLowerContractionToSVEI8MMPatternPatterns(patterns); + } (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp index 2a1271dfd6bdf..5ce3d2b28aeb3 100644 --- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp +++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp @@ -56,6 +56,10 @@ class LowerContractionToSMMLAPattern // Avoid 0-D vectors and 1-D rhs: if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2) return failure(); + // This codegen does not work for scalable vectors. Return failure so this + // pattern is not accidentally chosen over patterns that lower to ArmSVE. + if (lhsType.isScalable() || rhsType.isScalable()) + return failure(); auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0); auto dimN = rhsType.getDimSize(0); auto dimK = rhsType.getDimSize(1); @@ -238,5 +242,5 @@ class LowerContractionToSMMLAPattern void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns( RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); - patterns.add(context, /*benefit=*/1); + patterns.add(context, /*benefit=*/2); } diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt index a70c489a51fea..65f98b44b1b69 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRArmSVETransforms LegalizeForLLVMExport.cpp LegalizeVectorStorage.cpp + LowerContractionToSVEI8MMPattern.cpp DEPENDS MLIRArmSVEConversionsIncGen diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp new file mode 100644 index 0000000000000..b1233c5c06eb4 --- /dev/null +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp @@ -0,0 +1,365 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements lowering patterns from vector.contract to operations +// that map to instructions from the SVE FEAT_I8MM extension. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; + +namespace { +// Get the operand of a `vector.contract`. This function is intended to abstract +// away from the particular way a value is extended before feeding it into the +// `vector.contract` - via zero-extend or an explicit or implicit sign-extend +// (for implicit sign-extension see `vector.contract` documentation). +// +// The template parameter `Op` indicates the extension operation (explicit or +// implicit) for which we are checking. +// +// Return success only for extensions from `i8` to `i32`. +template +std::optional getExtOperand(Value v, Type i8Ty, Type i32Ty) { + + static_assert(llvm::is_one_of::value, + "Must be instantiated with either sign- or zero- extension op"); + + // If the operand is not defined by an explicit extend operation of the + // accepted operation type allow for an implicit sign-extension. + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) { + if constexpr (std::is_same::value) { + auto vTy = cast(v.getType()); + if (vTy.getElementType() != i8Ty) + return {}; + return v; + } + return {}; + } + + // If the operand is defined by an explicit extend operation of the accepted + // operation type, check it's extended from `i8` to `i32`. + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) + return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) + return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. +enum class MMLA { + Signed, // smmla + Unsigned, // ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix mulitply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: + return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: + return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: + return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: + // The accumulator comes transposed and the result will be transposed + // later, so all we have to do here is swap the operands. + return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +/// Lower a contraction operation that performs a matrix multiplication +/// of two 8-bit integer matrix tiles with logical dimensions and <8x[N]> +/// for the left-hand side and the right-hand side, respectively, +/// yielding a 32-bit integer result. +/// +/// The operands' shapes are such that the operands can be evenly split into +/// sub-tiles with dimensions as expected by the targeted FEAT_I8MM +/// instructions. The intent is that M and N are chosen (by higher level +/// transforms) in such a way as to maximise register usage. The main use case +/// we envision as of now is MMT4D, thus the RHS operand is expected +/// pre-transposed. +/// +/// The matrix multiplication is performed by unrolling the usual tiled matrix +/// multiplication algorithm using sub-tiles with dimensions <2x8> for the LHS, +/// <8x[2]> for the RHS, and <2x[2]> for the result and the input accumulator. +/// +/// One way to illustrate the operation is as follows: +/// +/// RHS<8x[N]>: <8x[2]> <8x[2]> ... <8x[2]> +/// +----------------------------- +/// LHS: <2x8> | <2x[2]> <2x[2]> ... <2x[2]> +/// <2x8> | <2x[2]> <2x[2]> ... <2x[2]> +/// ... | ... ... ... ... +/// <2x8> | <2x[2]> <2x[2]> ... <2x[2]> +/// +/// The RHS operand is unpacked into N/2 values, each representing a sequence of +/// VSCALE number of sub-tiles with dimensions <8x2>. +/// The LHS operand is initially unpacked into M/2 values, each representing a +/// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated +/// VSCALE times. +/// Multiplying thus replicated LHS sub-tile by the corresponding RHS sub-tile +/// correctly computes an entire result sub-tile. +class LowerContractionToSVEI8MMPattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + mlir::VectorType lhsType = op.getLhsType(); + mlir::VectorType rhsType = op.getRhsType(); + + // Check the rank the types so we can safely examine their dimensions. + if (lhsType.getRank() != 2 || rhsType.getRank() != 2) + return rewriter.notifyMatchFailure(op, "non-matching operand shape"); + + auto M = lhsType.getDimSize(0); + auto N = rhsType.getDimSize(0); + auto K = rhsType.getDimSize(1); + + // Check the operands have the expected shape: + // * for LHS: fixed vector MxK + // * for RHS: scalable vector [N]xK + // * K == 8 + // * M and N even and at least 2 + if (lhsType.isScalable() || !rhsType.getScalableDims()[0] || + rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 8 || + M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 || + !rhsType.getScalableDims()[0]) + return rewriter.notifyMatchFailure(op, "non-matching operand shape"); + + // Check permutation maps. For now only accept + // lhs: (d0, d1, d2) -> (d0, d2) + // rhs: (d0, d1, d2) -> (d1, d2) + // acc: (d0, d1, d2) -> (d0, d1) + // This corresponds to matrix multiplication with transposed RHS. + if (op.getIndexingMapsArray()[0] != + AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u}, + op.getContext()) || + op.getIndexingMapsArray()[1] != + AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u}, + op.getContext()) || + op.getIndexingMapsArray()[2] != + AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u}, + op.getContext())) + return rewriter.notifyMatchFailure(op, "non-matching permutation maps"); + + // Check iterator types for matrix multiplication. + auto itTypes = op.getIteratorTypesArray(); + if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel || + itTypes[1] != vector::IteratorType::parallel || + itTypes[2] != vector::IteratorType::reduction) + return rewriter.notifyMatchFailure( + op, "iterator types do not correspond to matrix multiplication"); + + // Check the combining kind is addition. + if (op.getKind() != vector::CombiningKind::ADD) + return rewriter.notifyMatchFailure(op, + "combining kind is not an addition"); + + // Check the output is a vector of i32 elements. + auto outTy = dyn_cast(op.getResultType()); + if (!outTy || outTy.getElementType() != rewriter.getI32Type()) + return rewriter.notifyMatchFailure(op, + "output type is not a vector of i32"); + + // Check inputs are sign-/zero- extensions from i8 to i32. Get the values + // before the extension. All four signed/unsigned combinations for input + // operands are supported, but they are lowered to different operations. + // Determine which is the appropriate operation to lower to. + MMLA mmlaOp = MMLA::Signed; + auto maybeLhs = getExtOperand( + op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type()); + if (!maybeLhs) { + mmlaOp = MMLA::Unsigned; + maybeLhs = getExtOperand( + op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type()); + } + if (!maybeLhs) + return rewriter.notifyMatchFailure( + op, "LHS is not a sign- or zero- extended i8"); + + auto maybeRhs = getExtOperand( + op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type()); + if (maybeRhs) { + if (mmlaOp == MMLA::Unsigned) + mmlaOp = MMLA::Mixed; + } else { + if (mmlaOp == MMLA::Signed) + mmlaOp = MMLA::MixedSwapped; + maybeRhs = getExtOperand( + op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type()); + } + if (!maybeRhs) + return rewriter.notifyMatchFailure( + op, "RHS is not a sign- or zero- extended i8"); + + // One-dimensional vector types for arm_sve.*mmla + auto nxv16i8 = VectorType::get(/*shape=*/16, rewriter.getI8Type(), + /*scalableDims=*/{true}); + auto nxv4i32 = VectorType::get(/*shape=*/4, rewriter.getI32Type(), + /*scalableDims=*/{true}); + + // Extract LHS sub-tiles with logicall shape <2x8>. + SmallVector lhsTile; + for (int64_t i = 0; i < M; i += 2) { + // Extract two consecutive rows of the LHS tile. + auto r0 = rewriter.create(loc, *maybeLhs, + ArrayRef{i}); + auto r1 = rewriter.create(loc, *maybeLhs, + ArrayRef{i + 1}); + // Concatenate to obtain a 16 x i8 flattened sub-tile. + auto t = rewriter.create( + loc, r0, r1, + llvm::ArrayRef{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15}); + // Turn it into a scalable vector. + auto s = rewriter.create( + loc, t, rewriter.create(loc, nxv16i8), 0); + // Replicate the sub-tile VSCALE times to fill the entire vector. + auto r = rewriter.create(loc, s, 0); + lhsTile.push_back(r); + } + + // "Flatten" the RHS tile from <[N]x8> to <[8*N]>. + auto rhs = rewriter.create( + maybeRhs->getLoc(), + VectorType::get(/*shape=*/8 * N, rewriter.getI8Type(), + /*scalableDims=*/{true}), + *maybeRhs); + + // Extract the RHS sub-tiles with logical shape <8x[2]>. + SmallVector rhsTile; + for (int64_t j = 0; j < N; j += 2) + rhsTile.push_back( + rewriter.create(loc, nxv16i8, rhs, j * 8)); + + // Handy types for packing/unpacking of the accumulator tile. + auto accRowTy = VectorType::get(/*shape=*/N, rewriter.getI32Type(), + /*scalableDims=*/{true}); + auto accRowX2Ty = VectorType::get(/*shape=*/2 * N, rewriter.getI32Type(), + /*scalableDims=*/{true}); + auto accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(), + /*scalableDims=*/{true}); + auto accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(), + /*scalableDims=*/{true}); + + // Extract and pack the ACC sub-tiles. + SmallVector accTile; + for (int64_t i = 0; i < M; i += 2) { + // Extract two consecutive rows of the accumulator tile. + auto r0 = rewriter.create(loc, op.getAcc(), + ArrayRef{i}); + auto r1 = rewriter.create(loc, op.getAcc(), + ArrayRef{i + 1}); + Value accTileVec; + if (mmlaOp == MMLA::MixedSwapped) { + // We need to swap the positions of the LHS and RHS (since we don't have + // a signed * unsigned operation), but then each individual 2x2 tile of + // the acumulator and (later) the result need to be transposed. + accTileVec = rewriter.create(loc, r0, r1); + } else { + // Bitcast them to 64-bit elements, so subsequent + // interleave/deinterleave work on pairs of 32-bit numbers. + auto r0I64 = rewriter.create(loc, accRow64Ty, r0); + auto r1I64 = rewriter.create(loc, accRow64Ty, r1); + + // Interleave the rows, effectively flattening each 2x2 tile into 4 + // consecutive elements. + auto intrI64 = rewriter.create(loc, r0I64, r1I64); + + // Bitcast back to 32-bit elements. + accTileVec = + rewriter.create(loc, accRowX2Ty, intrI64); + } + // Extract ACC sub-tiles. + for (int64_t j = 0; j < N; j += 2) + accTile.push_back(rewriter.create( + loc, nxv4i32, accTileVec, j * 2)); + } + + // Emit sub-tile matrix multiplications. + SmallVector outTile; + for (int64_t i = 0; i < M / 2; ++i) + for (int64_t j = 0; j < N / 2; ++j) { + Value mmla = createMMLA(rewriter, mmlaOp, loc, nxv4i32, + accTile[i * N / 2 + j], lhsTile[i], rhsTile[j]); + outTile.push_back(mmla); + } + + // Unpack the OUT sub-tiles and insert into the result. + Value result = rewriter.create(loc, op.getResultType()); + for (int64_t i = 0; i < M / 2; ++i) { + // Collect a number of sub-tiles in a row. + Value row = rewriter.create(loc, accRowX2Ty); + for (int64_t j = 0; j < N / 2; ++j) + row = rewriter.create( + loc, outTile[i * N / 2 + j], row, j * 4); + + // Unpack the row to obtain two rows of the output. If we have the out + // sub-tiles transposed we obtain two consecutive output rows by + // separating even and odd elements, i.e. a simple deinterleave. + // Otherwise, the interleave is by pairs. + Value out0, out1; + if (mmlaOp == MMLA::MixedSwapped) { + auto tmp = rewriter.create(loc, row); + out0 = tmp.getRes1(); + out1 = tmp.getRes2(); + } else { + // Deinterleave by pairs. + auto row64 = rewriter.create(loc, accRowX264Ty, row); + auto deintr64 = rewriter.create(loc, row64); + + // Bitcast back into 32-bit elements and insert into the result. + out0 = rewriter.create(loc, accRowTy, + deintr64.getRes1()); + out1 = rewriter.create(loc, accRowTy, + deintr64.getRes2()); + } + result = rewriter.create(loc, out0, result, i * 2); + result = rewriter.create(loc, out1, result, i * 2 + 1); + } + + rewriter.replaceOp(op, result); + return success(); + } +}; + +} // namespace + +void mlir::populateLowerContractionToSVEI8MMPatternPatterns( + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add(context, /*benefit=*/2); +} diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir new file mode 100644 index 0000000000000..af0cb37e2d249 --- /dev/null +++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir @@ -0,0 +1,181 @@ +// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' | FileCheck %s + +#attrs = { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind +} + +// CHECK-LABEL: @test_vector_contract_to_smmla + +// Extract LHS rows 0 and 1, concatenate, turn into scalable vector +// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> +// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8> + +// Replicate across the entire length of the scalabale vector +// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> + +// Same for LHS rows 2 and 4 +// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> +// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8> +// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> + +// Extract sub-tiles from the RHS +// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8> +// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8> + +// Extract accumulator rows 0 and 1 and pack (into "registers") +// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> +// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Same for accumulator rows 2 and 3. +// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> +// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Do the sub-tile matrix multiplications +// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.smmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.smmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.smmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.smmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> + +// Unpack (from "registers") and insert in the output result rows 0 and 1 +// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>> + +// Same for result rows 2 and 3 +// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>> + +func.func @test_vector_contract_to_smmla(%lhs: vector<4x8xi8>, + %rhs: vector<[4]x8xi8>, + %acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> { + + %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32> + %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32> + %2 = vector.contract #attrs %0, %1, %acc + : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32> + + return %2 : vector<4x[4]xi32> +} + +// CHECK-LABEL: @test_vector_contract_to_smmla_implicit_sext + +// Extract LHS rows 0 and 1, concatenate, turn into scalable vector +// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> +// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8> + +// Replicate across the entire length of the scalabale vector +// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> + +// Same for LHS rows 2 and 4 +// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> +// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8> +// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> + +// Extract sub-tiles from the RHS +// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8> +// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8> + +// Extract accumulator rows 0 and 1 and pack (into "registers") +// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> +// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Same for accumulator rows 2 and 3. +// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> +// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Do the sub-tile matrix multiplications +// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.smmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.smmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.smmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.smmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> + +// Unpack (from "registers") and insert in the output result rows 0 and 1 +// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>> + +// Same for result rows 2 and 3 +// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>> + +// Test a variant where the sign-extension of the operands is +// implicit. The output is identical to the one of the previous test. +func.func @test_vector_contract_to_smmla_implicit_sext(%lhs: vector<4x8xi8>, + %rhs: vector<[4]x8xi8>, + %acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> { + + %0 = vector.contract #attrs %lhs, %rhs, %acc + : vector<4x8xi8>, vector<[4]x8xi8> into vector<4x[4]xi32> + + return %0 : vector<4x[4]xi32> +} diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir new file mode 100644 index 0000000000000..b6285d068b0f8 --- /dev/null +++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir @@ -0,0 +1,85 @@ +// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s + +#packed_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> +] + +// CHECK-LABEL: @test_vector_contract_to_usmmla_rev + +// Extract LHS rows 0 and 1, concatenate, turn into scalable vector +// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T1:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T1]][1] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> +// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T5:[0-9]+]][0] : vector<16xi8> into vector<[16]xi8> + +// Replicate across the entire length of the scalabale vector +// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> + +// Same for LHS rows 2 and 4 +// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T1]][2] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T1]][3] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> +// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T5]][0] : vector<16xi8> into vector<[16]xi8> +// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> + + +// Extract sub-tiles from the RHS +// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8> +// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8> + +// Extract accumulator rows 0 and 1 and pack (into "registers") +// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T0:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T0]][1] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T21:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T19]], %[[T20]]) : (vector<[4]xi32>, vector<[4]xi32>) -> vector<[8]xi32> +// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.intr.vector.extract %[[T21]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T23:[0-9]+]] = llvm.intr.vector.extract %[[T21]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Same for accumulator rows 2 and 3. +// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.extractvalue %[[T0]][2] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.extractvalue %[[T0]][3] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T26:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T24]], %[[T25]]) : (vector<[4]xi32>, vector<[4]xi32>) -> vector<[8]xi32> +// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.intr.vector.extract %[[T26]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.intr.vector.extract %[[T26]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Do the sub-tile matrix multiplications +// CHECK-NEXT: %[[T29:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T22]], %[[T17]], %[[T10]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T30:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T23]], %[[T18]], %[[T10]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T31:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T27]], %[[T17]], %[[T15]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T32:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T28]], %[[T18]], %[[T15]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> + +// Unpack (from "registers") and insert in the output result rows 0 and 1 +// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.insert %[[T29]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.insert %[[T30]], %[[T33]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T35:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T34]]) : (vector<[8]xi32>) -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)> +// CHECK-NEXT: %[[T36:[0-9]+]] = llvm.extractvalue %[[T35]][0] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)> +// CHECK-NEXT: %[[T37:[0-9]+]] = llvm.extractvalue %[[T35]][1] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)> +// CHECK-NEXT: %[[T38:[0-9]+]] = llvm.insertvalue %[[T36]], %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.insertvalue %[[T37]], %[[T38]][1] : !llvm.array<4 x vector<[4]xi32>> + +// Same for result rows 2 and 3 +// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T31]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.intr.vector.insert %[[T32]], %[[T40]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[8]xi32>) -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)> +// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)> +// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)> +// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.insertvalue %[[T43]], %[[T39]][2] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.insertvalue %[[T44]], %[[T45]][3] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T47:[0-9]+]] = builtin.unrealized_conversion_cast %[[T46]] : !llvm.array<4 x vector<[4]xi32>> to vector<4x[4]xi32> + +func.func @test_vector_contract_to_usmmla_rev( + %lhs: vector<4x8xi8>, + %rhs: vector<[4]x8xi8>, + %acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> { + + %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32> + %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32> + %2 = vector.contract {indexing_maps = #packed_maps, + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind} %0, %1, %acc + : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32> + + return %2 : vector<4x[4]xi32> +} diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir new file mode 100644 index 0000000000000..cde57842295f7 --- /dev/null +++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir @@ -0,0 +1,94 @@ +// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s + +#packed_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> +] + +// CHECK-LABEL: @test_vector_contract_to_ummla + +// Extract LHS rows 0 and 1, concatenate, turn into scalable vector +// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> +// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8> + +// Replicate across the entire length of the scalabale vector +// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> + +// Same for LHS rows 2 and 4 +// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> +// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8> +// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> + +// Extract sub-tiles from the RHS +// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8> +// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8> + +// Extract accumulator rows 0 and 1 and pack (into "registers") +// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> +// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Same for accumulator rows 2 and 3. +// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> +// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Do the sub-tile matrix multiplications +// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.ummla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.ummla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.ummla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.ummla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> + +// Unpack (from "registers") and insert in the output result rows 0 and 1 +// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>> + +// Same for result rows 2 and 3 +// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>> + +func.func @test_vector_contract_to_ummla(%lhs: vector<4x8xi8>, + %rhs: vector<[4]x8xi8>, + %acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> { + + %0 = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32> + %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32> + %2 = vector.contract {indexing_maps = #packed_maps, + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind} %0, %1, %acc + : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32> + + return %2 : vector<4x[4]xi32> +} diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir new file mode 100644 index 0000000000000..d0eef9fb9769c --- /dev/null +++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir @@ -0,0 +1,95 @@ +// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s + +#packed_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> +] + +// CHECK-LABEL: @test_vector_contract_to_usmmla + +// Extract LHS rows 0 and 1, concatenate, turn into scalable vector +// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> +// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8> + +// Replicate across the entire length of the scalabale vector +// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> + +// Same for LHS rows 2 and 4 +// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>> +// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8> +// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8> +// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> + +// Extract sub-tiles from the RHS +// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8> +// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8> + +// Extract accumulator rows 0 and 1 and pack (into "registers") +// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> +// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Same for accumulator rows 2 and 3. +// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> +// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Do the sub-tile matrix multiplications +// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> + +// Unpack (from "registers") and insert in the output result rows 0 and 1 +// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>> + +// Same for result rows 2 and 3 +// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>> + +func.func @test_vector_contract_to_usmmla( + %lhs: vector<4x8xi8>, + %rhs: vector<[4]x8xi8>, + %acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> { + + %0 = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32> + %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32> + %2 = vector.contract {indexing_maps = #packed_maps, + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind} %0, %1, %acc + : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32> + + return %2 : vector<4x[4]xi32> +}