Skip to content

Commit 8d8bd0a

Browse files
authored
[flang] Simplify the comparison of characters (#154593)
Because a character comparison pads the shorter operand with spaces, calls to trim() whose results are used only in the comparison can be eliminated. Example: `trim(x) == trim(y)` can be simplified to `x == y`. This makes 527.cam4_r about 3% faster, measured on Neoverse V2. This patch implements the optimization above in a new pass: ExpressionSimplification. Although no other expression simplifications are planned at the moment, they could easily be added to the new pass. The ExpressionSimplification pass runs early in the HLFIR pipeline, to make it easy to identify expressions before any transformations occur.
1 parent 010f96a commit 8d8bd0a

File tree

6 files changed

+218
-0
lines changed

6 files changed

+218
-0
lines changed

flang/include/flang/Optimizer/HLFIR/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ def SimplifyHLFIRIntrinsics : Pass<"simplify-hlfir-intrinsics"> {
6161
"the hlfir.matmul.">];
6262
}
6363

64+
// Declares the HLFIR expression simplification pass, selectable on the
// command line as `hlfir-expression-simplification`.
// The pass takes no options; only a summary string is provided here.
def ExpressionSimplification : Pass<"hlfir-expression-simplification"> {
65+
let summary = "Simplify Fortran expressions";
66+
}
67+
6468
def InlineElementals : Pass<"inline-elementals"> {
6569
let summary = "Inline chained hlfir.elemental operations";
6670
}

flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
33
add_flang_library(HLFIRTransforms
44
BufferizeHLFIR.cpp
55
ConvertToFIR.cpp
6+
ExpressionSimplification.cpp
67
InlineElementals.cpp
78
InlineHLFIRAssign.cpp
89
InlineHLFIRCopyIn.cpp
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
//===- ExpressionSimplification.cpp - Simplify HLFIR expressions ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang/Optimizer/Builder/FIRBuilder.h"
10+
#include "flang/Optimizer/HLFIR/HLFIROps.h"
11+
#include "flang/Optimizer/HLFIR/Passes.h"
12+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13+
14+
namespace hlfir {
15+
#define GEN_PASS_DEF_EXPRESSIONSIMPLIFICATION
16+
#include "flang/Optimizer/HLFIR/Passes.h.inc"
17+
} // namespace hlfir
18+
19+
// Get the first user of `op`.
20+
// Note that we consider the first user to be the one on the lowest line of
21+
// the emitted HLFIR. The user iterator considers the opposite.
22+
template <typename UserOp>
23+
static UserOp getFirstUser(mlir::Operation *op) {
24+
auto it = op->user_begin(), end = op->user_end(), prev = it;
25+
for (; it != end; prev = it++)
26+
;
27+
if (prev != end)
28+
if (auto userOp = mlir::dyn_cast<UserOp>(*prev))
29+
return userOp;
30+
return {};
31+
}
32+
33+
// Get the last user of `op`.
34+
// Note that we consider the last user to be the one on the highest line of
35+
// the emitted HLFIR. The user iterator considers the opposite.
36+
template <typename UserOp>
37+
static UserOp getLastUser(mlir::Operation *op) {
38+
if (!op->getUsers().empty())
39+
if (auto userOp = mlir::dyn_cast<UserOp>(*op->user_begin()))
40+
return userOp;
41+
return {};
42+
}
43+
44+
namespace {
45+
46+
// Trim operations can be erased in certain expressions, such as character
47+
// comparisons.
48+
// Since a character comparison appends spaces to the shorter character,
49+
// calls to trim() that are used only in the comparison can be eliminated.
50+
//
51+
// Example:
52+
// `trim(x) == trim(y)`
53+
// can be simplified to
54+
// `x == y`
55+
class EraseTrim : public mlir::OpRewritePattern<hlfir::CharTrimOp> {
56+
public:
57+
using mlir::OpRewritePattern<hlfir::CharTrimOp>::OpRewritePattern;
58+
59+
llvm::LogicalResult
60+
matchAndRewrite(hlfir::CharTrimOp trimOp,
61+
mlir::PatternRewriter &rewriter) const override {
62+
int trimUses = std::distance(trimOp->use_begin(), trimOp->use_end());
63+
auto cmpCharOp = getFirstUser<hlfir::CmpCharOp>(trimOp);
64+
auto destroyOp = getLastUser<hlfir::DestroyOp>(trimOp);
65+
if (!cmpCharOp || !destroyOp || trimUses != 2)
66+
return rewriter.notifyMatchFailure(
67+
trimOp, "hlfir.char_trim is not used (only) by hlfir.cmpchar");
68+
69+
rewriter.eraseOp(destroyOp);
70+
rewriter.replaceOp(trimOp, trimOp.getChr());
71+
return mlir::success();
72+
}
73+
};
74+
75+
class ExpressionSimplificationPass
76+
: public hlfir::impl::ExpressionSimplificationBase<
77+
ExpressionSimplificationPass> {
78+
public:
79+
void runOnOperation() override {
80+
mlir::MLIRContext *context = &getContext();
81+
82+
mlir::GreedyRewriteConfig config;
83+
// Prevent the pattern driver from merging blocks.
84+
config.setRegionSimplificationLevel(
85+
mlir::GreedySimplifyRegionLevel::Disabled);
86+
87+
mlir::RewritePatternSet patterns(context);
88+
patterns.insert<EraseTrim>(context);
89+
90+
if (mlir::failed(mlir::applyPatternsGreedily(
91+
getOperation(), std::move(patterns), config))) {
92+
mlir::emitError(getOperation()->getLoc(),
93+
"failure in HLFIR expression simplification");
94+
signalPassFailure();
95+
}
96+
}
97+
};
98+
99+
} // namespace

flang/lib/Optimizer/Passes/Pipelines.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,10 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
245245
void createHLFIRToFIRPassPipeline(mlir::PassManager &pm,
246246
EnableOpenMP enableOpenMP,
247247
llvm::OptimizationLevel optLevel) {
248+
if (optLevel.getSizeLevel() > 0 || optLevel.getSpeedupLevel() > 0) {
249+
addNestedPassToAllTopLevelOperations<PassConstructor>(
250+
pm, hlfir::createExpressionSimplification);
251+
}
248252
if (optLevel.isOptimizingForSpeed()) {
249253
addCanonicalizerPassWithoutRegionSimplification(pm);
250254
addNestedPassToAllTopLevelOperations<PassConstructor>(

flang/test/Driver/mlir-pass-pipeline.f90

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@
1515
! ALL: Pass statistics report
1616

1717
! ALL: Fortran::lower::VerifierPass
18+
! O2-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
19+
! O2-NEXT: 'fir.global' Pipeline
20+
! O2-NEXT: ExpressionSimplification
21+
! O2-NEXT: 'func.func' Pipeline
22+
! O2-NEXT: ExpressionSimplification
23+
! O2-NEXT: 'omp.declare_reduction' Pipeline
24+
! O2-NEXT: ExpressionSimplification
25+
! O2-NEXT: 'omp.private' Pipeline
26+
! O2-NEXT: ExpressionSimplification
1827
! O2-NEXT: Canonicalizer
1928
! ALL: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
2029
! ALL-NEXT:'fir.global' Pipeline
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
// RUN: fir-opt %s --hlfir-expression-simplification | FileCheck %s
2+
3+
// Test removal of trim() calls.
4+
5+
// logical function test_char_cmp(x, y) result(cmp)
6+
// character(*) :: x, y
7+
// cmp = trim(x) == trim(y)
8+
// end function
9+
10+
func.func @_QPtest_char_cmp(%arg0: !fir.boxchar<1> {fir.bindc_name = "x"},
11+
%arg1: !fir.boxchar<1> {fir.bindc_name = "y"}) -> !fir.logical<4> {
12+
%0 = fir.dummy_scope : !fir.dscope
13+
%1 = fir.alloca !fir.logical<4> {bindc_name = "cmp", uniq_name = "_QFtest_char_cmpEcmp"}
14+
%2:2 = hlfir.declare %1 {uniq_name = "_QFtest_char_cmpEcmp"} : (!fir.ref<!fir.logical<4>>) -> (!fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>)
15+
%3:2 = fir.unboxchar %arg0 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
16+
%4:2 = hlfir.declare %3#0 typeparams %3#1 dummy_scope %0 {uniq_name = "_QFtest_char_cmpEx"} : (!fir.ref<!fir.char<1,?>>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
17+
%5:2 = fir.unboxchar %arg1 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
18+
%6:2 = hlfir.declare %5#0 typeparams %5#1 dummy_scope %0 {uniq_name = "_QFtest_char_cmpEy"} : (!fir.ref<!fir.char<1,?>>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
19+
%7 = hlfir.char_trim %4#0 : (!fir.boxchar<1>) -> !hlfir.expr<!fir.char<1,?>>
20+
%8 = hlfir.char_trim %6#0 : (!fir.boxchar<1>) -> !hlfir.expr<!fir.char<1,?>>
21+
%9 = hlfir.cmpchar eq %7 %8 : (!hlfir.expr<!fir.char<1,?>>, !hlfir.expr<!fir.char<1,?>>) -> i1
22+
%10 = fir.convert %9 : (i1) -> !fir.logical<4>
23+
hlfir.assign %10 to %2#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
24+
hlfir.destroy %8 : !hlfir.expr<!fir.char<1,?>>
25+
hlfir.destroy %7 : !hlfir.expr<!fir.char<1,?>>
26+
%11 = fir.load %2#0 : !fir.ref<!fir.logical<4>>
27+
return %11 : !fir.logical<4>
28+
}
29+
30+
// CHECK-LABEL: func.func @_QPtest_char_cmp(
31+
// CHECK-SAME: %[[ARG_0:.*]]: !fir.boxchar<1> {fir.bindc_name = "x"},
32+
// CHECK-SAME: %[[ARG_1:.*]]: !fir.boxchar<1> {fir.bindc_name = "y"}) -> !fir.logical<4> {
33+
// CHECK: %[[VAL_0:.*]] = fir.dummy_scope : !fir.dscope
34+
// CHECK: %[[VAL_1:.*]] = fir.alloca !fir.logical<4> {bindc_name = "cmp", uniq_name = "_QFtest_char_cmpEcmp"}
35+
// CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = "_QFtest_char_cmpEcmp"} : (!fir.ref<!fir.logical<4>>) -> (!fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>)
36+
// CHECK: %[[VAL_3:.*]]:2 = fir.unboxchar %[[ARG_0]] : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
37+
// CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_3]]#0 typeparams %[[VAL_3]]#1 dummy_scope %[[VAL_0]] {uniq_name = "_QFtest_char_cmpEx"} : (!fir.ref<!fir.char<1,?>>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
38+
// CHECK: %[[VAL_6:.*]]:2 = fir.unboxchar %[[ARG_1]] : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
39+
// CHECK: %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_6]]#0 typeparams %[[VAL_6]]#1 dummy_scope %[[VAL_0]] {uniq_name = "_QFtest_char_cmpEy"} : (!fir.ref<!fir.char<1,?>>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
40+
// CHECK: %[[VAL_9:.*]] = hlfir.cmpchar eq %[[VAL_5]]#0 %[[VAL_8]]#0 : (!fir.boxchar<1>, !fir.boxchar<1>) -> i1
41+
// CHECK: %[[VAL_10:.*]] = fir.convert %[[VAL_9]] : (i1) -> !fir.logical<4>
42+
// CHECK: hlfir.assign %[[VAL_10]] to %[[VAL_2]]#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
43+
// CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_2]]#0 : !fir.ref<!fir.logical<4>>
44+
// CHECK: return %[[VAL_11]] : !fir.logical<4>
45+
// CHECK: }
46+
47+
// Check that trim() is not removed when its result is stored.
48+
49+
// logical function test_char_cmp2(x, y) result(res)
50+
// character(*) :: x, y
51+
// character(:), allocatable :: tx
52+
//
53+
// tx = trim(x)
54+
// res = tx == y
55+
// end function
56+
57+
func.func @_QPtest_char_cmp2(%arg0: !fir.boxchar<1> {fir.bindc_name = "x"}, %arg1: !fir.boxchar<1> {fir.bindc_name = "y"}) -> !fir.logical<4> {
58+
%0 = fir.dummy_scope : !fir.dscope
59+
%1 = fir.alloca !fir.logical<4> {bindc_name = "res", uniq_name = "_QFtest_char_cmp2Eres"}
60+
%2:2 = hlfir.declare %1 {uniq_name = "_QFtest_char_cmp2Eres"} : (!fir.ref<!fir.logical<4>>) -> (!fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>)
61+
%3 = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>> {bindc_name = "tx", uniq_name = "_QFtest_char_cmp2Etx"}
62+
%4 = fir.zero_bits !fir.heap<!fir.char<1,?>>
63+
%c0 = arith.constant 0 : index
64+
%5 = fir.embox %4 typeparams %c0 : (!fir.heap<!fir.char<1,?>>, index) -> !fir.box<!fir.heap<!fir.char<1,?>>>
65+
fir.store %5 to %3 : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
66+
%6:2 = hlfir.declare %3 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtest_char_cmp2Etx"} : (!fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>, !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>)
67+
%7:2 = fir.unboxchar %arg0 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
68+
%8:2 = hlfir.declare %7#0 typeparams %7#1 dummy_scope %0 {uniq_name = "_QFtest_char_cmp2Ex"} : (!fir.ref<!fir.char<1,?>>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
69+
%9:2 = fir.unboxchar %arg1 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
70+
%10:2 = hlfir.declare %9#0 typeparams %9#1 dummy_scope %0 {uniq_name = "_QFtest_char_cmp2Ey"} : (!fir.ref<!fir.char<1,?>>, index, !fir.dscope) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
71+
%11 = hlfir.char_trim %8#0 : (!fir.boxchar<1>) -> !hlfir.expr<!fir.char<1,?>>
72+
hlfir.assign %11 to %6#0 realloc : !hlfir.expr<!fir.char<1,?>>, !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
73+
hlfir.destroy %11 : !hlfir.expr<!fir.char<1,?>>
74+
%12 = fir.load %6#0 : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
75+
%13 = fir.box_addr %12 : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
76+
%14 = fir.load %6#0 : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
77+
%15 = fir.box_elesize %14 : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> index
78+
%16 = fir.emboxchar %13, %15 : (!fir.heap<!fir.char<1,?>>, index) -> !fir.boxchar<1>
79+
%17 = hlfir.cmpchar eq %16 %10#0 : (!fir.boxchar<1>, !fir.boxchar<1>) -> i1
80+
%18 = fir.convert %17 : (i1) -> !fir.logical<4>
81+
hlfir.assign %18 to %2#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
82+
%19 = fir.load %2#0 : !fir.ref<!fir.logical<4>>
83+
%20 = fir.load %6#0 : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
84+
%21 = fir.box_addr %20 : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
85+
%22 = fir.convert %21 : (!fir.heap<!fir.char<1,?>>) -> i64
86+
%c0_i64 = arith.constant 0 : i64
87+
%23 = arith.cmpi ne, %22, %c0_i64 : i64
88+
fir.if %23 {
89+
%24 = fir.load %6#0 : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
90+
%25 = fir.box_addr %24 : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
91+
fir.freemem %25 : !fir.heap<!fir.char<1,?>>
92+
%26 = fir.zero_bits !fir.heap<!fir.char<1,?>>
93+
%c0_0 = arith.constant 0 : index
94+
%27 = fir.embox %26 typeparams %c0_0 : (!fir.heap<!fir.char<1,?>>, index) -> !fir.box<!fir.heap<!fir.char<1,?>>>
95+
fir.store %27 to %6#0 : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
96+
}
97+
return %19 : !fir.logical<4>
98+
}
99+
100+
// CHECK-LABEL: func.func @_QPtest_char_cmp2(
101+
// CHECK: hlfir.char_trim

0 commit comments

Comments
 (0)