Skip to content

Commit 30e193f

Browse files
committed
Fixed handling of dynamically optional PAD.
1 parent 5b6b201 commit 30e193f

File tree

2 files changed

+37
-37
lines changed

2 files changed

+37
-37
lines changed

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

Lines changed: 30 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -1002,18 +1002,10 @@ class ReshapeAsElementalConversion
10021002
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayExtents =
10031003
hlfir::genExtentsVector(loc, builder, array);
10041004

1005-
mlir::Value arraySize, padSize;
1006-
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padExtents;
1007-
if (pad) {
1008-
// If PAD is present, we have to use array size to start taking
1009-
// elements from the PAD array.
1010-
arraySize = computeArraySize(loc, builder, arrayExtents);
1011-
1012-
padExtents = hlfir::genExtentsVector(loc, builder, hlfir::Entity{pad});
1013-
// PAD size is needed to wrap around the linear index addressing
1014-
// the PAD array.
1015-
padSize = computeArraySize(loc, builder, padExtents);
1016-
}
1005+
// If PAD is present, we have to use array size to start taking
1006+
// elements from the PAD array.
1007+
mlir::Value arraySize =
1008+
pad ? computeArraySize(loc, builder, arrayExtents) : nullptr;
10171009
hlfir::Entity shape = hlfir::Entity{reshape.getShape()};
10181010
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> resultExtents;
10191011
mlir::Type indexType = builder.getIndexType();
@@ -1037,15 +1029,18 @@ class ReshapeAsElementalConversion
10371029

10381030
// In the 'else' block, return an element from the PAD.
10391031
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1032+
// PAD is dynamically optional, but we can unconditionally access it
1033+
// in the 'else' block. If we have to start taking elements from it,
1034+
// then it must be present in a valid program.
1035+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padExtents =
1036+
hlfir::genExtentsVector(loc, builder, hlfir::Entity{pad});
10401037
// Subtract the ARRAY size from the zero-based linear index
10411038
// to get the zero-based linear index into PAD.
10421039
mlir::Value padLinearIndex =
10431040
builder.create<mlir::arith::SubIOp>(loc, linearIndex, arraySize);
1044-
// PAD wraps around, when additional elements are needed.
1045-
padLinearIndex =
1046-
builder.create<mlir::arith::RemUIOp>(loc, padLinearIndex, padSize);
10471041
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padIndices =
1048-
delinearizeIndex(loc, builder, padExtents, padLinearIndex);
1042+
delinearizeIndex(loc, builder, padExtents, padLinearIndex,
1043+
/*wrapAround=*/true);
10491044
mlir::Value padElement =
10501045
hlfir::loadElementAt(loc, builder, hlfir::Entity{pad}, padIndices);
10511046
builder.create<fir::ResultOp>(loc, padElement);
@@ -1055,7 +1050,8 @@ class ReshapeAsElementalConversion
10551050
}
10561051

10571052
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayIndices =
1058-
delinearizeIndex(loc, builder, arrayExtents, linearIndex);
1053+
delinearizeIndex(loc, builder, arrayExtents, linearIndex,
1054+
/*wrapAround=*/false);
10591055
mlir::Value arrayElement =
10601056
hlfir::loadElementAt(loc, builder, array, arrayIndices);
10611057

@@ -1119,33 +1115,39 @@ class ReshapeAsElementalConversion
11191115
/// ...
11201116
/// i(n-1) := linearIndex % e(n-1) + 1
11211117
/// linearIndex := linearIndex / e(n-1)
1122-
/// in := linearIndex + 1
1118+
/// if (wrapAround) {
1119+
/// // If the index is allowed to wrap around, then
1120+
/// // we need to modulo it by the last dimension's extent.
1121+
/// in := linearIndex % en + 1
1122+
/// } else {
1123+
/// in := linearIndex + 1
1124+
/// }
11231125
static llvm::SmallVector<mlir::Value, Fortran::common::maxRank>
11241126
delinearizeIndex(mlir::Location loc, fir::FirOpBuilder &builder,
1125-
mlir::ValueRange extents, mlir::Value linearIndex) {
1127+
mlir::ValueRange extents, mlir::Value linearIndex,
1128+
bool wrapAround) {
11261129
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
11271130
mlir::Type indexType = builder.getIndexType();
11281131
mlir::Value one = builder.createIntegerConstant(loc, indexType, 1);
11291132
linearIndex = builder.createConvert(loc, indexType, linearIndex);
11301133

11311134
for (std::size_t dim = 0; dim < extents.size(); ++dim) {
1132-
mlir::Value currentIndex;
1133-
if (dim == extents.size() - 1) {
1134-
currentIndex = linearIndex;
1135-
} else {
1136-
mlir::Value extent =
1137-
builder.createConvert(loc, indexType, extents[dim]);
1135+
mlir::Value extent = builder.createConvert(loc, indexType, extents[dim]);
1136+
// Avoid the modulo for the last index, unless wrap around is allowed.
1137+
mlir::Value currentIndex = linearIndex;
1138+
if (dim != extents.size() - 1 || wrapAround)
11381139
currentIndex =
11391140
builder.create<mlir::arith::RemUIOp>(loc, linearIndex, extent);
1140-
linearIndex =
1141-
builder.create<mlir::arith::DivUIOp>(loc, linearIndex, extent);
1142-
}
1141+
// The result of the last division is unused, so it will be DCEd.
1142+
linearIndex =
1143+
builder.create<mlir::arith::DivUIOp>(loc, linearIndex, extent);
11431144
indices.push_back(
11441145
builder.create<mlir::arith::AddIOp>(loc, currentIndex, one));
11451146
}
11461147
return indices;
11471148
}
11481149

1150+
/// Return size of an array given its extents.
11491151
static mlir::Value computeArraySize(mlir::Location loc,
11501152
fir::FirOpBuilder &builder,
11511153
mlir::ValueRange extents) {

flang/test/HLFIR/simplify-hlfir-intrinsics-reshape.fir

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -41,11 +41,6 @@ func.func @reshape_with_pad(%arg0: !fir.box<!fir.array<?x?x?xf32>>, %arg1: !fir.
4141
// CHECK: %[[ARRAY_DIM2:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
4242
// CHECK: %[[VAL_9:.*]] = arith.muli %[[ARRAY_DIM0]]#1, %[[ARRAY_DIM1]]#1 overflow<nuw> : index
4343
// CHECK: %[[ARRAY_SIZE:.*]] = arith.muli %[[VAL_9]], %[[ARRAY_DIM2]]#1 overflow<nuw> : index
44-
// CHECK: %[[PAD_DIM0:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_5]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
45-
// CHECK: %[[PAD_DIM1:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_4]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
46-
// CHECK: %[[PAD_DIM2:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
47-
// CHECK: %[[VAL_14:.*]] = arith.muli %[[PAD_DIM0]]#1, %[[PAD_DIM1]]#1 overflow<nuw> : index
48-
// CHECK: %[[PAD_SIZE:.*]] = arith.muli %[[VAL_14]], %[[PAD_DIM2]]#1 overflow<nuw> : index
4944
// CHECK: %[[VAL_16:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_4]]) : (!fir.ref<!fir.array<2xi32>>, index) -> !fir.ref<i32>
5045
// CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
5146
// CHECK: %[[VAL_18:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_3]]) : (!fir.ref<!fir.array<2xi32>>, index) -> !fir.ref<i32>
@@ -80,15 +75,18 @@ func.func @reshape_with_pad(%arg0: !fir.box<!fir.array<?x?x?xf32>>, %arg1: !fir.
8075
// CHECK: %[[VAL_48:.*]] = fir.load %[[VAL_47]] : !fir.ref<f32>
8176
// CHECK: fir.result %[[VAL_48]] : f32
8277
// CHECK: } else {
78+
// CHECK: %[[PAD_DIM0:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_5]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
79+
// CHECK: %[[PAD_DIM1:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_4]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
80+
// CHECK: %[[PAD_DIM2:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
8381
// CHECK: %[[PAD_LINEAR_INDEX:.*]] = arith.subi %[[LINEAR_INDEX]], %[[ARRAY_SIZE]] overflow<nuw> : index
84-
// CHECK: %[[PAD_LINEAR_INDEX_MOD:.*]] = arith.remui %[[PAD_LINEAR_INDEX]], %[[PAD_SIZE]] : index
85-
// CHECK: %[[VAL_51:.*]] = arith.remui %[[PAD_LINEAR_INDEX_MOD]], %[[PAD_DIM0]]#1 : index
86-
// CHECK: %[[VAL_52:.*]] = arith.divui %[[PAD_LINEAR_INDEX_MOD]], %[[PAD_DIM0]]#1 : index
82+
// CHECK: %[[VAL_51:.*]] = arith.remui %[[PAD_LINEAR_INDEX]], %[[PAD_DIM0]]#1 : index
83+
// CHECK: %[[VAL_52:.*]] = arith.divui %[[PAD_LINEAR_INDEX]], %[[PAD_DIM0]]#1 : index
8784
// CHECK: %[[PAD_IDX0:.*]] = arith.addi %[[VAL_51]], %[[VAL_4]] overflow<nuw> : index
8885
// CHECK: %[[VAL_54:.*]] = arith.remui %[[VAL_52]], %[[PAD_DIM1]]#1 : index
8986
// CHECK: %[[VAL_55:.*]] = arith.divui %[[VAL_52]], %[[PAD_DIM1]]#1 : index
9087
// CHECK: %[[PAD_IDX1:.*]] = arith.addi %[[VAL_54]], %[[VAL_4]] overflow<nuw> : index
91-
// CHECK: %[[PAD_IDX2:.*]] = arith.addi %[[VAL_55]], %[[VAL_4]] overflow<nuw> : index
88+
// CHECK: %[[VAL_56:.*]] = arith.remui %[[VAL_55]], %[[PAD_DIM2]]#1 : index
89+
// CHECK: %[[PAD_IDX2:.*]] = arith.addi %[[VAL_56]], %[[VAL_4]] overflow<nuw> : index
9290
// CHECK: %[[VAL_58:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_5]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
9391
// CHECK: %[[VAL_59:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_4]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
9492
// CHECK: %[[VAL_60:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)

0 commit comments

Comments
 (0)