Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 48 additions & 59 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1263,23 +1263,20 @@ def Vector_ExtractStridedSliceOp :

// TODO: Tighten semantics so that masks and inbounds can't be used
// simultaneously within the same transfer op.
def Vector_TransferReadOp :
Vector_Op<"transfer_read", [
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
Arguments<(ins AnyShaped:$source,
Variadic<Index>:$indices,
AffineMapAttr:$permutation_map,
AnyType:$padding,
Optional<VectorOfNonZeroRankOf<[I1]>>:$mask,
BoolArrayAttr:$in_bounds)>,
Results<(outs AnyVectorOfAnyRank:$vector)> {
def Vector_TransferReadOp
: Vector_Op<"transfer_read",
[DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<
VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
AttrSizedOperandSegments, DestinationStyleOpInterface]>,
Arguments<(ins AnyShaped:$base, Variadic<Index>:$indices,
AffineMapAttr:$permutation_map, AnyType:$padding,
Optional<VectorOfNonZeroRankOf<[I1]>>:$mask,
BoolArrayAttr:$in_bounds)>,
Results<(outs AnyVectorOfAnyRank:$vector)> {

let summary = "Reads a supervector from memory into an SSA vector value.";

Expand Down Expand Up @@ -1468,30 +1465,25 @@ def Vector_TransferReadOp :
}];

let builders = [
/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
"ValueRange":$indices,
"AffineMapAttr":$permutationMapAttr,
"ArrayAttr":$inBoundsAttr)>,
/// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
"ValueRange":$indices,
"AffineMap":$permutationMap,
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
"ValueRange":$indices,
"Value":$padding,
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
/// 4. Builder that sets padding to zero and permutation map to
/// 'getMinorIdentityMap'.
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
"ValueRange":$indices,
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
/// 1. Builder that sets padding to zero and an empty mask (variant with
/// attrs).
OpBuilder<(ins "VectorType":$vectorType, "Value":$base,
"ValueRange":$indices, "AffineMapAttr":$permutationMapAttr,
"ArrayAttr":$inBoundsAttr)>,
/// 2. Builder that sets padding to zero and an empty mask (variant
/// without attrs).
OpBuilder<(ins "VectorType":$vectorType, "Value":$base,
"ValueRange":$indices, "AffineMap":$permutationMap,
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
OpBuilder<(ins "VectorType":$vectorType, "Value":$base,
"ValueRange":$indices, "Value":$padding,
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
/// 4. Builder that sets padding to zero and permutation map to
/// 'getMinorIdentityMap'.
OpBuilder<(ins "VectorType":$vectorType, "Value":$base,
"ValueRange":$indices,
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
];

let extraClassDeclaration = [{
Expand All @@ -1511,23 +1503,20 @@ def Vector_TransferReadOp :

// TODO: Tighten semantics so that masks and inbounds can't be used
// simultaneously within the same transfer op.
def Vector_TransferWriteOp :
Vector_Op<"transfer_write", [
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
Arguments<(ins AnyVectorOfAnyRank:$valueToStore,
AnyShaped:$source,
Variadic<Index>:$indices,
AffineMapAttr:$permutation_map,
Optional<VectorOfNonZeroRankOf<[I1]>>:$mask,
BoolArrayAttr:$in_bounds)>,
Results<(outs Optional<AnyRankedTensor>:$result)> {
def Vector_TransferWriteOp
: Vector_Op<"transfer_write",
[DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<
VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
AttrSizedOperandSegments, DestinationStyleOpInterface]>,
Arguments<(ins AnyVectorOfAnyRank:$valueToStore, AnyShaped:$base,
Variadic<Index>:$indices, AffineMapAttr:$permutation_map,
Optional<VectorOfNonZeroRankOf<[I1]>>:$mask,
BoolArrayAttr:$in_bounds)>,
Results<(outs Optional<AnyRankedTensor>:$result)> {

let summary = "The vector.transfer_write op writes a supervector to memory.";

Expand Down Expand Up @@ -1663,7 +1652,7 @@ def Vector_TransferWriteOp :
/// ops of other dialects.
Value getValue() { return getVector(); }

MutableOperandRange getDpsInitsMutable() { return getSourceMutable(); }
MutableOperandRange getDpsInitsMutable() { return getBaseMutable(); }
}];

let hasFolder = 1;
Expand Down
104 changes: 46 additions & 58 deletions mlir/include/mlir/Interfaces/VectorInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -75,75 +75,66 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
}];
let cppNamespace = "::mlir";

let methods = [
InterfaceMethod<
/*desc=*/[{
let methods = [InterfaceMethod<
/*desc=*/[{
Return the `in_bounds` attribute name.
}],
/*retTy=*/"::mlir::StringRef",
/*methodName=*/"getInBoundsAttrName",
/*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/[{
/*retTy=*/"::mlir::StringRef",
/*methodName=*/"getInBoundsAttrName",
/*args=*/(ins)>,
InterfaceMethod<
/*desc=*/[{
Return the `permutation_map` attribute name.
}],
/*retTy=*/"::mlir::StringRef",
/*methodName=*/"getPermutationMapAttrName",
/*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/[{
/*retTy=*/"::mlir::StringRef",
/*methodName=*/"getPermutationMapAttrName",
/*args=*/(ins)>,
InterfaceMethod<
/*desc=*/[{
Return the optional in_bounds attribute that specifies for each vector
dimension whether it is in-bounds or not. (Broadcast dimensions are
always in-bounds).
}],
/*retTy=*/"::mlir::ArrayAttr",
/*methodName=*/"getInBounds",
/*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/[{
/*retTy=*/"::mlir::ArrayAttr",
/*methodName=*/"getInBounds",
/*args=*/(ins)>,
InterfaceMethod<
/*desc=*/[{
Return the memref or ranked tensor operand that this operation operates
on. In case of a "read" operation, that's the source from which the
operation reads. In case of a "write" operation, that's the destination
into which the operation writes.
TODO: Change name of operand, which is not accurate for xfer_write.
}],
/*retTy=*/"::mlir::Value",
/*methodName=*/"getSource",
/*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/[{
/*retTy=*/"::mlir::Value",
/*methodName=*/"getBase",
/*args=*/(ins)>,
InterfaceMethod<
/*desc=*/[{
Return the vector that this operation operates on. In case of a "read",
that's the vector OpResult. In case of a "write", that's the vector
operand value that is written by the op.
}],
/*retTy=*/"::mlir::Value",
/*methodName=*/"getVector",
/*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/[{
/*retTy=*/"::mlir::Value",
/*methodName=*/"getVector",
/*args=*/(ins)>,
InterfaceMethod<
/*desc=*/[{
Return the type of the vector that this operation operates on.
}],
/*retTy=*/"::mlir::VectorType",
/*methodName=*/"getVectorType",
/*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/[{
/*retTy=*/"::mlir::VectorType",
/*methodName=*/"getVectorType",
/*args=*/(ins)>,
InterfaceMethod<
/*desc=*/[{
Return the indices that specify the starting offsets into the source
operand. The starting offsets are guaranteed to be in-bounds.
}],
/*retTy=*/"::mlir::OperandRange",
/*methodName=*/"getIndices",
/*args=*/(ins)
>,
/*retTy=*/"::mlir::OperandRange",
/*methodName=*/"getIndices",
/*args=*/(ins)>,

InterfaceMethod<
/*desc=*/[{
InterfaceMethod<
/*desc=*/[{
Return the permutation map that describes the mapping of vector
dimensions to source dimensions, as well as broadcast dimensions.

Expand All @@ -162,20 +153,17 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
: memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
```
}],
/*retTy=*/"::mlir::AffineMap",
/*methodName=*/"getPermutationMap",
/*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/[{
/*retTy=*/"::mlir::AffineMap",
/*methodName=*/"getPermutationMap",
/*args=*/(ins)>,
InterfaceMethod<
/*desc=*/[{
Return the mask operand if the op has a mask. Otherwise, return an
empty value.
}],
/*retTy=*/"Value",
/*methodName=*/"getMask",
/*args=*/(ins)
>
];
/*retTy=*/"Value",
/*methodName=*/"getMask",
/*args=*/(ins)>];

let extraSharedClassDeclaration = [{
/// Return a vector of all in_bounds values as booleans (one per vector
Expand Down Expand Up @@ -203,7 +191,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {

/// Return the shaped type of the "source" operand value.
::mlir::ShapedType getShapedType() {
return ::llvm::cast<::mlir::ShapedType>($_op.getSource().getType());
return ::llvm::cast<::mlir::ShapedType>($_op.getBase().getType());
}

/// Return the number of dimensions that participate in the permutation map.
Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct TransferReadToArmSMELowering
return rewriter.notifyMatchFailure(transferReadOp,
"not a valid vector type for SME");

if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
if (!llvm::isa<MemRefType>(transferReadOp.getBase().getType()))
return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");

// Out-of-bounds dims are not supported.
Expand All @@ -84,7 +84,7 @@ struct TransferReadToArmSMELowering
auto mask = transferReadOp.getMask();
auto padding = mask ? transferReadOp.getPadding() : nullptr;
rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
transferReadOp, vectorType, transferReadOp.getSource(),
transferReadOp, vectorType, transferReadOp.getBase(),
transferReadOp.getIndices(), padding, mask, layout);

return success();
Expand Down Expand Up @@ -128,7 +128,7 @@ struct TransferWriteToArmSMELowering
if (!arm_sme::isValidSMETileVectorType(vType))
return failure();

if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
if (!llvm::isa<MemRefType>(writeOp.getBase().getType()))
return failure();

// Out-of-bounds dims are not supported.
Expand All @@ -149,7 +149,7 @@ struct TransferWriteToArmSMELowering
: arm_sme::TileSliceLayout::Horizontal;

rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
writeOp, writeOp.getVector(), writeOp.getBase(), writeOp.getIndices(),
writeOp.getMask(), layout);
return success();
}
Expand Down Expand Up @@ -686,7 +686,7 @@ struct FoldTransferWriteOfExtractTileSlice

LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
PatternRewriter &rewriter) const final {
if (!isa<MemRefType>(writeOp.getSource().getType()))
if (!isa<MemRefType>(writeOp.getBase().getType()))
return rewriter.notifyMatchFailure(writeOp, "destination not a memref");

if (writeOp.hasOutOfBoundsDim())
Expand All @@ -713,7 +713,7 @@ struct FoldTransferWriteOfExtractTileSlice

rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
writeOp, extractTileSlice.getTile(),
extractTileSlice.getTileSliceIndex(), mask, writeOp.getSource(),
extractTileSlice.getTileSliceIndex(), mask, writeOp.getBase(),
writeOp.getIndices(), extractTileSlice.getLayout());
return success();
}
Expand Down
Loading