Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 57 additions & 38 deletions mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,9 @@ class ScalableMaskedIOp<string mnemonic, string op_description,
"$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
}

def SdotOp : ArmSVE_Op<"sdot",
[Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>,
]> {
def SdotOp : ArmSVE_Op<"sdot", [Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>]> {
let summary = "Vector-vector dot product and accumulate op";
let description = [{
SDOT: Signed integer addition of dot product.
Expand All @@ -178,11 +176,9 @@ def SdotOp : ArmSVE_Op<"sdot",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}

def SmmlaOp : ArmSVE_Op<"smmla",
[Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>,
]> {
def SmmlaOp : ArmSVE_Op<"smmla", [Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>]> {
let summary = "Matrix-matrix multiply and accumulate op";
let description = [{
SMMLA: Signed integer matrix multiply-accumulate.
Expand Down Expand Up @@ -210,11 +206,9 @@ def SmmlaOp : ArmSVE_Op<"smmla",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}

def UdotOp : ArmSVE_Op<"udot",
[Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>,
]> {
def UdotOp : ArmSVE_Op<"udot", [Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>]> {
let summary = "Vector-vector dot product and accumulate op";
let description = [{
UDOT: Unsigned integer addition of dot product.
Expand All @@ -241,11 +235,9 @@ def UdotOp : ArmSVE_Op<"udot",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}

def UmmlaOp : ArmSVE_Op<"ummla",
[Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>,
]> {
def UmmlaOp : ArmSVE_Op<"ummla", [Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>]> {
let summary = "Matrix-matrix multiply and accumulate op";
let description = [{
UMMLA: Unsigned integer matrix multiply-accumulate.
Expand Down Expand Up @@ -273,14 +265,42 @@ def UmmlaOp : ArmSVE_Op<"ummla",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}

def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>]> {
let summary = "Matrix-matrix multiply and accumulate op";
let description = [{
USMMLA: Unsigned by signed integer matrix multiply-accumulate.

The unsigned by signed integer matrix multiply-accumulate operation
multiplies the 2×8 matrix of unsigned 8-bit integer values held
the first source vector by the 8×2 matrix of signed 8-bit integer
values in the second source vector. The resulting 2×2 widened 32-bit
integer matrix product is then added to the 32-bit integer matrix
accumulator.

Source:
https://developer.arm.com/documentation/100987/0000
}];
// Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
let arguments = (ins
ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
ScalableVectorOfLengthAndType<[16], [I8]>:$src2
);
let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
let assemblyFormat =
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}

class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
"expected corresponding svbool type widened to [16]xi1",
lhsArg, rhsArg,
"VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))">;

def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
[Pure, SvboolTypeConstraint<"result", "source">]>
{
[Pure,
SvboolTypeConstraint<"result", "source">]> {
let summary = "Convert a svbool type to a SVE predicate type";
let description = [{
Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g.
Expand Down Expand Up @@ -313,8 +333,8 @@ def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
}

def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
[Pure, SvboolTypeConstraint<"source", "result">]>
{
[Pure,
SvboolTypeConstraint<"source", "result">]> {
let summary = "Convert a SVE predicate type to a svbool type";
let description = [{
Converts SVE predicate types (or vectors of predicate types, e.g.
Expand Down Expand Up @@ -356,10 +376,9 @@ def ZipInputVectorType : AnyTypeOf<[
Scalable1DVectorOfLength<16, [I8]>],
"an SVE vector with element size <= 64-bit">;

def ZipX2Op : ArmSVE_Op<"zip.x2", [
Pure,
AllTypesMatch<["sourceV1", "sourceV2", "resultV1", "resultV2"]>]
> {
def ZipX2Op : ArmSVE_Op<"zip.x2", [Pure,
AllTypesMatch<["sourceV1", "sourceV2",
"resultV1", "resultV2"]>]> {
let summary = "Multi-vector two-way zip op";

let description = [{
Expand Down Expand Up @@ -400,12 +419,11 @@ def ZipX2Op : ArmSVE_Op<"zip.x2", [
}];
}

def ZipX4Op : ArmSVE_Op<"zip.x4", [
Pure,
AllTypesMatch<[
"sourceV1", "sourceV2", "sourceV3", "sourceV4",
"resultV1", "resultV2", "resultV3", "resultV4"]>]
> {
def ZipX4Op
: ArmSVE_Op<"zip.x4",
[Pure,
AllTypesMatch<["sourceV1", "sourceV2", "sourceV3", "sourceV4",
"resultV1", "resultV2", "resultV3", "resultV4"]>]> {
let summary = "Multi-vector four-way zip op";

let description = [{
Expand Down Expand Up @@ -463,10 +481,7 @@ def ZipX4Op : ArmSVE_Op<"zip.x4", [
}];
}

def PselOp : ArmSVE_Op<"psel", [
Pure,
AllTypesMatch<["p1", "result"]>,
]> {
def PselOp : ArmSVE_Op<"psel", [Pure, AllTypesMatch<["p1", "result"]>]> {
let summary = "Predicate select";

let description = [{
Expand Down Expand Up @@ -571,6 +586,10 @@ def SmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"smmla">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

def UsmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

def SdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
using DupQLaneLowering =
OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
using ScalableMaskedAddIOpLowering =
Expand Down Expand Up @@ -206,6 +207,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
SmmlaOpLowering,
UdotOpLowering,
UmmlaOpLowering,
UsmmlaOpLowering,
ZipX2OpLowering,
ZipX4OpLowering,
SdotOpLowering>(converter);
Expand Down Expand Up @@ -234,6 +236,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaIntrOp,
UdotIntrOp,
UmmlaIntrOp,
UsmmlaIntrOp,
WhileLTIntrOp,
ZipX2IntrOp,
ZipX4IntrOp,
Expand All @@ -254,6 +257,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaOp,
UdotOp,
UmmlaOp,
UsmmlaOp,
ZipX2Op,
ZipX4Op,
SdotOp>();
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,

// -----

func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>)
-> vector<[4]xi32> {
// CHECK: arm_sve.intr.usmmla
%0 = arm_sve.usmmla %c, %a, %b :
vector<[16]xi8> to vector<[4]xi32>
return %0 : vector<[4]xi32>
}

// -----

func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/Dialect/ArmSVE/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,

// -----

func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>) -> vector<[4]xi32> {
// CHECK: arm_sve.usmmla {{.*}}: vector<[16]xi8> to vector<[4]xi3
%0 = arm_sve.usmmla %c, %a, %b :
vector<[16]xi8> to vector<[4]xi32>
return %0 : vector<[4]xi32>
}

// -----

func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Target/LLVMIR/arm-sve.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ llvm.func @arm_sve_ummla(%arg0: vector<[16]xi8>,
llvm.return %0 : vector<[4]xi32>
}

// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_usmmla
llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
%arg1: vector<[16]xi8>,
%arg2: vector<[4]xi32>)
-> vector<[4]xi32> {
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4
%0 = "arm_sve.intr.usmmla"(%arg2, %arg0, %arg1) :
(vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
-> vector<[4]xi32>
llvm.return %0 : vector<[4]xi32>
}

// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
%arg1: vector<[4]xi32>,
Expand Down