@@ -446,17 +446,16 @@ struct LinearizeVectorSplat final
446446};
447447
448448// / This pattern converts the CreateMaskOp to work on a linearized vector.
449- // The pattern currently supports only 2D masks with a unit outer dimension.
449+ // / It currently supports only 2D masks with a unit outer dimension.
450450// / For example,
451451// / vector.create_mask %arg0, %arg1 : vector<1x4xi1>
452452// / is converted to:
453453// / %zero = arith.constant 0 : index
454- // / %cmpi = arith.cmpi sle, %arg0, %zero : index
455- // / %splat = vector.splat %cmpi : vector<4xi1>
456- // / %cst = arith.constant dense<false> : vector<4xi1>
457- // / %mask = vector.create_mask %arg1 : vector<4xi1>
458- // / %out = arith.select %splat, %cst, %mask : vector<4xi1>
459- // / %out_1d = vector.shape_cast %out : vector<4xi1> to vector<1x4xi1>
454+ // / %cmpi = arith.cmpi sgt, %arg0, %zero : index
455+ // / %index = arith.index_cast %cmpi : i1 to index
456+ // / %mul = arith.muli %index, %arg1 : index
457+ // / %mask = vector.create_mask %mul : vector<4xi1>
458+ // / %out_1d = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
460459struct LinearizeVectorCreateMask final
461460 : OpConversionPattern<vector::CreateMaskOp> {
462461 using OpConversionPattern::OpConversionPattern;
@@ -483,25 +482,23 @@ struct LinearizeVectorCreateMask final
483482 if (!dstTy)
484483 return rewriter.notifyMatchFailure (createMaskOp, " cannot convert type." );
485484
486- // Compare the first operand with 0. If it's less than or equal to 0,
487- // create a zero mask, else strip the first operand and create a mask
488- // using the second operand.
485+ // Compare the first operand with 0: the unit outer dimension is
486+ // enabled only if it is greater than 0. The comparison result is cast
487+ // to index and multiplied with the second operand of create_mask to
488+ // get the size of the 1D mask.
489489 auto firstOperand = adaptor.getOperands ().front ();
490490 auto zero = rewriter.create <mlir::arith::ConstantIndexOp>(loc, 0 );
491- auto isZeroOrNegative = rewriter.create <mlir::arith::CmpIOp>(
492- loc, mlir::arith::CmpIPredicate::sle, firstOperand, zero);
493- auto isZeroOrNegativeSplat =
494- rewriter.create <mlir::vector::SplatOp>(loc, dstTy, isZeroOrNegative);
495-
496- // Use a select operation to choose between the masks.
497- auto zeroMask = rewriter.create <mlir::arith::ConstantOp>(
498- loc, dstTy, rewriter.getZeroAttr (dstTy));
499- auto newMask = rewriter.create <mlir::vector::CreateMaskOp>(
500- loc, dstTy, adaptor.getOperands ().back ());
501- auto result = rewriter.create <mlir::arith::SelectOp>(
502- loc, isZeroOrNegativeSplat, zeroMask, newMask);
491+ auto isNonZero = rewriter.create <mlir::arith::CmpIOp>(
492+ loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
493+ auto isNonZeroIndex = rewriter.create <mlir::arith::IndexCastOp>(
494+ loc, rewriter.getIndexType (), isNonZero);
495+ auto secondOperand = adaptor.getOperands ().back ();
496+ auto maskSize = rewriter.create <mlir::arith::MulIOp>(
497+ loc, rewriter.getIndexType (), isNonZeroIndex, secondOperand);
503498
504- rewriter.replaceOp (createMaskOp, result.getResult ());
499+ auto newMask = rewriter.create <mlir::vector::CreateMaskOp>(
500+ loc, dstTy, maskSize.getResult ());
501+ rewriter.replaceOp (createMaskOp, newMask);
505502 return success ();
506503 }
507504};
0 commit comments