@@ -5532,41 +5532,34 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
55325532 setResultRanges (getResult (), argRanges.front ());
55335533}
55345534
5535- // / Returns true if each element of 'a' is equal to the product of a contiguous
5536- // / sequence of the elements of 'b'. Returns false otherwise.
5535+ // / Returns true if each element of 'a' is either 1 or equal to the product of a
5536+ // / contiguous sequence of the elements of 'b'. Returns false otherwise.
5537+ // /
5538+ // / This function assumes that the product of elements in a and b are the same.
55375539static bool isExpandingShapeCast (ArrayRef<int64_t > a, ArrayRef<int64_t > b) {
5538- unsigned rankA = a.size ();
5539- unsigned rankB = b.size ();
5540- if (rankA > rankB) {
5541- return false ;
5542- }
5543-
5544- auto isOne = [](int64_t v) { return v == 1 ; };
5545-
5546- // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
5547- // casted to a 0-d vector.
5548- if (rankA == 0 && llvm::all_of (b, isOne))
5549- return true ;
55505540
5541+ unsigned rankA = a.size ();
55515542 unsigned i = 0 ;
55525543 unsigned j = 0 ;
5553- while (i < rankA && j < rankB) {
5544+ while (i < rankA) {
5545+ if (a[i] == 1 ) {
5546+ ++i;
5547+ continue ;
5548+ }
5549+
55545550 int64_t dimA = a[i];
55555551 int64_t dimB = 1 ;
5556- while (dimB < dimA && j < rankB)
5552+
5553+ while (dimB < dimA) {
55575554 dimB *= b[j++];
5558- if (dimA != dimB)
5559- break ;
5560- ++i;
5555+ }
55615556
5562- // Handle the case when trailing dimensions are of size 1.
5563- // Include them into the contiguous sequence.
5564- if (i < rankA && llvm::all_of (a.slice (i), isOne))
5565- i = rankA;
5566- if (j < rankB && llvm::all_of (b.slice (j), isOne))
5567- j = rankB;
5557+ if (dimA != dimB) {
5558+ return false ;
5559+ }
5560+ ++i;
55685561 }
5569- return i == rankA && j == rankB ;
5562+ return true ;
55705563}
55715564
55725565static bool isValidShapeCast (ArrayRef<int64_t > a, ArrayRef<int64_t > b) {
@@ -5582,7 +5575,8 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
55825575 ArrayRef<int64_t > inShape = sourceVectorType.getShape ();
55835576 ArrayRef<int64_t > outShape = resultVectorType.getShape ();
55845577
5585- // Check that product of source dim sizes matches product of result dim sizes.
5578+ // Check that product of source dim sizes matches product of result dim
5579+ // sizes.
55865580 int64_t nInElms = std::accumulate (inShape.begin (), inShape.end (), 1LL ,
55875581 std::multiplies<int64_t >{});
55885582 int64_t nOutElms = std::accumulate (outShape.begin (), outShape.end (), 1LL ,
@@ -5702,8 +5696,8 @@ static VectorType trimTrailingOneDims(VectorType oldType) {
57025696// /
57035697// / Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
57045698// / dimension. If the input vector comes from `vector.create_mask` for which
5705- // / the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
5706- // / to fold shape_cast into create_mask.
5699+ // / the corresponding mask input value is 1 (e.g. `%c1` below), then it is
5700+ // / safe to fold shape_cast into create_mask.
57075701// /
57085702// / BEFORE:
57095703// / %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
@@ -5970,8 +5964,8 @@ LogicalResult TypeCastOp::verify() {
59705964 auto resultType = getResultMemRefType ();
59715965 if (getElementTypeOrSelf (getElementTypeOrSelf (sourceType)) !=
59725966 getElementTypeOrSelf (getElementTypeOrSelf (resultType)))
5973- return emitOpError (
5974- " expects result and operand with same underlying scalar type: " )
5967+ return emitOpError (" expects result and operand with same underlying "
5968+ " scalar type: " )
59755969 << resultType;
59765970 if (extractShape (sourceType) != extractShape (resultType))
59775971 return emitOpError (
@@ -6009,7 +6003,8 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
60096003 return attr.reshape (getResultVectorType ());
60106004
60116005 // Eliminate identity transpose ops. This happens when the dimensions of the
6012- // input vector remain in their original order after the transpose operation.
6006+ // input vector remain in their original order after the transpose
6007+ // operation.
60136008 ArrayRef<int64_t > perm = getPermutation ();
60146009
60156010 // Check if the permutation of the dimensions contains sequential values:
@@ -6068,7 +6063,8 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
60686063 return result;
60696064 };
60706065
6071- // Return if the input of 'transposeOp' is not defined by another transpose.
6066+ // Return if the input of 'transposeOp' is not defined by another
6067+ // transpose.
60726068 vector::TransposeOp parentTransposeOp =
60736069 transposeOp.getVector ().getDefiningOp <vector::TransposeOp>();
60746070 if (!parentTransposeOp)
@@ -6212,8 +6208,9 @@ LogicalResult ConstantMaskOp::verify() {
62126208 return emitOpError (
62136209 " only supports 'none set' or 'all set' scalable dimensions" );
62146210 }
6215- // Verify that if one mask dim size is zero, they all should be zero (because
6216- // the mask region is a conjunction of each mask dimension interval).
6211+ // Verify that if one mask dim size is zero, they all should be zero
6212+ // (because the mask region is a conjunction of each mask dimension
6213+ // interval).
62176214 bool anyZeros = llvm::is_contained (maskDimSizes, 0 );
62186215 bool allZeros = llvm::all_of (maskDimSizes, [](int64_t s) { return s == 0 ; });
62196216 if (anyZeros && !allZeros)
@@ -6251,7 +6248,8 @@ void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
62516248
62526249LogicalResult CreateMaskOp::verify () {
62536250 auto vectorType = llvm::cast<VectorType>(getResult ().getType ());
6254- // Verify that an operand was specified for each result vector each dimension.
6251+ // Verify that an operand was specified for each result vector each
6252+ // dimension.
62556253 if (vectorType.getRank () == 0 ) {
62566254 if (getNumOperands () != 1 )
62576255 return emitOpError (
@@ -6458,8 +6456,8 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
64586456void MaskOp::ensureTerminator (Region ®ion, Builder &builder, Location loc) {
64596457 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
64606458 MaskOp>::ensureTerminator (region, builder, loc);
6461- // Keep the default yield terminator if the number of masked operations is not
6462- // the expected. This case will trigger a verification failure.
6459+ // Keep the default yield terminator if the number of masked operations is
6460+ // not the expected. This case will trigger a verification failure.
64636461 Block &block = region.front ();
64646462 if (block.getOperations ().size () != 2 )
64656463 return ;
@@ -6563,9 +6561,9 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
65636561 return success ();
65646562}
65656563
6566- // Elides empty vector.mask operations with or without return values. Propagates
6567- // the yielded values by the vector.yield terminator, if any, or erases the op,
6568- // otherwise.
6564+ // Elides empty vector.mask operations with or without return values.
6565+ // Propagates the yielded values by the vector.yield terminator, if any, or
6566+ // erases the op, otherwise.
65696567class ElideEmptyMaskOp : public OpRewritePattern <MaskOp> {
65706568 using OpRewritePattern::OpRewritePattern;
65716569
@@ -6668,7 +6666,8 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
66686666 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
66696667 return {};
66706668
6671- // SplatElementsAttr::get treats single value for second arg as being a splat.
6669+ // SplatElementsAttr::get treats single value for second arg as being a
6670+ // splat.
66726671 return SplatElementsAttr::get (getType (), {constOperand});
66736672}
66746673
@@ -6790,12 +6789,12 @@ Operation *mlir::vector::maskOperation(OpBuilder &builder,
67906789}
67916790
67926791// / Creates a vector select operation that picks values from `newValue` or
6793- // / `passthru` for each result vector lane based on `mask`. This utility is used
6794- // / to propagate the pass-thru value of vector.mask or for cases where only the
6795- // / pass-thru value propagation is needed. VP intrinsics do not support
6796- // / pass-thru values and every mask-out lane is set to poison. LLVM backends are
6797- // / usually able to match op + select patterns and fold them into a native
6798- // / target instructions.
6792+ // / `passthru` for each result vector lane based on `mask`. This utility is
6793+ // / used to propagate the pass-thru value of vector.mask or for cases where
6794+ // / only the pass-thru value propagation is needed. VP intrinsics do not
6795+ // / support pass-thru values and every mask-out lane is set to poison. LLVM
6796+ // / backends are usually able to match op + select patterns and fold them into
6797+ // / a native target instructions.
67996798Value mlir::vector::selectPassthru (OpBuilder &builder, Value mask,
68006799 Value newValue, Value passthru) {
68016800 if (!mask)
0 commit comments