Skip to content

Commit c27ad3d

Browse files
committed
Fix scalar assignment bug and fix CI tests
1 parent 2b010b8 commit c27ad3d

File tree

3 files changed

+49
-19
lines changed

3 files changed

+49
-19
lines changed

flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -487,11 +487,30 @@ static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc,
487487
Value destBox = destConvert.getValue();
488488
Value srcBox = srcConvert.getValue();
489489

490+
// Get the defining alloca op of destBox.
491+
auto destAlloca = destBox.getDefiningOp<fir::AllocaOp>();
492+
493+
if (!destAlloca) {
494+
emitError(loc, "Unimplemented: FortranAssign to OpenMP lowering\n");
495+
return;
496+
}
497+
498+
// Find the first store op among the alloca's users and use its stored value as destBox.
499+
for (auto user : destAlloca->getUsers()) {
500+
if (auto storeOp = dyn_cast<fir::StoreOp>(user)) {
501+
destBox = storeOp.getValue();
502+
break;
503+
}
504+
}
505+
490506
builder.setInsertionPoint(teamsOp);
491-
// Load destination array box and source scalar
492-
auto arrayBox = builder.create<fir::LoadOp>(loc, destBox);
507+
// Load destination array box (if it's a reference)
508+
Value arrayBox = destBox;
509+
if (isa<fir::ReferenceType>(destBox.getType()))
510+
arrayBox = builder.create<fir::LoadOp>(loc, destBox);
511+
493512
auto scalarValue = builder.create<fir::BoxAddrOp>(loc, srcBox);
494-
auto scalar = builder.create<fir::LoadOp>(loc, scalarValue);
513+
Value scalar = builder.create<fir::LoadOp>(loc, scalarValue);
495514

496515
// Calculate total number of elements (flattened)
497516
auto c0 = builder.create<arith::ConstantIndexOp>(loc, 0);
@@ -543,9 +562,8 @@ WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
543562
bool changed = false;
544563
omp::TargetOp targetOp;
545564
// Get the target op parent of teams
546-
if (auto teamsOp = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp())) {
547-
targetOp = dyn_cast<omp::TargetOp>(teamsOp->getParentOp());
548-
}
565+
targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp());
566+
SmallVector<Operation *> opsToErase;
549567
for (auto &op : workdistribute.getOps()) {
550568
if (&op == terminator) {
551569
break;
@@ -560,12 +578,15 @@ WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
560578
targetOpsToProcess.insert(targetOp);
561579
replaceWithUnorderedDoLoop(rewriter, loc, teams, workdistribute,
562580
runtimeCall);
563-
op.erase();
564-
return true;
581+
opsToErase.push_back(&op);
582+
changed = true;
565583
}
566584
}
567585
}
568586
}
587+
for (auto *op : opsToErase) {
588+
op->erase();
589+
}
569590
return changed;
570591
}
571592

@@ -911,7 +932,7 @@ static void reloadCacheAndRecompute(
911932

912933
unsigned originalMapVarsSize = targetOp.getMapVars().size();
913934
unsigned hostEvalVarsSize = hostEvalVars.size();
914-
// Create Stores for allocs.
935+
// Create load operations for each allocated variable.
915936
for (unsigned i = 0; i < allocs.size(); ++i) {
916937
Value original = allocs[i];
917938
// Get the new block argument for this specific allocated value.
@@ -1196,6 +1217,12 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
11961217
Block *targetBlock = &targetOp.getRegion().front();
11971218
assert(targetBlock == &targetOp.getRegion().back());
11981219
IRMapping mapping;
1220+
1221+
auto targetDataOp = cast<omp::TargetDataOp>(targetOp->getParentOp());
1222+
if (!targetDataOp) {
1223+
llvm_unreachable("Expected target op to be inside target_data op");
1224+
return;
1225+
}
11991226
// create mapping for host_eval_vars
12001227
unsigned hostEvalVarCount = targetOp.getHostEvalVars().size();
12011228
for (unsigned i = 0; i < targetOp.getHostEvalVars().size(); ++i) {
@@ -1361,12 +1388,14 @@ static void computeAllocsCacheRecomputable(
13611388
it++) {
13621389
// Check if any of the results are used outside the split point.
13631390
for (auto res : it->getResults()) {
1364-
if (usedOutsideSplit(res, splitBeforeOp))
1391+
if (usedOutsideSplit(res, splitBeforeOp)) {
13651392
requiredVals.push_back(res);
1393+
}
13661394
}
13671395
// If the op is not recomputable, add it to the nonRecomputable set.
1368-
if (!isRecomputableAfterFission(&*it, splitBeforeOp))
1396+
if (!isRecomputableAfterFission(&*it, splitBeforeOp)) {
13691397
nonRecomputable.insert(&*it);
1398+
}
13701399
}
13711400
// For each required value, collect its dependencies.
13721401
for (auto requiredVal : requiredVals)

flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@
4242
// CHECK: fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref<index>
4343
// CHECK: fir.store %[[VAL_30]] to %[[VAL_20]] : !fir.ref<!fir.heap<index>>
4444
// CHECK: omp.target host_eval(%[[VAL_24]] -> %[[VAL_31:.*]], %[[VAL_25]] -> %[[VAL_32:.*]], %[[VAL_26]] -> %[[VAL_33:.*]] : index, index, index) map_entries(%[[VAL_7]] -> %[[VAL_34:.*]], %[[VAL_8]] -> %[[VAL_35:.*]], %[[VAL_9]] -> %[[VAL_36:.*]], %[[VAL_10]] -> %[[VAL_37:.*]], %[[VAL_13]] -> %[[VAL_38:.*]], %[[VAL_16]] -> %[[VAL_39:.*]], %[[VAL_19]] -> %[[VAL_40:.*]], %[[VAL_22]] -> %[[VAL_41:.*]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<!fir.heap<index>>) {
45-
// CHECK: %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.llvm_ptr<index>
46-
// CHECK: %[[VAL_43:.*]] = fir.load %[[VAL_39]] : !fir.llvm_ptr<index>
47-
// CHECK: %[[VAL_44:.*]] = fir.load %[[VAL_40]] : !fir.llvm_ptr<index>
48-
// CHECK: %[[VAL_45:.*]] = fir.load %[[VAL_41]] : !fir.llvm_ptr<!fir.heap<index>>
45+
// CHECK: %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.ref<index>
46+
// CHECK: %[[VAL_43:.*]] = fir.load %[[VAL_39]] : !fir.ref<index>
47+
// CHECK: %[[VAL_44:.*]] = fir.load %[[VAL_40]] : !fir.ref<index>
48+
// CHECK: %[[VAL_45:.*]] = fir.load %[[VAL_41]] : !fir.ref<!fir.heap<index>>
4949
// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_43]], %[[VAL_43]] : index
5050
// CHECK: omp.teams {
5151
// CHECK: omp.parallel {

flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@
4242
// CHECK: fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref<index>
4343
// CHECK: fir.store %[[VAL_30]] to %[[VAL_20]] : !fir.ref<!fir.heap<index>>
4444
// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_31:.*]], %[[VAL_8]] -> %[[VAL_32:.*]], %[[VAL_9]] -> %[[VAL_33:.*]], %[[VAL_10]] -> %[[VAL_34:.*]], %[[VAL_13]] -> %[[VAL_35:.*]], %[[VAL_16]] -> %[[VAL_36:.*]], %[[VAL_19]] -> %[[VAL_37:.*]], %[[VAL_22]] -> %[[VAL_38:.*]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<!fir.heap<index>>) {
45-
// CHECK: %[[VAL_39:.*]] = fir.load %[[VAL_35]] : !fir.llvm_ptr<index>
46-
// CHECK: %[[VAL_40:.*]] = fir.load %[[VAL_36]] : !fir.llvm_ptr<index>
47-
// CHECK: %[[VAL_41:.*]] = fir.load %[[VAL_37]] : !fir.llvm_ptr<index>
48-
// CHECK: %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.llvm_ptr<!fir.heap<index>>
45+
// CHECK: %[[VAL_39:.*]] = fir.load %[[VAL_35]] : !fir.ref<index>
46+
// CHECK: %[[VAL_40:.*]] = fir.load %[[VAL_36]] : !fir.ref<index>
47+
// CHECK: %[[VAL_41:.*]] = fir.load %[[VAL_37]] : !fir.ref<index>
48+
// CHECK: %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.ref<!fir.heap<index>>
4949
// CHECK: %[[VAL_43:.*]] = arith.addi %[[VAL_40]], %[[VAL_40]] : index
5050
// CHECK: omp.teams {
5151
// CHECK: omp.parallel {
@@ -77,6 +77,7 @@
7777
// CHECK: return
7878
// CHECK: }
7979

80+
8081
module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
8182
func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref<index>) {
8283
%lb_ref = fir.alloca index {bindc_name = "lb"}

0 commit comments

Comments
 (0)