Skip to content

Commit 62a8dd7

Browse files
committed
Comments fix and new test.
1 parent ce678bc commit 62a8dd7

File tree

2 files changed

+63
-19
lines changed

2 files changed

+63
-19
lines changed

flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
// Fortran array statements are lowered to fir as fir.do_loop unordered.
1313
// lower-workdistribute pass works mainly on identifying fir.do_loop unordered
1414
// that is nested in target{teams{workdistribute{fir.do_loop unordered}}} and
15-
// lowers it to target{teams{parallel{wsloop{loop_nest}}}}.
15+
// lowers it to target{teams{parallel{distribute{wsloop{loop_nest}}}}}.
1616
// It hoists all the other ops outside target region.
1717
// Replaces heap allocation on target with omp.target_allocmem and
1818
// deallocation with omp.target_freemem from host. Also replaces
19-
// runtime function "Assign" with omp.target_memcpy.
19+
// runtime function "Assign" with omp_target_memcpy.
2020
//
2121
//===----------------------------------------------------------------------===//
2222

@@ -319,13 +319,14 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
319319
// Then, its lowered to
320320
//
321321
// omp.teams {
322-
// omp.parallel {
323-
// omp.distribute {
324-
// omp.wsloop {
325-
// omp.loop_nest
326-
// ...
327-
// }
328-
// }
322+
// omp.parallel {
323+
// omp.distribute {
324+
// omp.wsloop {
325+
// omp.loop_nest
326+
// ...
327+
// }
328+
// }
329+
// }
329330
// }
330331
// }
331332

@@ -345,6 +346,7 @@ WorkdistributeDoLower(omp::WorkdistributeOp workdistribute,
345346
targetOpsToProcess.insert(targetOp);
346347
}
347348
}
349+
// Generate the nested parallel, distribute, wsloop and loop_nest ops.
348350
genParallelOp(wdLoc, rewriter, true);
349351
genDistributeOp(wdLoc, rewriter, true);
350352
mlir::omp::LoopNestOperands loopNestClauseOps;
@@ -584,6 +586,7 @@ WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
584586
}
585587
}
586588
}
589+
// Erase the runtime calls that have been replaced.
587590
for (auto *op : opsToErase) {
588591
op->erase();
589592
}
@@ -772,6 +775,7 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
772775
Value alloc;
773776
Type allocType;
774777
auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
778+
// Get the appropriate type for allocation.
775779
if (isPtr(ty)) {
776780
Type intTy = rewriter.getI32Type();
777781
auto one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1);
@@ -782,6 +786,7 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
782786
allocType = ty;
783787
alloc = rewriter.create<fir::AllocaOp>(loc, allocType);
784788
}
789+
// Lambda to create mapinfo ops.
785790
auto getMapInfo = [&](uint64_t mappingFlags, const char *name) {
786791
return rewriter.create<omp::MapInfoOp>(
787792
loc, alloc.getType(), alloc, TypeAttr::get(allocType),
@@ -796,6 +801,7 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
796801
/*mapperId=*/mlir::FlatSymbolRefAttr(),
797802
/*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false));
798803
};
804+
// Create mapinfo ops.
799805
uint64_t mapFrom =
800806
static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
801807
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
@@ -847,14 +853,17 @@ static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp,
847853
SetVector<Operation *> &toCache,
848854
SetVector<Operation *> &toRecompute) {
849855
Operation *op = v.getDefiningOp();
856+
// If v is a block argument, it must be from the targetOp.
850857
if (!op) {
851858
assert(cast<BlockArgument>(v).getOwner()->getParentOp() == targetOp);
852859
return;
853860
}
861+
// If the op is in the nonRecomputable set, add it to toCache and return.
854862
if (nonRecomputable.contains(op)) {
855863
toCache.insert(op);
856864
return;
857865
}
866+
// Add the op to toRecompute.
858867
toRecompute.insert(op);
859868
for (auto opr : op->getOperands())
860869
collectNonRecomputableDeps(opr, targetOp, nonRecomputable, toCache,
@@ -939,6 +948,8 @@ static void reloadCacheAndRecompute(
939948
Value newArg =
940949
newTargetBlock->getArgument(hostEvalVarsSize + originalMapVarsSize + i);
941950
Value restored;
951+
// If the original value is a pointer or reference, load and convert if
952+
// necessary.
942953
if (isPtr(original.getType())) {
943954
restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg);
944955
if (!isa<LLVM::LLVMPointerType>(original.getType()))
@@ -967,6 +978,7 @@ static mlir::omp::LoopNestOp getLoopNestFromTeams(mlir::omp::TeamsOp teamsOp) {
967978
return nullptr;
968979
// Find parallel op inside teams
969980
mlir::omp::ParallelOp parallelOp = nullptr;
981+
// Look for the parallel op in the teams region.
970982
for (auto &op : teamsOp.getRegion().front()) {
971983
if (auto parallel = dyn_cast<mlir::omp::ParallelOp>(op)) {
972984
parallelOp = parallel;
@@ -1218,6 +1230,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
12181230
assert(targetBlock == &targetOp.getRegion().back());
12191231
IRMapping mapping;
12201232

1233+
// Get the parent target_data op
12211234
auto targetDataOp = cast<omp::TargetDataOp>(targetOp->getParentOp());
12221235
if (!targetDataOp) {
12231236
llvm_unreachable("Expected target op to be inside target_data op");
@@ -1255,6 +1268,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
12551268
SmallVector<Operation *> opsToReplace;
12561269
Value device = targetOp.getDevice();
12571270

1271+
// If device is not specified, default to device 0.
12581272
if (!device) {
12591273
device = genI32Constant(targetOp.getLoc(), rewriter, 0);
12601274
}
@@ -1508,15 +1522,12 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
15081522
SmallVector<Value> isolatedHostEvalVars{targetOp.getHostEvalVars()};
15091523
// update the hostEvalVars of isolatedTargetOp
15101524
if (!hostEvalVars.lbs.empty() && !isTargetDevice) {
1511-
for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i) {
1512-
isolatedHostEvalVars.push_back(hostEvalVars.lbs[i]);
1513-
}
1514-
for (size_t i = 0; i < hostEvalVars.ubs.size(); ++i) {
1515-
isolatedHostEvalVars.push_back(hostEvalVars.ubs[i]);
1516-
}
1517-
for (size_t i = 0; i < hostEvalVars.steps.size(); ++i) {
1518-
isolatedHostEvalVars.push_back(hostEvalVars.steps[i]);
1519-
}
1525+
isolatedHostEvalVars.append(hostEvalVars.lbs.begin(),
1526+
hostEvalVars.lbs.end());
1527+
isolatedHostEvalVars.append(hostEvalVars.ubs.begin(),
1528+
hostEvalVars.ubs.end());
1529+
isolatedHostEvalVars.append(hostEvalVars.steps.begin(),
1530+
hostEvalVars.steps.end());
15201531
}
15211532
// Create the isolated target op
15221533
omp::TargetOp isolatedTargetOp = rewriter.create<omp::TargetOp>(
@@ -1708,13 +1719,14 @@ static void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter,
17081719
Operation *toIsolate = std::get<0>(*tuple);
17091720
bool splitBefore = !std::get<1>(*tuple);
17101721
bool splitAfter = !std::get<2>(*tuple);
1711-
1722+
// Recursively isolate the target op.
17121723
if (splitBefore && splitAfter) {
17131724
auto res =
17141725
isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice);
17151726
fissionTarget(res.postTargetOp, rewriter, module, isTargetDevice);
17161727
return;
17171728
}
1729+
// Isolate only before the op.
17181730
if (splitBefore) {
17191731
isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice);
17201732
return;
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s
2+
3+
! CHECK-LABEL: func @_QPtarget_teams_workdistribute
4+
! CHECK: omp.target_data map_entries({{.*}})
5+
! CHECK: omp.target thread_limit({{.*}}) host_eval({{.*}}) map_entries({{.*}})
6+
! CHECK: omp.teams num_teams({{.*}})
7+
! CHECK: omp.parallel
8+
! CHECK: omp.distribute
9+
! CHECK: omp.wsloop
10+
! CHECK: omp.loop_nest
11+
12+
subroutine target_teams_workdistribute()
13+
use iso_fortran_env
14+
real(kind=real32) :: a
15+
real(kind=real32), dimension(10) :: x
16+
real(kind=real32), dimension(10) :: y
17+
integer :: i
18+
19+
a = 2.0_real32
20+
x = [(real(i, real32), i = 1, 10)]
21+
y = [(real(i * 0.5, real32), i = 1, 10)]
22+
23+
!$omp target teams workdistribute &
24+
!$omp& num_teams(4) &
25+
!$omp& thread_limit(8) &
26+
!$omp& default(shared) &
27+
!$omp& private(i) &
28+
!$omp& map(to: x) &
29+
!$omp& map(tofrom: y)
30+
y = a * x + y
31+
!$omp end target teams workdistribute
32+
end subroutine target_teams_workdistribute

0 commit comments

Comments
 (0)