Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
mlir::Value mask = sum.getMask();
mlir::Value dim = sum.getDim();
int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
assert(dimVal > 0 && "DIM must be present and a positive constant");
mlir::Value resultShape, dimExtent;
std::tie(resultShape, dimExtent) =
genResultShape(loc, builder, array, dimVal);
Expand Down Expand Up @@ -235,6 +234,9 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
mlir::Value inShape = hlfir::genShape(loc, builder, array);
llvm::SmallVector<mlir::Value> inExtents =
hlfir::getExplicitExtentsFromShape(inShape, builder);
assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
"DIM must be present and a positive constant not exceeding "
"the array's rank");
if (inShape.getUses().empty())
inShape.getDefiningOp()->erase();

Expand Down Expand Up @@ -348,12 +350,20 @@ class SimplifyHLFIRIntrinsics
// would avoid creating a temporary for the elemental array expression.
target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
if (mlir::Value dim = sum.getDim()) {
if (fir::getIntIfConstant(dim)) {
if (auto dimVal = fir::getIntIfConstant(dim)) {
if (!fir::isa_trivial(sum.getType())) {
// Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
// It is only legal when X is 1, and it should probably be
// canonicalized into SUM(a).
return false;
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
hlfir::getFortranElementOrSequenceType(
sum.getArray().getType()));
if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
// Ignore SUMs with illegal DIM values.
// They may appear in dead code,
// and they do not have to be converted.
return false;
}
}
}
}
Expand Down
18 changes: 18 additions & 0 deletions flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,21 @@ func.func @sum_non_const_dim(%arg0: !fir.box<!fir.array<3xi32>>, %dim: i32) {
// CHECK: %[[VAL_2:.*]] = hlfir.sum %[[VAL_0]] dim %[[VAL_1]] : (!fir.box<!fir.array<3xi32>>, i32) -> i32
// CHECK: return
// CHECK: }

// negative: invalid dim==0
func.func @sum_invalid_dim0(%arg0: !hlfir.expr<2x3xi32>) {
%cst = arith.constant 0 : i32
%res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
return
}
// CHECK-LABEL: func.func @sum_invalid_dim0(
// CHECK: hlfir.sum %{{.*}} dim %{{.*}} : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>

// negative: invalid dim>rank
func.func @sum_invalid_dim_big(%arg0: !hlfir.expr<2x3xi32>) {
%cst = arith.constant 3 : i32
%res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
return
}
// CHECK-LABEL: func.func @sum_invalid_dim_big(
// CHECK: hlfir.sum %{{.*}} dim %{{.*}} : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
Loading