From 196d6de06e1c9b18db5b5bb06fa9ad05d068ac54 Mon Sep 17 00:00:00 2001 From: Slava Zakharin Date: Tue, 3 Dec 2024 14:21:11 -0800 Subject: [PATCH 1/3] [flang] Expand SUM(DIM=CONSTANT) into an hlfir.elemental. An array SUM with the specified constant DIM argument may be expanded into hlfir.elemental with a reduction loop inside it processing all elements of the specified dimension. The expansion allows further optimization of the cases like `A=SUM(B+1,DIM=1)` in the optimized bufferization pass (given that it can prove there are no read/write conflicts). --- .../Transforms/SimplifyHLFIRIntrinsics.cpp | 204 ++++++++++ .../HLFIR/simplify-hlfir-intrinsics-sum.fir | 361 ++++++++++++++++++ 2 files changed, 565 insertions(+) create mode 100644 flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp index 60b06437e6a98..35dc881e880df 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -10,6 +10,7 @@ // into the calling function. //===----------------------------------------------------------------------===// +#include "flang/Optimizer/Builder/Complex.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/HLFIRTools.h" #include "flang/Optimizer/Dialect/FIRDialect.h" @@ -90,6 +91,190 @@ class TransposeAsElementalConversion } }; +// Expand the SUM(DIM=CONSTANT) operation into . +class SumAsElementalConversion : public mlir::OpRewritePattern { +public: + using mlir::OpRewritePattern::OpRewritePattern; + + llvm::LogicalResult + matchAndRewrite(hlfir::SumOp sum, + mlir::PatternRewriter &rewriter) const override { + mlir::Location loc = sum.getLoc(); + fir::FirOpBuilder builder{rewriter, sum.getOperation()}; + hlfir::ExprType expr = mlir::dyn_cast(sum.getType()); + assert(expr && "expected an expression type for the result of hlfir.sum"); + mlir::Type elementType = expr.getElementType(); + hlfir::Entity array = hlfir::Entity{sum.getArray()}; + mlir::Value mask = sum.getMask(); + mlir::Value dim = sum.getDim(); + int64_t dimVal = fir::getIntIfConstant(dim).value_or(0); + assert(dimVal > 0 && "DIM must be present and a positive constant"); + mlir::Value resultShape, dimExtent; + std::tie(resultShape, dimExtent) = + genResultShape(loc, builder, array, dimVal); + + auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder, + mlir::ValueRange inputIndices) -> hlfir::Entity { + // Loop over all indices in the DIM dimension, and reduce all values. + // We do not need to create the reduction loop always: if we can + // slice the input array given the inputIndices, then we can + // just apply a new SUM operation (total reduction) to the slice. + // For the time being, generate the explicit loop because the slicing + // requires generating an elemental operation for the input array + // (and the mask, if present). + // TODO: produce the slices and new SUM after adding a pattern + // for expanding total reduction SUM case. + mlir::Type indexType = builder.getIndexType(); + auto one = builder.createIntegerConstant(loc, indexType, 1); + auto ub = builder.createConvert(loc, indexType, dimExtent); + + // Initial value for the reduction. + mlir::Value initValue = genInitValue(loc, builder, elementType); + + // The reduction loop may be unordered if FastMathFlags::reassoc + // transformations are allowed. The integer reduction is always + // unordered. + bool isUnordered = mlir::isa(elementType) || + static_cast(sum.getFastmath() & + mlir::arith::FastMathFlags::reassoc); + + // If the mask is present and is a scalar, then we'd better load its value + // outside of the reduction loop making the loop unswitching easier. + // Maybe it is worth hoisting it from the elemental operation as well. + if (mask) { + hlfir::Entity maskValue{mask}; + if (maskValue.isScalar()) + mask = hlfir::loadTrivialScalar(loc, builder, maskValue); + } + + // NOTE: the outer elemental operation may be lowered into + // omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction + // loop may appear disjoint from the workshare loop nest. + // Moreover, the inner loop is not strictly nested (due to the reduction + // starting value initialization), and the above omp dialect operations + // cannot produce results. + // It is unclear what we should do about it yet. + auto doLoop = builder.create( + loc, one, ub, one, isUnordered, /*finalCountValue=*/false, + mlir::ValueRange{initValue}); + + // Address the input array using the reduction loop's IV + // for the DIM dimension. + mlir::Value iv = doLoop.getInductionVar(); + llvm::SmallVector indices{inputIndices}; + indices.insert(indices.begin() + dimVal - 1, iv); + + mlir::OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(doLoop.getBody()); + mlir::Value reductionValue = doLoop.getRegionIterArgs()[0]; + fir::IfOp ifOp; + if (mask) { + // Make the reduction value update conditional on the value + // of the mask. + hlfir::Entity maskValue{mask}; + if (!maskValue.isScalar()) { + // If the mask is an array, use the elemental and the loop indices + // to address the proper mask element. + maskValue = hlfir::getElementAt(loc, builder, maskValue, indices); + maskValue = hlfir::loadTrivialScalar(loc, builder, maskValue); + } + mlir::Value isUnmasked = + builder.create(loc, builder.getI1Type(), maskValue); + ifOp = builder.create(loc, elementType, isUnmasked, + /*withElseRegion=*/true); + // In the 'else' block return the current reduction value. + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + builder.create(loc, reductionValue); + + // In the 'then' block do the actual addition. + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + } + + hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices); + hlfir::Entity elementValue = + hlfir::loadTrivialScalar(loc, builder, element); + // NOTE: we can use "Kahan summation" same way as the runtime + // (e.g. when fast-math is not allowed), but let's start with + // the simple version. + reductionValue = genScalarAdd(loc, builder, reductionValue, elementValue); + builder.create(loc, reductionValue); + + if (ifOp) { + builder.setInsertionPointAfter(ifOp); + builder.create(loc, ifOp.getResult(0)); + } + + return hlfir::Entity{doLoop.getResult(0)}; + }; + hlfir::ElementalOp elementalOp = hlfir::genElementalOp( + loc, builder, elementType, resultShape, {}, genKernel, + /*isUnordered=*/true, /*polymorphicMold=*/nullptr, + sum.getResult().getType()); + + // it wouldn't be safe to replace block arguments with a different + // hlfir.expr type. Types can differ due to differing amounts of shape + // information + assert(elementalOp.getResult().getType() == sum.getResult().getType()); + + rewriter.replaceOp(sum, elementalOp); + return mlir::success(); + } + +private: + // Return fir.shape specifying the shape of the result + // of a SUM reduction with DIM=dimVal. The second return value + // is the extent of the DIM dimension. + static std::tuple + genResultShape(mlir::Location loc, fir::FirOpBuilder &builder, + hlfir::Entity array, int64_t dimVal) { + mlir::Value inShape = hlfir::genShape(loc, builder, array); + llvm::SmallVector inExtents = + hlfir::getExplicitExtentsFromShape(inShape, builder); + if (inShape.getUses().empty()) + inShape.getDefiningOp()->erase(); + + mlir::Value dimExtent = inExtents[dimVal - 1]; + inExtents.erase(inExtents.begin() + dimVal - 1); + return {builder.create(loc, inExtents), dimExtent}; + } + + // Generate the initial value for a SUM reduction with the given + // data type. + static mlir::Value genInitValue(mlir::Location loc, + fir::FirOpBuilder &builder, + mlir::Type elementType) { + if (auto ty = mlir::dyn_cast(elementType)) { + const llvm::fltSemantics &sem = ty.getFloatSemantics(); + return builder.createRealConstant(loc, elementType, + llvm::APFloat::getZero(sem)); + } else if (auto ty = mlir::dyn_cast(elementType)) { + mlir::Value initValue = genInitValue(loc, builder, ty.getElementType()); + return fir::factory::Complex{builder, loc}.createComplex(ty, initValue, + initValue); + } else if (mlir::isa(elementType)) { + return builder.createIntegerConstant(loc, elementType, 0); + } + + llvm_unreachable("unsupported SUM reduction type"); + } + + // Generate scalar addition of the two values (of the same data type). + static mlir::Value genScalarAdd(mlir::Location loc, + fir::FirOpBuilder &builder, + mlir::Value value1, mlir::Value value2) { + mlir::Type ty = value1.getType(); + assert(ty == value2.getType() && "reduction values' types do not match"); + if (mlir::isa(ty)) + return builder.create(loc, value1, value2); + else if (mlir::isa(ty)) + return builder.create(loc, value1, value2); + else if (mlir::isa(ty)) + return builder.create(loc, value1, value2); + + llvm_unreachable("unsupported SUM reduction type"); + } +}; + class SimplifyHLFIRIntrinsics : public hlfir::impl::SimplifyHLFIRIntrinsicsBase { public: @@ -97,6 +282,7 @@ class SimplifyHLFIRIntrinsics mlir::MLIRContext *context = &getContext(); mlir::RewritePatternSet patterns(context); patterns.insert(context); + patterns.insert(context); mlir::ConversionTarget target(*context); // don't transform transpose of polymorphic arrays (not currently supported // by hlfir.elemental) @@ -105,6 +291,24 @@ class SimplifyHLFIRIntrinsics return mlir::cast(transpose.getType()) .isPolymorphic(); }); + // Handle only SUM(DIM=CONSTANT) case for now. + // It may be beneficial to expand the non-DIM case as well. + // E.g. when the input array is an elemental array expression, + // expanding the SUM into a total reduction loop nest + // would avoid creating a temporary for the elemental array expression. + target.addDynamicallyLegalOp([](hlfir::SumOp sum) { + if (mlir::Value dim = sum.getDim()) { + if (fir::getIntIfConstant(dim)) { + if (!fir::isa_trivial(sum.getType())) { + // Ignore the case SUM(a, DIM=X), where 'a' is a 1D array. + // It is only legal when X is 1, and it should probably be + // canonicalized into SUM(a). + return false; + } + } + } + return true; + }); target.markUnknownOpDynamicallyLegal( [](mlir::Operation *) { return true; }); if (mlir::failed(mlir::applyFullConversion(getOperation(), target, diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir new file mode 100644 index 0000000000000..05a4dfde6344e --- /dev/null +++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir @@ -0,0 +1,361 @@ +// RUN: fir-opt --simplify-hlfir-intrinsics %s | FileCheck %s + +// box with known extents +func.func @sum_box_known_extents(%arg0: !fir.box>) { + %cst = arith.constant 2 : i32 + %res = hlfir.sum %arg0 dim %cst : (!fir.box>, i32) -> !hlfir.expr<2xi32> + return +} +// CHECK-LABEL: func.func @sum_box_known_extents( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.box>) { +// CHECK: %[[VAL_1:.*]] = arith.constant 2 : i32 +// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<2xi32> { +// CHECK: ^bb0(%[[VAL_6:.*]]: index): +// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_3]] step %[[VAL_7]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (i32) { +// CHECK: %[[VAL_12:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_13:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_12]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[VAL_14:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_15:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_14]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[VAL_16:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_17:.*]] = arith.subi %[[VAL_13]]#0, %[[VAL_16]] : index +// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_6]], %[[VAL_17]] : index +// CHECK: %[[VAL_19:.*]] = arith.subi %[[VAL_15]]#0, %[[VAL_16]] : index +// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_10]], %[[VAL_19]] : index +// CHECK: %[[VAL_21:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_18]], %[[VAL_20]]) : (!fir.box>, index, index) -> !fir.ref +// CHECK: %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref +// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_11]], %[[VAL_22]] : i32 +// CHECK: fir.result %[[VAL_23]] : i32 +// CHECK: } +// CHECK: hlfir.yield_element %[[VAL_9]] : i32 +// CHECK: } +// CHECK: return +// CHECK: } + +// expr with known extents +func.func @sum_expr_known_extents(%arg0: !hlfir.expr<2x3xi32>) { + %cst = arith.constant 1 : i32 + %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32> + return +} +// CHECK-LABEL: func.func @sum_expr_known_extents( +// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr<2x3xi32>) { +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xi32> { +// CHECK: ^bb0(%[[VAL_6:.*]]: index): +// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_2]] step %[[VAL_7]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (i32) { +// CHECK: %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]], %[[VAL_6]] : (!hlfir.expr<2x3xi32>, index, index) -> i32 +// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i32 +// CHECK: fir.result %[[VAL_13]] : i32 +// CHECK: } +// CHECK: hlfir.yield_element %[[VAL_9]] : i32 +// CHECK: } +// CHECK: return +// CHECK: } + +// box with unknown extent +func.func @sum_box_unknown_extent1(%arg0: !fir.box>>) { + %cst = arith.constant 1 : i32 + %res = hlfir.sum %arg0 dim %cst : (!fir.box>>, i32) -> !hlfir.expr<3xcomplex> + return +} +// CHECK-LABEL: func.func @sum_box_unknown_extent1( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.box>>) { +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_2]] : (!fir.box>>, index) -> (index, index, index) +// CHECK: %[[VAL_4:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xcomplex> { +// CHECK: ^bb0(%[[VAL_7:.*]]: index): +// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK: %[[VAL_10:.*]] = fir.undefined complex +// CHECK: %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_9]], [0 : index] : (complex, f64) -> complex +// CHECK: %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex, f64) -> complex +// CHECK: %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_8]] to %[[VAL_3]]#1 step %[[VAL_8]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (complex) { +// CHECK: %[[VAL_16:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_16]] : (!fir.box>>, index) -> (index, index, index) +// CHECK: %[[VAL_18:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_19:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_18]] : (!fir.box>>, index) -> (index, index, index) +// CHECK: %[[VAL_20:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_21:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_20]] : index +// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_14]], %[[VAL_21]] : index +// CHECK: %[[VAL_23:.*]] = arith.subi %[[VAL_19]]#0, %[[VAL_20]] : index +// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_7]], %[[VAL_23]] : index +// CHECK: %[[VAL_25:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_22]], %[[VAL_24]]) : (!fir.box>>, index, index) -> !fir.ref> +// CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref> +// CHECK: %[[VAL_27:.*]] = fir.addc %[[VAL_15]], %[[VAL_26]] : complex +// CHECK: fir.result %[[VAL_27]] : complex +// CHECK: } +// CHECK: hlfir.yield_element %[[VAL_13]] : complex +// CHECK: } +// CHECK: return +// CHECK: } + +func.func @sum_box_unknown_extent2(%arg0: !fir.box>>) { + %cst = arith.constant 2 : i32 + %res = hlfir.sum %arg0 dim %cst : (!fir.box>>, i32) -> !hlfir.expr> + return +} +// CHECK-LABEL: func.func @sum_box_unknown_extent2( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.box>>) { +// CHECK: %[[VAL_1:.*]] = arith.constant 2 : i32 +// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_2]] : (!fir.box>>, index) -> (index, index, index) +// CHECK: %[[VAL_4:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_3]]#1 : (index) -> !fir.shape<1> +// CHECK: %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr> { +// CHECK: ^bb0(%[[VAL_7:.*]]: index): +// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK: %[[VAL_10:.*]] = fir.undefined complex +// CHECK: %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_9]], [0 : index] : (complex, f64) -> complex +// CHECK: %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex, f64) -> complex +// CHECK: %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_8]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (complex) { +// CHECK: %[[VAL_16:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_16]] : (!fir.box>>, index) -> (index, index, index) +// CHECK: %[[VAL_18:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_19:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_18]] : (!fir.box>>, index) -> (index, index, index) +// CHECK: %[[VAL_20:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_21:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_20]] : index +// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_7]], %[[VAL_21]] : index +// CHECK: %[[VAL_23:.*]] = arith.subi %[[VAL_19]]#0, %[[VAL_20]] : index +// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_14]], %[[VAL_23]] : index +// CHECK: %[[VAL_25:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_22]], %[[VAL_24]]) : (!fir.box>>, index, index) -> !fir.ref> +// CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref> +// CHECK: %[[VAL_27:.*]] = fir.addc %[[VAL_15]], %[[VAL_26]] : complex +// CHECK: fir.result %[[VAL_27]] : complex +// CHECK: } +// CHECK: hlfir.yield_element %[[VAL_13]] : complex +// CHECK: } +// CHECK: return +// CHECK: } + +// expr with unknown extent +func.func @sum_expr_unkwnonw_extent1(%arg0: !hlfir.expr) { + %cst = arith.constant 1 : i32 + %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr, i32) -> !hlfir.expr<3xf32> + return +} +// CHECK-LABEL: func.func @sum_expr_unkwnonw_extent1( +// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr) { +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_2:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr) -> !fir.shape<2> +// CHECK: %[[VAL_3:.*]] = hlfir.get_extent %[[VAL_2]] {dim = 0 : index} : (!fir.shape<2>) -> index +// CHECK: %[[VAL_4:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> { +// CHECK: ^bb0(%[[VAL_7:.*]]: index): +// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_10:.*]] = fir.do_loop %[[VAL_11:.*]] = %[[VAL_8]] to %[[VAL_3]] step %[[VAL_8]] iter_args(%[[VAL_12:.*]] = %[[VAL_9]]) -> (f32) { +// CHECK: %[[VAL_13:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_11]], %[[VAL_7]] : (!hlfir.expr, index, index) -> f32 +// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: fir.result %[[VAL_14]] : f32 +// CHECK: } +// CHECK: hlfir.yield_element %[[VAL_10]] : f32 +// CHECK: } +// CHECK: return +// CHECK: } + +func.func @sum_expr_unkwnonw_extent2(%arg0: !hlfir.expr) { + %cst = arith.constant 2 : i32 + %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr, i32) -> !hlfir.expr + return +} +// CHECK-LABEL: func.func @sum_expr_unkwnonw_extent2( +// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr) { +// CHECK: %[[VAL_1:.*]] = arith.constant 2 : i32 +// CHECK: %[[VAL_2:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr) -> !fir.shape<2> +// CHECK: %[[VAL_3:.*]] = hlfir.get_extent %[[VAL_2]] {dim = 0 : index} : (!fir.shape<2>) -> index +// CHECK: %[[VAL_4:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr { +// CHECK: ^bb0(%[[VAL_7:.*]]: index): +// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_10:.*]] = fir.do_loop %[[VAL_11:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_8]] iter_args(%[[VAL_12:.*]] = %[[VAL_9]]) -> (f32) { +// CHECK: %[[VAL_13:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_7]], %[[VAL_11]] : (!hlfir.expr, index, index) -> f32 +// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: fir.result %[[VAL_14]] : f32 +// CHECK: } +// CHECK: hlfir.yield_element %[[VAL_10]] : f32 +// CHECK: } +// CHECK: return +// CHECK: } + +// scalar mask +func.func @sum_scalar_mask(%arg0: !hlfir.expr, %mask: !fir.ref>) { + %cst = arith.constant 1 : i32 + %res = hlfir.sum %arg0 dim %cst mask %mask : (!hlfir.expr, i32, !fir.ref>) -> !hlfir.expr<3xf32> + return +} +// CHECK-LABEL: func.func @sum_scalar_mask( +// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr, +// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref>) { +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_3:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr) -> !fir.shape<2> +// CHECK: %[[VAL_4:.*]] = hlfir.get_extent %[[VAL_3]] {dim = 0 : index} : (!fir.shape<2>) -> index +// CHECK: %[[VAL_5:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> { +// CHECK: ^bb0(%[[VAL_8:.*]]: index): +// CHECK: %[[VAL_9:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_1]] : !fir.ref> +// CHECK: %[[VAL_12:.*]] = fir.do_loop %[[VAL_13:.*]] = %[[VAL_9]] to %[[VAL_4]] step %[[VAL_9]] iter_args(%[[VAL_14:.*]] = %[[VAL_10]]) -> (f32) { +// CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_11]] : (!fir.logical<1>) -> i1 +// CHECK: %[[VAL_16:.*]] = fir.if %[[VAL_15]] -> (f32) { +// CHECK: %[[VAL_17:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_13]], %[[VAL_8]] : (!hlfir.expr, index, index) -> f32 +// CHECK: %[[VAL_18:.*]] = arith.addf %[[VAL_14]], %[[VAL_17]] : f32 +// CHECK: fir.result %[[VAL_18]] : f32 +// CHECK: } else { +// CHECK: fir.result %[[VAL_14]] : f32 +// CHECK: } +// CHECK: fir.result %[[VAL_16]] : f32 +// CHECK: } +// CHECK: hlfir.yield_element %[[VAL_12]] : f32 +// CHECK: } +// CHECK: return +// CHECK: } + +// array mask +func.func @sum_array_mask(%arg0: !hlfir.expr, %mask: !fir.box>>) { + %cst = arith.constant 2 : i32 + %res = hlfir.sum %arg0 dim %cst mask %mask : (!hlfir.expr, i32, !fir.box>>) -> !hlfir.expr + return +} +// CHECK-LABEL: func.func @sum_array_mask( +// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr, +// CHECK-SAME: %[[VAL_1:.*]]: !fir.box>>) { +// CHECK: %[[VAL_2:.*]] = arith.constant 2 : i32 +// CHECK: %[[VAL_3:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr) -> !fir.shape<2> +// CHECK: %[[VAL_4:.*]] = hlfir.get_extent %[[VAL_3]] {dim = 0 : index} : (!fir.shape<2>) -> index +// CHECK: %[[VAL_5:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr { +// CHECK: ^bb0(%[[VAL_8:.*]]: index): +// CHECK: %[[VAL_9:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_11:.*]] = fir.do_loop %[[VAL_12:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_9]] iter_args(%[[VAL_13:.*]] = %[[VAL_10]]) -> (f32) { +// CHECK: %[[VAL_14:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_15:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_14]] : (!fir.box>>, index) -> (index, index, index) +// CHECK: %[[VAL_16:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_16]] : (!fir.box>>, index) -> (index, index, index) +// CHECK: %[[VAL_18:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_19:.*]] = arith.subi %[[VAL_15]]#0, %[[VAL_18]] : index +// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_8]], %[[VAL_19]] : index +// CHECK: %[[VAL_21:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_18]] : index +// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_12]], %[[VAL_21]] : index +// CHECK: %[[VAL_23:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_20]], %[[VAL_22]]) : (!fir.box>>, index, index) -> !fir.ref> +// CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_23]] : !fir.ref> +// CHECK: %[[VAL_25:.*]] = fir.convert %[[VAL_24]] : (!fir.logical<1>) -> i1 +// CHECK: %[[VAL_26:.*]] = fir.if %[[VAL_25]] -> (f32) { +// CHECK: %[[VAL_27:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]], %[[VAL_12]] : (!hlfir.expr, index, index) -> f32 +// CHECK: %[[VAL_28:.*]] = arith.addf %[[VAL_13]], %[[VAL_27]] : f32 +// CHECK: fir.result %[[VAL_28]] : f32 +// CHECK: } else { +// CHECK: fir.result %[[VAL_13]] : f32 +// CHECK: } +// CHECK: fir.result %[[VAL_26]] : f32 +// CHECK: } +// CHECK: hlfir.yield_element %[[VAL_11]] : f32 +// CHECK: } +// CHECK: return +// CHECK: } + +// array expr mask +func.func @sum_array_expr_mask(%arg0: !hlfir.expr, %mask: !hlfir.expr>) { + %cst = arith.constant 2 : i32 + %res = hlfir.sum %arg0 dim %cst mask %mask : (!hlfir.expr, i32, !hlfir.expr>) -> !hlfir.expr + return +} +// CHECK-LABEL: func.func @sum_array_expr_mask( +// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr, +// CHECK-SAME: %[[VAL_1:.*]]: !hlfir.expr>) { +// CHECK: %[[VAL_2:.*]] = arith.constant 2 : i32 +// CHECK: %[[VAL_3:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr) -> !fir.shape<2> +// CHECK: %[[VAL_4:.*]] = hlfir.get_extent %[[VAL_3]] {dim = 0 : index} : (!fir.shape<2>) -> index +// CHECK: %[[VAL_5:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr { +// CHECK: ^bb0(%[[VAL_8:.*]]: index): +// CHECK: %[[VAL_9:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_11:.*]] = fir.do_loop %[[VAL_12:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_9]] iter_args(%[[VAL_13:.*]] = %[[VAL_10]]) -> (f32) { +// CHECK: %[[VAL_14:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_8]], %[[VAL_12]] : (!hlfir.expr>, index, index) -> !fir.logical<1> +// CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (!fir.logical<1>) -> i1 +// CHECK: %[[VAL_16:.*]] = fir.if %[[VAL_15]] -> (f32) { +// CHECK: %[[VAL_17:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]], %[[VAL_12]] : (!hlfir.expr, index, index) -> f32 +// CHECK: %[[VAL_18:.*]] = arith.addf %[[VAL_13]], %[[VAL_17]] : f32 +// CHECK: fir.result %[[VAL_18]] : f32 +// CHECK: } else { +// CHECK: fir.result %[[VAL_13]] : f32 +// CHECK: } +// CHECK: fir.result %[[VAL_16]] : f32 +// CHECK: } +// CHECK: hlfir.yield_element %[[VAL_11]] : f32 +// CHECK: } +// CHECK: return +// CHECK: } + +// unordered floating point reduction +func.func @sum_unordered_reduction(%arg0: !hlfir.expr<2x3xf32>) { + %cst = arith.constant 1 : i32 + %res = hlfir.sum %arg0 dim %cst {fastmath = #arith.fastmath} : (!hlfir.expr<2x3xf32>, i32) -> !hlfir.expr<3xf32> + return +} +// CHECK-LABEL: func.func @sum_unordered_reduction( +// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr<2x3xf32>) { +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> { +// CHECK: ^bb0(%[[VAL_6:.*]]: index): +// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_2]] step %[[VAL_7]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (f32) { +// CHECK: %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]], %[[VAL_6]] : (!hlfir.expr<2x3xf32>, index, index) -> f32 +// CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_11]], %[[VAL_12]] fastmath : f32 +// CHECK: fir.result %[[VAL_13]] : f32 +// CHECK: } +// CHECK: hlfir.yield_element %[[VAL_9]] : f32 +// CHECK: } +// CHECK: return +// CHECK: } + +// negative: total reduction +func.func @sum_total_reduction(%arg0: !fir.box>) { + %cst = arith.constant 1 : i32 + %res = hlfir.sum %arg0 dim %cst : (!fir.box>, i32) -> i32 + return +} +// CHECK-LABEL: func.func @sum_total_reduction( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.box>) { +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_2:.*]] = hlfir.sum %[[VAL_0]] dim %[[VAL_1]] : (!fir.box>, i32) -> i32 +// CHECK: return +// CHECK: } + +// negative: non-const dim +func.func @sum_non_const_dim(%arg0: !fir.box>, %dim: i32) { + %res = hlfir.sum %arg0 dim %dim : (!fir.box>, i32) -> i32 + return +} +// CHECK-LABEL: func.func @sum_non_const_dim( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.box>, +// CHECK-SAME: %[[VAL_1:.*]]: i32) { +// CHECK: %[[VAL_2:.*]] = hlfir.sum %[[VAL_0]] dim %[[VAL_1]] : (!fir.box>, i32) -> i32 +// CHECK: return +// CHECK: } From 0542e16c0b45f74be5f2ca6d52d9420876d5bc60 Mon Sep 17 00:00:00 2001 From: Slava Zakharin Date: Wed, 4 Dec 2024 14:26:35 -0800 Subject: [PATCH 2/3] Handle dynamically absent mask argument properly. --- .../Transforms/SimplifyHLFIRIntrinsics.cpp | 64 +++++++++++-- .../HLFIR/simplify-hlfir-intrinsics-sum.fir | 92 +++++++++++++++---- 2 files changed, 129 insertions(+), 27 deletions(-) diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp index 35dc881e880df..0c34c8221aeda 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -141,10 +141,17 @@ class SumAsElementalConversion : public mlir::OpRewritePattern { // If the mask is present and is a scalar, then we'd better load its value // outside of the reduction loop making the loop unswitching easier. // Maybe it is worth hoisting it from the elemental operation as well. + mlir::Value isPresentPred, maskValue; if (mask) { - hlfir::Entity maskValue{mask}; - if (maskValue.isScalar()) - mask = hlfir::loadTrivialScalar(loc, builder, maskValue); + if (mlir::isa(mask.getType())) { + // MASK represented by a box might be dynamically optional, + // so we have to check for its presence before accessing it. + isPresentPred = + builder.create(loc, builder.getI1Type(), mask); + } + + if (hlfir::Entity{mask}.isScalar()) + maskValue = genMaskValue(loc, builder, mask, isPresentPred, {}); } // NOTE: the outer elemental operation may be lowered into @@ -171,12 +178,10 @@ class SumAsElementalConversion : public mlir::OpRewritePattern { if (mask) { // Make the reduction value update conditional on the value // of the mask. - hlfir::Entity maskValue{mask}; - if (!maskValue.isScalar()) { + if (!maskValue) { // If the mask is an array, use the elemental and the loop indices // to address the proper mask element. - maskValue = hlfir::getElementAt(loc, builder, maskValue, indices); - maskValue = hlfir::loadTrivialScalar(loc, builder, maskValue); + maskValue = genMaskValue(loc, builder, mask, isPresentPred, indices); } mlir::Value isUnmasked = builder.create(loc, builder.getI1Type(), maskValue); @@ -273,6 +278,51 @@ class SumAsElementalConversion : public mlir::OpRewritePattern { llvm_unreachable("unsupported SUM reduction type"); } + + static mlir::Value genMaskValue(mlir::Location loc, + fir::FirOpBuilder &builder, mlir::Value mask, + mlir::Value isPresentPred, + mlir::ValueRange indices) { + mlir::OpBuilder::InsertionGuard guard(builder); + fir::IfOp ifOp; + mlir::Type maskType = + hlfir::getFortranElementType(fir::unwrapPassByRefType(mask.getType())); + if (isPresentPred) { + ifOp = builder.create(loc, maskType, isPresentPred, + /*withElseRegion=*/true); + + // Use 'true', if the mask is not present. + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + mlir::Value trueValue = builder.createBool(loc, true); + trueValue = builder.createConvert(loc, maskType, trueValue); + builder.create(loc, trueValue); + + // Load the mask value, if the mask is present. + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + } + + hlfir::Entity maskVar{mask}; + if (maskVar.isScalar()) { + if (mlir::isa(mask.getType())) { + // MASK may be a boxed scalar. + mlir::Value addr = hlfir::genVariableRawAddress(loc, builder, maskVar); + mask = builder.create(loc, hlfir::Entity{addr}); + } else { + mask = hlfir::loadTrivialScalar(loc, builder, maskVar); + } + } else { + // Load from the mask array. + assert(!indices.empty() && "no indices for addressing the mask array"); + maskVar = hlfir::getElementAt(loc, builder, maskVar, indices); + mask = hlfir::loadTrivialScalar(loc, builder, maskVar); + } + + if (!isPresentPred) + return mask; + + builder.create(loc, mask); + return ifOp.getResult(0); + } }; class SimplifyHLFIRIntrinsics diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir index 05a4dfde6344e..48c4144f70393 100644 --- a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir +++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir @@ -229,6 +229,50 @@ func.func @sum_scalar_mask(%arg0: !hlfir.expr, %mask: !fir.ref, %mask: !fir.box>) { + %cst = arith.constant 1 : i32 + %res = hlfir.sum %arg0 dim %cst mask %mask : (!hlfir.expr, i32, !fir.box>) -> !hlfir.expr<3xf32> + return +} +// CHECK-LABEL: func.func @sum_scalar_boxed_mask( +// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr, +// CHECK-SAME: %[[VAL_1:.*]]: !fir.box>) { +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_3:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr) -> !fir.shape<2> +// CHECK: %[[VAL_4:.*]] = hlfir.get_extent %[[VAL_3]] {dim = 0 : index} : (!fir.shape<2>) -> index +// CHECK: %[[VAL_5:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> { +// CHECK: ^bb0(%[[VAL_8:.*]]: index): +// CHECK: %[[VAL_9:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_11:.*]] = fir.is_present %[[VAL_1]] : (!fir.box>) -> i1 +// CHECK: %[[VAL_12:.*]] = fir.if %[[VAL_11]] -> (!fir.logical<1>) { +// CHECK: %[[VAL_13:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box>) -> !fir.ref> +// CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref> +// CHECK: fir.result %[[VAL_14]] : !fir.logical<1> +// CHECK: } else { +// CHECK: %[[VAL_15:.*]] = arith.constant true +// CHECK: %[[VAL_16:.*]] = fir.convert %[[VAL_15]] : (i1) -> !fir.logical<1> +// CHECK: fir.result %[[VAL_16]] : !fir.logical<1> +// CHECK: } +// CHECK: %[[VAL_17:.*]] = fir.do_loop %[[VAL_18:.*]] = %[[VAL_9]] to %[[VAL_4]] step %[[VAL_9]] iter_args(%[[VAL_19:.*]] = %[[VAL_10]]) -> (f32) { +// CHECK: %[[VAL_20:.*]] = fir.convert %[[VAL_12]] : (!fir.logical<1>) -> i1 +// CHECK: %[[VAL_21:.*]] = fir.if %[[VAL_20]] -> (f32) { +// CHECK: %[[VAL_22:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_18]], %[[VAL_8]] : (!hlfir.expr, index, index) -> f32 +// CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_19]], %[[VAL_22]] : f32 +// CHECK: fir.result %[[VAL_23]] : f32 +// CHECK: } else { +// CHECK: fir.result %[[VAL_19]] : f32 +// CHECK: } +// CHECK: fir.result %[[VAL_21]] : f32 +// CHECK: } +// CHECK: hlfir.yield_element %[[VAL_17]] : f32 +// CHECK: } +// CHECK: return +// CHECK: } + // array mask func.func @sum_array_mask(%arg0: !hlfir.expr, %mask: !fir.box>>) { %cst = arith.constant 2 : i32 @@ -247,29 +291,37 @@ func.func @sum_array_mask(%arg0: !hlfir.expr, %mask: !fir.box (f32) { -// CHECK: %[[VAL_14:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_15:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_14]] : (!fir.box>>, index) -> (index, index, index) -// CHECK: %[[VAL_16:.*]] = arith.constant 1 : index -// CHECK: %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_16]] : (!fir.box>>, index) -> (index, index, index) -// CHECK: %[[VAL_18:.*]] = arith.constant 1 : index -// CHECK: %[[VAL_19:.*]] = arith.subi %[[VAL_15]]#0, %[[VAL_18]] : index -// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_8]], %[[VAL_19]] : index -// CHECK: %[[VAL_21:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_18]] : index -// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_12]], %[[VAL_21]] : index -// CHECK: %[[VAL_23:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_20]], %[[VAL_22]]) : (!fir.box>>, index, index) -> !fir.ref> -// CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_23]] : !fir.ref> -// CHECK: %[[VAL_25:.*]] = fir.convert %[[VAL_24]] : (!fir.logical<1>) -> i1 -// CHECK: %[[VAL_26:.*]] = fir.if %[[VAL_25]] -> (f32) { -// CHECK: %[[VAL_27:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]], %[[VAL_12]] : (!hlfir.expr, index, index) -> f32 -// CHECK: %[[VAL_28:.*]] = arith.addf %[[VAL_13]], %[[VAL_27]] : f32 -// CHECK: fir.result %[[VAL_28]] : f32 +// CHECK: %[[VAL_11:.*]] = fir.is_present %[[VAL_1]] : (!fir.box>>) -> i1 +// CHECK: %[[VAL_12:.*]] = fir.do_loop %[[VAL_13:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_9]] iter_args(%[[VAL_14:.*]] = %[[VAL_10]]) -> (f32) { +// CHECK: %[[VAL_15:.*]] = fir.if %[[VAL_11]] -> (!fir.logical<1>) { +// CHECK: %[[VAL_16:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_16]] : (!fir.box>>, index) -> (index, index, index) +// CHECK: %[[VAL_18:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_19:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_18]] : (!fir.box>>, index) -> (index, index, index) +// CHECK: %[[VAL_20:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_21:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_20]] : index +// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_8]], %[[VAL_21]] : index +// CHECK: %[[VAL_23:.*]] = arith.subi %[[VAL_19]]#0, %[[VAL_20]] : index +// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_13]], %[[VAL_23]] : index +// CHECK: %[[VAL_25:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_22]], %[[VAL_24]]) : (!fir.box>>, index, index) -> !fir.ref> +// CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref> +// CHECK: fir.result %[[VAL_26]] : !fir.logical<1> // CHECK: } else { -// CHECK: fir.result %[[VAL_13]] : f32 +// CHECK: %[[VAL_27:.*]] = arith.constant true +// CHECK: %[[VAL_28:.*]] = fir.convert %[[VAL_27]] : (i1) -> !fir.logical<1> +// CHECK: fir.result %[[VAL_28]] : !fir.logical<1> // CHECK: } -// CHECK: fir.result %[[VAL_26]] : f32 +// CHECK: %[[VAL_29:.*]] = fir.convert %[[VAL_15]] : (!fir.logical<1>) -> i1 +// CHECK: %[[VAL_30:.*]] = fir.if %[[VAL_29]] -> (f32) { +// CHECK: %[[VAL_31:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]], %[[VAL_13]] : (!hlfir.expr, index, index) -> f32 +// CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_14]], %[[VAL_31]] : f32 +// CHECK: fir.result %[[VAL_32]] : f32 +// CHECK: } else { +// CHECK: fir.result %[[VAL_14]] : f32 +// CHECK: } +// CHECK: fir.result %[[VAL_30]] : f32 // CHECK: } -// CHECK: hlfir.yield_element %[[VAL_11]] : f32 +// CHECK: hlfir.yield_element %[[VAL_12]] : f32 // CHECK: } // CHECK: return // CHECK: } From 091e5922a4a6b9cb84dc323be2b753d69caf881e Mon Sep 17 00:00:00 2001 From: Slava Zakharin Date: Wed, 4 Dec 2024 14:48:59 -0800 Subject: [PATCH 3/3] Fixed typo in the test. --- flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir index 48c4144f70393..703b6673154f3 100644 --- a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir +++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir @@ -142,12 +142,12 @@ func.func @sum_box_unknown_extent2(%arg0: !fir.box>> // CHECK: } // expr with unknown extent -func.func @sum_expr_unkwnonw_extent1(%arg0: !hlfir.expr) { +func.func @sum_expr_unknown_extent1(%arg0: !hlfir.expr) { %cst = arith.constant 1 : i32 %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr, i32) -> !hlfir.expr<3xf32> return } -// CHECK-LABEL: func.func @sum_expr_unkwnonw_extent1( +// CHECK-LABEL: func.func @sum_expr_unknown_extent1( // CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr) { // CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32 // CHECK: %[[VAL_2:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr) -> !fir.shape<2> @@ -168,12 +168,12 @@ func.func @sum_expr_unkwnonw_extent1(%arg0: !hlfir.expr) { // CHECK: return // CHECK: } -func.func @sum_expr_unkwnonw_extent2(%arg0: !hlfir.expr) { +func.func @sum_expr_unknown_extent2(%arg0: !hlfir.expr) { %cst = arith.constant 2 : i32 %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr, i32) -> !hlfir.expr return } -// CHECK-LABEL: func.func @sum_expr_unkwnonw_extent2( +// CHECK-LABEL: func.func @sum_expr_unknown_extent2( // CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr) { // CHECK: %[[VAL_1:.*]] = arith.constant 2 : i32 // CHECK: %[[VAL_2:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr) -> !fir.shape<2>