From 0dc71bede28b26bcfcc4deb29e3eb8bc38e63975 Mon Sep 17 00:00:00 2001 From: swapnilghanshyala Date: Thu, 26 Oct 2023 13:49:54 +0530 Subject: [PATCH] [MLIR][Affine] Fixed crash with invalid reduction op (Issue #64073) Updated AffineParallelLowering to check if reduction op value being returned is valid else return failure in matchAndRewrite Updated ArithOps.cpp getIdentityValue method to return nullptr if op is not a valid reduction op Code cleanup in ArithOps.cpp getReductionOp method; removed cases maxnumf and minnumf as not valid reduction ops Reporting Issue is #64073 --- .../AffineToStandard/AffineToStandard.cpp | 14 +++++++------- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 6 ++---- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index 7dbbf015182f3..a7540fcc386fa 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -225,13 +225,13 @@ class AffineParallelLowering : public OpRewritePattern { // initialization of the result values. Attribute reduction = std::get<0>(pair); Type resultType = std::get<1>(pair); - std::optional reductionOp = - arith::symbolizeAtomicRMWKind( - static_cast(cast(reduction).getInt())); - assert(reductionOp && "Reduction operation cannot be of None Type"); - arith::AtomicRMWKind reductionOpValue = *reductionOp; - identityVals.push_back( - arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc)); + arith::AtomicRMWKind reductionOpValue = *(arith::symbolizeAtomicRMWKind( + static_cast(cast(reduction).getInt()))); + auto reductionOp = + arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc); + if (!reductionOp) + return failure(); + identityVals.push_back(reductionOp); } parOp = rewriter.create( loc, lowerBoundTuple, upperBoundTuple, steps, identityVals, diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 1002719f0b89f..b3bb0f03995d5 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2505,6 +2505,8 @@ Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, bool useOnlyFiniteValue) { auto attr = getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue); + if (!attr) + return nullptr; return builder.create(loc, attr); } @@ -2525,10 +2527,6 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, return builder.create(loc, lhs, rhs); case AtomicRMWKind::minimumf: return builder.create(loc, lhs, rhs); - case AtomicRMWKind::maxnumf: - return builder.create(loc, lhs, rhs); - case AtomicRMWKind::minnumf: - return builder.create(loc, lhs, rhs); case AtomicRMWKind::maxs: return builder.create(loc, lhs, rhs); case AtomicRMWKind::mins: