Skip to content

Commit 8f16af3

Browse files
authored
[Flang][OpenMP] Fix mapping of character type with LEN > 1 specified (#154172)
Currently, there's a number of issues with mapping characters with LEN's specified (strings effectively). They're represented as a char type in FIR with a len parameter, and then later on they're expanded into an array of characters when we're translating to the LLVM dialect. However, we don't generate a bounds for these at lowering. The fix in this PR for this is to generate a bounds from the LEN parameter and attatch it to the map on lowering from FIR to the LLVM dialect when we encounter this type.
1 parent 4294907 commit 8f16af3

File tree

4 files changed

+236
-2
lines changed

4 files changed

+236
-2
lines changed

flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,21 @@ struct MapInfoOpConversion
6060
: public OpenMPFIROpConversion<mlir::omp::MapInfoOp> {
6161
using OpenMPFIROpConversion::OpenMPFIROpConversion;
6262

63+
mlir::omp::MapBoundsOp
64+
createBoundsForCharString(mlir::ConversionPatternRewriter &rewriter,
65+
unsigned int len, mlir::Location loc) const {
66+
mlir::Type i64Ty = rewriter.getIntegerType(64);
67+
auto lBound = mlir::LLVM::ConstantOp::create(rewriter, loc, i64Ty, 0);
68+
auto uBoundAndExt =
69+
mlir::LLVM::ConstantOp::create(rewriter, loc, i64Ty, len - 1);
70+
auto stride = mlir::LLVM::ConstantOp::create(rewriter, loc, i64Ty, 1);
71+
auto baseLb = mlir::LLVM::ConstantOp::create(rewriter, loc, i64Ty, 1);
72+
auto mapBoundType = rewriter.getType<mlir::omp::MapBoundsType>();
73+
return mlir::omp::MapBoundsOp::create(rewriter, loc, mapBoundType, lBound,
74+
uBoundAndExt, uBoundAndExt, stride,
75+
/*strideInBytes*/ false, baseLb);
76+
}
77+
6378
llvm::LogicalResult
6479
matchAndRewrite(mlir::omp::MapInfoOp curOp, OpAdaptor adaptor,
6580
mlir::ConversionPatternRewriter &rewriter) const override {
@@ -69,13 +84,79 @@ struct MapInfoOpConversion
6984
return mlir::failure();
7085

7186
llvm::SmallVector<mlir::NamedAttribute> newAttrs;
72-
mlir::omp::MapInfoOp newOp;
87+
mlir::omp::MapBoundsOp mapBoundsOp;
7388
for (mlir::NamedAttribute attr : curOp->getAttrs()) {
7489
if (auto typeAttr = mlir::dyn_cast<mlir::TypeAttr>(attr.getValue())) {
7590
mlir::Type newAttr;
7691
if (fir::isTypeWithDescriptor(typeAttr.getValue())) {
7792
newAttr = lowerTy().convertBoxTypeAsStruct(
7893
mlir::cast<fir::BaseBoxType>(typeAttr.getValue()));
94+
} else if (fir::isa_char_string(fir::unwrapSequenceType(
95+
fir::unwrapPassByRefType(typeAttr.getValue()))) &&
96+
!characterWithDynamicLen(
97+
fir::unwrapPassByRefType(typeAttr.getValue()))) {
98+
// Characters with a LEN param are represented as strings
99+
// (array of characters), the lowering to LLVM dialect
100+
// doesn't generate bounds for these (and this is not
101+
// done at the initial lowering either) and there is
102+
// minor inconsistencies in the variable types we
103+
// create for the map without this step when converting
104+
// to the LLVM dialect.
105+
//
106+
// For example, given the types:
107+
//
108+
// 1) CHARACTER(LEN=16), dimension(:,:), allocatable :: char_arr
109+
// 2) CHARACTER(LEN=16), dimension(10,10) :: char_arr
110+
//
111+
// We get the FIR types (note for 1: we already peeled off the
112+
// dynamic extents from the type at this stage, but the conversion
113+
// to llvm dialect does that in any case, so the final result
114+
// is the same):
115+
//
116+
// 1) !fir.char<1,16>
117+
// 2) !fir.array<10x10x!fir.char<1,16>>
118+
//
119+
// Which are converted to the LLVM dialect types:
120+
//
121+
// 1) !llvm.array<16 x i8>
122+
// 2) llvm.array<10 x array<10 x array<16 x i8>>
123+
//
124+
// And in both cases, we are missing the innermost bounds for
125+
// the !fir.char<1,16> which is expanded into a 16 x i8 array
126+
// in the conversion to LLVM dialect.
127+
//
128+
// The problem with this is that we would like to treat these
129+
// cases identically and not have to create specialised
130+
// lowerings for either of these in the lowering to LLVM-IR
131+
// and treat them like any other array that passes through.
132+
//
133+
// To do so below, we generate an extra bound for the
134+
// innermost array (the char type/string) using the LEN
135+
// parameter of the character type. And we "canonicalize"
136+
// the type, stripping it down to the base element type,
137+
// which in this case is an i8. This effectively allows
138+
// the lowering to treat this as a 1-D array with multiple
139+
// bounds which it is capable of handling without any special
140+
// casing.
141+
// TODO: Handle dynamic LEN characters.
142+
if (auto ct = mlir::dyn_cast_or_null<fir::CharacterType>(
143+
fir::unwrapSequenceType(typeAttr.getValue()))) {
144+
newAttr = converter->convertType(
145+
fir::unwrapSequenceType(typeAttr.getValue()));
146+
if (auto type = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(newAttr))
147+
newAttr = type.getElementType();
148+
// We do not generate MapBoundsOps for the device pass, as
149+
// MapBoundsOps are not generated for the device pass, as
150+
// they're unused in the device lowering.
151+
auto offloadMod =
152+
llvm::dyn_cast_or_null<mlir::omp::OffloadModuleInterface>(
153+
*curOp->getParentOfType<mlir::ModuleOp>());
154+
if (!offloadMod.getIsTargetDevice())
155+
mapBoundsOp = createBoundsForCharString(rewriter, ct.getLen(),
156+
curOp.getLoc());
157+
} else {
158+
newAttr = converter->convertType(typeAttr.getValue());
159+
}
79160
} else {
80161
newAttr = converter->convertType(typeAttr.getValue());
81162
}
@@ -85,8 +166,13 @@ struct MapInfoOpConversion
85166
}
86167
}
87168

88-
rewriter.replaceOpWithNewOp<mlir::omp::MapInfoOp>(
169+
auto newOp = rewriter.replaceOpWithNewOp<mlir::omp::MapInfoOp>(
89170
curOp, resTypes, adaptor.getOperands(), newAttrs);
171+
if (mapBoundsOp) {
172+
rewriter.startOpModification(newOp);
173+
newOp.getBoundsMutable().append(mlir::ValueRange{mapBoundsOp});
174+
rewriter.finalizeOpModification(newOp);
175+
}
90176

91177
return mlir::success();
92178
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// RUN: fir-opt --cfg-conversion --fir-to-llvm-ir="target=aarch64-unknown-linux-gnu" %s | FileCheck %s
2+
3+
module attributes {omp.is_target_device = false} {
4+
func.func @_QPchar_array(%arg0 : !fir.ref<!fir.array<10x10x!fir.char<1,16>>>) {
5+
%c9 = arith.constant 9 : index
6+
%c0 = arith.constant 0 : index
7+
%c1 = arith.constant 1 : index
8+
%c10 = arith.constant 10 : index
9+
%0 = omp.map.bounds lower_bound(%c0 : index) upper_bound(%c9 : index) extent(%c10 : index) stride(%c1 : index) start_idx(%c1 : index)
10+
%1 = omp.map.bounds lower_bound(%c0 : index) upper_bound(%c9 : index) extent(%c10 : index) stride(%c1 : index) start_idx(%c1 : index)
11+
%2 = omp.map.info var_ptr(%arg0 : !fir.ref<!fir.array<10x10x!fir.char<1,16>>>, !fir.array<10x10x!fir.char<1,16>>) map_clauses(tofrom) capture(ByRef) bounds(%0, %1) -> !fir.ref<!fir.array<10x10x!fir.char<1,16>>> {name = ""}
12+
omp.target map_entries(%2 -> %arg1 : !fir.ref<!fir.array<10x10x!fir.char<1,16>>>) {
13+
omp.terminator
14+
}
15+
return
16+
}
17+
18+
// CHECK-LABEL: llvm.func @_QPchar_array(
19+
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr) {
20+
// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(9 : index) : i64
21+
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64
22+
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : index) : i64
23+
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(10 : index) : i64
24+
// CHECK: %[[VAL_4:.*]] = omp.map.bounds lower_bound(%[[VAL_1]] : i64) upper_bound(%[[VAL_0]] : i64) extent(%[[VAL_3]] : i64) stride(%[[VAL_2]] : i64) start_idx(%[[VAL_2]] : i64)
25+
// CHECK: %[[VAL_5:.*]] = omp.map.bounds lower_bound(%[[VAL_1]] : i64) upper_bound(%[[VAL_0]] : i64) extent(%[[VAL_3]] : i64) stride(%[[VAL_2]] : i64) start_idx(%[[VAL_2]] : i64)
26+
// CHECK: %[[VAL_6:.*]] = llvm.mlir.constant(0 : i64) : i64
27+
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(15 : i64) : i64
28+
// CHECK: %[[VAL_8:.*]] = llvm.mlir.constant(1 : i64) : i64
29+
// CHECK: %[[VAL_9:.*]] = llvm.mlir.constant(1 : i64) : i64
30+
// CHECK: %[[VAL_10:.*]] = omp.map.bounds lower_bound(%[[VAL_6]] : i64) upper_bound(%[[VAL_7]] : i64) extent(%[[VAL_7]] : i64) stride(%[[VAL_8]] : i64) start_idx(%[[VAL_9]] : i64)
31+
// CHECK: %[[VAL_11:.*]] = omp.map.info var_ptr(%[[ARG0]] : !llvm.ptr, i8) map_clauses(tofrom) capture(ByRef) bounds(%[[VAL_4]], %[[VAL_5]], %[[VAL_10]]) -> !llvm.ptr {name = ""}
32+
// CHECK: omp.target map_entries(%[[VAL_11]] -> %[[VAL_12:.*]] : !llvm.ptr) {
33+
// CHECK: omp.terminator
34+
// CHECK: }
35+
// CHECK: llvm.return
36+
// CHECK: }
37+
38+
func.func @_QPallocatable_char_array(%arg0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.char<1,16>>>>>) {
39+
%c1 = arith.constant 1 : index
40+
%c0 = arith.constant 0 : index
41+
%0 = fir.load %arg0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.char<1,16>>>>>
42+
%1:3 = fir.box_dims %0, %c0 : (!fir.box<!fir.heap<!fir.array<?x?x!fir.char<1,16>>>>, index) -> (index, index, index)
43+
%2 = arith.subi %1#1, %c1 : index
44+
%3 = omp.map.bounds lower_bound(%c0 : index) upper_bound(%2 : index) extent(%1#1 : index) stride(%1#2 : index) start_idx(%1#0 : index) {stride_in_bytes = true}
45+
%4 = arith.muli %1#2, %1#1 : index
46+
%5:3 = fir.box_dims %0, %c1 : (!fir.box<!fir.heap<!fir.array<?x?x!fir.char<1,16>>>>, index) -> (index, index, index)
47+
%6 = arith.subi %5#1, %c1 : index
48+
%7 = omp.map.bounds lower_bound(%c0 : index) upper_bound(%6 : index) extent(%5#1 : index) stride(%4 : index) start_idx(%5#0 : index) {stride_in_bytes = true}
49+
%8 = fir.box_offset %arg0 base_addr : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.char<1,16>>>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?x?x!fir.char<1,16>>>>
50+
%9 = omp.map.info var_ptr(%arg0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.char<1,16>>>>>, !fir.char<1,16>) map_clauses(tofrom) capture(ByRef) var_ptr_ptr(%8 : !fir.llvm_ptr<!fir.ref<!fir.array<?x?x!fir.char<1,16>>>>) bounds(%3, %7) -> !fir.llvm_ptr<!fir.ref<!fir.array<?x?x!fir.char<1,16>>>> {name = ""}
51+
%10 = omp.map.info var_ptr(%arg0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.char<1,16>>>>>, !fir.box<!fir.heap<!fir.array<?x?x!fir.char<1,16>>>>) map_clauses(to) capture(ByRef) members(%9 : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?x?x!fir.char<1,16>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.char<1,16>>>>> {name = "csv_chem_list_a"}
52+
omp.target map_entries(%10 -> %arg1, %9 -> %arg2 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.char<1,16>>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?x?x!fir.char<1,16>>>>) {
53+
omp.terminator
54+
}
55+
return
56+
}
57+
58+
// CHECK-LABEL: llvm.func @_QPallocatable_char_array(
59+
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr) {
60+
// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(1 : i32) : i32
61+
// CHECK: %[[VAL_1:.*]] = llvm.alloca %[[VAL_0]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<2 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
62+
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : index) : i64
63+
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : index) : i64
64+
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(72 : i32) : i32
65+
// CHECK: "llvm.intr.memcpy"(%[[VAL_1]], %[[ARG0]], %[[VAL_4]]) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
66+
// CHECK: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_1]][0, 7, %[[VAL_3]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<2 x array<3 x i64>>)>
67+
// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr -> i64
68+
// CHECK: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_1]][0, 7, %[[VAL_3]], 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<2 x array<3 x i64>>)>
69+
// CHECK: %[[VAL_8:.*]] = llvm.load %[[VAL_7]] : !llvm.ptr -> i64
70+
// CHECK: %[[VAL_9:.*]] = llvm.getelementptr %[[VAL_1]][0, 7, %[[VAL_3]], 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<2 x array<3 x i64>>)>
71+
// CHECK: %[[VAL_10:.*]] = llvm.load %[[VAL_9]] : !llvm.ptr -> i64
72+
// CHECK: %[[VAL_11:.*]] = llvm.sub %[[VAL_8]], %[[VAL_2]] : i64
73+
// CHECK: %[[VAL_12:.*]] = omp.map.bounds lower_bound(%[[VAL_3]] : i64) upper_bound(%[[VAL_11]] : i64) extent(%[[VAL_8]] : i64) stride(%[[VAL_10]] : i64) start_idx(%[[VAL_6]] : i64) {stride_in_bytes = true}
74+
// CHECK: %[[VAL_13:.*]] = llvm.mul %[[VAL_10]], %[[VAL_8]] : i64
75+
// CHECK: %[[VAL_14:.*]] = llvm.getelementptr %[[VAL_1]][0, 7, %[[VAL_2]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<2 x array<3 x i64>>)>
76+
// CHECK: %[[VAL_15:.*]] = llvm.load %[[VAL_14]] : !llvm.ptr -> i64
77+
// CHECK: %[[VAL_16:.*]] = llvm.getelementptr %[[VAL_1]][0, 7, %[[VAL_2]], 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<2 x array<3 x i64>>)>
78+
// CHECK: %[[VAL_17:.*]] = llvm.load %[[VAL_16]] : !llvm.ptr -> i64
79+
// CHECK: %[[VAL_18:.*]] = llvm.getelementptr %[[VAL_1]][0, 7, %[[VAL_2]], 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<2 x array<3 x i64>>)>
80+
// CHECK: %[[VAL_19:.*]] = llvm.load %[[VAL_18]] : !llvm.ptr -> i64
81+
// CHECK: %[[VAL_20:.*]] = llvm.sub %[[VAL_17]], %[[VAL_2]] : i64
82+
// CHECK: %[[VAL_21:.*]] = omp.map.bounds lower_bound(%[[VAL_3]] : i64) upper_bound(%[[VAL_20]] : i64) extent(%[[VAL_17]] : i64) stride(%[[VAL_13]] : i64) start_idx(%[[VAL_15]] : i64) {stride_in_bytes = true}
83+
// CHECK: %[[VAL_22:.*]] = llvm.getelementptr %[[ARG0]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<2 x array<3 x i64>>)>
84+
// CHECK: %[[VAL_23:.*]] = llvm.mlir.constant(0 : i64) : i64
85+
// CHECK: %[[VAL_24:.*]] = llvm.mlir.constant(15 : i64) : i64
86+
// CHECK: %[[VAL_25:.*]] = llvm.mlir.constant(1 : i64) : i64
87+
// CHECK: %[[VAL_26:.*]] = llvm.mlir.constant(1 : i64) : i64
88+
// CHECK: %[[VAL_27:.*]] = omp.map.bounds lower_bound(%[[VAL_23]] : i64) upper_bound(%[[VAL_24]] : i64) extent(%[[VAL_24]] : i64) stride(%[[VAL_25]] : i64) start_idx(%[[VAL_26]] : i64)
89+
// CHECK: %[[VAL_28:.*]] = omp.map.info var_ptr(%[[ARG0]] : !llvm.ptr, i8) map_clauses(tofrom) capture(ByRef) var_ptr_ptr(%[[VAL_22]] : !llvm.ptr) bounds(%[[VAL_12]], %[[VAL_21]], %[[VAL_27]]) -> !llvm.ptr {name = ""}
90+
// CHECK: %[[VAL_29:.*]] = omp.map.info var_ptr(%[[ARG0]] : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<2 x array<3 x i64>>)>) map_clauses(to) capture(ByRef) members(%[[VAL_28]] : [0] : !llvm.ptr) -> !llvm.ptr {name = "csv_chem_list_a"}
91+
// CHECK: omp.target map_entries(%[[VAL_29]] -> %[[VAL_30:.*]], %[[VAL_28]] -> %[[VAL_31:.*]] : !llvm.ptr, !llvm.ptr) {
92+
// CHECK: omp.terminator
93+
// CHECK: }
94+
// CHECK: llvm.return
95+
// CHECK: }
96+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
! Offloading test that verifies certain type of character string arrays
2+
! map to and from device without problem.
3+
! REQUIRES: flang, amdgpu
4+
5+
! RUN: %libomptarget-compile-fortran-run-and-check-generic
6+
program main
7+
implicit none
8+
type char_t
9+
CHARACTER(LEN=16), dimension(10,10) :: char_arr
10+
end type char_t
11+
type(char_t) :: dtype_char
12+
13+
!$omp target enter data map(alloc:dtype_char%char_arr)
14+
15+
!$omp target
16+
dtype_char%char_arr(2,2) = 'c'
17+
!$omp end target
18+
19+
!$omp target update from(dtype_char%char_arr)
20+
21+
22+
print *, dtype_char%char_arr(2,2)
23+
end program
24+
25+
!CHECK: c
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
! Offloading test that verifies certain type of character string arrays
2+
! (in this case allocatable) map to and from device without problem.
3+
! REQUIRES: flang, amdgpu
4+
5+
! RUN: %libomptarget-compile-fortran-run-and-check-generic
6+
program main
7+
implicit none
8+
type char_t
9+
CHARACTER(LEN=16), dimension(:,:), allocatable :: char_arr
10+
end type char_t
11+
type(char_t) :: dtype_char
12+
13+
allocate(dtype_char%char_arr(10,10))
14+
15+
!$omp target enter data map(alloc:dtype_char%char_arr)
16+
17+
!$omp target
18+
dtype_char%char_arr(2,2) = 'c'
19+
!$omp end target
20+
21+
!$omp target update from(dtype_char%char_arr)
22+
23+
24+
print *, dtype_char%char_arr(2,2)
25+
end program
26+
27+
!CHECK: c

0 commit comments

Comments
 (0)