@@ -43,20 +43,17 @@ using namespace mlir::vector;
4343struct VectorShape {
4444 ArrayRef<int64_t > sizes;
4545 ArrayRef<bool > scalableFlags;
46-
47- bool empty () const { return sizes.empty (); }
4846};
4947
50- // Returns vector shape if the type is a vector. Returns an empty shape if it is
51- // not a vector.
52- static VectorShape vectorShape (Type type) {
53- auto vectorType = dyn_cast<VectorType>(type);
54- return vectorType
55- ? VectorShape{vectorType.getShape (), vectorType.getScalableDims ()}
56- : VectorShape{};
48+ // Returns vector shape if the type is a vector, otherwise return nullopt.
49+ static std::optional<VectorShape> vectorShape (Type type) {
50+ if (auto vectorType = dyn_cast<VectorType>(type)) {
51+ return VectorShape{vectorType.getShape (), vectorType.getScalableDims ()};
52+ }
53+ return std::nullopt ;
5754}
5855
59- static VectorShape vectorShape (Value value) {
56+ static std::optional< VectorShape> vectorShape (Value value) {
6057 return vectorShape (value.getType ());
6158}
6259
@@ -65,19 +62,18 @@ static VectorShape vectorShape(Value value) {
6562// ----------------------------------------------------------------------------//
6663
6764// Broadcasts scalar type into vector type (iff shape is non-scalar).
68- static Type broadcast (Type type, VectorShape shape) {
65+ static Type broadcast (Type type, std::optional< VectorShape> shape) {
6966 assert (!isa<VectorType>(type) && " must be scalar type" );
70- return !shape.empty ()
71- ? VectorType::get (shape.sizes , type, shape.scalableFlags )
72- : type;
67+ return shape ? VectorType::get (shape->sizes , type, shape->scalableFlags )
68+ : type;
7369}
7470
7571// Broadcasts scalar value into vector (iff shape is non-scalar).
7672static Value broadcast (ImplicitLocOpBuilder &builder, Value value,
77- VectorShape shape) {
73+ std::optional< VectorShape> shape) {
7874 assert (!isa<VectorType>(value.getType ()) && " must be scalar value" );
7975 auto type = broadcast (value.getType (), shape);
80- return ! shape. empty () ? builder.create <BroadcastOp>(type, value) : value;
76+ return shape ? builder.create <BroadcastOp>(type, value) : value;
8177}
8278
8379// ----------------------------------------------------------------------------//
@@ -227,7 +223,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
227223static std::pair<Value, Value> frexp (ImplicitLocOpBuilder &builder, Value arg,
228224 bool isPositive = false ) {
229225 assert (getElementTypeOrSelf (arg).isF32 () && " arg must be f32 type" );
230- VectorShape shape = vectorShape (arg);
226+ std::optional< VectorShape> shape = vectorShape (arg);
231227
232228 auto bcast = [&](Value value) -> Value {
233229 return broadcast (builder, value, shape);
@@ -267,7 +263,7 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
267263// Computes exp2 for an i32 argument.
268264static Value exp2I32 (ImplicitLocOpBuilder &builder, Value arg) {
269265 assert (getElementTypeOrSelf (arg).isInteger (32 ) && " arg must be i32 type" );
270- VectorShape shape = vectorShape (arg);
266+ std::optional< VectorShape> shape = vectorShape (arg);
271267
272268 auto bcast = [&](Value value) -> Value {
273269 return broadcast (builder, value, shape);
@@ -293,7 +289,7 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
293289 Type elementType = getElementTypeOrSelf (x);
294290 assert ((elementType.isF32 () || elementType.isF16 ()) &&
295291 " x must be f32 or f16 type" );
296- VectorShape shape = vectorShape (x);
292+ std::optional< VectorShape> shape = vectorShape (x);
297293
298294 if (coeffs.empty ())
299295 return broadcast (builder, floatCst (builder, 0 .0f , elementType), shape);
@@ -391,7 +387,7 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
391387 if (!getElementTypeOrSelf (operand).isF32 ())
392388 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
393389
394- VectorShape shape = vectorShape (op.getOperand ());
390+ std::optional< VectorShape> shape = vectorShape (op.getOperand ());
395391
396392 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
397393 Value abs = builder.create <math::AbsFOp>(operand);
@@ -490,7 +486,7 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op,
490486 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
491487
492488 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
493- VectorShape shape = vectorShape (op.getResult ());
489+ std::optional< VectorShape> shape = vectorShape (op.getResult ());
494490
495491 // Compute atan in the valid range.
496492 auto div = builder.create <arith::DivFOp>(y, x);
@@ -556,7 +552,7 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
556552 if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
557553 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
558554
559- VectorShape shape = vectorShape (op.getOperand ());
555+ std::optional< VectorShape> shape = vectorShape (op.getOperand ());
560556
561557 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
562558 auto bcast = [&](Value value) -> Value {
@@ -644,7 +640,7 @@ LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
644640 if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
645641 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
646642
647- VectorShape shape = vectorShape (op.getOperand ());
643+ std::optional< VectorShape> shape = vectorShape (op.getOperand ());
648644
649645 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
650646 auto bcast = [&](Value value) -> Value {
@@ -791,7 +787,7 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
791787 if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
792788 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
793789
794- VectorShape shape = vectorShape (op.getOperand ());
790+ std::optional< VectorShape> shape = vectorShape (op.getOperand ());
795791
796792 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
797793 auto bcast = [&](Value value) -> Value {
@@ -846,7 +842,7 @@ AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
846842 if (!(elementType.isF32 () || elementType.isF16 ()))
847843 return rewriter.notifyMatchFailure (op,
848844 " only f32 and f16 type is supported." );
849- VectorShape shape = vectorShape (operand);
845+ std::optional< VectorShape> shape = vectorShape (operand);
850846
851847 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
852848 auto bcast = [&](Value value) -> Value {
@@ -941,7 +937,7 @@ AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
941937 if (!(elementType.isF32 () || elementType.isF16 ()))
942938 return rewriter.notifyMatchFailure (op,
943939 " only f32 and f16 type is supported." );
944- VectorShape shape = vectorShape (operand);
940+ std::optional< VectorShape> shape = vectorShape (operand);
945941
946942 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
947943 auto bcast = [&](Value value) -> Value {
@@ -1019,7 +1015,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
10191015 if (!(elementType.isF32 () || elementType.isF16 ()))
10201016 return rewriter.notifyMatchFailure (op,
10211017 " only f32 and f16 type is supported." );
1022- VectorShape shape = vectorShape (operand);
1018+ std::optional< VectorShape> shape = vectorShape (operand);
10231019
10241020 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
10251021 auto bcast = [&](Value value) -> Value {
@@ -1128,8 +1124,9 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
11281124
11291125namespace {
11301126
1131- Value clampWithNormals (ImplicitLocOpBuilder &builder, const VectorShape shape,
1132- Value value, float lowerBound, float upperBound) {
1127+ Value clampWithNormals (ImplicitLocOpBuilder &builder,
1128+ const std::optional<VectorShape> shape, Value value,
1129+ float lowerBound, float upperBound) {
11331130 assert (!std::isnan (lowerBound));
11341131 assert (!std::isnan (upperBound));
11351132
@@ -1320,7 +1317,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
13201317 if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
13211318 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
13221319
1323- VectorShape shape = vectorShape (op.getOperand ());
1320+ std::optional< VectorShape> shape = vectorShape (op.getOperand ());
13241321
13251322 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
13261323 auto bcast = [&](Value value) -> Value {
@@ -1390,7 +1387,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
13901387 if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
13911388 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
13921389
1393- VectorShape shape = vectorShape (op.getOperand ());
1390+ std::optional< VectorShape> shape = vectorShape (op.getOperand ());
13941391
13951392 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
13961393 auto bcast = [&](Value value) -> Value {
@@ -1517,7 +1514,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
15171514 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
15181515
15191516 ImplicitLocOpBuilder b (op->getLoc (), rewriter);
1520- VectorShape shape = vectorShape (operand);
1517+ std::optional< VectorShape> shape = vectorShape (operand);
15211518
15221519 Type floatTy = getElementTypeOrSelf (operand.getType ());
15231520 Type intTy = b.getIntegerType (floatTy.getIntOrFloatBitWidth ());
@@ -1606,10 +1603,10 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
16061603 if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
16071604 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
16081605
1609- VectorShape shape = vectorShape (op.getOperand ());
1606+ std::optional< VectorShape> shape = vectorShape (op.getOperand ());
16101607
16111608 // Only support already-vectorized rsqrt's.
1612- if (shape.empty () || shape. sizes .back () % 8 != 0 )
1609+ if (! shape || shape-> sizes .empty () || shape-> sizes .back () % 8 != 0 )
16131610 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
16141611
16151612 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
0 commit comments