diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp index 0c34c8221aeda..ace63a970db93 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -108,7 +108,6 @@ class SumAsElementalConversion : public mlir::OpRewritePattern { mlir::Value mask = sum.getMask(); mlir::Value dim = sum.getDim(); int64_t dimVal = fir::getIntIfConstant(dim).value_or(0); - assert(dimVal > 0 && "DIM must be present and a positive constant"); mlir::Value resultShape, dimExtent; std::tie(resultShape, dimExtent) = genResultShape(loc, builder, array, dimVal); @@ -235,6 +234,9 @@ class SumAsElementalConversion : public mlir::OpRewritePattern { mlir::Value inShape = hlfir::genShape(loc, builder, array); llvm::SmallVector inExtents = hlfir::getExplicitExtentsFromShape(inShape, builder); + assert(dimVal > 0 && dimVal <= static_cast(inExtents.size()) && + "DIM must be present and a positive constant not exceeding " + "the array's rank"); if (inShape.getUses().empty()) inShape.getDefiningOp()->erase(); @@ -348,12 +350,20 @@ class SimplifyHLFIRIntrinsics // would avoid creating a temporary for the elemental array expression. target.addDynamicallyLegalOp([](hlfir::SumOp sum) { if (mlir::Value dim = sum.getDim()) { - if (fir::getIntIfConstant(dim)) { + if (auto dimVal = fir::getIntIfConstant(dim)) { if (!fir::isa_trivial(sum.getType())) { // Ignore the case SUM(a, DIM=X), where 'a' is a 1D array. // It is only legal when X is 1, and it should probably be // canonicalized into SUM(a). - return false; + fir::SequenceType arrayTy = mlir::cast( + hlfir::getFortranElementOrSequenceType( + sum.getArray().getType())); + if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) { + // Ignore SUMs with illegal DIM values. + // They may appear in dead code, + // and they do not have to be converted. + return false; + } } } } diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir index 703b6673154f3..313e54d5d0c4a 100644 --- a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir +++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir @@ -411,3 +411,21 @@ func.func @sum_non_const_dim(%arg0: !fir.box>, %dim: i32) { // CHECK: %[[VAL_2:.*]] = hlfir.sum %[[VAL_0]] dim %[[VAL_1]] : (!fir.box>, i32) -> i32 // CHECK: return // CHECK: } + +// negative: invalid dim==0 +func.func @sum_invalid_dim0(%arg0: !hlfir.expr<2x3xi32>) { + %cst = arith.constant 0 : i32 + %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32> + return +} +// CHECK-LABEL: func.func @sum_invalid_dim0( +// CHECK: hlfir.sum %{{.*}} dim %{{.*}} : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32> + +// negative: invalid dim>rank +func.func @sum_invalid_dim_big(%arg0: !hlfir.expr<2x3xi32>) { + %cst = arith.constant 3 : i32 + %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32> + return +} +// CHECK-LABEL: func.func @sum_invalid_dim_big( +// CHECK: hlfir.sum %{{.*}} dim %{{.*}} : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>