Skip to content

Commit 7d56f36

Browse files
committed
[𝘀𝗽𝗿] changes to main this commit is based on
Created using spr 1.3.4 [skip ci]
1 parent f873fc3 commit 7d56f36

File tree

34 files changed

+1508
-115
lines changed

34 files changed

+1508
-115
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -879,7 +879,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
879879
TypeAttr:$var_type,
880880
Optional<OpenMP_PointerLikeType>:$var_ptr_ptr,
881881
Variadic<OpenMP_PointerLikeType>:$members,
882-
OptionalAttr<AnyIntElementsAttr>:$members_index,
882+
OptionalAttr<IndexListArrayAttr>:$members_index,
883883
Variadic<OpenMP_MapBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
884884
OptionalAttr<UI64Attr>:$map_type,
885885
OptionalAttr<VariableCaptureKindAttr>:$map_capture_type,

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 19 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,16 +1392,15 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
13921392
}
13931393

13941394
static ParseResult parseMembersIndex(OpAsmParser &parser,
1395-
DenseIntElementsAttr &membersIdx) {
1396-
SmallVector<APInt> values;
1395+
ArrayAttr &membersIdx) {
1396+
SmallVector<Attribute> values, memberIdxs;
13971397
int64_t value;
1398-
int64_t shape[2] = {0, 0};
1399-
unsigned shapeTmp = 0;
1398+
14001399
auto parseIndices = [&]() -> ParseResult {
14011400
if (parser.parseInteger(value))
14021401
return failure();
1403-
shapeTmp++;
1404-
values.push_back(APInt(32, value));
1402+
values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
1403+
mlir::APInt(64, value)));
14051404
return success();
14061405
};
14071406

@@ -1415,51 +1414,32 @@ static ParseResult parseMembersIndex(OpAsmParser &parser,
14151414
if (failed(parser.parseRSquare()))
14161415
return failure();
14171416

1418-
// Only set once, if any indices are not the same size
1419-
// we error out in the next check as that's unsupported
1420-
if (shape[1] == 0)
1421-
shape[1] = shapeTmp;
1422-
1423-
// Verify that the recently parsed list is equal to the
1424-
// first one we parsed, they must be equal lengths to
1425-
// keep the rectangular shape DenseIntElementsAttr
1426-
// requires
1427-
if (shapeTmp != shape[1])
1428-
return failure();
1429-
1430-
shapeTmp = 0;
1431-
shape[0]++;
1417+
memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
1418+
values.clear();
14321419
} while (succeeded(parser.parseOptionalComma()));
14331420

1434-
if (!values.empty()) {
1435-
ShapedType valueType =
1436-
VectorType::get(shape, IntegerType::get(parser.getContext(), 32));
1437-
membersIdx = DenseIntElementsAttr::get(valueType, values);
1438-
}
1421+
if (!memberIdxs.empty())
1422+
membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
14391423

14401424
return success();
14411425
}
14421426

14431427
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1444-
DenseIntElementsAttr membersIdx) {
1445-
llvm::ArrayRef<int64_t> shape = membersIdx.getShapedType().getShape();
1446-
assert(shape.size() <= 2);
1447-
1428+
ArrayAttr membersIdx) {
14481429
if (!membersIdx)
14491430
return;
14501431

1451-
for (int i = 0; i < shape[0]; ++i) {
1432+
SmallVector<std::string> idxs;
1433+
for (auto [i, v] : llvm::enumerate(membersIdx)) {
1434+
auto memberIdx = mlir::cast<mlir::ArrayAttr>(v);
14521435
p << "[";
1453-
int rowOffset = i * shape[1];
1454-
for (int j = 0; j < shape[1]; ++j) {
1455-
p << membersIdx.getValues<int32_t>()[rowOffset + j];
1456-
if ((j + 1) < shape[1])
1457-
p << ",";
1458-
}
1459-
p << "]";
1460-
1461-
if ((i + 1) < shape[0])
1436+
for (auto v2 : memberIdx.getValue())
1437+
idxs.push_back(
1438+
std::to_string(mlir::cast<mlir::IntegerAttr>(v2).getInt()));
1439+
p << llvm::join(idxs, ",") << "]";
1440+
if ((i + 1) < membersIdx.getValue().size())
14621441
p << ", ";
1442+
idxs.clear();
14631443
}
14641444
}
14651445

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 58 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2468,51 +2468,45 @@ static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
24682468
return std::distance(mapData.MapClause.begin(), res);
24692469
}
24702470

2471-
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
2472-
bool first) {
2473-
DenseIntElementsAttr indexAttr = mapInfo.getMembersIndexAttr();
2474-
2471+
static mlir::omp::MapInfoOp
2472+
getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
2473+
mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
24752474
// Only 1 member has been mapped, we can return it.
24762475
if (indexAttr.size() == 1)
2477-
if (auto mapOp =
2478-
dyn_cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp()))
2476+
if (auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(
2477+
mapInfo.getMembers()[0].getDefiningOp()))
24792478
return mapOp;
24802479

2481-
llvm::ArrayRef<int64_t> shape = indexAttr.getShapedType().getShape();
2482-
llvm::SmallVector<size_t> indices(shape[0]);
2480+
llvm::SmallVector<size_t> indices(indexAttr.size());
24832481
std::iota(indices.begin(), indices.end(), 0);
24842482

24852483
llvm::sort(indices.begin(), indices.end(),
24862484
[&](const size_t a, const size_t b) {
2487-
auto indexValues = indexAttr.getValues<int32_t>();
2488-
for (int i = 0; i < shape[1]; ++i) {
2489-
int aIndex = indexValues[a * shape[1] + i];
2490-
int bIndex = indexValues[b * shape[1] + i];
2485+
auto memberIndicesA = mlir::cast<mlir::ArrayAttr>(indexAttr[a]);
2486+
auto memberIndicesB = mlir::cast<mlir::ArrayAttr>(indexAttr[b]);
2487+
for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
2488+
int64_t aIndex =
2489+
mlir::cast<mlir::IntegerAttr>(std::get<0>(it)).getInt();
2490+
int64_t bIndex =
2491+
mlir::cast<mlir::IntegerAttr>(std::get<1>(it)).getInt();
24912492

24922493
if (aIndex == bIndex)
24932494
continue;
24942495

2495-
if (aIndex != -1 && bIndex == -1)
2496-
return false;
2497-
2498-
if (aIndex == -1 && bIndex != -1)
2499-
return true;
2500-
2501-
// A is earlier in the record type layout than B
25022496
if (aIndex < bIndex)
25032497
return first;
25042498

2505-
if (bIndex < aIndex)
2499+
if (aIndex > bIndex)
25062500
return !first;
25072501
}
25082502

2509-
// Iterated the entire list and couldn't make a decision, all
2510-
// elements were likely the same. Return false, since the sort
2511-
// comparator should return false for equal elements.
2512-
return false;
2503+
// Iterated up until the end of the smallest member and
2504+
// they were found to be equal up to that point, so select
2505+
// the member with the lowest index count, so the "parent"
2506+
return memberIndicesA.size() < memberIndicesB.size();
25132507
});
25142508

2515-
return llvm::cast<omp::MapInfoOp>(
2509+
return llvm::cast<mlir::omp::MapInfoOp>(
25162510
mapInfo.getMembers()[indices.front()].getDefiningOp());
25172511
}
25182512

@@ -2663,6 +2657,8 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
26632657
auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
26642658
int firstMemberIdx = getMapDataMemberIdx(
26652659
mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
2660+
// NOTE/TODO: Should perhaps use OriginalValue here instead of Pointers to
2661+
// avoid offset or any manipulations interfering with the calculation.
26662662
lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
26672663
builder.getPtrTy());
26682664
int lastMemberIdx = getMapDataMemberIdx(
@@ -2680,17 +2676,8 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
26802676
/*isSigned=*/false);
26812677
combinedInfo.Sizes.push_back(size);
26822678

2683-
// TODO: This will need to be expanded to include the whole host of logic for
2684-
// the map flags that Clang currently supports (e.g. it should take the map
2685-
// flag of the parent map flag, remove the OMP_MAP_TARGET_PARAM and do some
2686-
// further case specific flag modifications). For the moment, it handles what
2687-
// we support as expected.
2688-
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
2689-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
2690-
26912679
llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
26922680
ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
2693-
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
26942681

26952682
// This creates the initial MEMBER_OF mapping that consists of
26962683
// the parent/top level container (same as above effectively, except
@@ -2699,6 +2686,12 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
26992686
// only relevant if the structure in its totality is being mapped,
27002687
// otherwise the above suffices.
27012688
if (!parentClause.getPartialMap()) {
2689+
// TODO: This will need to be expanded to include the whole host of logic
2690+
// for the map flags that Clang currently supports (e.g. it should do some
2691+
// further case specific flag modifications). For the moment, it handles
2692+
// what we support as expected.
2693+
llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
2694+
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
27022695
combinedInfo.Types.emplace_back(mapFlag);
27032696
combinedInfo.DevicePointers.emplace_back(
27042697
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
@@ -2749,6 +2742,31 @@ static void processMapMembersWithParent(
27492742

27502743
assert(memberDataIdx >= 0 && "could not find mapped member of structure");
27512744

2745+
// If we're currently mapping a pointer to a block of data, we must
2746+
// initially map the pointer, and then attach/bind the data with a
2747+
// subsequent map to the pointer. This segment of code generates the
2748+
// pointer mapping, which in certain cases can be optimised
2749+
// out as Clang currently does in its lowering. However, for the moment
2750+
// we do not do so, in part as we currently have substantially less
2751+
// information on the data being mapped at this stage.
2752+
if (checkIfPointerMap(memberClause)) {
2753+
auto mapFlag = llvm::omp::OpenMPOffloadMappingFlags(
2754+
memberClause.getMapType().value());
2755+
mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
2756+
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
2757+
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
2758+
combinedInfo.Types.emplace_back(mapFlag);
2759+
combinedInfo.DevicePointers.emplace_back(
2760+
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2761+
combinedInfo.Names.emplace_back(
2762+
LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
2763+
combinedInfo.BasePointers.emplace_back(
2764+
mapData.BasePointers[mapDataIndex]);
2765+
combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
2766+
combinedInfo.Sizes.emplace_back(builder.getInt64(
2767+
moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
2768+
}
2769+
27522770
// Same MemberOfFlag to indicate its link with parent and other members
27532771
// of.
27542772
auto mapFlag =
@@ -2764,7 +2782,12 @@ static void processMapMembersWithParent(
27642782
mapData.DevicePointers[memberDataIdx]);
27652783
combinedInfo.Names.emplace_back(
27662784
LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
2767-
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
2785+
if (checkIfPointerMap(memberClause))
2786+
combinedInfo.BasePointers.emplace_back(
2787+
mapData.BasePointers[memberDataIdx]);
2788+
else
2789+
combinedInfo.BasePointers.emplace_back(
2790+
mapData.BasePointers[mapDataIndex]);
27682791
combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
27692792
combinedInfo.Sizes.emplace_back(mapData.Sizes[memberDataIdx]);
27702793
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// This test checks the offload sizes, map types and base pointers and pointers
4+
// provided to the OpenMP kernel argument structure are correct when lowering
5+
// to LLVM-IR from MLIR when performing explicit member mapping of a record type
6+
// that includes Fortran allocatables in various locations of the record type's
7+
// hierarchy.
8+
9+
module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
10+
llvm.func @omp_nested_derived_type_alloca_map(%arg0: !llvm.ptr) {
11+
%0 = llvm.mlir.constant(4 : index) : i64
12+
%1 = llvm.mlir.constant(1 : index) : i64
13+
%2 = llvm.mlir.constant(2 : index) : i64
14+
%3 = llvm.mlir.constant(0 : index) : i64
15+
%4 = llvm.mlir.constant(6 : index) : i64
16+
%5 = omp.map.bounds lower_bound(%3 : i64) upper_bound(%0 : i64) extent(%0 : i64) stride(%1 : i64) start_idx(%3 : i64) {stride_in_bytes = true}
17+
%6 = llvm.getelementptr %arg0[0, 6] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"_QFtest_nested_derived_type_alloca_map_operand_and_block_additionTtop_layer", (f32, struct<(ptr, i64, i32, i8, i8, i8, i8)>, array<10 x i32>, f32, struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, i32, struct<"_QFtest_nested_derived_type_alloca_map_operand_and_block_additionTmiddle_layer", (f32, array<10 x i32>, struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, i32)>)>
18+
%7 = llvm.getelementptr %6[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"_QFtest_nested_derived_type_alloca_map_operand_and_block_additionTmiddle_layer", (f32, array<10 x i32>, struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, i32)>
19+
%8 = llvm.getelementptr %7[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
20+
%9 = omp.map.info var_ptr(%7 : !llvm.ptr, i32) var_ptr_ptr(%8 : !llvm.ptr) map_clauses(tofrom) capture(ByRef) bounds(%5) -> !llvm.ptr {name = ""}
21+
%10 = omp.map.info var_ptr(%7 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "one_l%nest%array_k"}
22+
%11 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.struct<"_QFtest_nested_derived_type_alloca_map_operand_and_block_additionTtop_layer", (f32, struct<(ptr, i64, i32, i8, i8, i8, i8)>, array<10 x i32>, f32, struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, i32, struct<"_QFtest_nested_derived_type_alloca_map_operand_and_block_additionTmiddle_layer", (f32, array<10 x i32>, struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, i32)>)>) map_clauses(tofrom) capture(ByRef) members(%10, %9 : [6,2], [6,2,0] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "one_l", partial_map = true}
23+
omp.target map_entries(%10 -> %arg1, %9 -> %arg2, %11 -> %arg3 : !llvm.ptr, !llvm.ptr, !llvm.ptr) {
24+
omp.terminator
25+
}
26+
llvm.return
27+
}
28+
}
29+
30+
// CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [4 x i64] [i64 0, i64 48, i64 8, i64 20]
31+
// CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [4 x i64] [i64 32, i64 281474976710659, i64 281474976710659, i64 281474976710675]
32+
33+
// CHECK: define void @omp_nested_derived_type_alloca_map(ptr %[[ARG:.*]]) {
34+
35+
// CHECK: %[[NESTED_DTYPE_MEMBER_GEP:.*]] = getelementptr %_QFtest_nested_derived_type_alloca_map_operand_and_block_additionTtop_layer, ptr %[[ARG]], i32 0, i32 6
36+
// CHECK: %[[NESTED_ALLOCATABLE_MEMBER_GEP:.*]] = getelementptr %_QFtest_nested_derived_type_alloca_map_operand_and_block_additionTmiddle_layer, ptr %[[NESTED_DTYPE_MEMBER_GEP]], i32 0, i32 2
37+
// CHECK: %[[NESTED_ALLOCATABLE_MEMBER_BADDR_GEP:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[NESTED_ALLOCATABLE_MEMBER_GEP]], i32 0, i32 0
38+
// CHECK: %[[NESTED_ALLOCATABLE_MEMBER_BADDR_LOAD:.*]] = load ptr, ptr %[[NESTED_ALLOCATABLE_MEMBER_BADDR_GEP]], align 8
39+
// CHECK: %[[ARR_OFFSET:.*]] = getelementptr inbounds i32, ptr %[[NESTED_ALLOCATABLE_MEMBER_BADDR_LOAD]], i64 0
40+
// CHECK: %[[DTYPE_SIZE_SEGMENT_CALC_1:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[NESTED_ALLOCATABLE_MEMBER_GEP]], i64 1
41+
// CHECK: %[[DTYPE_SIZE_SEGMENT_CALC_2:.*]] = ptrtoint ptr %[[DTYPE_SIZE_SEGMENT_CALC_1]] to i64
42+
// CHECK: %[[DTYPE_SIZE_SEGMENT_CALC_3:.*]] = ptrtoint ptr %[[NESTED_ALLOCATABLE_MEMBER_GEP]] to i64
43+
// CHECK: %[[DTYPE_SIZE_SEGMENT_CALC_4:.*]] = sub i64 %[[DTYPE_SIZE_SEGMENT_CALC_2]], %[[DTYPE_SIZE_SEGMENT_CALC_3]]
44+
// CHECK: %[[DTYPE_SIZE_SEGMENT_CALC_5:.*]] = sdiv exact i64 %[[DTYPE_SIZE_SEGMENT_CALC_4]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64)
45+
46+
// CHECK: %[[BASE_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
47+
// CHECK: store ptr %[[ARG]], ptr %[[BASE_PTRS]], align 8
48+
// CHECK: %[[OFFLOAD_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 0
49+
// CHECK: store ptr %[[NESTED_ALLOCATABLE_MEMBER_GEP]], ptr %[[OFFLOAD_PTRS]], align 8
50+
// CHECK: %[[OFFLOAD_SIZES:.*]] = getelementptr inbounds [4 x i64], ptr %.offload_sizes, i32 0, i32 0
51+
// CHECK: store i64 %[[DTYPE_SIZE_SEGMENT_CALC_5]], ptr %[[OFFLOAD_SIZES]], align 8
52+
53+
// CHECK: %[[BASE_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 1
54+
// CHECK: store ptr %[[ARG]], ptr %[[BASE_PTRS]], align 8
55+
// CHECK: %[[OFFLOAD_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 1
56+
// CHECK: store ptr %[[NESTED_ALLOCATABLE_MEMBER_GEP]], ptr %[[OFFLOAD_PTRS]], align 8
57+
58+
// CHECK: %[[BASE_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 2
59+
// CHECK: store ptr %[[ARG]], ptr %[[BASE_PTRS]], align 8
60+
// CHECK: %[[OFFLOAD_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 2
61+
// CHECK: store ptr %[[NESTED_ALLOCATABLE_MEMBER_BADDR_GEP]], ptr %[[OFFLOAD_PTRS]], align 8
62+
63+
// CHECK: %[[BASE_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 3
64+
// CHECK: store ptr %[[NESTED_ALLOCATABLE_MEMBER_BADDR_GEP]], ptr %[[BASE_PTRS]], align 8
65+
// CHECK: %[[OFFLOAD_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 3
66+
// CHECK: store ptr %[[ARR_OFFSET]], ptr %[[OFFLOAD_PTRS]], align 8

0 commit comments

Comments
 (0)