Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 64 additions & 7 deletions flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "flang/Optimizer/OpenMP/Passes.h"
#include "flang/Optimizer/Support/Utils.h"
#include "flang/Optimizer/Transforms/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Dominance.h"
Expand Down Expand Up @@ -758,6 +759,16 @@ class BroadcastAssignBufferization
mlir::PatternRewriter &rewriter) const override;
};

/// Return true when \p ty is a fir::BoxType wrapping a fir::HeapType whose
/// element type is a fir::SequenceType; return false for any other type.
static bool isAllocatableArray(mlir::Type ty) {
  if (auto boxedTy = mlir::dyn_cast<fir::BoxType>(ty))
    if (auto heapEleTy =
            mlir::dyn_cast<fir::HeapType>(boxedTy.getElementType()))
      return mlir::isa<fir::SequenceType>(heapEleTy.getElementType());
  return false;
}

llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
hlfir::AssignOp assign, mlir::PatternRewriter &rewriter) const {
// Since RHS is a scalar and LHS is an array, LHS must be allocated
Expand Down Expand Up @@ -786,13 +797,59 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
mlir::Value shape = hlfir::genShape(loc, builder, lhs);
llvm::SmallVector<mlir::Value> extents =
hlfir::getIndexExtents(loc, builder, shape);
hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
flangomp::shouldUseWorkshareLowering(assign));
builder.setInsertionPointToStart(loopNest.body);
auto arrayElement =
hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
builder.create<hlfir::AssignOp>(loc, rhs, arrayElement);

bool isArrayRef =
mlir::isa<fir::SequenceType>(fir::unwrapRefType(lhs.getType()));
if (lhs.isSimplyContiguous() && extents.size() > 1 &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't lhs.isSimplyContiguous() always true for allocatable arrays? What is the reason for adding the isAllocatableArray method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, lhs.isSimplyContiguous() is always true for allocatable arrays.

I've added isAllocatableArray because the code now handles only !fir.ref<!fir.array> and !fir.box<!fir.heap<!fir.array>> types.

I guess it should also handle !fir.box<!fir.array>, when those are contiguous. They may appear when using non-default lower bounds.

By handling these cases, do you think it would be safe to drop the isAllocatableArray check? Or are there other array types that also pass the isSimplyContiguous check?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe isSimplyContiguous may also return true for !fir.box<!fir.ptr<...>>. I do not think you need to explicitly handle all these different cases. There is an hlfir::derefPointersAndAllocatables call above. If lhs is a box, you may use the box_addr operation to get the base address of the array; otherwise, lhs is already the base address of the array.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Removing the isAllocatableArray check also made it easier to support !fir.box<!fir.array>.
Now all boxed arrays are converted to !fir.box<!fir.array<?x type>>

(isArrayRef || isAllocatableArray(lhs.getType()))) {
// Flatten the array to use a single assign loop, that can be better
// optimized.
mlir::Value n = extents[0];
for (size_t i = 1; i < extents.size(); ++i)
n = builder.create<mlir::arith::MulIOp>(loc, n, extents[i]);
llvm::SmallVector<mlir::Value> flatExtents = {n};

mlir::Type flatArrayType;
mlir::Value flatArray = lhs.getBase();
if (isArrayRef) {
// Array references must have fixed shape, when used in assignments.
int64_t flatExtent = 1;
for (const mlir::Value &extent : extents) {
mlir::Operation *op = extent.getDefiningOp();
assert(op && "no defining operation for constant array extent");
flatExtent *= fir::toInt(mlir::cast<mlir::arith::ConstantOp>(*op));
}

flatArrayType =
fir::ReferenceType::get(fir::SequenceType::get({flatExtent}, eleTy));
flatArray = builder.createConvert(loc, flatArrayType, flatArray);
} else {
shape = builder.genShape(loc, flatExtents);
flatArrayType = fir::BoxType::get(
fir::HeapType::get(fir::SequenceType::get(eleTy, 1)));
flatArray = builder.create<fir::ReboxOp>(loc, flatArrayType, flatArray,
shape, /*slice=*/mlir::Value{});
}

hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, flatExtents, /*isUnordered=*/true,
flangomp::shouldUseWorkshareLowering(assign));
builder.setInsertionPointToStart(loopNest.body);

mlir::Value arrayElement =
builder.create<hlfir::DesignateOp>(loc, fir::ReferenceType::get(eleTy),
flatArray, loopNest.oneBasedIndices);
builder.create<hlfir::AssignOp>(loc, rhs, arrayElement);
} else {
hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
flangomp::shouldUseWorkshareLowering(assign));
builder.setInsertionPointToStart(loopNest.body);
auto arrayElement =
hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
builder.create<hlfir::AssignOp>(loc, rhs, arrayElement);
}

rewriter.eraseOp(assign);
return mlir::success();
}
Expand Down
50 changes: 38 additions & 12 deletions flang/test/HLFIR/opt-scalar-assign.fir
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ func.func @_QPtest1() {
return
}
// CHECK-LABEL: func.func @_QPtest1() {
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_2:.*]] = arith.constant 11 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 13 : index
// CHECK: %[[VAL_4:.*]] = fir.alloca !fir.array<11x13xf32> {bindc_name = "x", uniq_name = "_QFtest1Ex"}
// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_2]], %[[VAL_3]] : (index, index) -> !fir.shape<2>
// CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_4]](%[[VAL_5]]) {uniq_name = "_QFtest1Ex"} : (!fir.ref<!fir.array<11x13xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<11x13xf32>>, !fir.ref<!fir.array<11x13xf32>>)
// CHECK: fir.do_loop %[[VAL_7:.*]] = %[[VAL_0]] to %[[VAL_3]] step %[[VAL_0]] unordered {
// CHECK: fir.do_loop %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_0]] unordered {
// CHECK: %[[VAL_9:.*]] = hlfir.designate %[[VAL_6]]#0 (%[[VAL_8]], %[[VAL_7]]) : (!fir.ref<!fir.array<11x13xf32>>, index, index) -> !fir.ref<f32>
// CHECK: hlfir.assign %[[VAL_1]] to %[[VAL_9]] : f32, !fir.ref<f32>
// CHECK: }
// CHECK: %[[VAL_0:.*]] = arith.constant 143 : index
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_3:.*]] = arith.constant 11 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 13 : index
// CHECK: %[[VAL_5:.*]] = fir.alloca !fir.array<11x13xf32> {bindc_name = "x", uniq_name = "_QFtest1Ex"}
// CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_3]], %[[VAL_4]] : (index, index) -> !fir.shape<2>
// CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_5]](%[[VAL_6]]) {uniq_name = "_QFtest1Ex"} : (!fir.ref<!fir.array<11x13xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<11x13xf32>>, !fir.ref<!fir.array<11x13xf32>>)
// CHECK: %[[VAL_8:.*]] = fir.convert %[[VAL_7]]#0 : (!fir.ref<!fir.array<11x13xf32>>) -> !fir.ref<!fir.array<143xf32>>
// CHECK: fir.do_loop %[[VAL_9:.*]] = %[[VAL_1]] to %[[VAL_0]] step %[[VAL_1]] unordered {
// CHECK: %[[VAL_10:.*]] = hlfir.designate %[[VAL_8]] (%[[VAL_9]]) : (!fir.ref<!fir.array<143xf32>>, index) -> !fir.ref<f32>
// CHECK: hlfir.assign %[[VAL_2]] to %[[VAL_10]] : f32, !fir.ref<f32>
// CHECK: }
// CHECK: return
// CHECK: }
Expand Down Expand Up @@ -129,3 +129,29 @@ func.func @_QPtest5(%arg0: !fir.ref<!fir.array<77xcomplex<f32>>> {fir.bindc_name
// CHECK: }
// CHECK: return
// CHECK: }

// Test input: assigns the scalar i32 constant 0 (with the `realloc` attribute)
// to `x`, which is declared with the `allocatable` attribute and has type
// !fir.box<!fir.heap<!fir.array<?x?xi32>>> (rank 2, both extents dynamic).
func.func @_QPtest6(%arg0: !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>> {fir.bindc_name = "x"}) {
%c0_i32 = arith.constant 0 : i32
%0:2 = hlfir.declare %arg0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtest6Ex"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
hlfir.assign %c0_i32 to %0#0 realloc : i32, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>
return
}

// CHECK-LABEL: func.func @_QPtest6(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>> {fir.bindc_name = "x"}) {
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtest6Ex"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
// CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>
// CHECK: %[[VAL_6:.*]]:3 = fir.box_dims %[[VAL_5]], %[[VAL_2]] : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>, index) -> (index, index, index)
// CHECK: %[[VAL_7:.*]]:3 = fir.box_dims %[[VAL_5]], %[[VAL_1]] : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>, index) -> (index, index, index)
// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_6]]#1, %[[VAL_7]]#1 : index
// CHECK: %[[VAL_9:.*]] = fir.shape %[[VAL_8]] : (index) -> !fir.shape<1>
// CHECK: %[[VAL_10:.*]] = fir.rebox %[[VAL_5]](%[[VAL_9]]) : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
// CHECK: fir.do_loop %[[VAL_11:.*]] = %[[VAL_1]] to %[[VAL_8]] step %[[VAL_1]] unordered {
// CHECK: %[[VAL_12:.*]] = hlfir.designate %[[VAL_10]] (%[[VAL_11]]) : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> !fir.ref<i32>
// CHECK: hlfir.assign %[[VAL_3]] to %[[VAL_12]] : i32, !fir.ref<i32>
// CHECK: }
// CHECK: return
// CHECK: }
Loading