Skip to content

Commit 3a83e2d

Browse files
committed
Add linearization pattern for vector.create_mask
1 parent 4efcc52 commit 3a83e2d

File tree

3 files changed

+97
-4
lines changed

3 files changed

+97
-4
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,64 @@ struct LinearizeVectorSplat final
445445
}
446446
};
447447

448+
/// This pattern converts the CreateMaskOp to work on a
449+
/// linearized vector. The pattern currently
450+
/// supports only 2D masks with a unit outer dimension.
451+
/// Following,
452+
/// vector.create_mask %dims : vector<1x4xi1>
453+
/// is converted to:
454+
/// %out_1d = vector.create_mask %dims : vector<4xi1>
455+
/// %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1>
456+
struct LinearizeVectorCreateMask final
457+
: OpConversionPattern<vector::CreateMaskOp> {
458+
using OpConversionPattern::OpConversionPattern;
459+
460+
LinearizeVectorCreateMask(const TypeConverter &typeConverter,
461+
MLIRContext *context, PatternBenefit benefit = 1)
462+
: OpConversionPattern(typeConverter, context, benefit) {}
463+
464+
LogicalResult
465+
matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
466+
ConversionPatternRewriter &rewriter) const override {
467+
auto srcTy = createMaskOp.getType();
468+
auto srcShape = srcTy.getShape();
469+
if (srcShape.size() != 2)
470+
return rewriter.notifyMatchFailure(createMaskOp,
471+
"only 2D mask is supported.");
472+
473+
if (srcShape[0] != 1)
474+
return rewriter.notifyMatchFailure(
475+
createMaskOp, "only unit outer dimension is supported.");
476+
477+
auto dstTy = getTypeConverter()->convertType(srcTy);
478+
if (!dstTy)
479+
return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
480+
481+
// Compare the first operand with 0. If it's less than or equal to 0,
482+
// create a zero mask, else strip the first operand and create a mask
483+
// using the second operand.
484+
auto firstOperand = adaptor.getOperands().front();
485+
auto zero =
486+
rewriter.create<mlir::arith::ConstantIndexOp>(createMaskOp.getLoc(), 0);
487+
auto isZeroOrNegative = rewriter.create<mlir::arith::CmpIOp>(
488+
createMaskOp.getLoc(), mlir::arith::CmpIPredicate::sle, firstOperand,
489+
zero);
490+
auto isZeroOrNegativeSplat = rewriter.create<mlir::vector::SplatOp>(
491+
createMaskOp.getLoc(), dstTy, isZeroOrNegative);
492+
493+
// Use a select operation to choose between the masks.
494+
auto zeroMask = rewriter.create<mlir::arith::ConstantOp>(
495+
createMaskOp.getLoc(), dstTy, rewriter.getZeroAttr(dstTy));
496+
auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
497+
createMaskOp.getLoc(), dstTy, adaptor.getOperands().back());
498+
auto result = rewriter.create<mlir::arith::SelectOp>(
499+
createMaskOp.getLoc(), isZeroOrNegativeSplat, zeroMask, newMask);
500+
501+
rewriter.replaceOp(createMaskOp, result.getResult());
502+
return success();
503+
}
504+
};
505+
448506
} // namespace
449507

450508
/// Return true if the operation `op` does not support scalable vectors and
@@ -530,9 +588,10 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
530588
void mlir::vector::populateVectorLinearizeBasePatterns(
531589
const TypeConverter &typeConverter, const ConversionTarget &target,
532590
RewritePatternSet &patterns) {
533-
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
534-
LinearizeVectorBitCast, LinearizeVectorSplat>(
535-
typeConverter, patterns.getContext());
591+
patterns
592+
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
593+
LinearizeVectorSplat, LinearizeVectorCreateMask>(
594+
typeConverter, patterns.getContext());
536595
}
537596

538597
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,3 +447,36 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
447447
%0 = vector.splat %arg0 : vector<4x[2]xi32>
448448
return %0 : vector<4x[2]xi32>
449449
}
450+
451+
// -----
452+
// ALL-LABEL: test_create_mask
453+
func.func @test_create_mask() -> vector<1x16xi1> {
454+
// DEFAULT: %[[C0:.*]] = arith.constant 0 : index
455+
// BW-128: %[[C0:.*]] = arith.constant 0 : index
456+
// DEFAULT: %[[C20:.*]] = arith.constant 20 : index
457+
// BW-128: %[[C20:.*]] = arith.constant 20 : index
458+
// DEFAULT: %[[C0_0:.*]] = arith.constant 0 : index
459+
// BW-128: %[[C0_0:.*]] = arith.constant 0 : index
460+
// DEFAULT: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
461+
// BW-128: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
462+
// DEFAULT: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
463+
// BW-128: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
464+
// DEFAULT: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
465+
// BW-128: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
466+
// DEFAULT: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
467+
// BW-128: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
468+
// DEFAULT: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1>
469+
// BW-128: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>
470+
// DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
471+
// BW-128: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
472+
// DEFAULT: return %[[CAST]] : vector<1x16xi1>
473+
// BW-128: return %[[CAST]] : vector<1x16xi1>
474+
475+
// BW-0: %[[C0:.*]] = arith.constant 0 : index
476+
// BW-0: %[[C20:.*]] = arith.constant 20 : index
477+
// BW-0: %[[MASK:.*]] = vector.create_mask %[[C0]], %[[C20]] : vector<1x16xi1>
478+
%c0 = arith.constant 0 : index
479+
%c20 = arith.constant 20 : index
480+
%0 = vector.create_mask %c0, %c20 : vector<1x16xi1>
481+
return %0 : vector<1x16xi1>
482+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -973,7 +973,8 @@ struct TestVectorLinearize final
973973
return "Linearizes ND vectors for N >= 2 into 1D vectors";
974974
}
975975
void getDependentDialects(DialectRegistry &registry) const override {
976-
registry.insert<vector::VectorDialect>();
976+
registry.insert<vector::VectorDialect, memref::MemRefDialect,
977+
arith::ArithDialect>();
977978
}
978979

979980
void runOnOperation() override {

0 commit comments

Comments
 (0)