@@ -108,7 +108,6 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
108108 mlir::Value mask = sum.getMask ();
109109 mlir::Value dim = sum.getDim ();
110110 int64_t dimVal = fir::getIntIfConstant (dim).value_or (0 );
111- assert (dimVal > 0 && " DIM must be present and a positive constant" );
112111 mlir::Value resultShape, dimExtent;
113112 std::tie (resultShape, dimExtent) =
114113 genResultShape (loc, builder, array, dimVal);
@@ -235,6 +234,9 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
235234 mlir::Value inShape = hlfir::genShape (loc, builder, array);
236235 llvm::SmallVector<mlir::Value> inExtents =
237236 hlfir::getExplicitExtentsFromShape (inShape, builder);
237+ assert (dimVal > 0 && dimVal <= static_cast <int64_t >(inExtents.size ()) &&
238+ " DIM must be present and a positive constant not exceeding "
239+ " the array's rank" );
238240 if (inShape.getUses ().empty ())
239241 inShape.getDefiningOp ()->erase ();
240242
@@ -348,12 +350,20 @@ class SimplifyHLFIRIntrinsics
348350 // would avoid creating a temporary for the elemental array expression.
349351 target.addDynamicallyLegalOp <hlfir::SumOp>([](hlfir::SumOp sum) {
350352 if (mlir::Value dim = sum.getDim ()) {
351- if (fir::getIntIfConstant (dim)) {
353+ if (auto dimVal = fir::getIntIfConstant (dim)) {
352354 if (!fir::isa_trivial (sum.getType ())) {
353355 // Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
354356 // It is only legal when X is 1, and it should probably be
355357 // canonicalized into SUM(a).
356- return false ;
358+ fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
359+ hlfir::getFortranElementOrSequenceType (
360+ sum.getArray ().getType ()));
361+ if (*dimVal > 0 && *dimVal <= arrayTy.getDimension ()) {
362+ // Ignore SUMs with illegal DIM values.
363+ // They may appear in dead code,
364+ // and they do not have to be converted.
365+ return false ;
366+ }
357367 }
358368 }
359369 }
0 commit comments