Skip to content

Commit 0dc71be

Browse files
[MLIR][Affine] Fixed crash with invalid reduction op (Issue llvm#64073)
Updated AffineParallelLowering to check if reduction op value being returned is valid else return failure in matchAndRewrite Updated ArithOps.cpp getIdentityValue method to return nullptr if op is not a valid reduction op Code cleanup in ArithOps.cpp getReductionOp method; removed cases maxnumf and minnumf as not valid reduction ops Reporting Issue is llvm#64073
1 parent a8913f8 commit 0dc71be

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -225,13 +225,13 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
225225
// initialization of the result values.
226226
Attribute reduction = std::get<0>(pair);
227227
Type resultType = std::get<1>(pair);
228-
std::optional<arith::AtomicRMWKind> reductionOp =
229-
arith::symbolizeAtomicRMWKind(
230-
static_cast<uint64_t>(cast<IntegerAttr>(reduction).getInt()));
231-
assert(reductionOp && "Reduction operation cannot be of None Type");
232-
arith::AtomicRMWKind reductionOpValue = *reductionOp;
233-
identityVals.push_back(
234-
arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
228+
arith::AtomicRMWKind reductionOpValue = *(arith::symbolizeAtomicRMWKind(
229+
static_cast<uint64_t>(cast<IntegerAttr>(reduction).getInt())));
230+
auto reductionOp =
231+
arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc);
232+
if (!reductionOp)
233+
return failure();
234+
identityVals.push_back(reductionOp);
235235
}
236236
parOp = rewriter.create<scf::ParallelOp>(
237237
loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2505,6 +2505,8 @@ Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
25052505
bool useOnlyFiniteValue) {
25062506
auto attr =
25072507
getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2508+
if (!attr)
2509+
return nullptr;
25082510
return builder.create<arith::ConstantOp>(loc, attr);
25092511
}
25102512

@@ -2525,10 +2527,6 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
25252527
return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
25262528
case AtomicRMWKind::minimumf:
25272529
return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
2528-
case AtomicRMWKind::maxnumf:
2529-
return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
2530-
case AtomicRMWKind::minnumf:
2531-
return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
25322530
case AtomicRMWKind::maxs:
25332531
return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
25342532
case AtomicRMWKind::mins:

0 commit comments

Comments
 (0)