Skip to content

Commit 0542e16

Browse files
committed
Handle dynamically absent mask argument properly.
1 parent 196d6de commit 0542e16

File tree

2 files changed

+129
-27
lines changed

2 files changed

+129
-27
lines changed

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

Lines changed: 57 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -141,10 +141,17 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
141141
// If the mask is present and is a scalar, then we'd better load its value
142142
// outside of the reduction loop making the loop unswitching easier.
143143
// Maybe it is worth hoisting it from the elemental operation as well.
144+
mlir::Value isPresentPred, maskValue;
144145
if (mask) {
145-
hlfir::Entity maskValue{mask};
146-
if (maskValue.isScalar())
147-
mask = hlfir::loadTrivialScalar(loc, builder, maskValue);
146+
if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
147+
// MASK represented by a box might be dynamically optional,
148+
// so we have to check for its presence before accessing it.
149+
isPresentPred =
150+
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
151+
}
152+
153+
if (hlfir::Entity{mask}.isScalar())
154+
maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
148155
}
149156

150157
// NOTE: the outer elemental operation may be lowered into
@@ -171,12 +178,10 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
171178
if (mask) {
172179
// Make the reduction value update conditional on the value
173180
// of the mask.
174-
hlfir::Entity maskValue{mask};
175-
if (!maskValue.isScalar()) {
181+
if (!maskValue) {
176182
// If the mask is an array, use the elemental and the loop indices
177183
// to address the proper mask element.
178-
maskValue = hlfir::getElementAt(loc, builder, maskValue, indices);
179-
maskValue = hlfir::loadTrivialScalar(loc, builder, maskValue);
184+
maskValue = genMaskValue(loc, builder, mask, isPresentPred, indices);
180185
}
181186
mlir::Value isUnmasked =
182187
builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
@@ -273,6 +278,51 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
273278

274279
llvm_unreachable("unsupported SUM reduction type");
275280
}
281+
282+
/// Generate the scalar value of \p mask to be used as the reduction
/// predicate.
///
/// \param loc            location attached to every generated operation.
/// \param builder        builder used to emit the operations; its insertion
///                       point is saved on entry by an InsertionGuard.
/// \param mask           the MASK operand: either a scalar (possibly boxed)
///                       or an array addressed through \p indices.
/// \param isPresentPred  optional i1 presence predicate; may be null.
/// \param indices        indices of the mask element to load; must be
///                       non-empty when \p mask is not a scalar.
///
/// \returns the loaded mask value when \p isPresentPred is null. Otherwise,
/// the result of a fir.if on \p isPresentPred whose "then" region loads the
/// mask and whose "else" region yields 'true' converted to the mask's
/// element type, i.e. an absent mask behaves as an all-true mask.
static mlir::Value genMaskValue(mlir::Location loc,
283+
fir::FirOpBuilder &builder, mlir::Value mask,
284+
mlir::Value isPresentPred,
285+
mlir::ValueRange indices) {
286+
// The insertion point is moved into the fir.if regions below; the guard
// is what puts it back for the caller when this function returns.
mlir::OpBuilder::InsertionGuard guard(builder);
287+
fir::IfOp ifOp;
288+
// Element type of the mask with any pass-by-reference wrapper removed;
// this is the type of the value produced by both fir.if regions.
mlir::Type maskType =
289+
hlfir::getFortranElementType(fir::unwrapPassByRefType(mask.getType()));
290+
if (isPresentPred) {
291+
ifOp = builder.create<fir::IfOp>(loc, maskType, isPresentPred,
292+
/*withElseRegion=*/true);
293+
294+
// Use 'true', if the mask is not present.
295+
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
296+
mlir::Value trueValue = builder.createBool(loc, true);
297+
trueValue = builder.createConvert(loc, maskType, trueValue);
298+
builder.create<fir::ResultOp>(loc, trueValue);
299+
300+
// Load the mask value, if the mask is present.
// Everything emitted from here on lands in the "then" region, so the
// mask is only accessed under the presence check.
301+
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
302+
}
303+
304+
hlfir::Entity maskVar{mask};
305+
if (maskVar.isScalar()) {
306+
if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
307+
// MASK may be a boxed scalar.
// Take the raw address of the boxed variable and load through it.
308+
mlir::Value addr = hlfir::genVariableRawAddress(loc, builder, maskVar);
309+
mask = builder.create<fir::LoadOp>(loc, hlfir::Entity{addr});
310+
} else {
311+
mask = hlfir::loadTrivialScalar(loc, builder, maskVar);
312+
}
313+
} else {
314+
// Load from the mask array.
315+
assert(!indices.empty() && "no indices for addressing the mask array");
316+
maskVar = hlfir::getElementAt(loc, builder, maskVar, indices);
317+
mask = hlfir::loadTrivialScalar(loc, builder, maskVar);
318+
}
319+
320+
// No presence check was requested: the load was emitted at the caller's
// insertion point and its result can be returned directly.
if (!isPresentPred)
321+
return mask;
322+
323+
// Terminate the "then" region with the loaded value and hand back the
// fir.if result, which merges it with the 'true' from the "else" region.
builder.create<fir::ResultOp>(loc, mask);
324+
return ifOp.getResult(0);
325+
}
276326
};
277327

278328
class SimplifyHLFIRIntrinsics

flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir

Lines changed: 72 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -229,6 +229,50 @@ func.func @sum_scalar_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.ref<!fir.log
229229
// CHECK: return
230230
// CHECK: }
231231

232+
// scalar boxed mask
233+
func.func @sum_scalar_boxed_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.box<!fir.logical<1>>) {
234+
%cst = arith.constant 1 : i32
235+
%res = hlfir.sum %arg0 dim %cst mask %mask : (!hlfir.expr<?x3xf32>, i32, !fir.box<!fir.logical<1>>) -> !hlfir.expr<3xf32>
236+
return
237+
}
238+
// CHECK-LABEL: func.func @sum_scalar_boxed_mask(
239+
// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr<?x3xf32>,
240+
// CHECK-SAME: %[[VAL_1:.*]]: !fir.box<!fir.logical<1>>) {
241+
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
242+
// CHECK: %[[VAL_3:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x3xf32>) -> !fir.shape<2>
243+
// CHECK: %[[VAL_4:.*]] = hlfir.get_extent %[[VAL_3]] {dim = 0 : index} : (!fir.shape<2>) -> index
244+
// CHECK: %[[VAL_5:.*]] = arith.constant 3 : index
245+
// CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1>
246+
// CHECK: %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> {
247+
// CHECK: ^bb0(%[[VAL_8:.*]]: index):
248+
// CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
249+
// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
250+
// CHECK: %[[VAL_11:.*]] = fir.is_present %[[VAL_1]] : (!fir.box<!fir.logical<1>>) -> i1
251+
// CHECK: %[[VAL_12:.*]] = fir.if %[[VAL_11]] -> (!fir.logical<1>) {
252+
// CHECK: %[[VAL_13:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.logical<1>>) -> !fir.ref<!fir.logical<1>>
253+
// CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<!fir.logical<1>>
254+
// CHECK: fir.result %[[VAL_14]] : !fir.logical<1>
255+
// CHECK: } else {
256+
// CHECK: %[[VAL_15:.*]] = arith.constant true
257+
// CHECK: %[[VAL_16:.*]] = fir.convert %[[VAL_15]] : (i1) -> !fir.logical<1>
258+
// CHECK: fir.result %[[VAL_16]] : !fir.logical<1>
259+
// CHECK: }
260+
// CHECK: %[[VAL_17:.*]] = fir.do_loop %[[VAL_18:.*]] = %[[VAL_9]] to %[[VAL_4]] step %[[VAL_9]] iter_args(%[[VAL_19:.*]] = %[[VAL_10]]) -> (f32) {
261+
// CHECK: %[[VAL_20:.*]] = fir.convert %[[VAL_12]] : (!fir.logical<1>) -> i1
262+
// CHECK: %[[VAL_21:.*]] = fir.if %[[VAL_20]] -> (f32) {
263+
// CHECK: %[[VAL_22:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_18]], %[[VAL_8]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
264+
// CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_19]], %[[VAL_22]] : f32
265+
// CHECK: fir.result %[[VAL_23]] : f32
266+
// CHECK: } else {
267+
// CHECK: fir.result %[[VAL_19]] : f32
268+
// CHECK: }
269+
// CHECK: fir.result %[[VAL_21]] : f32
270+
// CHECK: }
271+
// CHECK: hlfir.yield_element %[[VAL_17]] : f32
272+
// CHECK: }
273+
// CHECK: return
274+
// CHECK: }
275+
232276
// array mask
233277
func.func @sum_array_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.box<!fir.array<?x3x!fir.logical<1>>>) {
234278
%cst = arith.constant 2 : i32
@@ -247,29 +291,37 @@ func.func @sum_array_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.box<!fir.arra
247291
// CHECK: ^bb0(%[[VAL_8:.*]]: index):
248292
// CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
249293
// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
250-
// CHECK: %[[VAL_11:.*]] = fir.do_loop %[[VAL_12:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_9]] iter_args(%[[VAL_13:.*]] = %[[VAL_10]]) -> (f32) {
251-
// CHECK: %[[VAL_14:.*]] = arith.constant 0 : index
252-
// CHECK: %[[VAL_15:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_14]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
253-
// CHECK: %[[VAL_16:.*]] = arith.constant 1 : index
254-
// CHECK: %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_16]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
255-
// CHECK: %[[VAL_18:.*]] = arith.constant 1 : index
256-
// CHECK: %[[VAL_19:.*]] = arith.subi %[[VAL_15]]#0, %[[VAL_18]] : index
257-
// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_8]], %[[VAL_19]] : index
258-
// CHECK: %[[VAL_21:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_18]] : index
259-
// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_12]], %[[VAL_21]] : index
260-
// CHECK: %[[VAL_23:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_20]], %[[VAL_22]]) : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index, index) -> !fir.ref<!fir.logical<1>>
261-
// CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_23]] : !fir.ref<!fir.logical<1>>
262-
// CHECK: %[[VAL_25:.*]] = fir.convert %[[VAL_24]] : (!fir.logical<1>) -> i1
263-
// CHECK: %[[VAL_26:.*]] = fir.if %[[VAL_25]] -> (f32) {
264-
// CHECK: %[[VAL_27:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]], %[[VAL_12]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
265-
// CHECK: %[[VAL_28:.*]] = arith.addf %[[VAL_13]], %[[VAL_27]] : f32
266-
// CHECK: fir.result %[[VAL_28]] : f32
294+
// CHECK: %[[VAL_11:.*]] = fir.is_present %[[VAL_1]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>) -> i1
295+
// CHECK: %[[VAL_12:.*]] = fir.do_loop %[[VAL_13:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_9]] iter_args(%[[VAL_14:.*]] = %[[VAL_10]]) -> (f32) {
296+
// CHECK: %[[VAL_15:.*]] = fir.if %[[VAL_11]] -> (!fir.logical<1>) {
297+
// CHECK: %[[VAL_16:.*]] = arith.constant 0 : index
298+
// CHECK: %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_16]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
299+
// CHECK: %[[VAL_18:.*]] = arith.constant 1 : index
300+
// CHECK: %[[VAL_19:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_18]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
301+
// CHECK: %[[VAL_20:.*]] = arith.constant 1 : index
302+
// CHECK: %[[VAL_21:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_20]] : index
303+
// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_8]], %[[VAL_21]] : index
304+
// CHECK: %[[VAL_23:.*]] = arith.subi %[[VAL_19]]#0, %[[VAL_20]] : index
305+
// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_13]], %[[VAL_23]] : index
306+
// CHECK: %[[VAL_25:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_22]], %[[VAL_24]]) : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index, index) -> !fir.ref<!fir.logical<1>>
307+
// CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref<!fir.logical<1>>
308+
// CHECK: fir.result %[[VAL_26]] : !fir.logical<1>
267309
// CHECK: } else {
268-
// CHECK: fir.result %[[VAL_13]] : f32
310+
// CHECK: %[[VAL_27:.*]] = arith.constant true
311+
// CHECK: %[[VAL_28:.*]] = fir.convert %[[VAL_27]] : (i1) -> !fir.logical<1>
312+
// CHECK: fir.result %[[VAL_28]] : !fir.logical<1>
269313
// CHECK: }
270-
// CHECK: fir.result %[[VAL_26]] : f32
314+
// CHECK: %[[VAL_29:.*]] = fir.convert %[[VAL_15]] : (!fir.logical<1>) -> i1
315+
// CHECK: %[[VAL_30:.*]] = fir.if %[[VAL_29]] -> (f32) {
316+
// CHECK: %[[VAL_31:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]], %[[VAL_13]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
317+
// CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_14]], %[[VAL_31]] : f32
318+
// CHECK: fir.result %[[VAL_32]] : f32
319+
// CHECK: } else {
320+
// CHECK: fir.result %[[VAL_14]] : f32
321+
// CHECK: }
322+
// CHECK: fir.result %[[VAL_30]] : f32
271323
// CHECK: }
272-
// CHECK: hlfir.yield_element %[[VAL_11]] : f32
324+
// CHECK: hlfir.yield_element %[[VAL_12]] : f32
273325
// CHECK: }
274326
// CHECK: return
275327
// CHECK: }

0 commit comments

Comments
 (0)