diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index 7dbbf015182f3..a7540fcc386fa 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -225,13 +225,13 @@ class AffineParallelLowering : public OpRewritePattern { // initialization of the result values. Attribute reduction = std::get<0>(pair); Type resultType = std::get<1>(pair); - std::optional reductionOp = - arith::symbolizeAtomicRMWKind( - static_cast(cast(reduction).getInt())); - assert(reductionOp && "Reduction operation cannot be of None Type"); - arith::AtomicRMWKind reductionOpValue = *reductionOp; - identityVals.push_back( - arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc)); + arith::AtomicRMWKind reductionOpValue = *(arith::symbolizeAtomicRMWKind( + static_cast(cast(reduction).getInt()))); + auto reductionOp = + arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc); + if (!reductionOp) + return failure(); + identityVals.push_back(reductionOp); } parOp = rewriter.create( loc, lowerBoundTuple, upperBoundTuple, steps, identityVals, diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 1002719f0b89f..b3bb0f03995d5 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2505,6 +2505,8 @@ Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, bool useOnlyFiniteValue) { auto attr = getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue); + if (!attr) + return nullptr; return builder.create(loc, attr); } @@ -2525,10 +2527,6 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, return builder.create(loc, lhs, rhs); case AtomicRMWKind::minimumf: return builder.create(loc, lhs, rhs); - case AtomicRMWKind::maxnumf: - return builder.create(loc, lhs, rhs); - case AtomicRMWKind::minnumf: - return builder.create(loc, lhs, rhs); case AtomicRMWKind::maxs: return builder.create(loc, lhs, rhs); case AtomicRMWKind::mins: