@@ -229,6 +229,50 @@ func.func @sum_scalar_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.ref<!fir.log
229229// CHECK: return
230230// CHECK: }
231231
232+ // scalar boxed mask
233+ func.func @sum_scalar_boxed_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.box<!fir.logical<1>>) {
234+ %cst = arith.constant 1 : i32
235+ %res = hlfir.sum %arg0 dim %cst mask %mask : (!hlfir.expr<?x3xf32>, i32, !fir.box<!fir.logical<1>>) -> !hlfir.expr<3xf32>
236+ return
237+ }
238+ // CHECK-LABEL: func.func @sum_scalar_boxed_mask(
239+ // CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr<?x3xf32>,
240+ // CHECK-SAME: %[[VAL_1:.*]]: !fir.box<!fir.logical<1>>) {
241+ // CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
242+ // CHECK: %[[VAL_3:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x3xf32>) -> !fir.shape<2>
243+ // CHECK: %[[VAL_4:.*]] = hlfir.get_extent %[[VAL_3]] {dim = 0 : index} : (!fir.shape<2>) -> index
244+ // CHECK: %[[VAL_5:.*]] = arith.constant 3 : index
245+ // CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1>
246+ // CHECK: %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> {
247+ // CHECK: ^bb0(%[[VAL_8:.*]]: index):
248+ // CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
249+ // CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
250+ // CHECK: %[[VAL_11:.*]] = fir.is_present %[[VAL_1]] : (!fir.box<!fir.logical<1>>) -> i1
251+ // CHECK: %[[VAL_12:.*]] = fir.if %[[VAL_11]] -> (!fir.logical<1>) {
252+ // CHECK: %[[VAL_13:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.logical<1>>) -> !fir.ref<!fir.logical<1>>
253+ // CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<!fir.logical<1>>
254+ // CHECK: fir.result %[[VAL_14]] : !fir.logical<1>
255+ // CHECK: } else {
256+ // CHECK: %[[VAL_15:.*]] = arith.constant true
257+ // CHECK: %[[VAL_16:.*]] = fir.convert %[[VAL_15]] : (i1) -> !fir.logical<1>
258+ // CHECK: fir.result %[[VAL_16]] : !fir.logical<1>
259+ // CHECK: }
260+ // CHECK: %[[VAL_17:.*]] = fir.do_loop %[[VAL_18:.*]] = %[[VAL_9]] to %[[VAL_4]] step %[[VAL_9]] iter_args(%[[VAL_19:.*]] = %[[VAL_10]]) -> (f32) {
261+ // CHECK: %[[VAL_20:.*]] = fir.convert %[[VAL_12]] : (!fir.logical<1>) -> i1
262+ // CHECK: %[[VAL_21:.*]] = fir.if %[[VAL_20]] -> (f32) {
263+ // CHECK: %[[VAL_22:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_18]], %[[VAL_8]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
264+ // CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_19]], %[[VAL_22]] : f32
265+ // CHECK: fir.result %[[VAL_23]] : f32
266+ // CHECK: } else {
267+ // CHECK: fir.result %[[VAL_19]] : f32
268+ // CHECK: }
269+ // CHECK: fir.result %[[VAL_21]] : f32
270+ // CHECK: }
271+ // CHECK: hlfir.yield_element %[[VAL_17]] : f32
272+ // CHECK: }
273+ // CHECK: return
274+ // CHECK: }
275+
232276// array mask
233277func.func @sum_array_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.box<!fir.array<?x3x!fir.logical<1>>>) {
234278 %cst = arith.constant 2 : i32
@@ -247,29 +291,37 @@ func.func @sum_array_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.box<!fir.arra
247291// CHECK: ^bb0(%[[VAL_8:.*]]: index):
248292// CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
249293// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
250- // CHECK: %[[VAL_11:.*]] = fir.do_loop %[[VAL_12:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_9]] iter_args(%[[VAL_13:.*]] = %[[VAL_10]]) -> (f32) {
251- // CHECK: %[[VAL_14:.*]] = arith.constant 0 : index
252- // CHECK: %[[VAL_15:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_14]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
253- // CHECK: %[[VAL_16:.*]] = arith.constant 1 : index
254- // CHECK: %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_16]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
255- // CHECK: %[[VAL_18:.*]] = arith.constant 1 : index
256- // CHECK: %[[VAL_19:.*]] = arith.subi %[[VAL_15]]#0, %[[VAL_18]] : index
257- // CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_8]], %[[VAL_19]] : index
258- // CHECK: %[[VAL_21:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_18]] : index
259- // CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_12]], %[[VAL_21]] : index
260- // CHECK: %[[VAL_23:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_20]], %[[VAL_22]]) : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index, index) -> !fir.ref<!fir.logical<1>>
261- // CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_23]] : !fir.ref<!fir.logical<1>>
262- // CHECK: %[[VAL_25:.*]] = fir.convert %[[VAL_24]] : (!fir.logical<1>) -> i1
263- // CHECK: %[[VAL_26:.*]] = fir.if %[[VAL_25]] -> (f32) {
264- // CHECK: %[[VAL_27:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]], %[[VAL_12]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
265- // CHECK: %[[VAL_28:.*]] = arith.addf %[[VAL_13]], %[[VAL_27]] : f32
266- // CHECK: fir.result %[[VAL_28]] : f32
294+ // CHECK: %[[VAL_11:.*]] = fir.is_present %[[VAL_1]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>) -> i1
295+ // CHECK: %[[VAL_12:.*]] = fir.do_loop %[[VAL_13:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_9]] iter_args(%[[VAL_14:.*]] = %[[VAL_10]]) -> (f32) {
296+ // CHECK: %[[VAL_15:.*]] = fir.if %[[VAL_11]] -> (!fir.logical<1>) {
297+ // CHECK: %[[VAL_16:.*]] = arith.constant 0 : index
298+ // CHECK: %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_16]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
299+ // CHECK: %[[VAL_18:.*]] = arith.constant 1 : index
300+ // CHECK: %[[VAL_19:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_18]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
301+ // CHECK: %[[VAL_20:.*]] = arith.constant 1 : index
302+ // CHECK: %[[VAL_21:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_20]] : index
303+ // CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_8]], %[[VAL_21]] : index
304+ // CHECK: %[[VAL_23:.*]] = arith.subi %[[VAL_19]]#0, %[[VAL_20]] : index
305+ // CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_13]], %[[VAL_23]] : index
306+ // CHECK: %[[VAL_25:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_22]], %[[VAL_24]]) : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index, index) -> !fir.ref<!fir.logical<1>>
307+ // CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref<!fir.logical<1>>
308+ // CHECK: fir.result %[[VAL_26]] : !fir.logical<1>
267309// CHECK: } else {
268- // CHECK: fir.result %[[VAL_13]] : f32
310+ // CHECK: %[[VAL_27:.*]] = arith.constant true
311+ // CHECK: %[[VAL_28:.*]] = fir.convert %[[VAL_27]] : (i1) -> !fir.logical<1>
312+ // CHECK: fir.result %[[VAL_28]] : !fir.logical<1>
269313// CHECK: }
270- // CHECK: fir.result %[[VAL_26]] : f32
314+ // CHECK: %[[VAL_29:.*]] = fir.convert %[[VAL_15]] : (!fir.logical<1>) -> i1
315+ // CHECK: %[[VAL_30:.*]] = fir.if %[[VAL_29]] -> (f32) {
316+ // CHECK: %[[VAL_31:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]], %[[VAL_13]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
317+ // CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_14]], %[[VAL_31]] : f32
318+ // CHECK: fir.result %[[VAL_32]] : f32
319+ // CHECK: } else {
320+ // CHECK: fir.result %[[VAL_14]] : f32
321+ // CHECK: }
322+ // CHECK: fir.result %[[VAL_30]] : f32
271323// CHECK: }
272- // CHECK: hlfir.yield_element %[[VAL_11 ]] : f32
324+ // CHECK: hlfir.yield_element %[[VAL_12 ]] : f32
273325// CHECK: }
274326// CHECK: return
275327// CHECK: }
0 commit comments