Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def ArmNeon_Dialect : Dialect {
//===----------------------------------------------------------------------===//

class NeonVectorOfLength<int length, Type elementType> : ShapedContainerType<
[elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorTypePred]>,
[elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorOfAnyRankTypePred]>,
"a vector with length " # length,
"::mlir::VectorType">;

Expand Down
54 changes: 27 additions & 27 deletions mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ class ScalableMaskedFOp<string mnemonic, string op_description,
op_description # [{ on active lanes. Inactive lanes will keep the value of
the first operand.}];
let arguments = (ins
ScalableVectorOf<[I1]>:$mask,
ScalableVectorOf<[AnyFloat]>:$src1,
ScalableVectorOf<[AnyFloat]>:$src2
ScalableVectorOfAnyRank<[I1]>:$mask,
ScalableVectorOfAnyRank<[AnyFloat]>:$src1,
ScalableVectorOfAnyRank<[AnyFloat]>:$src2
);
let results = (outs ScalableVectorOf<[AnyFloat]>:$res);
let results = (outs ScalableVectorOfAnyRank<[AnyFloat]>:$res);
let assemblyFormat =
"$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
}
Expand All @@ -123,11 +123,11 @@ class ScalableMaskedIOp<string mnemonic, string op_description,
op_description # [{ on active lanes. Inactive lanes will keep the value of
the first operand.}];
let arguments = (ins
ScalableVectorOf<[I1]>:$mask,
ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
ScalableVectorOf<[I8, I16, I32, I64]>:$src2
ScalableVectorOfAnyRank<[I1]>:$mask,
ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$src1,
ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$src2
);
let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$res);
let results = (outs ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$res);
let assemblyFormat =
"$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
}
Expand Down Expand Up @@ -511,55 +511,55 @@ def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;

def UmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"ummla">,
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

def SmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"smmla">,
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

def SdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

def UdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"udot">,
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

def ScalableMaskedAddIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"add">,
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

def ScalableMaskedAddFIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"fadd">,
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

def ScalableMaskedMulIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"mul">,
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

def ScalableMaskedMulFIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"fmul">,
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

def ScalableMaskedSubIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sub">,
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

def ScalableMaskedSubFIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"fsub">,
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

def ScalableMaskedSDivIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdiv">,
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

def ScalableMaskedUDivIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"udiv">,
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

def ScalableMaskedDivFIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

def ConvertFromSvboolIntrOp :
ArmSVE_IntrOp<"convert.from.svbool",
Expand All @@ -581,19 +581,19 @@ def ZipX2IntrOp : ArmSVE_IntrOp<"zip.x2",
/*overloadedOperands=*/[0],
/*overloadedResults=*/[],
/*numResults=*/2>,
Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
Arg<AnyScalableVector, "v2">:$v2)>;
Arguments<(ins Arg<AnyScalableVectorOfAnyRank, "v1">:$v1,
Arg<AnyScalableVectorOfAnyRank, "v2">:$v2)>;

// Note: This multi-vector intrinsic requires SME2.
def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
/*traits=*/[],
/*overloadedOperands=*/[0],
/*overloadedResults=*/[],
/*numResults=*/4>,
Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
Arg<AnyScalableVector, "v2">:$v2,
Arg<AnyScalableVector, "v3">:$v3,
Arg<AnyScalableVector, "v3">:$v4)>;
Arguments<(ins Arg<AnyScalableVectorOfAnyRank, "v1">:$v1,
Arg<AnyScalableVectorOfAnyRank, "v2">:$v2,
Arg<AnyScalableVectorOfAnyRank, "v3">:$v3,
Arg<AnyScalableVectorOfAnyRank, "v3">:$v4)>;

// Note: This intrinsic requires SME or SVE2.1.
def PselIntrOp : ArmSVE_IntrOp<"psel",
Expand Down
22 changes: 12 additions & 10 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -417,16 +417,18 @@ def Vector_BroadcastOp :
let hasVerifier = 1;
}

def Vector_ShuffleOp :
Vector_Op<"shuffle", [Pure,
PredOpTrait<"first operand v1 and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"second operand v2 and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>,
InferTypeOpAdaptor]>,
Arguments<(ins AnyFixedVector:$v1, AnyFixedVector:$v2,
DenseI64ArrayAttr:$mask)>,
Results<(outs AnyVector:$vector)> {
def Vector_ShuffleOp
: Vector_Op<
"shuffle",
[Pure,
PredOpTrait<"first operand v1 and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"second operand v2 and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>,
InferTypeOpAdaptor]>,
Arguments<(ins AnyFixedVectorOfAnyRank:$v1, AnyFixedVectorOfAnyRank:$v2,
DenseI64ArrayAttr:$mask)>,
Results<(outs AnyVector:$vector)> {
let summary = "shuffle operation";
let description = [{
The shuffle operation constructs a permutation (or duplication) of elements
Expand Down
35 changes: 18 additions & 17 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def IsVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;

// Whether a type is a fixed-length VectorType.
def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
!::llvm::cast<VectorType>($_self).isScalable()}]>;

// Whether a type is a scalable VectorType.
Expand Down Expand Up @@ -438,11 +438,11 @@ class VectorOfAnyRankOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
"::mlir::VectorType">;

class FixedVectorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsFixedVectorTypePred,
class FixedVectorOfAnyRank<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsFixedVectorOfAnyRankTypePred,
"fixed-length vector", "::mlir::VectorType">;

class ScalableVectorOf<list<Type> allowedTypes> :
class ScalableVectorOfAnyRank<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorTypeWithAnyDimScalablePred,
"scalable vector", "::mlir::VectorType">;

Expand All @@ -467,7 +467,7 @@ class IsVectorOfRankPred<list<int> allowedRanks> :
// Whether the number of elements of a fixed-length vector is from the given
// `allowedRanks` list
class IsFixedVectorOfRankPred<list<int> allowedRanks> :
And<[IsFixedVectorTypePred,
And<[IsFixedVectorOfAnyRankTypePred,
Or<!foreach(allowedlength, allowedRanks,
CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
== }]
Expand Down Expand Up @@ -509,8 +509,8 @@ class VectorOfRankAndType<list<int> allowedRanks,
// the type is from the given `allowedTypes` list
class FixedVectorOfRankAndType<list<int> allowedRanks,
list<Type> allowedTypes> : AllOfType<
[FixedVectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
FixedVectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
[FixedVectorOfAnyRank<allowedTypes>, VectorOfRank<allowedRanks>],
FixedVectorOfAnyRank<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
"::mlir::VectorType">;

// Whether the number of elements of a vector is from the given
Expand All @@ -525,7 +525,7 @@ class IsVectorOfLengthPred<list<int> allowedLengths> :
// Whether the number of elements of a fixed-length vector is from the given
// `allowedLengths` list
class IsFixedVectorOfLengthPred<list<int> allowedLengths> :
And<[IsFixedVectorTypePred,
And<[IsFixedVectorOfAnyRankTypePred,
Or<!foreach(allowedlength, allowedLengths,
CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
== }]
Expand Down Expand Up @@ -612,17 +612,17 @@ class VectorOfLengthAndType<list<int> allowedLengths,
// `allowedLengths` list and the type is from the given `allowedTypes` list
class FixedVectorOfLengthAndType<list<int> allowedLengths,
list<Type> allowedTypes> : AllOfType<
[FixedVectorOf<allowedTypes>, FixedVectorOfLength<allowedLengths>],
FixedVectorOf<allowedTypes>.summary #
[FixedVectorOfAnyRank<allowedTypes>, FixedVectorOfLength<allowedLengths>],
FixedVectorOfAnyRank<allowedTypes>.summary #
FixedVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;

// Any scalable vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes` list
class ScalableVectorOfLengthAndType<list<int> allowedLengths,
list<Type> allowedTypes> : AllOfType<
[ScalableVectorOf<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
ScalableVectorOf<allowedTypes>.summary #
[ScalableVectorOfAnyRank<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
ScalableVectorOfAnyRank<allowedTypes>.summary #
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;

Expand All @@ -632,10 +632,10 @@ class ScalableVectorOfLengthAndType<list<int> allowedLengths,
class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
list<int> allowedLengths,
list<Type> allowedTypes> : AllOfType<
[ScalableVectorOfRank<allowedRanks>, ScalableVectorOf<allowedTypes>,
[ScalableVectorOfRank<allowedRanks>, ScalableVectorOfAnyRank<allowedTypes>,
ScalableVectorOfLength<allowedLengths>],
ScalableVectorOfRank<allowedRanks>.summary #
ScalableVectorOf<allowedTypes>.summary #
ScalableVectorOfAnyRank<allowedTypes>.summary #
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;

Expand All @@ -657,13 +657,14 @@ class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
"::mlir::VectorType">;

// Unlike the following definitions, this one excludes 0-D vectors
def AnyVector : VectorOf<[AnyType]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.

def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;

def AnyFixedVector : FixedVectorOf<[AnyType]>;
def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>;

def AnyScalableVector : ScalableVectorOf<[AnyType]>;
def AnyScalableVectorOfAnyRank : ScalableVectorOfAnyRank<[AnyType]>;

// Shaped types.

Expand Down
Loading