@@ -1003,6 +1003,97 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
10031003 vector::UnrollVectorOptions options;
10041004};
10051005
1006+ // / This pattern unrolls `vector.create_mask` operations into smaller mask
1007+ // / operations based on the target unroll shape. Each unrolled slice computes
1008+ // / its local mask size in each dimension (d) as:
1009+ // / min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).
1010+ // / Example:
1011+ // / Given a create_mask operation:
1012+ // / %0 = vector.create_mask %c6, %c10 : vector<8x16xi1> // mask first 6x10
1013+ // / elements
1014+ // /
1015+ // / and a target unroll shape of <4x8>, the pattern produces:
1016+ // /
1017+ // / %false = arith.constant dense<false> : vector<8x16xi1>
1018+ // /
1019+ // / Slice [0,0]:
1020+ // / mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8
1021+ // / %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1>
1022+ // / %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
1023+ // / : vector<4x8xi1> into vector<8x16xi1>
1024+ // / Slice [0,8]:
1025+ // / mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2
1026+ // / %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1>
1027+ // / %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
1028+ // / : vector<4x8xi1> into vector<8x16xi1>
1029+ // / Slice [4,0]:
1030+ // / mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8
1031+ // / %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1>
1032+ // / %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
1033+ // / : vector<4x8xi1> into vector<8x16xi1>
1034+ // / Slice [4,8]:
1035+ // / mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2
1036+ // / %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1>
1037+ // / %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
1038+ // / : vector<4x8xi1> into vector<8x16xi1>
1039+ struct UnrollCreateMaskPattern : public OpRewritePattern <vector::CreateMaskOp> {
1040+ UnrollCreateMaskPattern (MLIRContext *context,
1041+ const vector::UnrollVectorOptions &options,
1042+ PatternBenefit benefit = 1 )
1043+ : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1044+ options (options) {}
1045+
1046+ LogicalResult matchAndRewrite (vector::CreateMaskOp createMaskOp,
1047+ PatternRewriter &rewriter) const override {
1048+ auto targetShape = getTargetShape (options, createMaskOp);
1049+ if (!targetShape)
1050+ return failure ();
1051+
1052+ VectorType resultType = createMaskOp.getVectorType ();
1053+ SmallVector<int64_t > originalSize = *createMaskOp.getShapeForUnroll ();
1054+ Location loc = createMaskOp.getLoc ();
1055+
1056+ Value result = arith::ConstantOp::create (rewriter, loc, resultType,
1057+ rewriter.getZeroAttr (resultType));
1058+ VectorType targetVectorType =
1059+ VectorType::get (*targetShape, rewriter.getI1Type ());
1060+ SmallVector<int64_t > strides (targetShape->size (), 1 );
1061+
1062+ // In each dimension (d), each unrolled vector computes its mask size as:
1063+ // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]).
1064+ for (SmallVector<int64_t > offsets :
1065+ StaticTileOffsetRange (originalSize, *targetShape)) {
1066+ SmallVector<Value> unrolledOperands;
1067+
1068+ for (auto [i, originalMaskOperand] :
1069+ llvm::enumerate (createMaskOp.getOperands ())) {
1070+ Value offsetVal =
1071+ arith::ConstantIndexOp::create (rewriter, loc, offsets[i]);
1072+ Value adjustedMaskSize = rewriter.createOrFold <arith::SubIOp>(
1073+ loc, originalMaskOperand, offsetVal);
1074+ Value zero = arith::ConstantIndexOp::create (rewriter, loc, 0 );
1075+ Value unrolledDimSize =
1076+ arith::ConstantIndexOp::create (rewriter, loc, (*targetShape)[i]);
1077+ Value nonNegative =
1078+ rewriter.createOrFold <arith::MaxSIOp>(loc, adjustedMaskSize, zero);
1079+ Value unrolledOperand = rewriter.createOrFold <arith::MinSIOp>(
1080+ loc, nonNegative, unrolledDimSize);
1081+ unrolledOperands.push_back (unrolledOperand);
1082+ }
1083+
1084+ auto unrolledMask = rewriter.createOrFold <vector::CreateMaskOp>(
1085+ loc, targetVectorType, unrolledOperands);
1086+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
1087+ loc, unrolledMask, result, offsets, strides);
1088+ }
1089+ rewriter.replaceOp (createMaskOp, result);
1090+ return success ();
1091+ }
1092+
1093+ private:
1094+ vector::UnrollVectorOptions options;
1095+ };
1096+
10061097// / Checks whether extractShape is a contiguous slice of shape.
10071098// / For extractShape to be contiguous in shape:
10081099// / 1) All but the leading dimension of extractShape and shape must match
@@ -1202,8 +1293,9 @@ void mlir::vector::populateVectorUnrollPatterns(
12021293 UnrollReductionPattern, UnrollMultiReductionPattern,
12031294 UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
12041295 UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
1205- UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
1206- patterns.getContext (), options, benefit);
1296+ UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
1297+ UnrollCreateMaskPattern>(patterns.getContext (), options,
1298+ benefit);
12071299}
12081300
12091301void mlir::vector::populateVectorToElementsUnrollPatterns (
0 commit comments