diff --git a/flang/include/flang/Optimizer/HLFIR/Passes.td b/flang/include/flang/Optimizer/HLFIR/Passes.td index 04d7aec5fe489..bfff458f7a6c5 100644 --- a/flang/include/flang/Optimizer/HLFIR/Passes.td +++ b/flang/include/flang/Optimizer/HLFIR/Passes.td @@ -61,6 +61,10 @@ def SimplifyHLFIRIntrinsics : Pass<"simplify-hlfir-intrinsics"> { "the hlfir.matmul.">]; } +def ExpressionSimplification : Pass<"hlfir-expression-simplification"> { + let summary = "Simplify Fortran expressions"; +} + def InlineElementals : Pass<"inline-elementals"> { let summary = "Inline chained hlfir.elemental operations"; } diff --git a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt index 3775a13e31e95..5c24fe58b05c4 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) add_flang_library(HLFIRTransforms BufferizeHLFIR.cpp ConvertToFIR.cpp + ExpressionSimplification.cpp InlineElementals.cpp InlineHLFIRAssign.cpp InlineHLFIRCopyIn.cpp diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ExpressionSimplification.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ExpressionSimplification.cpp new file mode 100644 index 0000000000000..0559b49d8ecba --- /dev/null +++ b/flang/lib/Optimizer/HLFIR/Transforms/ExpressionSimplification.cpp @@ -0,0 +1,99 @@ +//===- ExpressionSimplification.cpp - Simplify HLFIR expressions ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "flang/Optimizer/HLFIR/Passes.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace hlfir { +#define GEN_PASS_DEF_EXPRESSIONSIMPLIFICATION +#include "flang/Optimizer/HLFIR/Passes.h.inc" +} // namespace hlfir + +// Get the first user of `op`. +// Note that we consider the first user to be the one on the lowest line of +// the emitted HLFIR. The user iterator considers the opposite. +template +static UserOp getFirstUser(mlir::Operation *op) { + auto it = op->user_begin(), end = op->user_end(), prev = it; + for (; it != end; prev = it++) + ; + if (prev != end) + if (auto userOp = mlir::dyn_cast(*prev)) + return userOp; + return {}; +} + +// Get the last user of `op`. +// Note that we consider the last user to be the one on the highest line of +// the emitted HLFIR. The user iterator considers the opposite. +template +static UserOp getLastUser(mlir::Operation *op) { + if (!op->getUsers().empty()) + if (auto userOp = mlir::dyn_cast(*op->user_begin())) + return userOp; + return {}; +} + +namespace { + +// Trim operations can be erased in certain expressions, such as character +// comparisons. +// Since a character comparison appends spaces to the shorter character, +// calls to trim() that are used only in the comparison can be eliminated. +// +// Example: +// `trim(x) == trim(y)` +// can be simplified to +// `x == y` +class EraseTrim : public mlir::OpRewritePattern { +public: + using mlir::OpRewritePattern::OpRewritePattern; + + llvm::LogicalResult + matchAndRewrite(hlfir::CharTrimOp trimOp, + mlir::PatternRewriter &rewriter) const override { + int trimUses = std::distance(trimOp->use_begin(), trimOp->use_end()); + auto cmpCharOp = getFirstUser(trimOp); + auto destroyOp = getLastUser(trimOp); + if (!cmpCharOp || !destroyOp || trimUses != 2) + return rewriter.notifyMatchFailure( + trimOp, "hlfir.char_trim is not used (only) by hlfir.cmpchar"); + + rewriter.eraseOp(destroyOp); + rewriter.replaceOp(trimOp, trimOp.getChr()); + return mlir::success(); + } +}; + +class ExpressionSimplificationPass + : public hlfir::impl::ExpressionSimplificationBase< + ExpressionSimplificationPass> { +public: + void runOnOperation() override { + mlir::MLIRContext *context = &getContext(); + + mlir::GreedyRewriteConfig config; + // Prevent the pattern driver from merging blocks. + config.setRegionSimplificationLevel( + mlir::GreedySimplifyRegionLevel::Disabled); + + mlir::RewritePatternSet patterns(context); + patterns.insert(context); + + if (mlir::failed(mlir::applyPatternsGreedily( + getOperation(), std::move(patterns), config))) { + mlir::emitError(getOperation()->getLoc(), + "failure in HLFIR expression simplification"); + signalPassFailure(); + } + } +}; + +} // namespace diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index 7c2777baebef1..fe58f4512871c 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -242,6 +242,10 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm, void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, EnableOpenMP enableOpenMP, llvm::OptimizationLevel optLevel) { + if (optLevel.getSizeLevel() > 0 || optLevel.getSpeedupLevel() > 0) { + addNestedPassToAllTopLevelOperations( + pm, hlfir::createExpressionSimplification); + } if (optLevel.isOptimizingForSpeed()) { addCanonicalizerPassWithoutRegionSimplification(pm); addNestedPassToAllTopLevelOperations( diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90 index e85a7728fc9af..34f319ee444b3 100644 --- a/flang/test/Driver/mlir-pass-pipeline.f90 +++ b/flang/test/Driver/mlir-pass-pipeline.f90 @@ -15,6 +15,15 @@ ! ALL: Pass statistics report ! ALL: Fortran::lower::VerifierPass +! O2-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private'] +! O2-NEXT: 'fir.global' Pipeline +! O2-NEXT: ExpressionSimplification +! O2-NEXT: 'func.func' Pipeline +! O2-NEXT: ExpressionSimplification +! O2-NEXT: 'omp.declare_reduction' Pipeline +! O2-NEXT: ExpressionSimplification +! O2-NEXT: 'omp.private' Pipeline +! O2-NEXT: ExpressionSimplification ! O2-NEXT: Canonicalizer ! ALL: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private'] ! ALL-NEXT:'fir.global' Pipeline diff --git a/flang/test/HLFIR/expression-simplification.fir b/flang/test/HLFIR/expression-simplification.fir new file mode 100644 index 0000000000000..15d1550f1f172 --- /dev/null +++ b/flang/test/HLFIR/expression-simplification.fir @@ -0,0 +1,101 @@ +// RUN: fir-opt %s --hlfir-expression-simplification | FileCheck %s + +// Test removal of trim() calls. + +// logical function test_char_cmp(x, y) result(cmp) +// character(*) :: x, y +// cmp = trim(x) == trim(y) +// end function + +func.func @_QPtest_char_cmp(%arg0: !fir.boxchar<1> {fir.bindc_name = "x"}, + %arg1: !fir.boxchar<1> {fir.bindc_name = "y"}) -> !fir.logical<4> { + %0 = fir.dummy_scope : !fir.dscope + %1 = fir.alloca !fir.logical<4> {bindc_name = "cmp", uniq_name = "_QFtest_char_cmpEcmp"} + %2:2 = hlfir.declare %1 {uniq_name = "_QFtest_char_cmpEcmp"} : (!fir.ref>) -> (!fir.ref>, !fir.ref>) + %3:2 = fir.unboxchar %arg0 : (!fir.boxchar<1>) -> (!fir.ref>, index) + %4:2 = hlfir.declare %3#0 typeparams %3#1 dummy_scope %0 {uniq_name = "_QFtest_char_cmpEx"} : (!fir.ref>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref>) + %5:2 = fir.unboxchar %arg1 : (!fir.boxchar<1>) -> (!fir.ref>, index) + %6:2 = hlfir.declare %5#0 typeparams %5#1 dummy_scope %0 {uniq_name = "_QFtest_char_cmpEy"} : (!fir.ref>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref>) + %7 = hlfir.char_trim %4#0 : (!fir.boxchar<1>) -> !hlfir.expr> + %8 = hlfir.char_trim %6#0 : (!fir.boxchar<1>) -> !hlfir.expr> + %9 = hlfir.cmpchar eq %7 %8 : (!hlfir.expr>, !hlfir.expr>) -> i1 + %10 = fir.convert %9 : (i1) -> !fir.logical<4> + hlfir.assign %10 to %2#0 : !fir.logical<4>, !fir.ref> + hlfir.destroy %8 : !hlfir.expr> + hlfir.destroy %7 : !hlfir.expr> + %11 = fir.load %2#0 : !fir.ref> + return %11 : !fir.logical<4> +} + +// CHECK-LABEL: func.func @_QPtest_char_cmp( +// CHECK-SAME: %[[ARG_0:.*]]: !fir.boxchar<1> {fir.bindc_name = "x"}, +// CHECK-SAME: %[[ARG_1:.*]]: !fir.boxchar<1> {fir.bindc_name = "y"}) -> !fir.logical<4> { +// CHECK: %[[VAL_0:.*]] = fir.dummy_scope : !fir.dscope +// CHECK: %[[VAL_1:.*]] = fir.alloca !fir.logical<4> {bindc_name = "cmp", uniq_name = "_QFtest_char_cmpEcmp"} +// CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = "_QFtest_char_cmpEcmp"} : (!fir.ref>) -> (!fir.ref>, !fir.ref>) +// CHECK: %[[VAL_3:.*]]:2 = fir.unboxchar %[[ARG_0]] : (!fir.boxchar<1>) -> (!fir.ref>, index) +// CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_3]]#0 typeparams %[[VAL_3]]#1 dummy_scope %[[VAL_0]] {uniq_name = "_QFtest_char_cmpEx"} : (!fir.ref>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref>) +// CHECK: %[[VAL_6:.*]]:2 = fir.unboxchar %[[ARG_1]] : (!fir.boxchar<1>) -> (!fir.ref>, index) +// CHECK: %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_6]]#0 typeparams %[[VAL_6]]#1 dummy_scope %[[VAL_0]] {uniq_name = "_QFtest_char_cmpEy"} : (!fir.ref>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref>) +// CHECK: %[[VAL_9:.*]] = hlfir.cmpchar eq %[[VAL_5]]#0 %[[VAL_8]]#0 : (!fir.boxchar<1>, !fir.boxchar<1>) -> i1 +// CHECK: %[[VAL_10:.*]] = fir.convert %[[VAL_9]] : (i1) -> !fir.logical<4> +// CHECK: hlfir.assign %[[VAL_10]] to %[[VAL_2]]#0 : !fir.logical<4>, !fir.ref> +// CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_2]]#0 : !fir.ref> +// CHECK: return %[[VAL_11]] : !fir.logical<4> +// CHECK: } + +// Check that trim() is not removed when its result is stored. + +// logical function test_char_cmp2(x, y) result(res) +// character(*) :: x, y +// character(:), allocatable :: tx +// +// tx = trim(x) +// res = tx == y +// end function + +func.func @_QPtest_char_cmp2(%arg0: !fir.boxchar<1> {fir.bindc_name = "x"}, %arg1: !fir.boxchar<1> {fir.bindc_name = "y"}) -> !fir.logical<4> { + %0 = fir.dummy_scope : !fir.dscope + %1 = fir.alloca !fir.logical<4> {bindc_name = "res", uniq_name = "_QFtest_char_cmp2Eres"} + %2:2 = hlfir.declare %1 {uniq_name = "_QFtest_char_cmp2Eres"} : (!fir.ref>) -> (!fir.ref>, !fir.ref>) + %3 = fir.alloca !fir.box>> {bindc_name = "tx", uniq_name = "_QFtest_char_cmp2Etx"} + %4 = fir.zero_bits !fir.heap> + %c0 = arith.constant 0 : index + %5 = fir.embox %4 typeparams %c0 : (!fir.heap>, index) -> !fir.box>> + fir.store %5 to %3 : !fir.ref>>> + %6:2 = hlfir.declare %3 {fortran_attrs = #fir.var_attrs, uniq_name = "_QFtest_char_cmp2Etx"} : (!fir.ref>>>) -> (!fir.ref>>>, !fir.ref>>>) + %7:2 = fir.unboxchar %arg0 : (!fir.boxchar<1>) -> (!fir.ref>, index) + %8:2 = hlfir.declare %7#0 typeparams %7#1 dummy_scope %0 {uniq_name = "_QFtest_char_cmp2Ex"} : (!fir.ref>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref>) + %9:2 = fir.unboxchar %arg1 : (!fir.boxchar<1>) -> (!fir.ref>, index) + %10:2 = hlfir.declare %9#0 typeparams %9#1 dummy_scope %0 {uniq_name = "_QFtest_char_cmp2Ey"} : (!fir.ref>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref>) + %11 = hlfir.char_trim %8#0 : (!fir.boxchar<1>) -> !hlfir.expr> + hlfir.assign %11 to %6#0 realloc : !hlfir.expr>, !fir.ref>>> + hlfir.destroy %11 : !hlfir.expr> + %12 = fir.load %6#0 : !fir.ref>>> + %13 = fir.box_addr %12 : (!fir.box>>) -> !fir.heap> + %14 = fir.load %6#0 : !fir.ref>>> + %15 = fir.box_elesize %14 : (!fir.box>>) -> index + %16 = fir.emboxchar %13, %15 : (!fir.heap>, index) -> !fir.boxchar<1> + %17 = hlfir.cmpchar eq %16 %10#0 : (!fir.boxchar<1>, !fir.boxchar<1>) -> i1 + %18 = fir.convert %17 : (i1) -> !fir.logical<4> + hlfir.assign %18 to %2#0 : !fir.logical<4>, !fir.ref> + %19 = fir.load %2#0 : !fir.ref> + %20 = fir.load %6#0 : !fir.ref>>> + %21 = fir.box_addr %20 : (!fir.box>>) -> !fir.heap> + %22 = fir.convert %21 : (!fir.heap>) -> i64 + %c0_i64 = arith.constant 0 : i64 + %23 = arith.cmpi ne, %22, %c0_i64 : i64 + fir.if %23 { + %24 = fir.load %6#0 : !fir.ref>>> + %25 = fir.box_addr %24 : (!fir.box>>) -> !fir.heap> + fir.freemem %25 : !fir.heap> + %26 = fir.zero_bits !fir.heap> + %c0_0 = arith.constant 0 : index + %27 = fir.embox %26 typeparams %c0_0 : (!fir.heap>, index) -> !fir.box>> + fir.store %27 to %6#0 : !fir.ref>>> + } + return %19 : !fir.logical<4> +} + +// CHECK-LABEL: func.func @_QPtest_char_cmp2( +// CHECK: hlfir.char_trim