Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def SimplifyHLFIRIntrinsics : Pass<"simplify-hlfir-intrinsics"> {
"the hlfir.matmul.">];
}

def ExpressionSimplification
: Pass<"expression-simplification", "::mlir::ModuleOp"> {
let summary = "Simplify Fortran expressions";
}

def InlineElementals : Pass<"inline-elementals"> {
let summary = "Inline chained hlfir.elemental operations";
}
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
add_flang_library(HLFIRTransforms
BufferizeHLFIR.cpp
ConvertToFIR.cpp
ExpressionSimplification.cpp
InlineElementals.cpp
InlineHLFIRAssign.cpp
InlineHLFIRCopyIn.cpp
Expand Down
127 changes: 127 additions & 0 deletions flang/lib/Optimizer/HLFIR/Transforms/ExpressionSimplification.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
//===- ExpressionSimplification.cpp - Simplify HLFIR expressions ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "llvm/Support/DebugLog.h"

namespace hlfir {
#define GEN_PASS_DEF_EXPRESSIONSIMPLIFICATION
#include "flang/Optimizer/HLFIR/Passes.h.inc"
} // namespace hlfir

#define DEBUG_TYPE "expression-simplification"

static void removeOp(mlir::Operation *op) {
op->dropAllReferences();
op->dropAllUses();
op->erase();
}

// Get the first user of `op`.
// Note that we consider the first user to be the one on the lowest line of
// the emitted HLFIR. The user iterator considers the opposite.
template <typename UserOp>
static UserOp getFirstUser(mlir::Operation *op) {
auto it = op->user_begin(), end = op->user_end(), prev = it;
for (; it != end; prev = it++)
;
if (prev != end)
if (auto userOp = mlir::dyn_cast<UserOp>(*prev))
return userOp;
return {};
}

// Get the last user of `op`.
// Note that we consider the last user to be the one on the highest line of
// the emitted HLFIR. The user iterator considers the opposite.
template <typename UserOp>
static UserOp getLastUser(mlir::Operation *op) {
if (!op->getUsers().empty())
if (auto userOp = mlir::dyn_cast<UserOp>(*op->user_begin()))
return userOp;
return {};
}

namespace {

// This class analyzes a trimmed character and removes the trim operation if
// its result is not used elsewhere.
class TrimRemover {
public:
TrimRemover(mlir::Value charVal) : charVal(charVal) {}
TrimRemover(const TrimRemover &) = delete;

bool charWasTrimmed();
void removeTrim();

private:
mlir::Value charVal;
hlfir::CharTrimOp trimOp;
hlfir::CmpCharOp cmpCharOp;
hlfir::DestroyOp destroyOp;
};

bool TrimRemover::charWasTrimmed() {
LDBG() << "charWasTrimmed: " << charVal;

trimOp = mlir::dyn_cast<hlfir::CharTrimOp>(charVal.getDefiningOp());
if (!trimOp)
return false;
int trimUses = std::distance(trimOp->use_begin(), trimOp->use_end());
cmpCharOp = getFirstUser<hlfir::CmpCharOp>(trimOp);
destroyOp = getLastUser<hlfir::DestroyOp>(trimOp);
return cmpCharOp && destroyOp && trimUses == 2;
}

void TrimRemover::removeTrim() {
LDBG() << "removeTrim: " << trimOp;

cmpCharOp->replaceUsesOfWith(trimOp.getResult(), trimOp.getChr());
removeOp(destroyOp);
removeOp(trimOp);
}

class ExpressionSimplification
: public hlfir::impl::ExpressionSimplificationBase<
ExpressionSimplification> {
public:
using ExpressionSimplificationBase<
ExpressionSimplification>::ExpressionSimplificationBase;

void runOnOperation() override;

private:
// Simplify character comparisons.
// Because character comparison appends spaces to the shorter character,
// calls to trim() that are used only in the comparison can be eliminated.
//
// Example:
// `trim(x) == trim(y)`
// can be simplified to
// `x == y`
void simplifyCmpChar(hlfir::CmpCharOp cmpChar);
};

void ExpressionSimplification::simplifyCmpChar(hlfir::CmpCharOp cmpChar) {
TrimRemover lhsTrimRem(cmpChar.getLchr());
TrimRemover rhsTrimRem(cmpChar.getRchr());

if (lhsTrimRem.charWasTrimmed())
lhsTrimRem.removeTrim();
if (rhsTrimRem.charWasTrimmed())
rhsTrimRem.removeTrim();
}

void ExpressionSimplification::runOnOperation() {
mlir::ModuleOp module = getOperation();
module.walk([&](hlfir::CmpCharOp cmpChar) { simplifyCmpChar(cmpChar); });
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would recommend walking all trim operations instead of the comparisons, e.g. consider the case where the same trim result is used in multiple character comparisons. We may not have such cases now, but the CSE may be able to optimize it in codes like (trim(x) == trim(y)).and.(trim(x) == trim(z)).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks for the suggestion, it simplified the code even further.

At least for now the pass is able to erase all trims in (trim(x) == trim(y)).and.(trim(x) == trim(z)).
But for each trim an hlfir.char_trim is generated, no result is currently reused.

If this becomes the case in the future, we would need to check all users to see if they are composed only of hlfir.cmpchars and one hlfir.destroy. To keep the code simple, I have not made this change.
Let me know if you think it would be better to implement it already.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I missed the notification. I will review now.

}

} // namespace
3 changes: 3 additions & 0 deletions flang/lib/Optimizer/Passes/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,9 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
void createHLFIRToFIRPassPipeline(mlir::PassManager &pm,
EnableOpenMP enableOpenMP,
llvm::OptimizationLevel optLevel) {
if (optLevel.getSizeLevel() > 0 || optLevel.getSpeedupLevel() > 0) {
pm.addPass(hlfir::createExpressionSimplification());
}
if (optLevel.isOptimizingForSpeed()) {
addCanonicalizerPassWithoutRegionSimplification(pm);
addNestedPassToAllTopLevelOperations<PassConstructor>(
Expand Down
1 change: 1 addition & 0 deletions flang/test/Driver/mlir-pass-pipeline.f90
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
! ALL: Pass statistics report

! ALL: Fortran::lower::VerifierPass
! O2-NEXT: ExpressionSimplification
! O2-NEXT: Canonicalizer
! ALL: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
! ALL-NEXT:'fir.global' Pipeline
Expand Down
101 changes: 101 additions & 0 deletions flang/test/HLFIR/expression-simplification.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// RUN: fir-opt %s --expression-simplification | FileCheck %s

// Test removal of trim() calls.

// logical function test_char_cmp(x, y) result(cmp)
// character(*) :: x, y
// cmp = trim(x) == trim(y)
// end function

func.func @_QPtest_char_cmp(%arg0: !fir.boxchar<1> {fir.bindc_name = "x"},
%arg1: !fir.boxchar<1> {fir.bindc_name = "y"}) -> !fir.logical<4> {
%0 = fir.dummy_scope : !fir.dscope
%1 = fir.alloca !fir.logical<4> {bindc_name = "cmp", uniq_name = "_QFtest_char_cmpEcmp"}
%2:2 = hlfir.declare %1 {uniq_name = "_QFtest_char_cmpEcmp"} : (!fir.ref<!fir.logical<4>>) -> (!fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>)
%3:2 = fir.unboxchar %arg0 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
%4:2 = hlfir.declare %3#0 typeparams %3#1 dummy_scope %0 {uniq_name = "_QFtest_char_cmpEx"} : (!fir.ref<!fir.char<1,?>>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
%5:2 = fir.unboxchar %arg1 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
%6:2 = hlfir.declare %5#0 typeparams %5#1 dummy_scope %0 {uniq_name = "_QFtest_char_cmpEy"} : (!fir.ref<!fir.char<1,?>>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
%7 = hlfir.char_trim %4#0 : (!fir.boxchar<1>) -> !hlfir.expr<!fir.char<1,?>>
%8 = hlfir.char_trim %6#0 : (!fir.boxchar<1>) -> !hlfir.expr<!fir.char<1,?>>
%9 = hlfir.cmpchar eq %7 %8 : (!hlfir.expr<!fir.char<1,?>>, !hlfir.expr<!fir.char<1,?>>) -> i1
%10 = fir.convert %9 : (i1) -> !fir.logical<4>
hlfir.assign %10 to %2#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
hlfir.destroy %8 : !hlfir.expr<!fir.char<1,?>>
hlfir.destroy %7 : !hlfir.expr<!fir.char<1,?>>
%11 = fir.load %2#0 : !fir.ref<!fir.logical<4>>
return %11 : !fir.logical<4>
}

// CHECK-LABEL: func.func @_QPtest_char_cmp(
// CHECK-SAME: %[[ARG_0:.*]]: !fir.boxchar<1> {fir.bindc_name = "x"},
// CHECK-SAME: %[[ARG_1:.*]]: !fir.boxchar<1> {fir.bindc_name = "y"}) -> !fir.logical<4> {
// CHECK: %[[VAL_0:.*]] = fir.dummy_scope : !fir.dscope
// CHECK: %[[VAL_1:.*]] = fir.alloca !fir.logical<4> {bindc_name = "cmp", uniq_name = "_QFtest_char_cmpEcmp"}
// CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = "_QFtest_char_cmpEcmp"} : (!fir.ref<!fir.logical<4>>) -> (!fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>)
// CHECK: %[[VAL_3:.*]]:2 = fir.unboxchar %[[ARG_0]] : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
// CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_3]]#0 typeparams %[[VAL_3]]#1 dummy_scope %[[VAL_0]] {uniq_name = "_QFtest_char_cmpEx"} : (!fir.ref<!fir.char<1,?>>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
// CHECK: %[[VAL_6:.*]]:2 = fir.unboxchar %[[ARG_1]] : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
// CHECK: %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_6]]#0 typeparams %[[VAL_6]]#1 dummy_scope %[[VAL_0]] {uniq_name = "_QFtest_char_cmpEy"} : (!fir.ref<!fir.char<1,?>>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
// CHECK: %[[VAL_9:.*]] = hlfir.cmpchar eq %[[VAL_5]]#0 %[[VAL_8]]#0 : (!fir.boxchar<1>, !fir.boxchar<1>) -> i1
// CHECK: %[[VAL_10:.*]] = fir.convert %[[VAL_9]] : (i1) -> !fir.logical<4>
// CHECK: hlfir.assign %[[VAL_10]] to %[[VAL_2]]#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
// CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_2]]#0 : !fir.ref<!fir.logical<4>>
// CHECK: return %[[VAL_11]] : !fir.logical<4>
// CHECK: }

// Check that trim() is not removed when its result is stored.

// logical function test_char_cmp2(x, y) result(res)
// character(*) :: x, y
// character(:), allocatable :: tx
//
// tx = trim(x)
// res = tx == y
// end function

func.func @_QPtest_char_cmp2(%arg0: !fir.boxchar<1> {fir.bindc_name = "x"}, %arg1: !fir.boxchar<1> {fir.bindc_name = "y"}) -> !fir.logical<4> {
%0 = fir.dummy_scope : !fir.dscope
%1 = fir.alloca !fir.logical<4> {bindc_name = "res", uniq_name = "_QFtest_char_cmp2Eres"}
%2:2 = hlfir.declare %1 {uniq_name = "_QFtest_char_cmp2Eres"} : (!fir.ref<!fir.logical<4>>) -> (!fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>)
%3 = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>> {bindc_name = "tx", uniq_name = "_QFtest_char_cmp2Etx"}
%4 = fir.zero_bits !fir.heap<!fir.char<1,?>>
%c0 = arith.constant 0 : index
%5 = fir.embox %4 typeparams %c0 : (!fir.heap<!fir.char<1,?>>, index) -> !fir.box<!fir.heap<!fir.char<1,?>>>
fir.store %5 to %3 : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
%6:2 = hlfir.declare %3 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtest_char_cmp2Etx"} : (!fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>, !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>)
%7:2 = fir.unboxchar %arg0 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
%8:2 = hlfir.declare %7#0 typeparams %7#1 dummy_scope %0 {uniq_name = "_QFtest_char_cmp2Ex"} : (!fir.ref<!fir.char<1,?>>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
%9:2 = fir.unboxchar %arg1 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
%10:2 = hlfir.declare %9#0 typeparams %9#1 dummy_scope %0 {uniq_name = "_QFtest_char_cmp2Ey"} : (!fir.ref<!fir.char<1,?>>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
%11 = hlfir.char_trim %8#0 : (!fir.boxchar<1>) -> !hlfir.expr<!fir.char<1,?>>
hlfir.assign %11 to %6#0 realloc : !hlfir.expr<!fir.char<1,?>>, !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
hlfir.destroy %11 : !hlfir.expr<!fir.char<1,?>>
%12 = fir.load %6#0 : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
%13 = fir.box_addr %12 : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
%14 = fir.load %6#0 : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
%15 = fir.box_elesize %14 : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> index
%16 = fir.emboxchar %13, %15 : (!fir.heap<!fir.char<1,?>>, index) -> !fir.boxchar<1>
%17 = hlfir.cmpchar eq %16 %10#0 : (!fir.boxchar<1>, !fir.boxchar<1>) -> i1
%18 = fir.convert %17 : (i1) -> !fir.logical<4>
hlfir.assign %18 to %2#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
%19 = fir.load %2#0 : !fir.ref<!fir.logical<4>>
%20 = fir.load %6#0 : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
%21 = fir.box_addr %20 : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
%22 = fir.convert %21 : (!fir.heap<!fir.char<1,?>>) -> i64
%c0_i64 = arith.constant 0 : i64
%23 = arith.cmpi ne, %22, %c0_i64 : i64
fir.if %23 {
%24 = fir.load %6#0 : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
%25 = fir.box_addr %24 : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
fir.freemem %25 : !fir.heap<!fir.char<1,?>>
%26 = fir.zero_bits !fir.heap<!fir.char<1,?>>
%c0_0 = arith.constant 0 : index
%27 = fir.embox %26 typeparams %c0_0 : (!fir.heap<!fir.char<1,?>>, index) -> !fir.box<!fir.heap<!fir.char<1,?>>>
fir.store %27 to %6#0 : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
}
return %19 : !fir.logical<4>
}

// CHECK-LABEL: func.func @_QPtest_char_cmp2(
// CHECK: hlfir.cmpchar eq