Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2605,7 +2605,9 @@ def Vector_ConstantMaskOp :
}

def Vector_CreateMaskOp :
Vector_Op<"create_mask", [Pure]>,
Vector_Op<"create_mask", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>
]>,
Arguments<(ins Variadic<Index>:$operands)>,
Results<(outs VectorOfAnyRankOf<[I1]>)> {
let summary = "creates a vector mask";
Expand Down
96 changes: 94 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,97 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
vector::UnrollVectorOptions options;
};

/// This pattern unrolls `vector.create_mask` operations into smaller mask
/// operations based on the target unroll shape. Each unrolled slice computes
/// its local mask size in each dimension (d) as:
/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).
/// Example:
/// Given a create_mask operation:
/// %0 = vector.create_mask %c6, %c10 : vector<8x16xi1> // mask first 6x10
/// elements
///
/// and a target unroll shape of <4x8>, the pattern produces:
///
/// %false = arith.constant dense<false> : vector<8x16xi1>
///
/// Slice [0,0]:
/// mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8
/// %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1>
/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
/// : vector<4x8xi1> into vector<8x16xi1>
/// Slice [0,8]:
/// mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2
/// %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1>
/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
/// : vector<4x8xi1> into vector<8x16xi1>
/// Slice [4,0]:
/// mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8
/// %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1>
/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
/// : vector<4x8xi1> into vector<8x16xi1>
/// Slice [4,8]:
/// mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2
/// %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1>
/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
/// : vector<4x8xi1> into vector<8x16xi1>
struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
UnrollCreateMaskPattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
PatternBenefit benefit = 1)
: OpRewritePattern<vector::CreateMaskOp>(context, benefit),
options(options) {}

LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
PatternRewriter &rewriter) const override {
auto targetShape = getTargetShape(options, createMaskOp);
if (!targetShape)
return failure();

VectorType resultType = createMaskOp.getVectorType();
SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
Location loc = createMaskOp.getLoc();

Value result = arith::ConstantOp::create(rewriter, loc, resultType,
rewriter.getZeroAttr(resultType));
VectorType targetVectorType =
VectorType::get(*targetShape, rewriter.getI1Type());
SmallVector<int64_t> strides(targetShape->size(), 1);

// In each dimension (d), each unrolled vector computes its mask size as:
// min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]).
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalSize, *targetShape)) {
SmallVector<Value> unrolledOperands;

for (auto [i, originalMaskOperand] :
llvm::enumerate(createMaskOp.getOperands())) {
Value offsetVal =
arith::ConstantIndexOp::create(rewriter, loc, offsets[i]);
Value adjustedMaskSize = rewriter.createOrFold<arith::SubIOp>(
loc, originalMaskOperand, offsetVal);
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value unrolledDimSize =
arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]);
Value nonNegative =
rewriter.createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero);
Value unrolledOperand = rewriter.createOrFold<arith::MinSIOp>(
loc, nonNegative, unrolledDimSize);
unrolledOperands.push_back(unrolledOperand);
}

auto unrolledMask = rewriter.createOrFold<vector::CreateMaskOp>(
loc, targetVectorType, unrolledOperands);
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, unrolledMask, result, offsets, strides);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'd need to use createOrFold for many if not all of the previous operations. For the example that you provided in the documentation, I'd expect everything to fold into a single constant mask op. We should make sure that such a folding happens in the tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing that out..changed it

}
rewriter.replaceOp(createMaskOp, result);
return success();
}

private:
vector::UnrollVectorOptions options;
};

/// Checks whether extractShape is a contiguous slice of shape.
/// For extractShape to be contiguous in shape:
/// 1) All but the leading dimension of extractShape and shape must match
Expand Down Expand Up @@ -1202,8 +1293,9 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
patterns.getContext(), options, benefit);
UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
UnrollCreateMaskPattern>(patterns.getContext(), options,
benefit);
}

void mlir::vector::populateVectorToElementsUnrollPatterns(
Expand Down
55 changes: 55 additions & 0 deletions mlir/test/Dialect/Vector/vector-unroll-options.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,61 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
// CHECK-NOT: arith.addf
// CHECK: return

func.func @vector_create_mask(%size1: index, %size2: index) -> vector<16x16xi1> {
%0 = vector.create_mask %size1, %size2 : vector<16x16xi1>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please, add a test with (full or partial?) constant indices to make sure that folding happens

return %0 : vector<16x16xi1>
}

// CHECK-LABEL: func @vector_create_mask
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<16x16xi1>
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[MAX0:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index
// CHECK: %[[MIN0:.*]] = arith.minsi %[[MAX0]], %[[C8]] : index
// CHECK: %[[MAX1:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index
// CHECK: %[[MIN1:.*]] = arith.minsi %[[MAX1]], %[[C8]] : index
// CHECK: %[[MASK00:.*]] = vector.create_mask %[[MIN0]], %[[MIN1]] : vector<8x8xi1>
// CHECK: %[[INS00:.*]] = vector.insert_strided_slice %[[MASK00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
// CHECK: %[[MAX0_2:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index
// CHECK: %[[MIN0_2:.*]] = arith.minsi %[[MAX0_2]], %[[C8]] : index
// CHECK: %[[SUB1:.*]] = arith.subi %[[ARG1]], %[[C8]] : index
// CHECK: %[[MAX1_2:.*]] = arith.maxsi %[[SUB1]], %[[C0]] : index
// CHECK: %[[MIN1_2:.*]] = arith.minsi %[[MAX1_2]], %[[C8]] : index
// CHECK: %[[MASK01:.*]] = vector.create_mask %[[MIN0_2]], %[[MIN1_2]] : vector<8x8xi1>
// CHECK: %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
// CHECK: %[[SUB0:.*]] = arith.subi %[[ARG0]], %[[C8]] : index
// CHECK: %[[MAX0_3:.*]] = arith.maxsi %[[SUB0]], %[[C0]] : index
// CHECK: %[[MIN0_3:.*]] = arith.minsi %[[MAX0_3]], %[[C8]] : index
// CHECK: %[[MAX1_3:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index
// CHECK: %[[MIN1_3:.*]] = arith.minsi %[[MAX1_3]], %[[C8]] : index
// CHECK: %[[MASK10:.*]] = vector.create_mask %[[MIN0_3]], %[[MIN1_3]] : vector<8x8xi1>
// CHECK: %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
// CHECK: %[[SUB0_2:.*]] = arith.subi %[[ARG0]], %[[C8]] : index
// CHECK: %[[MAX0_4:.*]] = arith.maxsi %[[SUB0_2]], %[[C0]] : index
// CHECK: %[[MIN0_4:.*]] = arith.minsi %[[MAX0_4]], %[[C8]] : index
// CHECK: %[[SUB1_2:.*]] = arith.subi %[[ARG1]], %[[C8]] : index
// CHECK: %[[MAX1_4:.*]] = arith.maxsi %[[SUB1_2]], %[[C0]] : index
// CHECK: %[[MIN1_4:.*]] = arith.minsi %[[MAX1_4]], %[[C8]] : index
// CHECK: %[[MASK11:.*]] = vector.create_mask %[[MIN0_4]], %[[MIN1_4]] : vector<8x8xi1>
// CHECK: %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
// CHECK: return %[[INS11]] : vector<16x16xi1>

func.func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> {
%cst16 = arith.constant 16 : index
%0 = vector.create_mask %cst16, %cst16 : vector<16x16xi1>
return %0 : vector<16x16xi1>
}

// CHECK-LABEL: func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> {
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1>
// CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x8xi1>
// CHECK: %[[S0:.*]] = vector.insert_strided_slice %[[CST_0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
// CHECK: %[[S1:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S0]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
// CHECK: %[[S2:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S1]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
// CHECK: %[[S3:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S2]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
// CHECK: return %[[S3]] : vector<16x16xi1>


func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
%0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>
Expand Down
6 changes: 6 additions & 0 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ struct TestVectorUnrollingPatterns
.setFilterConstraint([](Operation *op) {
return success(isa<vector::StepOp>(op));
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{8, 8})
.setFilterConstraint([](Operation *op) {
return success(isa<vector::CreateMaskOp>(op));
}));
populateVectorUnrollPatterns(
patterns,
UnrollVectorOptions()
Expand Down