Skip to content

Commit c2c1a22

Browse files
committed
Replace select with mul
1 parent 528f913 commit c2c1a22

File tree

2 files changed

+31
-55
lines changed

2 files changed

+31
-55
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -446,17 +446,16 @@ struct LinearizeVectorSplat final
446446
};
447447

448448
/// This pattern converts the CreateMaskOp to work on a linearized vector.
449-
// The pattern currently supports only 2D masks with a unit outer dimension.
449+
/// It currently supports only 2D masks with a unit outer dimension.
450450
/// For example, the following:
451451
/// vector.create_mask %arg0, %arg1 : vector<1x4xi1>
452452
/// is converted to:
453453
/// %zero = arith.constant 0 : index
454-
/// %cmpi = arith.cmpi sle, %arg0, %zero : index
455-
/// %splat = vector.splat %cmpi : vector<4xi1>
456-
/// %cst = arith.constant dense<false> : vector<4xi1>
457-
/// %mask = vector.create_mask %arg1 : vector<4xi1>
458-
/// %out = arith.select %splat, %cst, %mask : vector<4xi1>
459-
/// %out_1d = vector.shape_cast %out : vector<4xi1> to vector<1x4xi1>
454+
/// %cmpi = arith.cmpi sgt, %arg0, %zero : index
455+
/// %index = arith.index_cast %cmpi : i1 to index
456+
/// %mul = arith.muli %index, %arg1 : index
457+
/// %mask = vector.create_mask %mul : vector<4xi1>
458+
/// %out_1d = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
460459
struct LinearizeVectorCreateMask final
461460
: OpConversionPattern<vector::CreateMaskOp> {
462461
using OpConversionPattern::OpConversionPattern;
@@ -483,25 +482,23 @@ struct LinearizeVectorCreateMask final
483482
if (!dstTy)
484483
return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
485484

486-
// Compare the first operand with 0. If it's less than or equal to 0,
487-
// create a zero mask, else strip the first operand and create a mask
488-
// using the second operand.
485+
// Compare the first operand with 0. The i1 result (true when it is
486+
// greater than 0) is cast to index and multiplied with the second
487+
// operand of create_mask to get the size of the 1D mask.
488+
// NOTE(review): arith.index_cast sign-extends, so true may become -1; verify.
489489
auto firstOperand = adaptor.getOperands().front();
490490
auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
491-
auto isZeroOrNegative = rewriter.create<mlir::arith::CmpIOp>(
492-
loc, mlir::arith::CmpIPredicate::sle, firstOperand, zero);
493-
auto isZeroOrNegativeSplat =
494-
rewriter.create<mlir::vector::SplatOp>(loc, dstTy, isZeroOrNegative);
495-
496-
// Use a select operation to choose between the masks.
497-
auto zeroMask = rewriter.create<mlir::arith::ConstantOp>(
498-
loc, dstTy, rewriter.getZeroAttr(dstTy));
499-
auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
500-
loc, dstTy, adaptor.getOperands().back());
501-
auto result = rewriter.create<mlir::arith::SelectOp>(
502-
loc, isZeroOrNegativeSplat, zeroMask, newMask);
491+
auto isNonZero = rewriter.create<mlir::arith::CmpIOp>(
492+
loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
493+
auto isNonZeroIndex = rewriter.create<mlir::arith::IndexCastOp>(
494+
loc, rewriter.getIndexType(), isNonZero);
495+
auto secondOperand = adaptor.getOperands().back();
496+
auto maskSize = rewriter.create<mlir::arith::MulIOp>(
497+
loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
503498

504-
rewriter.replaceOp(createMaskOp, result.getResult());
499+
auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
500+
loc, dstTy, maskSize.getResult());
501+
rewriter.replaceOp(createMaskOp, newMask);
505502
return success();
506503
}
507504
};

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -347,39 +347,18 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
347347
}
348348

349349
// -----
350-
// ALL-LABEL: linearize_create_mask
351-
func.func @linearize_create_mask() -> vector<1x16xi1> {
350+
351+
// CHECK-LABEL: linearize_create_mask
352+
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
353+
func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {
354+
352355
// CHECK: %[[C0:.*]] = arith.constant 0 : index
353-
// CHECK: %[[C10:.*]] = arith.constant 10 : index
354-
// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
355-
// CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
356-
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
357-
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
358-
// CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C10]] : vector<16xi1>
359-
// CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1>
360-
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
356+
// CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index
357+
// CHECK: %[[INDEXCAST:.*]] = arith.index_cast %[[CMP]] : i1 to index
358+
// CHECK: %[[MULI:.*]] = arith.muli %[[INDEXCAST]], %[[ARG1]] : index
359+
// CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[MULI]] : vector<16xi1>
360+
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK_1D]] : vector<16xi1> to vector<1x16xi1>
361361
// CHECK: return %[[CAST]] : vector<1x16xi1>
362-
%c0 = arith.constant 0 : index
363-
%c10 = arith.constant 10 : index
364-
%0 = vector.create_mask %c0, %c10 : vector<1x16xi1>
362+
%0 = vector.create_mask %arg0, %arg1 : vector<1x16xi1>
365363
return %0 : vector<1x16xi1>
366364
}
367-
368-
// -----
369-
// ALL-LABEL: linearize_scalable_create_mask
370-
func.func @linearize_scalable_create_mask() -> vector<1x[16]xi1> {
371-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
372-
// CHECK: %[[C10:.*]] = arith.constant 10 : index
373-
// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
374-
// CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
375-
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<[16]xi1>
376-
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<[16]xi1>
377-
// CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C10]] : vector<[16]xi1>
378-
// CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<[16]xi1>, vector<[16]xi1>
379-
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<[16]xi1> to vector<1x[16]xi1>
380-
// CHECK: return %[[CAST]] : vector<1x[16]xi1>
381-
%c0 = arith.constant 0 : index
382-
%c10 = arith.constant 10 : index
383-
%0 = vector.create_mask %c0, %c10 : vector<1x[16]xi1>
384-
return %0 : vector<1x[16]xi1>
385-
}

0 commit comments

Comments
 (0)