Skip to content

Commit bac9575

Browse files
authored
[flang] Reset all extents to zero for empty hlfir.elemental loops. (#124867)
An hlfir.elemental with a shape `(0, HUGE)` still runs `HUGE` iterations when expanded into a loop nest. As a result, HLFIR transformational operations inlined as hlfir.elemental may execute slower than the Fortran runtime implementation. This patch adds an option to the BufferizeHLFIR pass that resets all upper bounds in the elemental loop nest to zero if the result is an empty array. A separate patch will enable this option in the driver after I do more performance testing. The option is off by default for now.
1 parent b870875 commit bac9575

File tree

5 files changed

+119
-9
lines changed

5 files changed

+119
-9
lines changed

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,6 +813,18 @@ uint64_t getAllocaAddressSpace(mlir::DataLayout *dataLayout);
813813
llvm::SmallVector<mlir::Value> deduceOptimalExtents(mlir::ValueRange extents1,
814814
mlir::ValueRange extents2);
815815

816+
/// Given array extents generate code that sets them all to zeroes,
817+
/// if the array is empty, e.g.:
818+
/// %false = arith.constant false
819+
/// %c0 = arith.constant 0 : index
820+
/// %p1 = arith.cmpi eq, %e0, %c0 : index
821+
/// %p2 = arith.ori %false, %p1 : i1
822+
/// %p3 = arith.cmpi eq, %e1, %c0 : index
823+
/// %p4 = arith.ori %p2, %p3 : i1
824+
/// %result0 = arith.select %p4, %c0, %e0 : index
825+
/// %result1 = arith.select %p4, %c0, %e1 : index
826+
llvm::SmallVector<mlir::Value> updateRuntimeExtentsForEmptyArrays(
827+
fir::FirOpBuilder &builder, mlir::Location loc, mlir::ValueRange extents);
816828
} // namespace fir::factory
817829

818830
#endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H

flang/include/flang/Optimizer/HLFIR/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ def ConvertHLFIRtoFIR : Pass<"convert-hlfir-to-fir", "::mlir::ModuleOp"> {
1919

2020
def BufferizeHLFIR : Pass<"bufferize-hlfir", "::mlir::ModuleOp"> {
2121
let summary = "Convert HLFIR operations operating on hlfir.expr into operations on memory";
22+
let options = [Option<"optimizeEmptyElementals", "opt-empty-elementals",
23+
"bool", /*default=*/"false",
24+
"When converting hlfir.elemental into a loop nest, "
25+
"check if the resulting expression is an empty array, "
26+
"and make sure none of the loops is executed.">];
2227
}
2328

2429
def OptimizedBufferization : Pass<"opt-bufferization"> {

flang/lib/Optimizer/Builder/FIRBuilder.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1759,3 +1759,29 @@ fir::factory::deduceOptimalExtents(mlir::ValueRange extents1,
17591759
}
17601760
return extents;
17611761
}
1762+
1763+
llvm::SmallVector<mlir::Value> fir::factory::updateRuntimeExtentsForEmptyArrays(
1764+
fir::FirOpBuilder &builder, mlir::Location loc, mlir::ValueRange extents) {
1765+
if (extents.size() <= 1)
1766+
return extents;
1767+
1768+
mlir::Type i1Type = builder.getI1Type();
1769+
mlir::Value isEmpty = createZeroValue(builder, loc, i1Type);
1770+
1771+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> zeroes;
1772+
for (mlir::Value extent : extents) {
1773+
mlir::Type type = extent.getType();
1774+
mlir::Value zero = createZeroValue(builder, loc, type);
1775+
zeroes.push_back(zero);
1776+
mlir::Value isZero = builder.create<mlir::arith::CmpIOp>(
1777+
loc, mlir::arith::CmpIPredicate::eq, extent, zero);
1778+
isEmpty = builder.create<mlir::arith::OrIOp>(loc, isEmpty, isZero);
1779+
}
1780+
1781+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> newExtents;
1782+
for (auto [zero, extent] : llvm::zip_equal(zeroes, extents)) {
1783+
newExtents.push_back(
1784+
builder.create<mlir::arith::SelectOp>(loc, isEmpty, zero, extent));
1785+
}
1786+
return newExtents;
1787+
}

flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -761,8 +761,10 @@ struct HLFIRListener : public mlir::OpBuilder::Listener {
761761
struct ElementalOpConversion
762762
: public mlir::OpConversionPattern<hlfir::ElementalOp> {
763763
using mlir::OpConversionPattern<hlfir::ElementalOp>::OpConversionPattern;
764-
explicit ElementalOpConversion(mlir::MLIRContext *ctx)
765-
: mlir::OpConversionPattern<hlfir::ElementalOp>{ctx} {
764+
explicit ElementalOpConversion(mlir::MLIRContext *ctx,
765+
bool optimizeEmptyElementals = false)
766+
: mlir::OpConversionPattern<hlfir::ElementalOp>{ctx},
767+
optimizeEmptyElementals(optimizeEmptyElementals) {
766768
// This pattern recursively converts nested ElementalOp's
767769
// by cloning and then converting them, so we have to allow
768770
// for recursive pattern application. The recursion is bounded
@@ -791,6 +793,10 @@ struct ElementalOpConversion
791793
// of the loop nest.
792794
temp = derefPointersAndAllocatables(loc, builder, temp);
793795

796+
if (optimizeEmptyElementals)
797+
extents = fir::factory::updateRuntimeExtentsForEmptyArrays(builder, loc,
798+
extents);
799+
794800
// Generate a loop nest looping around the hlfir.elemental shape and clone
795801
// hlfir.elemental region inside the inner loop.
796802
hlfir::LoopNest loopNest =
@@ -861,6 +867,9 @@ struct ElementalOpConversion
861867
rewriter.replaceOp(elemental, bufferizedExpr);
862868
return mlir::success();
863869
}
870+
871+
private:
872+
bool optimizeEmptyElementals = false;
864873
};
865874
struct CharExtremumOpConversion
866875
: public mlir::OpConversionPattern<hlfir::CharExtremumOp> {
@@ -932,6 +941,8 @@ struct EvaluateInMemoryOpConversion
932941

933942
class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
934943
public:
944+
using BufferizeHLFIRBase<BufferizeHLFIR>::BufferizeHLFIRBase;
945+
935946
void runOnOperation() override {
936947
// TODO: make this a pass operating on FuncOp. The issue is that
937948
// FirOpBuilder helpers may generate new FuncOp because of runtime/llvm
@@ -943,13 +954,13 @@ class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
943954
auto module = this->getOperation();
944955
auto *context = &getContext();
945956
mlir::RewritePatternSet patterns(context);
946-
patterns
947-
.insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
948-
AssociateOpConversion, CharExtremumOpConversion,
949-
ConcatOpConversion, DestroyOpConversion, ElementalOpConversion,
950-
EndAssociateOpConversion, EvaluateInMemoryOpConversion,
951-
NoReassocOpConversion, SetLengthOpConversion,
952-
ShapeOfOpConversion, GetLengthOpConversion>(context);
957+
patterns.insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
958+
AssociateOpConversion, CharExtremumOpConversion,
959+
ConcatOpConversion, DestroyOpConversion,
960+
EndAssociateOpConversion, EvaluateInMemoryOpConversion,
961+
NoReassocOpConversion, SetLengthOpConversion,
962+
ShapeOfOpConversion, GetLengthOpConversion>(context);
963+
patterns.insert<ElementalOpConversion>(context, optimizeEmptyElementals);
953964
mlir::ConversionTarget target(*context);
954965
// Note that YieldElementOp is not marked as an illegal operation.
955966
// It must be erased by its parent converter and there is no explicit
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Test hlfir.elemental code generation with a dynamic check
2+
// for empty result array
3+
// RUN: fir-opt %s --bufferize-hlfir=opt-empty-elementals=true | FileCheck %s
4+
5+
func.func @test(%v: i32, %e0: i32, %e1: i32, %e2: i64, %e3: i64) {
6+
%shape = fir.shape %e0, %e1, %e2, %e3 : (i32, i32, i64, i64) -> !fir.shape<4>
7+
%result = hlfir.elemental %shape : (!fir.shape<4>) -> !hlfir.expr<?x?x?x?xi32> {
8+
^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index):
9+
hlfir.yield_element %v : i32
10+
}
11+
return
12+
}
13+
// CHECK-LABEL: func.func @test(
14+
// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32,
15+
// CHECK-SAME: %[[VAL_3:.*]]: i64, %[[VAL_4:.*]]: i64) {
16+
// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] : (i32, i32, i64, i64) -> !fir.shape<4>
17+
// CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_1]] : (i32) -> index
18+
// CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_2]] : (i32) -> index
19+
// CHECK: %[[VAL_8:.*]] = fir.convert %[[VAL_3]] : (i64) -> index
20+
// CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
21+
// CHECK: %[[VAL_10:.*]] = fir.allocmem !fir.array<?x?x?x?xi32>, %[[VAL_6]], %[[VAL_7]], %[[VAL_8]], %[[VAL_9]] {bindc_name = ".tmp.array", uniq_name = ""}
22+
// CHECK: %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_10]](%[[VAL_5]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<?x?x?x?xi32>>, !fir.shape<4>) -> (!fir.box<!fir.array<?x?x?x?xi32>>, !fir.heap<!fir.array<?x?x?x?xi32>>)
23+
// CHECK: %[[VAL_12:.*]] = arith.constant true
24+
// CHECK: %[[VAL_13:.*]] = arith.constant false
25+
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
26+
// CHECK: %[[VAL_15:.*]] = arith.cmpi eq, %[[VAL_6]], %[[C0_1]] : index
27+
// CHECK: %[[VAL_16:.*]] = arith.ori %[[VAL_13]], %[[VAL_15]] : i1
28+
// CHECK: %[[C0_2:.*]] = arith.constant 0 : index
29+
// CHECK: %[[VAL_17:.*]] = arith.cmpi eq, %[[VAL_7]], %[[C0_2]] : index
30+
// CHECK: %[[VAL_18:.*]] = arith.ori %[[VAL_16]], %[[VAL_17]] : i1
31+
// CHECK: %[[C0_3:.*]] = arith.constant 0 : index
32+
// CHECK: %[[VAL_19:.*]] = arith.cmpi eq, %[[VAL_8]], %[[C0_3]] : index
33+
// CHECK: %[[VAL_20:.*]] = arith.ori %[[VAL_18]], %[[VAL_19]] : i1
34+
// CHECK: %[[C0_4:.*]] = arith.constant 0 : index
35+
// CHECK: %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_9]], %[[C0_4]] : index
36+
// CHECK: %[[VAL_22:.*]] = arith.ori %[[VAL_20]], %[[VAL_21]] : i1
37+
// CHECK: %[[VAL_23:.*]] = arith.select %[[VAL_22]], %[[C0_1]], %[[VAL_6]] : index
38+
// CHECK: %[[VAL_24:.*]] = arith.select %[[VAL_22]], %[[C0_2]], %[[VAL_7]] : index
39+
// CHECK: %[[VAL_25:.*]] = arith.select %[[VAL_22]], %[[C0_3]], %[[VAL_8]] : index
40+
// CHECK: %[[VAL_26:.*]] = arith.select %[[VAL_22]], %[[C0_4]], %[[VAL_9]] : index
41+
// CHECK: %[[VAL_27:.*]] = arith.constant 1 : index
42+
// CHECK: fir.do_loop %[[VAL_28:.*]] = %[[VAL_27]] to %[[VAL_26]] step %[[VAL_27]] {
43+
// CHECK: fir.do_loop %[[VAL_29:.*]] = %[[VAL_27]] to %[[VAL_25]] step %[[VAL_27]] {
44+
// CHECK: fir.do_loop %[[VAL_30:.*]] = %[[VAL_27]] to %[[VAL_24]] step %[[VAL_27]] {
45+
// CHECK: fir.do_loop %[[VAL_31:.*]] = %[[VAL_27]] to %[[VAL_23]] step %[[VAL_27]] {
46+
// CHECK: %[[VAL_32:.*]] = hlfir.designate %[[VAL_11]]#0 (%[[VAL_31]], %[[VAL_30]], %[[VAL_29]], %[[VAL_28]]) : (!fir.box<!fir.array<?x?x?x?xi32>>, index, index, index, index) -> !fir.ref<i32>
47+
// CHECK: hlfir.assign %[[VAL_0]] to %[[VAL_32]] temporary_lhs : i32, !fir.ref<i32>
48+
// CHECK: }
49+
// CHECK: }
50+
// CHECK: }
51+
// CHECK: }
52+
// CHECK: %[[VAL_33:.*]] = fir.undefined tuple<!fir.box<!fir.array<?x?x?x?xi32>>, i1>
53+
// CHECK: %[[VAL_34:.*]] = fir.insert_value %[[VAL_33]], %[[VAL_12]], [1 : index] : (tuple<!fir.box<!fir.array<?x?x?x?xi32>>, i1>, i1) -> tuple<!fir.box<!fir.array<?x?x?x?xi32>>, i1>
54+
// CHECK: %[[VAL_35:.*]] = fir.insert_value %[[VAL_34]], %[[VAL_11]]#0, [0 : index] : (tuple<!fir.box<!fir.array<?x?x?x?xi32>>, i1>, !fir.box<!fir.array<?x?x?x?xi32>>) -> tuple<!fir.box<!fir.array<?x?x?x?xi32>>, i1>
55+
// CHECK: return
56+
// CHECK: }

0 commit comments

Comments
 (0)