@@ -3937,6 +3937,23 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
39373937 return success ();
39383938}
39393939
3940+ template <typename TransferOp>
3941+ static LogicalResult foldTransferFullMask (TransferOp op) {
3942+ auto mask = op.getMask ();
3943+ if (!mask)
3944+ return failure ();
3945+
3946+ auto constantMask = mask.template getDefiningOp <vector::ConstantMaskOp>();
3947+ if (!constantMask)
3948+ return failure ();
3949+
3950+ if (!constantMask.isFullMask ())
3951+ return failure ();
3952+
3953+ op.getMaskMutable ().clear ();
3954+ return success ();
3955+ }
3956+
39403957// / ```
39413958// / %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
39423959// / : vector<1x4xf32>, tensor<4x4xf32>
@@ -3969,6 +3986,8 @@ OpFoldResult TransferReadOp::fold(FoldAdaptor) {
39693986 // / transfer_read(memrefcast) -> transfer_read
39703987 if (succeeded (foldTransferInBoundsAttribute (*this )))
39713988 return getResult ();
3989+ if (succeeded (foldTransferFullMask (*this )))
3990+ return getResult ();
39723991 if (succeeded (memref::foldMemRefCast (*this )))
39733992 return getResult ();
39743993 if (succeeded (tensor::foldTensorCast (*this )))
@@ -4334,6 +4353,8 @@ LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
43344353 return success ();
43354354 if (succeeded (foldTransferInBoundsAttribute (*this )))
43364355 return success ();
4356+ if (succeeded (foldTransferFullMask (*this )))
4357+ return success ();
43374358 return memref::foldMemRefCast (*this );
43384359}
43394360
@@ -5601,6 +5622,22 @@ LogicalResult ConstantMaskOp::verify() {
56015622 return success ();
56025623}
56035624
5625+ bool ConstantMaskOp::isFullMask () {
5626+ auto resultType = getVectorType ();
5627+ // Check the corner case of 0-D vectors first.
5628+ if (resultType.getRank () == 0 ) {
5629+ assert (getMaskDimSizes ().size () == 1 && " invalid sizes for zero rank mask" );
5630+ return llvm::cast<IntegerAttr>(getMaskDimSizes ()[0 ]).getInt () == 1 ;
5631+ }
5632+ for (const auto [resultSize, intAttr] :
5633+ llvm::zip_equal (resultType.getShape (), getMaskDimSizes ())) {
5634+ int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt ();
5635+ if (maskDimSize < resultSize)
5636+ return false ;
5637+ }
5638+ return true ;
5639+ }
5640+
56045641// ===----------------------------------------------------------------------===//
56055642// CreateMaskOp
56065643// ===----------------------------------------------------------------------===//
0 commit comments