Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def SimplifyHLFIRIntrinsics : Pass<"simplify-hlfir-intrinsics"> {
"the hlfir.matmul.">];
}

def ExpressionSimplification : Pass<"hlfir-expression-simplification"> {
let summary = "Simplify Fortran expressions";
}

def InlineElementals : Pass<"inline-elementals"> {
let summary = "Inline chained hlfir.elemental operations";
}
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
add_flang_library(HLFIRTransforms
BufferizeHLFIR.cpp
ConvertToFIR.cpp
ExpressionSimplification.cpp
InlineElementals.cpp
InlineHLFIRAssign.cpp
InlineHLFIRCopyIn.cpp
Expand Down
99 changes: 99 additions & 0 deletions flang/lib/Optimizer/HLFIR/Transforms/ExpressionSimplification.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
//===- ExpressionSimplification.cpp - Simplify HLFIR expressions ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace hlfir {
#define GEN_PASS_DEF_EXPRESSIONSIMPLIFICATION
#include "flang/Optimizer/HLFIR/Passes.h.inc"
} // namespace hlfir

// Get the first user of `op`.
// Note that we consider the first user to be the one on the lowest line of
// the emitted HLFIR. The user iterator considers the opposite.
template <typename UserOp>
static UserOp getFirstUser(mlir::Operation *op) {
auto it = op->user_begin(), end = op->user_end(), prev = it;
for (; it != end; prev = it++)
;
if (prev != end)
if (auto userOp = mlir::dyn_cast<UserOp>(*prev))
return userOp;
return {};
}

// Get the last user of `op`.
// Note that we consider the last user to be the one on the highest line of
// the emitted HLFIR. The user iterator considers the opposite.
template <typename UserOp>
static UserOp getLastUser(mlir::Operation *op) {
if (!op->getUsers().empty())
if (auto userOp = mlir::dyn_cast<UserOp>(*op->user_begin()))
return userOp;
return {};
}

namespace {

// Trim operations can be erased in certain expressions, such as character
// comparisons.
// Since a character comparison appends spaces to the shorter character,
// calls to trim() that are used only in the comparison can be eliminated.
//
// Example:
// `trim(x) == trim(y)`
// can be simplified to
// `x == y`
class EraseTrim : public mlir::OpRewritePattern<hlfir::CharTrimOp> {
public:
using mlir::OpRewritePattern<hlfir::CharTrimOp>::OpRewritePattern;

llvm::LogicalResult
matchAndRewrite(hlfir::CharTrimOp trimOp,
mlir::PatternRewriter &rewriter) const override {
int trimUses = std::distance(trimOp->use_begin(), trimOp->use_end());
auto cmpCharOp = getFirstUser<hlfir::CmpCharOp>(trimOp);
auto destroyOp = getLastUser<hlfir::DestroyOp>(trimOp);
if (!cmpCharOp || !destroyOp || trimUses != 2)
return rewriter.notifyMatchFailure(
trimOp, "hlfir.char_trim is not used (only) by hlfir.cmpchar");

rewriter.eraseOp(destroyOp);
rewriter.replaceOp(trimOp, trimOp.getChr());
return mlir::success();
}
};

class ExpressionSimplificationPass
: public hlfir::impl::ExpressionSimplificationBase<
ExpressionSimplificationPass> {
public:
void runOnOperation() override {
mlir::MLIRContext *context = &getContext();

mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks.
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);

mlir::RewritePatternSet patterns(context);
patterns.insert<EraseTrim>(context);

if (mlir::failed(mlir::applyPatternsGreedily(
getOperation(), std::move(patterns), config))) {
mlir::emitError(getOperation()->getLoc(),
"failure in HLFIR expression simplification");
signalPassFailure();
}
}
};

} // namespace
4 changes: 4 additions & 0 deletions flang/lib/Optimizer/Passes/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,10 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
void createHLFIRToFIRPassPipeline(mlir::PassManager &pm,
EnableOpenMP enableOpenMP,
llvm::OptimizationLevel optLevel) {
if (optLevel.getSizeLevel() > 0 || optLevel.getSpeedupLevel() > 0) {
addNestedPassToAllTopLevelOperations<PassConstructor>(
pm, hlfir::createExpressionSimplification);
}
if (optLevel.isOptimizingForSpeed()) {
addCanonicalizerPassWithoutRegionSimplification(pm);
addNestedPassToAllTopLevelOperations<PassConstructor>(
Expand Down
9 changes: 9 additions & 0 deletions flang/test/Driver/mlir-pass-pipeline.f90
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@
! ALL: Pass statistics report

! ALL: Fortran::lower::VerifierPass
! O2-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
! O2-NEXT: 'fir.global' Pipeline
! O2-NEXT: ExpressionSimplification
! O2-NEXT: 'func.func' Pipeline
! O2-NEXT: ExpressionSimplification
! O2-NEXT: 'omp.declare_reduction' Pipeline
! O2-NEXT: ExpressionSimplification
! O2-NEXT: 'omp.private' Pipeline
! O2-NEXT: ExpressionSimplification
! O2-NEXT: Canonicalizer
! ALL: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
! ALL-NEXT:'fir.global' Pipeline
Expand Down
101 changes: 101 additions & 0 deletions flang/test/HLFIR/expression-simplification.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// RUN: fir-opt %s --hlfir-expression-simplification | FileCheck %s

// Test removal of trim() calls.

// logical function test_char_cmp(x, y) result(cmp)
// character(*) :: x, y
// cmp = trim(x) == trim(y)
// end function

func.func @_QPtest_char_cmp(%arg0: !fir.boxchar<1> {fir.bindc_name = "x"},
%arg1: !fir.boxchar<1> {fir.bindc_name = "y"}) -> !fir.logical<4> {
%0 = fir.dummy_scope : !fir.dscope
%1 = fir.alloca !fir.logical<4> {bindc_name = "cmp", uniq_name = "_QFtest_char_cmpEcmp"}
%2:2 = hlfir.declare %1 {uniq_name = "_QFtest_char_cmpEcmp"} : (!fir.ref<!fir.logical<4>>) -> (!fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>)
%3:2 = fir.unboxchar %arg0 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
%4:2 = hlfir.declare %3#0 typeparams %3#1 dummy_scope %0 {uniq_name = "_QFtest_char_cmpEx"} : (!fir.ref<!fir.char<1,?>>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
%5:2 = fir.unboxchar %arg1 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
%6:2 = hlfir.declare %5#0 typeparams %5#1 dummy_scope %0 {uniq_name = "_QFtest_char_cmpEy"} : (!fir.ref<!fir.char<1,?>>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
%7 = hlfir.char_trim %4#0 : (!fir.boxchar<1>) -> !hlfir.expr<!fir.char<1,?>>
%8 = hlfir.char_trim %6#0 : (!fir.boxchar<1>) -> !hlfir.expr<!fir.char<1,?>>
%9 = hlfir.cmpchar eq %7 %8 : (!hlfir.expr<!fir.char<1,?>>, !hlfir.expr<!fir.char<1,?>>) -> i1
%10 = fir.convert %9 : (i1) -> !fir.logical<4>
hlfir.assign %10 to %2#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
hlfir.destroy %8 : !hlfir.expr<!fir.char<1,?>>
hlfir.destroy %7 : !hlfir.expr<!fir.char<1,?>>
%11 = fir.load %2#0 : !fir.ref<!fir.logical<4>>
return %11 : !fir.logical<4>
}

// CHECK-LABEL: func.func @_QPtest_char_cmp(
// CHECK-SAME: %[[ARG_0:.*]]: !fir.boxchar<1> {fir.bindc_name = "x"},
// CHECK-SAME: %[[ARG_1:.*]]: !fir.boxchar<1> {fir.bindc_name = "y"}) -> !fir.logical<4> {
// CHECK: %[[VAL_0:.*]] = fir.dummy_scope : !fir.dscope
// CHECK: %[[VAL_1:.*]] = fir.alloca !fir.logical<4> {bindc_name = "cmp", uniq_name = "_QFtest_char_cmpEcmp"}
// CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = "_QFtest_char_cmpEcmp"} : (!fir.ref<!fir.logical<4>>) -> (!fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>)
// CHECK: %[[VAL_3:.*]]:2 = fir.unboxchar %[[ARG_0]] : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
// CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_3]]#0 typeparams %[[VAL_3]]#1 dummy_scope %[[VAL_0]] {uniq_name = "_QFtest_char_cmpEx"} : (!fir.ref<!fir.char<1,?>>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
// CHECK: %[[VAL_6:.*]]:2 = fir.unboxchar %[[ARG_1]] : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
// CHECK: %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_6]]#0 typeparams %[[VAL_6]]#1 dummy_scope %[[VAL_0]] {uniq_name = "_QFtest_char_cmpEy"} : (!fir.ref<!fir.char<1,?>>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
// CHECK: %[[VAL_9:.*]] = hlfir.cmpchar eq %[[VAL_5]]#0 %[[VAL_8]]#0 : (!fir.boxchar<1>, !fir.boxchar<1>) -> i1
// CHECK: %[[VAL_10:.*]] = fir.convert %[[VAL_9]] : (i1) -> !fir.logical<4>
// CHECK: hlfir.assign %[[VAL_10]] to %[[VAL_2]]#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
// CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_2]]#0 : !fir.ref<!fir.logical<4>>
// CHECK: return %[[VAL_11]] : !fir.logical<4>
// CHECK: }

// Check that trim() is not removed when its result is stored.

// logical function test_char_cmp2(x, y) result(res)
// character(*) :: x, y
// character(:), allocatable :: tx
//
// tx = trim(x)
// res = tx == y
// end function

func.func @_QPtest_char_cmp2(%arg0: !fir.boxchar<1> {fir.bindc_name = "x"}, %arg1: !fir.boxchar<1> {fir.bindc_name = "y"}) -> !fir.logical<4> {
%0 = fir.dummy_scope : !fir.dscope
%1 = fir.alloca !fir.logical<4> {bindc_name = "res", uniq_name = "_QFtest_char_cmp2Eres"}
%2:2 = hlfir.declare %1 {uniq_name = "_QFtest_char_cmp2Eres"} : (!fir.ref<!fir.logical<4>>) -> (!fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>)
%3 = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>> {bindc_name = "tx", uniq_name = "_QFtest_char_cmp2Etx"}
%4 = fir.zero_bits !fir.heap<!fir.char<1,?>>
%c0 = arith.constant 0 : index
%5 = fir.embox %4 typeparams %c0 : (!fir.heap<!fir.char<1,?>>, index) -> !fir.box<!fir.heap<!fir.char<1,?>>>
fir.store %5 to %3 : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
%6:2 = hlfir.declare %3 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtest_char_cmp2Etx"} : (!fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>, !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>)
%7:2 = fir.unboxchar %arg0 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
%8:2 = hlfir.declare %7#0 typeparams %7#1 dummy_scope %0 {uniq_name = "_QFtest_char_cmp2Ex"} : (!fir.ref<!fir.char<1,?>>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
%9:2 = fir.unboxchar %arg1 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
%10:2 = hlfir.declare %9#0 typeparams %9#1 dummy_scope %0 {uniq_name = "_QFtest_char_cmp2Ey"} : (!fir.ref<!fir.char<1,?>>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
%11 = hlfir.char_trim %8#0 : (!fir.boxchar<1>) -> !hlfir.expr<!fir.char<1,?>>
hlfir.assign %11 to %6#0 realloc : !hlfir.expr<!fir.char<1,?>>, !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
hlfir.destroy %11 : !hlfir.expr<!fir.char<1,?>>
%12 = fir.load %6#0 : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
%13 = fir.box_addr %12 : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
%14 = fir.load %6#0 : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
%15 = fir.box_elesize %14 : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> index
%16 = fir.emboxchar %13, %15 : (!fir.heap<!fir.char<1,?>>, index) -> !fir.boxchar<1>
%17 = hlfir.cmpchar eq %16 %10#0 : (!fir.boxchar<1>, !fir.boxchar<1>) -> i1
%18 = fir.convert %17 : (i1) -> !fir.logical<4>
hlfir.assign %18 to %2#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
%19 = fir.load %2#0 : !fir.ref<!fir.logical<4>>
%20 = fir.load %6#0 : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
%21 = fir.box_addr %20 : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
%22 = fir.convert %21 : (!fir.heap<!fir.char<1,?>>) -> i64
%c0_i64 = arith.constant 0 : i64
%23 = arith.cmpi ne, %22, %c0_i64 : i64
fir.if %23 {
%24 = fir.load %6#0 : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
%25 = fir.box_addr %24 : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
fir.freemem %25 : !fir.heap<!fir.char<1,?>>
%26 = fir.zero_bits !fir.heap<!fir.char<1,?>>
%c0_0 = arith.constant 0 : index
%27 = fir.embox %26 typeparams %c0_0 : (!fir.heap<!fir.char<1,?>>, index) -> !fir.box<!fir.heap<!fir.char<1,?>>>
fir.store %27 to %6#0 : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
}
return %19 : !fir.logical<4>
}

// CHECK-LABEL: func.func @_QPtest_char_cmp2(
// CHECK: hlfir.char_trim