@@ -487,11 +487,30 @@ static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc,
487
487
Value destBox = destConvert.getValue ();
488
488
Value srcBox = srcConvert.getValue ();
489
489
490
+ // get defining alloca op of destBox and srcBox
491
+ auto destAlloca = destBox.getDefiningOp <fir::AllocaOp>();
492
+
493
+ if (!destAlloca) {
494
+ emitError (loc, " Unimplemented: FortranAssign to OpenMP lowering\n " );
495
+ return ;
496
+ }
497
+
498
+ // get the store op that stores to the alloca
499
+ for (auto user : destAlloca->getUsers ()) {
500
+ if (auto storeOp = dyn_cast<fir::StoreOp>(user)) {
501
+ destBox = storeOp.getValue ();
502
+ break ;
503
+ }
504
+ }
505
+
490
506
builder.setInsertionPoint (teamsOp);
491
- // Load destination array box and source scalar
492
- auto arrayBox = builder.create <fir::LoadOp>(loc, destBox);
507
+ // Load destination array box (if it's a reference)
508
+ Value arrayBox = destBox;
509
+ if (isa<fir::ReferenceType>(destBox.getType ()))
510
+ arrayBox = builder.create <fir::LoadOp>(loc, destBox);
511
+
493
512
auto scalarValue = builder.create <fir::BoxAddrOp>(loc, srcBox);
494
- auto scalar = builder.create <fir::LoadOp>(loc, scalarValue);
513
+ Value scalar = builder.create <fir::LoadOp>(loc, scalarValue);
495
514
496
515
// Calculate total number of elements (flattened)
497
516
auto c0 = builder.create <arith::ConstantIndexOp>(loc, 0 );
@@ -543,9 +562,8 @@ WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
543
562
bool changed = false ;
544
563
omp::TargetOp targetOp;
545
564
// Get the target op parent of teams
546
- if (auto teamsOp = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp ())) {
547
- targetOp = dyn_cast<omp::TargetOp>(teamsOp->getParentOp ());
548
- }
565
+ targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp ());
566
+ SmallVector<Operation *> opsToErase;
549
567
for (auto &op : workdistribute.getOps ()) {
550
568
if (&op == terminator) {
551
569
break ;
@@ -560,12 +578,15 @@ WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
560
578
targetOpsToProcess.insert (targetOp);
561
579
replaceWithUnorderedDoLoop (rewriter, loc, teams, workdistribute,
562
580
runtimeCall);
563
- op. erase ( );
564
- return true ;
581
+ opsToErase. push_back (&op );
582
+ changed = true ;
565
583
}
566
584
}
567
585
}
568
586
}
587
+ for (auto *op : opsToErase) {
588
+ op->erase ();
589
+ }
569
590
return changed;
570
591
}
571
592
@@ -911,7 +932,7 @@ static void reloadCacheAndRecompute(
911
932
912
933
unsigned originalMapVarsSize = targetOp.getMapVars ().size ();
913
934
unsigned hostEvalVarsSize = hostEvalVars.size ();
914
- // Create Stores for allocs .
935
+ // Create load operations for each allocated variable .
915
936
for (unsigned i = 0 ; i < allocs.size (); ++i) {
916
937
Value original = allocs[i];
917
938
// Get the new block argument for this specific allocated value.
@@ -1196,6 +1217,12 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
1196
1217
Block *targetBlock = &targetOp.getRegion ().front ();
1197
1218
assert (targetBlock == &targetOp.getRegion ().back ());
1198
1219
IRMapping mapping;
1220
+
1221
+ auto targetDataOp = cast<omp::TargetDataOp>(targetOp->getParentOp ());
1222
+ if (!targetDataOp) {
1223
+ llvm_unreachable (" Expected target op to be inside target_data op" );
1224
+ return ;
1225
+ }
1199
1226
// create mapping for host_eval_vars
1200
1227
unsigned hostEvalVarCount = targetOp.getHostEvalVars ().size ();
1201
1228
for (unsigned i = 0 ; i < targetOp.getHostEvalVars ().size (); ++i) {
@@ -1361,12 +1388,14 @@ static void computeAllocsCacheRecomputable(
1361
1388
it++) {
1362
1389
// Check if any of the results are used outside the split point.
1363
1390
for (auto res : it->getResults ()) {
1364
- if (usedOutsideSplit (res, splitBeforeOp))
1391
+ if (usedOutsideSplit (res, splitBeforeOp)) {
1365
1392
requiredVals.push_back (res);
1393
+ }
1366
1394
}
1367
1395
// If the op is not recomputable, add it to the nonRecomputable set.
1368
- if (!isRecomputableAfterFission (&*it, splitBeforeOp))
1396
+ if (!isRecomputableAfterFission (&*it, splitBeforeOp)) {
1369
1397
nonRecomputable.insert (&*it);
1398
+ }
1370
1399
}
1371
1400
// For each required value, collect its dependencies.
1372
1401
for (auto requiredVal : requiredVals)
0 commit comments