diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 43ebcaa03a470..d8ed46c2820fe 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2605,7 +2605,9 @@ def Vector_ConstantMaskOp : } def Vector_CreateMaskOp : - Vector_Op<"create_mask", [Pure]>, + Vector_Op<"create_mask", [Pure, + DeclareOpInterfaceMethods + ]>, Arguments<(ins Variadic:$operands)>, Results<(outs VectorOfAnyRankOf<[I1]>)> { let summary = "creates a vector mask"; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index b60f80534bfb6..462bd8c3dc4a6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -1003,6 +1003,97 @@ struct UnrollFromElements : OpRewritePattern { vector::UnrollVectorOptions options; }; +/// This pattern unrolls `vector.create_mask` operations into smaller mask +/// operations based on the target unroll shape. Each unrolled slice computes +/// its local mask size in each dimension (d) as: +/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]). +/// Example: +/// Given a create_mask operation: +/// %0 = vector.create_mask %c6, %c10 : vector<8x16xi1> // mask first 6x10 +/// elements +/// +/// and a target unroll shape of <4x8>, the pattern produces: +/// +/// %false = arith.constant dense : vector<8x16xi1> +/// +/// Slice [0,0]: +/// mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8 +/// %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1> +/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +/// Slice [0,8]: +/// mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2 +/// %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1> +/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +/// Slice [4,0]: +/// mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8 +/// %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1> +/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +/// Slice [4,8]: +/// mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2 +/// %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1> +/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +struct UnrollCreateMaskPattern : public OpRewritePattern { + UnrollCreateMaskPattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + options(options) {} + + LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp, + PatternRewriter &rewriter) const override { + auto targetShape = getTargetShape(options, createMaskOp); + if (!targetShape) + return failure(); + + VectorType resultType = createMaskOp.getVectorType(); + SmallVector originalSize = *createMaskOp.getShapeForUnroll(); + Location loc = createMaskOp.getLoc(); + + Value result = arith::ConstantOp::create(rewriter, loc, resultType, + rewriter.getZeroAttr(resultType)); + VectorType targetVectorType = + VectorType::get(*targetShape, rewriter.getI1Type()); + SmallVector strides(targetShape->size(), 1); + + // In each dimension (d), each unrolled vector computes its mask size as: + // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]). + for (SmallVector offsets : + StaticTileOffsetRange(originalSize, *targetShape)) { + SmallVector unrolledOperands; + + for (auto [i, originalMaskOperand] : + llvm::enumerate(createMaskOp.getOperands())) { + Value offsetVal = + arith::ConstantIndexOp::create(rewriter, loc, offsets[i]); + Value adjustedMaskSize = rewriter.createOrFold( + loc, originalMaskOperand, offsetVal); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value unrolledDimSize = + arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]); + Value nonNegative = + rewriter.createOrFold(loc, adjustedMaskSize, zero); + Value unrolledOperand = rewriter.createOrFold( + loc, nonNegative, unrolledDimSize); + unrolledOperands.push_back(unrolledOperand); + } + + auto unrolledMask = rewriter.createOrFold( + loc, targetVectorType, unrolledOperands); + result = rewriter.createOrFold( + loc, unrolledMask, result, offsets, strides); + } + rewriter.replaceOp(createMaskOp, result); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + /// Checks whether extractShape is a contiguous slice of shape. /// For extractShape to be contiguous in shape: /// 1) All but the leading dimension of extractShape and shape must match @@ -1202,8 +1293,9 @@ void mlir::vector::populateVectorUnrollPatterns( UnrollReductionPattern, UnrollMultiReductionPattern, UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements, - UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>( - patterns.getContext(), options, benefit); + UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern, + UnrollCreateMaskPattern>(patterns.getContext(), options, + benefit); } void mlir::vector::populateVectorToElementsUnrollPatterns( diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index dec32e1c61a9b..805e66f133c59 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -497,6 +497,61 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3 // CHECK-NOT: arith.addf // CHECK: return +func.func @vector_create_mask(%size1: index, %size2: index) -> vector<16x16xi1> { + %0 = vector.create_mask %size1, %size2 : vector<16x16xi1> + return %0 : vector<16x16xi1> +} + +// CHECK-LABEL: func @vector_create_mask +// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<16x16xi1> +// CHECK: %[[CST:.*]] = arith.constant dense : vector<16x16xi1> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[MAX0:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index +// CHECK: %[[MIN0:.*]] = arith.minsi %[[MAX0]], %[[C8]] : index +// CHECK: %[[MAX1:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index +// CHECK: %[[MIN1:.*]] = arith.minsi %[[MAX1]], %[[C8]] : index +// CHECK: %[[MASK00:.*]] = vector.create_mask %[[MIN0]], %[[MIN1]] : vector<8x8xi1> +// CHECK: %[[INS00:.*]] = vector.insert_strided_slice %[[MASK00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[MAX0_2:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index +// CHECK: %[[MIN0_2:.*]] = arith.minsi %[[MAX0_2]], %[[C8]] : index +// CHECK: %[[SUB1:.*]] = arith.subi %[[ARG1]], %[[C8]] : index +// CHECK: %[[MAX1_2:.*]] = arith.maxsi %[[SUB1]], %[[C0]] : index +// CHECK: %[[MIN1_2:.*]] = arith.minsi %[[MAX1_2]], %[[C8]] : index +// CHECK: %[[MASK01:.*]] = vector.create_mask %[[MIN0_2]], %[[MIN1_2]] : vector<8x8xi1> +// CHECK: %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[SUB0:.*]] = arith.subi %[[ARG0]], %[[C8]] : index +// CHECK: %[[MAX0_3:.*]] = arith.maxsi %[[SUB0]], %[[C0]] : index +// CHECK: %[[MIN0_3:.*]] = arith.minsi %[[MAX0_3]], %[[C8]] : index +// CHECK: %[[MAX1_3:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index +// CHECK: %[[MIN1_3:.*]] = arith.minsi %[[MAX1_3]], %[[C8]] : index +// CHECK: %[[MASK10:.*]] = vector.create_mask %[[MIN0_3]], %[[MIN1_3]] : vector<8x8xi1> +// CHECK: %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[SUB0_2:.*]] = arith.subi %[[ARG0]], %[[C8]] : index +// CHECK: %[[MAX0_4:.*]] = arith.maxsi %[[SUB0_2]], %[[C0]] : index +// CHECK: %[[MIN0_4:.*]] = arith.minsi %[[MAX0_4]], %[[C8]] : index +// CHECK: %[[SUB1_2:.*]] = arith.subi %[[ARG1]], %[[C8]] : index +// CHECK: %[[MAX1_4:.*]] = arith.maxsi %[[SUB1_2]], %[[C0]] : index +// CHECK: %[[MIN1_4:.*]] = arith.minsi %[[MAX1_4]], %[[C8]] : index +// CHECK: %[[MASK11:.*]] = vector.create_mask %[[MIN0_4]], %[[MIN1_4]] : vector<8x8xi1> +// CHECK: %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: return %[[INS11]] : vector<16x16xi1> + +func.func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> { + %cst16 = arith.constant 16 : index + %0 = vector.create_mask %cst16, %cst16 : vector<16x16xi1> + return %0 : vector<16x16xi1> +} + +// CHECK-LABEL: func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> { +// CHECK: %[[CST:.*]] = arith.constant dense : vector<16x16xi1> +// CHECK: %[[CST_0:.*]] = arith.constant dense : vector<8x8xi1> +// CHECK: %[[S0:.*]] = vector.insert_strided_slice %[[CST_0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[S1:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S0]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[S2:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S1]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[S3:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S2]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: return %[[S3]] : vector<16x16xi1> + func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> { %0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32> diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index e8ea0cc02d7f6..f834d0cdd42bd 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -178,6 +178,12 @@ struct TestVectorUnrollingPatterns .setFilterConstraint([](Operation *op) { return success(isa(op)); })); + populateVectorUnrollPatterns( + patterns, UnrollVectorOptions() + .setNativeShape(ArrayRef{8, 8}) + .setFilterConstraint([](Operation *op) { + return success(isa(op)); + })); populateVectorUnrollPatterns( patterns, UnrollVectorOptions()