// Fortran array statements are lowered to fir as fir.do_loop unordered.
// The lower-workdistribute pass works mainly by identifying a fir.do_loop unordered
// that is nested in target{teams{workdistribute{fir.do_loop unordered}}} and
- // lowers it to target{teams{parallel{wsloop{loop_nest}}}}.
+ // lowers it to target{teams{parallel{distribute{wsloop{loop_nest}}}}}.
// It hoists all the other ops outside the target region.
// Replaces heap allocation on the target with omp.target_allocmem and
// deallocation with omp.target_freemem from the host. Also replaces the
- // runtime function "Assign" with omp.target_memcpy.
+ // runtime function "Assign" with omp_target_memcpy.
//
//===----------------------------------------------------------------------===//
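For orientation, here is a rough before/after sketch of the rewrite described above, written in the same comment style this file uses for IR examples. Operand lists, map clauses, and terminators are elided, and the exact printed form is an approximation rather than output taken from this patch:

// Before lower-workdistribute:
//   omp.target {
//     omp.teams {
//       omp.workdistribute {
//         fir.do_loop %i = %lb to %ub step %step unordered { ... }
//       }
//     }
//   }
//
// After lower-workdistribute:
//   omp.target {
//     omp.teams {
//       omp.parallel {
//         omp.distribute {
//           omp.wsloop {
//             omp.loop_nest (%i) : index = (%lb) to (%ub) step (%step) { ... }
//           }
//         }
//       }
//     }
//   }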
@@ -319,13 +319,14 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
// Then, it's lowered to
//
// omp.teams {
- //   omp.parallel {
- //     omp.distribute {
- //       omp.wsloop {
- //         omp.loop_nest
- //         ...
- //       }
- //     }
+ //   omp.parallel {
+ //     omp.distribute {
+ //       omp.wsloop {
+ //         omp.loop_nest
+ //         ...
+ //       }
+ //     }
+ //   }
//   }
// }
@@ -345,6 +346,7 @@ WorkdistributeDoLower(omp::WorkdistributeOp workdistribute,
    targetOpsToProcess.insert(targetOp);
  }
}
+ // Generate the nested parallel, distribute, wsloop and loop_nest ops.
genParallelOp(wdLoc, rewriter, true);
genDistributeOp(wdLoc, rewriter, true);
mlir::omp::LoopNestOperands loopNestClauseOps;
@@ -584,6 +586,7 @@ WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
    }
  }
}
+ // Erase the runtime calls that have been replaced.
for (auto *op : opsToErase) {
  op->erase();
}
@@ -772,6 +775,7 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
Value alloc;
Type allocType;
auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
+ // Get the appropriate type for allocation
if (isPtr(ty)) {
  Type intTy = rewriter.getI32Type();
  auto one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1);
@@ -782,6 +786,7 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
  allocType = ty;
  alloc = rewriter.create<fir::AllocaOp>(loc, allocType);
}
+ // Lambda to create mapinfo ops
auto getMapInfo = [&](uint64_t mappingFlags, const char *name) {
  return rewriter.create<omp::MapInfoOp>(
      loc, alloc.getType(), alloc, TypeAttr::get(allocType),
@@ -796,6 +801,7 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
      /*mapperId=*/mlir::FlatSymbolRefAttr(),
      /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false));
};
+ // Create mapinfo ops.
uint64_t mapFrom =
    static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
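The static_cast through std::underlying_type_t above is required because OpenMPOffloadMappingFlags is a scoped enum and does not convert to an integer implicitly. A minimal standalone sketch of the same idiom, using a hypothetical MapFlags enum rather than the real LLVM type:

#include <cstdint>
#include <type_traits>

// Hypothetical stand-in for llvm::omp::OpenMPOffloadMappingFlags.
enum class MapFlags : uint64_t { None = 0x0, To = 0x1, From = 0x2 };

int main() {
  // Build plain integer flag values from the scoped enum, as the patch does
  // when constructing the MapInfoOp mapping-flags attribute.
  uint64_t mapTo = static_cast<std::underlying_type_t<MapFlags>>(MapFlags::To);
  uint64_t mapFrom =
      static_cast<std::underlying_type_t<MapFlags>>(MapFlags::From);
  return (mapTo | mapFrom) == 0x3 ? 0 : 1;
}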
@@ -847,14 +853,17 @@ static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp,
                                       SetVector<Operation *> &toCache,
                                       SetVector<Operation *> &toRecompute) {
  Operation *op = v.getDefiningOp();
+ // If v is a block argument, it must be from the targetOp.
  if (!op) {
    assert(cast<BlockArgument>(v).getOwner()->getParentOp() == targetOp);
    return;
  }
+ // If the op is in the nonRecomputable set, add it to toCache and return.
  if (nonRecomputable.contains(op)) {
    toCache.insert(op);
    return;
  }
+ // Add the op to toRecompute.
  toRecompute.insert(op);
  for (auto opr : op->getOperands())
    collectNonRecomputableDeps(opr, targetOp, nonRecomputable, toCache,
@@ -939,6 +948,8 @@ static void reloadCacheAndRecompute(
  Value newArg =
      newTargetBlock->getArgument(hostEvalVarsSize + originalMapVarsSize + i);
  Value restored;
+ // If the original value is a pointer or reference, load and convert if
+ // necessary.
  if (isPtr(original.getType())) {
    restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg);
    if (!isa<LLVM::LLVMPointerType>(original.getType()))
@@ -967,6 +978,7 @@ static mlir::omp::LoopNestOp getLoopNestFromTeams(mlir::omp::TeamsOp teamsOp) {
    return nullptr;
  // Find parallel op inside teams
  mlir::omp::ParallelOp parallelOp = nullptr;
+ // Look for the parallel op in the teams region
  for (auto &op : teamsOp.getRegion().front()) {
    if (auto parallel = dyn_cast<mlir::omp::ParallelOp>(op)) {
      parallelOp = parallel;
@@ -1218,6 +1230,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
  assert(targetBlock == &targetOp.getRegion().back());
  IRMapping mapping;

+ // Get the parent target_data op
  auto targetDataOp = cast<omp::TargetDataOp>(targetOp->getParentOp());
  if (!targetDataOp) {
    llvm_unreachable("Expected target op to be inside target_data op");
@@ -1255,6 +1268,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
  SmallVector<Operation *> opsToReplace;
  Value device = targetOp.getDevice();

+ // If device is not specified, default to device 0.
  if (!device) {
    device = genI32Constant(targetOp.getLoc(), rewriter, 0);
  }
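For reference, the host-side OpenMP runtime entry points in this area take an explicit device number like the one defaulted above. Their standard C signatures, taken from the OpenMP API rather than from this patch (and assuming, without the patch showing it, that omp.target_allocmem, omp.target_freemem and the Assign replacement eventually bottom out in these routines), are:

#include <stddef.h>

// Standard OpenMP device memory routines; declarations shown for reference,
// the definitions live in the OpenMP runtime.
void *omp_target_alloc(size_t size, int device_num);
void omp_target_free(void *device_ptr, int device_num);
int omp_target_memcpy(void *dst, const void *src, size_t length,
                      size_t dst_offset, size_t src_offset,
                      int dst_device_num, int src_device_num);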
@@ -1508,15 +1522,12 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
  SmallVector<Value> isolatedHostEvalVars{targetOp.getHostEvalVars()};
  // update the hostEvalVars of isolatedTargetOp
  if (!hostEvalVars.lbs.empty() && !isTargetDevice) {
-     for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i) {
-       isolatedHostEvalVars.push_back(hostEvalVars.lbs[i]);
-     }
-     for (size_t i = 0; i < hostEvalVars.ubs.size(); ++i) {
-       isolatedHostEvalVars.push_back(hostEvalVars.ubs[i]);
-     }
-     for (size_t i = 0; i < hostEvalVars.steps.size(); ++i) {
-       isolatedHostEvalVars.push_back(hostEvalVars.steps[i]);
-     }
+     isolatedHostEvalVars.append(hostEvalVars.lbs.begin(),
+                                 hostEvalVars.lbs.end());
+     isolatedHostEvalVars.append(hostEvalVars.ubs.begin(),
+                                 hostEvalVars.ubs.end());
+     isolatedHostEvalVars.append(hostEvalVars.steps.begin(),
+                                 hostEvalVars.steps.end());
  }
  // Create the isolated target op
  omp::TargetOp isolatedTargetOp = rewriter.create<omp::TargetOp>(
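The loop-to-append change above is behavior-preserving: llvm::SmallVector::append(first, last) copies a whole range exactly as the removed element-wise push_back loops did. A minimal standalone sketch of the idiom, using int elements instead of mlir::Value:

#include "llvm/ADT/SmallVector.h"

int main() {
  llvm::SmallVector<int> hostEvalVars{1, 2};
  llvm::SmallVector<int> lbs{3, 4}, ubs{5, 6}, steps{7, 8};
  // Each append is equivalent to push_back-ing every element of the range.
  hostEvalVars.append(lbs.begin(), lbs.end());
  hostEvalVars.append(ubs.begin(), ubs.end());
  hostEvalVars.append(steps.begin(), steps.end());
  return hostEvalVars.size() == 8 ? 0 : 1;
}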
@@ -1708,13 +1719,14 @@ static void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter,
  Operation *toIsolate = std::get<0>(*tuple);
  bool splitBefore = !std::get<1>(*tuple);
  bool splitAfter = !std::get<2>(*tuple);
-
+ // Recursively isolate the target op.
  if (splitBefore && splitAfter) {
    auto res =
        isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice);
    fissionTarget(res.postTargetOp, rewriter, module, isTargetDevice);
    return;
  }
+ // Isolate only before the op.
  if (splitBefore) {
    isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice);
    return;