diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index e2f092024c250..bfbaa5f838e90 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -93,6 +93,10 @@ def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
   let summary = "Lower workshare construct";
 }
 
+def LowerWorkdistribute : Pass<"lower-workdistribute", "::mlir::ModuleOp"> {
+  let summary = "Lower workdistribute construct";
+}
+
 def GenericLoopConversionPass
     : Pass<"omp-generic-loop-conversion", "mlir::func::FuncOp"> {
   let summary = "Converts OpenMP generic `omp.loop` to semantically "
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index b85ee7e861a4f..23a7dc8f08399 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -8,6 +8,7 @@ add_flang_library(FlangOpenMPTransforms
   MapsForPrivatizedSymbols.cpp
   MapInfoFinalization.cpp
   MarkDeclareTarget.cpp
+  LowerWorkdistribute.cpp
   LowerWorkshare.cpp
   LowerNontemporal.cpp
   SimdOnly.cpp
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
new file mode 100644
index 0000000000000..090d9a0e3b985
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -0,0 +1,1859 @@
+//===- LowerWorkdistribute.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering and optimization of omp.workdistribute.
+//
+// Fortran array statements are lowered to FIR as unordered fir.do_loop ops.
+// The lower-workdistribute pass identifies an unordered fir.do_loop nested in
+// target{teams{workdistribute{fir.do_loop unordered}}} and lowers it to
+// target{teams{parallel{distribute{wsloop{loop_nest}}}}}.
+// It hoists all other ops outside the target region.
+// It replaces heap allocation on the target with omp.target_allocmem and
+// deallocation with omp.target_freemem issued from the host, and replaces
+// calls to the runtime function "Assign" with omp_target_memcpy.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Utils.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/Utils/IndexingUtils.h>
+#include <mlir/IR/BlockSupport.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/Diagnostics.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include <optional>
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKDISTRIBUTE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workdistribute"
+
+using namespace mlir;
+
+namespace {
+
+/// This string is used to identify the Fortran-specific runtime function
+/// FortranAAssign.
+static constexpr llvm::StringRef FortranAssignStr = "_FortranAAssign";
+
+/// The isRuntimeCall function is a utility designed to determine
+/// if a given operation is a call to a Fortran-specific runtime function.
+static bool isRuntimeCall(Operation *op) {
+  if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+    auto callee = callOp.getCallee();
+    if (!callee)
+      return false;
+    auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
+    if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
+      return true;
+  }
+  return false;
+}
+
+/// This is the single source of truth about whether we should parallelize an
+/// operation nested in an omp.workdistribute region.
+/// Parallelize here refers to dividing the work into units.
+static bool shouldParallelize(Operation *op) {
+  // True if the op is a runtime call to Assign.
+  if (isRuntimeCall(op)) {
+    fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+    auto funcName = (*runtimeCall.getCallee()).getRootReference().getValue();
+    if (funcName == FortranAssignStr) {
+      return true;
+    }
+  }
+  // We cannot parallelize ops with side effects.
+  // Parallelizable operations should not produce
+  // values that other operations depend on.
+  if (llvm::any_of(op->getResults(),
+                   [](OpResult v) -> bool { return !v.use_empty(); }))
+    return false;
+  // We will parallelize unordered loops - these come from array syntax.
+  if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
+    auto unordered = loop.getUnordered();
+    if (!unordered)
+      return false;
+    return *unordered;
+  }
+  // We cannot parallelize anything else.
+  return false;
+}
+
+/// The getPerfectlyNested function is a generic utility for finding
+/// a single, "perfectly nested" operation within a parent operation.
+template <typename T>
+static T getPerfectlyNested(Operation *op) {
+  if (op->getNumRegions() != 1)
+    return nullptr;
+  auto &region = op->getRegion(0);
+  if (region.getBlocks().size() != 1)
+    return nullptr;
+  auto *block = &region.front();
+  auto *firstOp = &block->front();
+  if (auto nested = dyn_cast<T>(firstOp))
+    if (firstOp->getNextNode() == block->getTerminator())
+      return nested;
+  return nullptr;
+}
+
+/// verifyTargetTeamsWorkdistribute verifies that
+/// omp.target { teams { workdistribute { ... } } } is well formed
+/// and fails for function calls whose lowering is not implemented yet.
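+///
+/// For illustration, the nesting this pass expects (a sketch; map clauses
+/// and attributes elided):
+///
+///   omp.target {
+///     omp.teams {
+///       omp.workdistribute {
+///         ...
+///         omp.terminator
+///       }
+///       omp.terminator
+///     }
+///     omp.terminator
+///   }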
+static LogicalResult
+verifyTargetTeamsWorkdistribute(omp::WorkdistributeOp workdistribute) {
+  OpBuilder rewriter(workdistribute);
+  auto loc = workdistribute->getLoc();
+  auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+  if (!teams) {
+    emitError(loc, "workdistribute not nested in teams\n");
+    return failure();
+  }
+  if (workdistribute.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "workdistribute with multiple blocks\n");
+    return failure();
+  }
+  if (teams.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "teams with multiple blocks\n");
+    return failure();
+  }
+
+  bool foundWorkdistribute = false;
+  for (auto &op : teams.getOps()) {
+    if (isa<omp::WorkdistributeOp>(op)) {
+      if (foundWorkdistribute) {
+        emitError(loc, "teams has multiple workdistribute ops.\n");
+        return failure();
+      }
+      foundWorkdistribute = true;
+      continue;
+    }
+    // Identify any omp dialect ops present before/after workdistribute.
+    if (op.getDialect() && isa<omp::OpenMPDialect>(op.getDialect()) &&
+        !isa<omp::TerminatorOp>(op)) {
+      emitError(loc, "teams has omp ops other than workdistribute. Lowering "
+                     "not implemented yet.\n");
+      return failure();
+    }
+  }
+
+  omp::TargetOp targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp());
+  // Return if the teams op is not nested in omp.target.
+  if (!targetOp)
+    return success();
+
+  for (auto &op : workdistribute.getOps()) {
+    if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+      if (isRuntimeCall(&op)) {
+        auto funcName = (*callOp.getCallee()).getRootReference().getValue();
+        // _FortranAAssign is handled. Other runtime calls are not supported
+        // in omp.workdistribute yet.
+        if (funcName == FortranAssignStr)
+          continue;
+        else {
+          emitError(loc, "Runtime call " + funcName +
+                             " lowering not supported for workdistribute yet.");
+          return failure();
+        }
+      }
+    }
+  }
+  return success();
+}
+
+/// fissionWorkdistribute finds the parallelizable ops
+/// within a teams {workdistribute} region and moves each of them to its
+/// own teams {workdistribute} region.
+///
+/// If B() and D() are parallelizable,
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///     C()
+///     D()
+///     E()
+///   }
+/// }
+///
+/// becomes
+///
+/// A()
+/// omp.teams {
+///   omp.workdistribute {
+///     B()
+///   }
+/// }
+/// C()
+/// omp.teams {
+///   omp.workdistribute {
+///     D()
+///   }
+/// }
+/// E()
+static FailureOr<bool>
+fissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
+  OpBuilder rewriter(workdistribute);
+  auto loc = workdistribute->getLoc();
+  auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+  auto *teamsBlock = &teams.getRegion().front();
+  bool changed = false;
+  // Move the ops inside teams and before workdistribute outside.
+  IRMapping irMapping;
+  llvm::SmallVector<Operation *> teamsHoisted;
+  for (auto &op : teams.getOps()) {
+    if (&op == workdistribute) {
+      break;
+    }
+    if (shouldParallelize(&op)) {
+      emitError(loc, "teams has parallelize ops before first workdistribute\n");
+      return failure();
+    } else {
+      rewriter.setInsertionPoint(teams);
+      rewriter.clone(op, irMapping);
+      teamsHoisted.push_back(&op);
+      changed = true;
+    }
+  }
+  for (auto *op : llvm::reverse(teamsHoisted)) {
+    op->replaceAllUsesWith(irMapping.lookup(op));
+    op->erase();
+  }
+
+  // While we have unhandled operations in the original workdistribute:
+  auto *workdistributeBlock = &workdistribute.getRegion().front();
+  auto *terminator = workdistributeBlock->getTerminator();
+  while (&workdistributeBlock->front() != terminator) {
+    rewriter.setInsertionPoint(teams);
+    IRMapping mapping;
+    llvm::SmallVector<Operation *> hoisted;
+    Operation *parallelize = nullptr;
+    for (auto &op : workdistribute.getOps()) {
+      if (&op == terminator) {
+        break;
+      }
+      if (shouldParallelize(&op)) {
+        parallelize = &op;
+        break;
+      } else {
+        rewriter.clone(op, mapping);
+        hoisted.push_back(&op);
+        changed = true;
+      }
+    }
+
+    for (auto *op : llvm::reverse(hoisted)) {
+      op->replaceAllUsesWith(mapping.lookup(op));
+      op->erase();
+    }
+
+    if (parallelize && hoisted.empty() &&
+        parallelize->getNextNode() == terminator)
+      break;
+    if (parallelize) {
+      auto newTeams = rewriter.cloneWithoutRegions(teams);
+      auto *newTeamsBlock = rewriter.createBlock(
+          &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {});
+      for (auto arg : teamsBlock->getArguments())
+        newTeamsBlock->addArgument(arg.getType(), arg.getLoc());
+      auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
+      rewriter.create<omp::TerminatorOp>(loc);
+      rewriter.createBlock(&newWorkdistribute.getRegion(),
+                           newWorkdistribute.getRegion().begin(), {}, {});
+      auto *cloned = rewriter.clone(*parallelize);
+      parallelize->replaceAllUsesWith(cloned);
+      parallelize->erase();
+      rewriter.create<omp::TerminatorOp>(loc);
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+/// Generate an omp.parallel operation with an empty region.
+static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
+  auto parallelOp = rewriter.create<omp::ParallelOp>(loc);
+  parallelOp.setComposite(composite);
+  rewriter.createBlock(&parallelOp.getRegion());
+  rewriter.setInsertionPoint(rewriter.create<omp::TerminatorOp>(loc));
+}
+
+/// Generate an omp.distribute operation with an empty region.
+static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) {
+  mlir::omp::DistributeOperands distributeClauseOps;
+  auto distributeOp =
+      rewriter.create<omp::DistributeOp>(loc, distributeClauseOps);
+  distributeOp.setComposite(composite);
+  auto *distributeBlock = rewriter.createBlock(&distributeOp.getRegion());
+  rewriter.setInsertionPointToStart(distributeBlock);
+}
+
+/// Generate loop nest clause operands from the fir.do_loop operation.
+static void
+genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop,
+                     mlir::omp::LoopNestOperands &loopNestClauseOps) {
+  assert(loopNestClauseOps.loopLowerBounds.empty() &&
+         "Loop nest bounds were already emitted!");
+  loopNestClauseOps.loopLowerBounds.push_back(loop.getLowerBound());
+  loopNestClauseOps.loopUpperBounds.push_back(loop.getUpperBound());
+  loopNestClauseOps.loopSteps.push_back(loop.getStep());
+  loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
+}
+
+/// Generate an omp.wsloop operation with an empty region and
+/// clone the body of the fir.do_loop operation inside the loop nest region.
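+///
+/// For example (a sketch; index types and clauses elided), the loop
+///
+///   fir.do_loop %i = %lb to %ub step %step unordered {
+///     ...
+///     fir.result
+///   }
+///
+/// becomes
+///
+///   omp.wsloop {
+///     omp.loop_nest (%i) : index = (%lb) to (%ub) inclusive step (%step) {
+///       ...
+///       omp.yield
+///     }
+///   }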
+static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
+                        const mlir::omp::LoopNestOperands &clauseOps,
+                        bool composite) {
+
+  auto wsloopOp = rewriter.create<omp::WsloopOp>(doLoop.getLoc());
+  wsloopOp.setComposite(composite);
+  rewriter.createBlock(&wsloopOp.getRegion());
+
+  auto loopNestOp =
+      rewriter.create<omp::LoopNestOp>(doLoop.getLoc(), clauseOps);
+
+  // Clone the loop's body inside the loop nest construct using the
+  // mapped values.
+  rewriter.cloneRegionBefore(doLoop.getRegion(), loopNestOp.getRegion(),
+                             loopNestOp.getRegion().begin());
+  Block *clonedBlock = &loopNestOp.getRegion().back();
+  mlir::Operation *terminatorOp = clonedBlock->getTerminator();
+
+  // Erase the fir.result op of the do loop and create an omp.yield op.
+  if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) {
+    rewriter.setInsertionPoint(terminatorOp);
+    rewriter.create<omp::YieldOp>(doLoop->getLoc());
+    terminatorOp->erase();
+  }
+}
+
+/// workdistributeDoLower finds the unordered fir.do_loop
+/// nested in teams {workdistribute {fir.do_loop unordered}} and
+/// lowers it to teams {parallel {distribute {wsloop {loop_nest}}}}.
+///
+/// If a fir.do_loop is present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     fir.do_loop unordered {
+///       ...
+///     }
+///   }
+/// }
+///
+/// then it is lowered to
+///
+/// omp.teams {
+///   omp.parallel {
+///     omp.distribute {
+///       omp.wsloop {
+///         omp.loop_nest {
+///           ...
+///         }
+///       }
+///     }
+///   }
+/// }
+static bool
+workdistributeDoLower(omp::WorkdistributeOp workdistribute,
+                      SetVector<omp::TargetOp> &targetOpsToProcess) {
+  OpBuilder rewriter(workdistribute);
+  auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
+  auto wdLoc = workdistribute->getLoc();
+  if (doLoop && shouldParallelize(doLoop)) {
+    assert(doLoop.getReduceOperands().empty());
+
+    // Record the target ops to process later.
+    if (auto teamsOp = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp())) {
+      auto targetOp = dyn_cast<omp::TargetOp>(teamsOp->getParentOp());
+      if (targetOp)
+        targetOpsToProcess.insert(targetOp);
+    }
+    // Generate the nested parallel, distribute, wsloop and loop_nest ops.
+    genParallelOp(wdLoc, rewriter, true);
+    genDistributeOp(wdLoc, rewriter, true);
+    mlir::omp::LoopNestOperands loopNestClauseOps;
+    genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps);
+    genWsLoopOp(rewriter, doLoop, loopNestClauseOps, true);
+    workdistribute.erase();
+    return true;
+  }
+  return false;
+}
+
+/// Check if the type enclosed in fir.ref is a fir.box and the box encloses an
+/// array.
+static bool isEnclosedTypeRefToBoxArray(Type type) {
+  // Check if it's a reference type.
+  if (auto refType = dyn_cast<fir::ReferenceType>(type)) {
+    // Get the referenced type (should be fir.box).
+    auto referencedType = refType.getEleTy();
+    // Check if the referenced type is a box.
+    if (auto boxType = dyn_cast<fir::BoxType>(referencedType)) {
+      // Get the boxed type and check if it's an array.
+      auto boxedType = boxType.getEleTy();
+      // Check if the boxed type is a sequence (array).
+      return isa<fir::SequenceType>(boxedType);
+    }
+  }
+  return false;
+}
+
+/// Check if the type enclosed in fir.box is scalar (not an array).
+static bool isEnclosedTypeBoxScalar(Type type) {
+  // Check if it's a box type.
+  if (auto boxType = dyn_cast<fir::BoxType>(type)) {
+    // Get the boxed type.
+    auto boxedType = boxType.getEleTy();
+    // Check if the boxed type is NOT a sequence (array).
+    return !isa<fir::SequenceType>(boxedType);
+  }
+  return false;
+}
+
+/// Check if the FortranAAssign call has a scalar src and an array dest.
+static bool isFortranAssignSrcScalarAndDestArray(fir::CallOp callOp) {
+  if (callOp.getNumOperands() < 2)
+    return false;
+  auto srcArg = callOp.getOperand(1);
+  auto destArg = callOp.getOperand(0);
+  // Both operands should be fir.convert ops.
+  auto srcConvert = srcArg.getDefiningOp<fir::ConvertOp>();
+  auto destConvert = destArg.getDefiningOp<fir::ConvertOp>();
+  if (!srcConvert || !destConvert) {
+    emitError(callOp->getLoc(),
+              "Unimplemented: FortranAssign to OpenMP lowering\n");
+    return false;
+  }
+  // Get the original types before conversion.
+  auto srcOrigType = srcConvert.getValue().getType();
+  auto destOrigType = destConvert.getValue().getType();
+
+  // Check if src is scalar and dest is an array.
+  bool srcIsScalar = isEnclosedTypeBoxScalar(srcOrigType);
+  bool destIsArray = isEnclosedTypeRefToBoxArray(destOrigType);
+  return srcIsScalar && destIsArray;
+}
+
+/// Convert a flat index to multi-dimensional indices for an array box.
+/// Example: 2D array with shape (2,4)
+///          Col 1  Col 2  Col 3  Col 4
+///  Row 1:  (1,1)  (1,2)  (1,3)  (1,4)
+///  Row 2:  (2,1)  (2,2)  (2,3)  (2,4)
+///
+///  extents: (2,4)
+///
+///  flatIdx:   0      1      2      3      4      5      6      7
+///  Indices: (1,1)  (1,2)  (1,3)  (1,4)  (2,1)  (2,2)  (2,3)  (2,4)
+static SmallVector<Value> convertFlatToMultiDim(OpBuilder &builder,
+                                                Location loc, Value flatIdx,
+                                                Value arrayBox) {
+  // Get the array type and rank.
+  auto boxType = cast<fir::BoxType>(arrayBox.getType());
+  auto seqType = cast<fir::SequenceType>(boxType.getEleTy());
+  int rank = seqType.getDimension();
+
+  // Get the extents of each dimension.
+  SmallVector<Value> extents;
+  for (int i = 0; i < rank; ++i) {
+    auto dimIdx = arith::ConstantIndexOp::create(builder, loc, i);
+    auto boxDims = fir::BoxDimsOp::create(builder, loc, arrayBox, dimIdx);
+    extents.push_back(boxDims.getResult(1));
+  }
+
+  // Convert the flat index to multi-dimensional indices.
+  SmallVector<Value> indices(rank);
+  Value temp = flatIdx;
+  auto c1 = builder.create<arith::ConstantIndexOp>(loc, 1);
+
+  // Work backwards through dimensions (row-major order).
+  for (int i = rank - 1; i >= 0; --i) {
+    Value zeroBasedIdx = builder.create<arith::RemSIOp>(loc, temp, extents[i]);
+    // Convert to a one-based index.
+    indices[i] = builder.create<arith::AddIOp>(loc, zeroBasedIdx, c1);
+    if (i > 0) {
+      temp = builder.create<arith::DivSIOp>(loc, temp, extents[i]);
+    }
+  }
+
+  return indices;
+}
+
+/// Calculate the total number of elements in the array box
+/// (totalElems = extent(1) * extent(2) * ... * extent(n)).
+static Value CalculateTotalElements(OpBuilder &builder, Location loc,
+                                    Value arrayBox) {
+  auto boxType = cast<fir::BoxType>(arrayBox.getType());
+  auto seqType = cast<fir::SequenceType>(boxType.getEleTy());
+  int rank = seqType.getDimension();
+
+  Value totalElems = nullptr;
+  for (int i = 0; i < rank; ++i) {
+    auto dimIdx = arith::ConstantIndexOp::create(builder, loc, i);
+    auto boxDims = fir::BoxDimsOp::create(builder, loc, arrayBox, dimIdx);
+    Value extent = boxDims.getResult(1);
+    if (i == 0) {
+      totalElems = extent;
+    } else {
+      totalElems = builder.create<arith::MulIOp>(loc, totalElems, extent);
+    }
+  }
+  return totalElems;
+}
+
+/// Replace the FortranAAssign runtime call with an unordered do loop.
+static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc,
+                                       omp::TeamsOp teamsOp,
+                                       omp::WorkdistributeOp workdistribute,
+                                       fir::CallOp callOp) {
+  auto destConvert = callOp.getOperand(0).getDefiningOp<fir::ConvertOp>();
+  auto srcConvert = callOp.getOperand(1).getDefiningOp<fir::ConvertOp>();
+
+  Value destBox = destConvert.getValue();
+  Value srcBox = srcConvert.getValue();
+
+  // Get the defining alloca op of destBox.
+  auto destAlloca = destBox.getDefiningOp<fir::AllocaOp>();
+
+  if (!destAlloca) {
+    emitError(loc, "Unimplemented: FortranAssign to OpenMP lowering\n");
+    return;
+  }
+
+  // Get the store op that stores to the alloca.
+  for (auto user : destAlloca->getUsers()) {
+    if (auto storeOp = dyn_cast<fir::StoreOp>(user)) {
+      destBox = storeOp.getValue();
+      break;
+    }
+  }
+
+  builder.setInsertionPoint(teamsOp);
+  // Load the destination array box (if it's a reference).
+  Value arrayBox = destBox;
+  if (isa<fir::ReferenceType>(destBox.getType()))
+    arrayBox = builder.create<fir::LoadOp>(loc, destBox);
+
+  auto scalarValue = builder.create<fir::BoxAddrOp>(loc, srcBox);
+  Value scalar = builder.create<fir::LoadOp>(loc, scalarValue);
+
+  // Calculate the total number of elements (flattened).
+  auto c0 = builder.create<arith::ConstantIndexOp>(loc, 0);
+  auto c1 = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value totalElems = CalculateTotalElements(builder, loc, arrayBox);
+
+  auto *workdistributeBlock = &workdistribute.getRegion().front();
+  builder.setInsertionPointToStart(workdistributeBlock);
+  // Create a single unordered loop over the flattened array.
+  auto doLoop = fir::DoLoopOp::create(builder, loc, c0, totalElems, c1, true);
+  Block *loopBlock = &doLoop.getRegion().front();
+  builder.setInsertionPointToStart(doLoop.getBody());
+
+  auto flatIdx = loopBlock->getArgument(0);
+  SmallVector<Value> indices =
+      convertFlatToMultiDim(builder, loc, flatIdx, arrayBox);
+  // Use fir.array_coor for linear addressing.
+  auto elemPtr = fir::ArrayCoorOp::create(
+      builder, loc, fir::ReferenceType::get(scalar.getType()), arrayBox,
+      nullptr, nullptr, ValueRange{indices}, ValueRange{});
+
+  builder.create<fir::StoreOp>(loc, scalar, elemPtr);
+}
+
+/// workdistributeRuntimeCallLower finds the runtime calls
+/// nested in teams {workdistribute {}} and
+/// lowers FortranAAssign to an unordered do loop if src is scalar and dest is
+/// an array. Other runtime calls are not handled currently.
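+///
+/// For a scalar-to-array assignment, the replacement is roughly (a sketch;
+/// the scalar load and the extent computation are emitted before teams):
+///
+///   %total = ...          // product of the extents of the destination box
+///   fir.do_loop %i = %c0 to %total step %c1 unordered {
+///     %indices = ...      // %i delinearized into per-dimension indices
+///     %addr = fir.array_coor %array, %indices
+///     fir.store %scalar to %addr
+///   }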
+static FailureOr<bool>
+workdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
+                               SetVector<omp::TargetOp> &targetOpsToProcess) {
+  OpBuilder rewriter(workdistribute);
+  auto loc = workdistribute->getLoc();
+  auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+  if (!teams) {
+    emitError(loc, "workdistribute not nested in teams\n");
+    return failure();
+  }
+  if (workdistribute.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "workdistribute with multiple blocks\n");
+    return failure();
+  }
+  if (teams.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "teams with multiple blocks\n");
+    return failure();
+  }
+  auto *workdistributeBlock = &workdistribute.getRegion().front();
+  auto *terminator = workdistributeBlock->getTerminator();
+  bool changed = false;
+  // Get the target op parent of teams.
+  omp::TargetOp targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp());
+  SmallVector<Operation *> opsToErase;
+  for (auto &op : workdistribute.getOps()) {
+    if (&op == terminator) {
+      break;
+    }
+    if (isRuntimeCall(&op)) {
+      rewriter.setInsertionPoint(&op);
+      fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+      auto funcName = (*runtimeCall.getCallee()).getRootReference().getValue();
+      if (funcName == FortranAssignStr) {
+        if (isFortranAssignSrcScalarAndDestArray(runtimeCall) && targetOp) {
+          // Record the target ops to process later.
+          targetOpsToProcess.insert(targetOp);
+          replaceWithUnorderedDoLoop(rewriter, loc, teams, workdistribute,
+                                     runtimeCall);
+          opsToErase.push_back(&op);
+          changed = true;
+        }
+      }
+    }
+  }
+  // Erase the runtime calls that have been replaced.
+  for (auto *op : opsToErase) {
+    op->erase();
+  }
+  return changed;
+}
+
+/// teamsWorkdistributeToSingleOp hoists all the ops inside
+/// teams {workdistribute {}} before the teams op.
+///
+/// If A() and B() are present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///   }
+/// }
+///
+/// then it is lowered to
+///
+/// A()
+/// B()
+///
+/// If only the terminator remains in teams after hoisting, we erase the
+/// teams op.
+static bool
+teamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp,
+                              SetVector<omp::TargetOp> &targetOpsToProcess) {
+  auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+  if (!workdistributeOp)
+    return false;
+  // Get the block containing teamsOp (the parent block).
+  Block *parentBlock = teamsOp->getBlock();
+  Block &workdistributeBlock = *workdistributeOp.getRegion().begin();
+  // Record the target ops to process later.
+  for (auto &op : workdistributeBlock.getOperations()) {
+    if (shouldParallelize(&op)) {
+      auto targetOp = dyn_cast<omp::TargetOp>(teamsOp->getParentOp());
+      if (targetOp)
+        targetOpsToProcess.insert(targetOp);
+    }
+  }
+  auto insertPoint = Block::iterator(teamsOp);
+  // Get the range of operations to move (excluding the terminator).
+  auto workdistributeBegin = workdistributeBlock.begin();
+  auto workdistributeEnd = workdistributeBlock.getTerminator()->getIterator();
+  // Move the operations from the workdistribute block to before teamsOp.
+  parentBlock->getOperations().splice(insertPoint,
+                                      workdistributeBlock.getOperations(),
+                                      workdistributeBegin, workdistributeEnd);
+  // Erase the now-empty workdistributeOp.
+  workdistributeOp.erase();
+  Block &teamsBlock = *teamsOp.getRegion().begin();
+  // Check if only the terminator remains and erase the teams op.
+  if (teamsBlock.getOperations().size() == 1 &&
+      teamsBlock.getTerminator() != nullptr) {
+    teamsOp.erase();
+  }
+  return true;
+}
+
+/// If multiple workdistribute ops are nested in a target region, we will need
+/// to split the target region, but we want to preserve the data semantics of
+/// the original data region and avoid unnecessary data movement at each of
+/// the subkernels - so we split the target region into a target_data{target}
+/// nest where only the outer target_data op moves the data.
+FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
+                                         RewriterBase &rewriter) {
+  auto loc = targetOp->getLoc();
+  if (targetOp.getMapVars().empty()) {
+    emitError(loc, "Target region has no data maps\n");
+    return failure();
+  }
+  // Collect all the map info ops.
+  SmallVector<omp::MapInfoOp> mapInfos;
+  for (auto opr : targetOp.getMapVars()) {
+    auto mapInfo = cast<omp::MapInfoOp>(opr.getDefiningOp());
+    mapInfos.push_back(mapInfo);
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  SmallVector<Value> innerMapInfos;
+  SmallVector<Value> outerMapInfos;
+  // Create new map info ops for the inner target region.
+  for (auto mapInfo : mapInfos) {
+    auto originalMapType =
+        (llvm::omp::OpenMPOffloadMappingFlags)(mapInfo.getMapType());
+    auto originalCaptureType = mapInfo.getMapCaptureType();
+    llvm::omp::OpenMPOffloadMappingFlags newMapType;
+    mlir::omp::VariableCaptureKind newCaptureType;
+    // For bycopy, we keep the same map type and capture type.
+    // For byref, we change the map type to none and keep the capture type.
+    if (originalCaptureType == mlir::omp::VariableCaptureKind::ByCopy) {
+      newMapType = originalMapType;
+      newCaptureType = originalCaptureType;
+    } else if (originalCaptureType == mlir::omp::VariableCaptureKind::ByRef) {
+      newMapType = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
+      newCaptureType = originalCaptureType;
+      outerMapInfos.push_back(mapInfo);
+    } else {
+      emitError(targetOp->getLoc(), "Unhandled case");
+      return failure();
+    }
+    auto innerMapInfo = cast<omp::MapInfoOp>(rewriter.clone(*mapInfo));
+    innerMapInfo.setMapTypeAttr(rewriter.getIntegerAttr(
+        rewriter.getIntegerType(64, false),
+        static_cast<
+            std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+            newMapType)));
+    innerMapInfo.setMapCaptureType(newCaptureType);
+    innerMapInfos.push_back(innerMapInfo.getResult());
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  auto device = targetOp.getDevice();
+  auto ifExpr = targetOp.getIfExpr();
+  auto deviceAddrVars = targetOp.getHasDeviceAddrVars();
+  auto devicePtrVars = targetOp.getIsDevicePtrVars();
+  // Create the target data op.
+  auto targetDataOp = rewriter.create<omp::TargetDataOp>(
+      loc, device, ifExpr, outerMapInfos, deviceAddrVars, devicePtrVars);
+  auto *targetDataBlock = rewriter.createBlock(&targetDataOp.getRegion());
+  rewriter.create<omp::TerminatorOp>(loc);
+  rewriter.setInsertionPointToStart(targetDataBlock);
+  // Create the inner target op.
+  auto newTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+      innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+  rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
+                              newTargetOp.getRegion().begin());
+
+  rewriter.replaceOp(targetOp, targetDataOp);
+  return newTargetOp;
+}
+
+/// getNestedOpToIsolate identifies an omp.teams op within the body of an
+/// omp::TargetOp that should be "isolated". It returns a tuple of the op,
+/// whether it is the first op in the target block, and whether it is the
+/// last op in the target block.
+static std::optional<std::tuple<Operation *, bool, bool>>
+getNestedOpToIsolate(omp::TargetOp targetOp) {
+  if (targetOp.getRegion().empty())
+    return std::nullopt;
+  auto *targetBlock = &targetOp.getRegion().front();
+  for (auto &op : *targetBlock) {
+    bool first = &op == &*targetBlock->begin();
+    bool last = op.getNextNode() == targetBlock->getTerminator();
+    if (first && last)
+      return std::nullopt;
+
+    if (isa<omp::TeamsOp>(&op))
+      return {{&op, first, last}};
+  }
+  return std::nullopt;
+}
+
+/// Temporary structure to hold the two map info ops.
+struct TempOmpVar {
+  omp::MapInfoOp from, to;
+};
+
+/// isPtr checks if the type is a pointer or reference type.
+static bool isPtr(Type ty) {
+  return isa<fir::ReferenceType>(ty) || isa<LLVM::LLVMPointerType>(ty);
+}
+
+/// getPtrTypeForOmp returns an LLVM pointer type for the given type.
+static Type getPtrTypeForOmp(Type ty) {
+  if (isPtr(ty))
+    return LLVM::LLVMPointerType::get(ty.getContext());
+  else
+    return fir::ReferenceType::get(ty);
+}
+
+/// allocateTempOmpVar allocates a temporary variable for OpenMP mapping.
+static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
+                                     RewriterBase &rewriter) {
+  MLIRContext &ctx = *ty.getContext();
+  Value alloc;
+  Type allocType;
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
+  // Get the appropriate type for the allocation.
+  if (isPtr(ty)) {
+    Type intTy = rewriter.getI32Type();
+    auto one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1);
+    allocType = llvmPtrTy;
+    alloc = rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, allocType, one);
+    allocType = intTy;
+  } else {
+    allocType = ty;
+    alloc = rewriter.create<fir::AllocaOp>(loc, allocType);
+  }
+  // Lambda to create the map info ops.
+  auto getMapInfo = [&](uint64_t mappingFlags, const char *name) {
+    return rewriter.create<omp::MapInfoOp>(
+        loc, alloc.getType(), alloc, TypeAttr::get(allocType),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false),
+                                mappingFlags),
+        rewriter.getAttr<omp::VariableCaptureKindAttr>(
+            omp::VariableCaptureKind::ByRef),
+        /*varPtrPtr=*/Value{},
+        /*members=*/SmallVector<Value>{},
+        /*member_index=*/mlir::ArrayAttr{},
+        /*bounds=*/ValueRange(),
+        /*mapperId=*/mlir::FlatSymbolRefAttr(),
+        /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false));
+  };
+  // Create the map info ops.
+  uint64_t mapFrom =
+      static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
+  uint64_t mapTo =
+      static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
+  auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from");
+  auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to");
+  return TempOmpVar{mapInfoFrom, mapInfoTo};
+}
+
+/// usedOutsideSplit checks if a value is used outside the split operation.
+static bool usedOutsideSplit(Value v, Operation *split) {
+  if (!split)
+    return false;
+  auto targetOp = cast<omp::TargetOp>(split->getParentOp());
+  auto *targetBlock = &targetOp.getRegion().front();
+  for (auto *user : v.getUsers()) {
+    while (user->getBlock() != targetBlock) {
+      user = user->getParentOp();
+    }
+    if (!user->isBeforeInBlock(split))
+      return true;
+  }
+  return false;
+}
+
+/// isRecomputableAfterFission checks if an operation can be recomputed.
+static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) {
+  // If the op has side effects, it cannot be recomputed.
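+  // For example, a fir.declare, fir.convert or arith constant can simply be
+  // re-executed in each of the split kernels, whereas a fir.allocmem or a
+  // runtime call must execute exactly once, so its result has to be cached
+  // and passed across the split instead.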
+  // We consider fir.declare as having no side effects.
+  return isa<fir::DeclareOp>(op) || isMemoryEffectFree(op);
+}
+
+/// collectNonRecomputableDeps collects dependencies that cannot be recomputed.
+static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp,
+                                       SetVector<Operation *> &nonRecomputable,
+                                       SetVector<Operation *> &toCache,
+                                       SetVector<Operation *> &toRecompute) {
+  Operation *op = v.getDefiningOp();
+  // If v is a block argument, it must be from the targetOp.
+  if (!op) {
+    assert(cast<BlockArgument>(v).getOwner()->getParentOp() == targetOp);
+    return;
+  }
+  // If the op is in the nonRecomputable set, add it to toCache and return.
+  if (nonRecomputable.contains(op)) {
+    toCache.insert(op);
+    return;
+  }
+  // Add the op to toRecompute.
+  toRecompute.insert(op);
+  for (auto opr : op->getOperands())
+    collectNonRecomputableDeps(opr, targetOp, nonRecomputable, toCache,
+                               toRecompute);
+}
+
+/// createBlockArgsAndMap creates the new block arguments and maps the
+/// original values to them.
+static void createBlockArgsAndMap(Location loc, RewriterBase &rewriter,
+                                  omp::TargetOp &targetOp, Block *targetBlock,
+                                  Block *newTargetBlock,
+                                  SmallVector<Value> &hostEvalVars,
+                                  SmallVector<Value> &mapOperands,
+                                  SmallVector<Value> &allocs,
+                                  IRMapping &irMapping) {
+  // FIRST: Map `host_eval_vars` to block arguments.
+  unsigned originalHostEvalVarsSize = targetOp.getHostEvalVars().size();
+  for (unsigned i = 0; i < hostEvalVars.size(); ++i) {
+    Value originalValue;
+    BlockArgument newArg;
+    if (i < originalHostEvalVarsSize) {
+      originalValue = targetBlock->getArgument(i); // Host_eval args come first.
+      newArg = newTargetBlock->addArgument(originalValue.getType(),
+                                           originalValue.getLoc());
+    } else {
+      originalValue = hostEvalVars[i];
+      newArg = newTargetBlock->addArgument(originalValue.getType(),
+                                           originalValue.getLoc());
+    }
+    irMapping.map(originalValue, newArg);
+  }
+
+  // SECOND: Map `map_operands` to block arguments.
+  unsigned originalMapVarsSize = targetOp.getMapVars().size();
+  for (unsigned i = 0; i < mapOperands.size(); ++i) {
+    Value originalValue;
+    BlockArgument newArg;
+    // Map the new arguments from the original block.
+    if (i < originalMapVarsSize) {
+      originalValue = targetBlock->getArgument(originalHostEvalVarsSize +
+                                               i); // Offset by host_eval count.
+      newArg = newTargetBlock->addArgument(originalValue.getType(),
+                                           originalValue.getLoc());
+    }
+    // Map the new arguments from the `allocs`.
+    else {
+      originalValue = allocs[i - originalMapVarsSize];
+      newArg = newTargetBlock->addArgument(
+          getPtrTypeForOmp(originalValue.getType()), originalValue.getLoc());
+    }
+    irMapping.map(originalValue, newArg);
+  }
+
+  // THIRD: Map `private_vars` to block arguments (if any).
+  unsigned originalPrivateVarsSize = targetOp.getPrivateVars().size();
+  for (unsigned i = 0; i < originalPrivateVarsSize; ++i) {
+    auto originalArg = targetBlock->getArgument(originalHostEvalVarsSize +
+                                                originalMapVarsSize + i);
+    auto newArg = newTargetBlock->addArgument(originalArg.getType(),
+                                              originalArg.getLoc());
+    irMapping.map(originalArg, newArg);
+  }
+}
+
+/// reloadCacheAndRecompute reloads cached values and recomputes operations.
+static void reloadCacheAndRecompute(
+    Location loc, RewriterBase &rewriter, Operation *splitBefore,
+    omp::TargetOp &targetOp, Block *targetBlock, Block *newTargetBlock,
+    SmallVector<Value> &hostEvalVars, SmallVector<Value> &mapOperands,
+    SmallVector<Value> &allocs, SetVector<Operation *> &toRecompute,
+    IRMapping &irMapping) {
+  // Handle the load operations for the allocs.
+  rewriter.setInsertionPointToStart(newTargetBlock);
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
+
+  unsigned originalMapVarsSize = targetOp.getMapVars().size();
+  unsigned hostEvalVarsSize = hostEvalVars.size();
+  // Create load operations for each allocated variable.
+  for (unsigned i = 0; i < allocs.size(); ++i) {
+    Value original = allocs[i];
+    // Get the new block argument for this specific allocated value.
+    Value newArg =
+        newTargetBlock->getArgument(hostEvalVarsSize + originalMapVarsSize + i);
+    Value restored;
+    // If the original value is a pointer or reference, load and convert if
+    // necessary.
+    if (isPtr(original.getType())) {
+      restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg);
+      if (!isa<LLVM::LLVMPointerType>(original.getType()))
+        restored =
+            rewriter.create<fir::ConvertOp>(loc, original.getType(), restored);
+    } else {
+      restored = rewriter.create<fir::LoadOp>(loc, newArg);
+    }
+    irMapping.map(original, restored);
+  }
+  // Clone the operations if they are in the toRecompute set.
+  for (auto it = targetBlock->begin(); it != splitBefore->getIterator(); it++) {
+    if (toRecompute.contains(&*it))
+      rewriter.clone(*it, irMapping);
+  }
+}
+
+/// Given a teamsOp, navigate down the nested structure to find the
+/// innermost LoopNestOp. The expected nesting is:
+/// teams -> parallel -> distribute -> wsloop -> loop_nest
+static mlir::omp::LoopNestOp getLoopNestFromTeams(mlir::omp::TeamsOp teamsOp) {
+  if (teamsOp.getRegion().empty())
+    return nullptr;
+  // Ensure the teams region has a single block.
+  if (teamsOp.getRegion().getBlocks().size() != 1)
+    return nullptr;
+  // Look for the parallel op in the teams region.
+  mlir::omp::ParallelOp parallelOp = nullptr;
+  for (auto &op : teamsOp.getRegion().front()) {
+    if (auto parallel = dyn_cast<omp::ParallelOp>(op)) {
+      parallelOp = parallel;
+      break;
+    }
+  }
+  if (!parallelOp)
+    return nullptr;
+
+  // Find the distribute op inside parallel.
+  mlir::omp::DistributeOp distributeOp = nullptr;
+  for (auto &op : parallelOp.getRegion().front()) {
+    if (auto distribute = dyn_cast<omp::DistributeOp>(op)) {
+      distributeOp = distribute;
+      break;
+    }
+  }
+  if (!distributeOp)
+    return nullptr;
+
+  // Find the wsloop op inside distribute.
+  mlir::omp::WsloopOp wsloopOp = nullptr;
+  for (auto &op : distributeOp.getRegion().front()) {
+    if (auto wsloop = dyn_cast<omp::WsloopOp>(op)) {
+      wsloopOp = wsloop;
+      break;
+    }
+  }
+  if (!wsloopOp)
+    return nullptr;
+
+  // Find the loop_nest op inside wsloop.
+  for (auto &op : wsloopOp.getRegion().front()) {
+    if (auto loopNest = dyn_cast<omp::LoopNestOp>(op)) {
+      return loopNest;
+    }
+  }
+
+  return nullptr;
+}
+
+/// Generate an LLVM constant operation of i32 type.
+static mlir::LLVM::ConstantOp
+genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
+  mlir::Type i32Ty = rewriter.getI32Type();
+  mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value);
+  return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr);
+}
+
+/// Given a box descriptor, extract the base address of the data it describes.
+/// If the box descriptor is a reference, load it first.
+/// The base address is returned as an i8* pointer.
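+/// The generated sequence is roughly (a sketch; the load is emitted only if
+/// a reference was passed):
+///   %box  = fir.load %boxRef
+///   %cast = fir.convert %box : (...) -> !fir.box<!fir.array<?xi8>>
+///   %addr = fir.box_addr %cast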
+static Value genDescriptorGetBaseAddress(fir::FirOpBuilder &builder,
+                                         Location loc, Value boxDesc) {
+  Value box = boxDesc;
+  if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+    box = fir::LoadOp::create(builder, loc, boxDesc);
+  }
+  assert(isa<fir::BoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetBaseAddress");
+  auto i8Type = builder.getI8Type();
+  auto unknownArrayType =
+      fir::SequenceType::get({fir::SequenceType::getUnknownExtent()}, i8Type);
+  auto i8BoxType = fir::BoxType::get(unknownArrayType);
+  auto typedBox = fir::ConvertOp::create(builder, loc, i8BoxType, box);
+  auto rawAddr = fir::BoxAddrOp::create(builder, loc, typedBox);
+  return rawAddr;
+}
+
+/// Given a box descriptor, extract the total number of elements in the array
+/// it describes. If the box descriptor is a reference, load it first.
+/// The total number of elements is returned as an i64 value.
+static Value genDescriptorGetTotalElements(fir::FirOpBuilder &builder,
+                                           Location loc, Value boxDesc) {
+  Value box = boxDesc;
+  if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+    box = fir::LoadOp::create(builder, loc, boxDesc);
+  }
+  assert(isa<fir::BoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetTotalElements");
+  auto i64Type = builder.getI64Type();
+  return fir::BoxTotalElementsOp::create(builder, loc, i64Type, box);
+}
+
+/// Given a box descriptor, extract the size of each element in the array it
+/// describes. If the box descriptor is a reference, load it first.
+/// The element size is returned as an i64 value.
+static Value genDescriptorGetEleSize(fir::FirOpBuilder &builder, Location loc,
+                                     Value boxDesc) {
+  Value box = boxDesc;
+  if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+    box = fir::LoadOp::create(builder, loc, boxDesc);
+  }
+  assert(isa<fir::BoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetEleSize");
+  auto i64Type = builder.getI64Type();
+  return fir::BoxEleSizeOp::create(builder, loc, i64Type, box);
+}
+
+/// Given a box descriptor, compute the total size in bytes of the data it
+/// describes. This is done by multiplying the total number of elements by the
+/// size of each element. If the box descriptor is a reference, load it first.
+/// The total size in bytes is returned as an i64 value.
+static Value genDescriptorGetDataSizeInBytes(fir::FirOpBuilder &builder,
+                                             Location loc, Value boxDesc) {
+  Value box = boxDesc;
+  if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+    box = fir::LoadOp::create(builder, loc, boxDesc);
+  }
+  assert(isa<fir::BoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetDataSizeInBytes");
+  Value eleSize = genDescriptorGetEleSize(builder, loc, box);
+  Value totalElements = genDescriptorGetTotalElements(builder, loc, box);
+  return mlir::arith::MulIOp::create(builder, loc, totalElements, eleSize);
+}
+
+/// Generate a call to the OpenMP runtime function `omp_get_mapped_ptr` to
+/// retrieve the device pointer corresponding to a given host pointer and
+/// device number. If no mapping exists, the original host pointer is returned.
+/// Signature:
+///   void *omp_get_mapped_ptr(void *host_ptr, int device_num);
+static mlir::Value genOmpGetMappedPtrIfPresent(fir::FirOpBuilder &builder,
+                                               mlir::Location loc,
+                                               mlir::Value hostPtr,
+                                               mlir::Value deviceNum,
+                                               mlir::ModuleOp module) {
+  auto *context = builder.getContext();
+  auto voidPtrType = fir::LLVMPointerType::get(context, builder.getI8Type());
+  auto i32Type = builder.getI32Type();
+  auto funcName = "omp_get_mapped_ptr";
+  auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(funcName);
+
+  if (!funcOp) {
+    auto funcType =
+        mlir::FunctionType::get(context, {voidPtrType, i32Type}, {voidPtrType});
+
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(module.getBody());
+
+    funcOp = mlir::func::FuncOp::create(builder, loc, funcName, funcType);
+    funcOp.setPrivate();
+  }
+
+  llvm::SmallVector<mlir::Value> args;
+  args.push_back(fir::ConvertOp::create(builder, loc, voidPtrType, hostPtr));
+  args.push_back(fir::ConvertOp::create(builder, loc, i32Type, deviceNum));
+  auto callOp = fir::CallOp::create(builder, loc, funcOp, args);
+  auto mappedPtr = callOp.getResult(0);
+  auto isNull = builder.genIsNullAddr(loc, mappedPtr);
+  auto convertedHostPtr =
+      fir::ConvertOp::create(builder, loc, voidPtrType, hostPtr);
+  auto result = arith::SelectOp::create(builder, loc, isNull, convertedHostPtr,
+                                        mappedPtr);
+  return result;
+}
+
+/// Generate a call to the OpenMP runtime function `omp_target_memcpy` to
+/// perform a memory copy between host and device or between devices.
+/// Signature:
+///   int omp_target_memcpy(void *dst, const void *src, size_t length,
+///                         size_t dst_offset, size_t src_offset,
+///                         int dst_device, int src_device);
+static void genOmpTargetMemcpyCall(fir::FirOpBuilder &builder,
+                                   mlir::Location loc, mlir::Value dst,
+                                   mlir::Value src, mlir::Value length,
+                                   mlir::Value dstOffset, mlir::Value srcOffset,
+                                   mlir::Value device, mlir::ModuleOp module) {
+  auto *context = builder.getContext();
+  auto funcName = "omp_target_memcpy";
+  auto voidPtrType = fir::LLVMPointerType::get(context, builder.getI8Type());
+  auto sizeTType = builder.getI64Type(); // Assuming size_t is 64-bit.
+  auto i32Type = builder.getI32Type();
+  auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(funcName);
+
+  if (!funcOp) {
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(module.getBody());
+    llvm::SmallVector<mlir::Type> argTypes = {
+        voidPtrType, voidPtrType, sizeTType, sizeTType,
+        sizeTType,   i32Type,     i32Type};
+    auto funcType = mlir::FunctionType::get(context, argTypes, {i32Type});
+    funcOp = mlir::func::FuncOp::create(builder, loc, funcName, funcType);
+    funcOp.setPrivate();
+  }
+
+  llvm::SmallVector<mlir::Value> args{dst,       src,    length, dstOffset,
+                                      srcOffset, device, device};
+  fir::CallOp::create(builder, loc, funcOp, args);
+}
+
+/// Generate code to replace a Fortran array assignment call with OpenMP
+/// runtime calls to perform the equivalent operation on the device.
+/// This involves extracting the source and destination pointers from the
+/// Fortran array descriptors, retrieving their mapped device pointers (if
+/// any), and invoking `omp_target_memcpy` to copy the data on the device.
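+///
+/// The emitted sequence corresponds roughly to (a sketch; both offsets are
+/// zero and the same device is used for source and destination):
+///
+///   dst = omp_get_mapped_ptr(dest_base, dev); if (!dst) dst = dest_base;
+///   src = omp_get_mapped_ptr(src_base, dev);  if (!src) src = src_base;
+///   omp_target_memcpy(dst, src, size_bytes, 0, 0, dev, dev);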
+static void genFortranAssignOmpReplacement(fir::FirOpBuilder &builder,
+                                           mlir::Location loc,
+                                           fir::CallOp callOp,
+                                           mlir::Value device,
+                                           mlir::ModuleOp module) {
+  assert(callOp.getNumResults() == 0 &&
+         "Expected _FortranAAssign to have no results");
+  assert(callOp.getNumOperands() >= 2 &&
+         "Expected _FortranAAssign to have at least two operands");
+
+  // Extract the source and destination pointers from the call operands.
+  mlir::Value dest = callOp.getOperand(0);
+  mlir::Value src = callOp.getOperand(1);
+
+  // Get the base addresses of the source and destination arrays.
+  mlir::Value srcBase = genDescriptorGetBaseAddress(builder, loc, src);
+  mlir::Value destBase = genDescriptorGetBaseAddress(builder, loc, dest);
+
+  // Get the total size in bytes of the data to be copied.
+  mlir::Value srcDataSize = genDescriptorGetDataSizeInBytes(builder, loc, src);
+
+  // Retrieve the mapped device pointers for source and destination.
+  // If no mapping exists, the original host pointer is used.
+  Value destPtr =
+      genOmpGetMappedPtrIfPresent(builder, loc, destBase, device, module);
+  Value srcPtr =
+      genOmpGetMappedPtrIfPresent(builder, loc, srcBase, device, module);
+  Value zero = builder.create<mlir::arith::ConstantOp>(
+      loc, builder.getI64Type(), builder.getI64IntegerAttr(0));
+
+  // Generate the call to omp_target_memcpy to perform the data copy on the
+  // device.
+  genOmpTargetMemcpyCall(builder, loc, destPtr, srcPtr, srcDataSize, zero, zero,
+                         device, module);
+}
+
+/// Struct to hold the host_eval vars corresponding to loop bounds and steps.
+struct HostEvalVars {
+  SmallVector<Value> lbs;
+  SmallVector<Value> ubs;
+  SmallVector<Value> steps;
+};
+
+/// moveToHost clones all the ops from the target region outside of it.
+/// It hoists the runtime call "_FortranAAssign" and replaces it with the omp
+/// version. It also hoists and replaces fir.allocmem with omp.target_allocmem
+/// and fir.freemem with omp.target_freemem.
+static LogicalResult moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
+                                mlir::ModuleOp module,
+                                struct HostEvalVars &hostEvalVars) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  Block *targetBlock = &targetOp.getRegion().front();
+  assert(targetBlock == &targetOp.getRegion().back());
+  IRMapping mapping;
+
+  // Get the parent target_data op.
+  auto targetDataOp = dyn_cast<omp::TargetDataOp>(targetOp->getParentOp());
+  if (!targetDataOp) {
+    emitError(targetOp->getLoc(),
+              "Expected target op to be inside target_data op");
+    return failure();
+  }
+  // Create the mapping for host_eval_vars.
+  unsigned hostEvalVarCount = targetOp.getHostEvalVars().size();
+  for (unsigned i = 0; i < targetOp.getHostEvalVars().size(); ++i) {
+    Value hostEvalVar = targetOp.getHostEvalVars()[i];
+    BlockArgument arg = targetBlock->getArguments()[i];
+    mapping.map(arg, hostEvalVar);
+  }
+  // Create the mapping for map_vars.
+  for (unsigned i = 0; i < targetOp.getMapVars().size(); ++i) {
+    Value mapInfo = targetOp.getMapVars()[i];
+    BlockArgument arg = targetBlock->getArguments()[hostEvalVarCount + i];
+    Operation *op = mapInfo.getDefiningOp();
+    assert(op);
+    auto mapInfoOp = cast<omp::MapInfoOp>(op);
+    // Map the block argument to the host-side variable pointer.
+    mapping.map(arg, mapInfoOp.getVarPtr());
+  }
+  // Create the mapping for private_vars.
+  unsigned mapSize = targetOp.getMapVars().size();
+  for (unsigned i = 0; i < targetOp.getPrivateVars().size(); ++i) {
+    Value privateVar = targetOp.getPrivateVars()[i];
+    // The mapping should link the device-side variable to the host-side one.
+    BlockArgument arg =
+        targetBlock->getArguments()[hostEvalVarCount + mapSize + i];
+    // Map the device-side copy (`arg`) to the host-side value (`privateVar`).
+    mapping.map(arg, privateVar);
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  SmallVector<Operation *> opsToReplace;
+  Value device = targetOp.getDevice();
+
+  // If the device is not specified, default to device 0.
+  if (!device) {
+    device = genI32Constant(targetOp.getLoc(), rewriter, 0);
+  }
+  // Clone all operations.
+  for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end());
+       it != end; ++it) {
+    auto *op = &*it;
+    Operation *clonedOp = rewriter.clone(*op, mapping);
+    // Map the results of the original op to the cloned op.
+    for (unsigned i = 0; i < op->getNumResults(); ++i) {
+      mapping.map(op->getResult(i), clonedOp->getResult(i));
+    }
+    // fir.declare changes its type when hoisted out of omp.target to
+    // omp.target_data. Introduce a load if the original declareOp input is
+    // not of reference type but the cloned declareOp input is.
+    if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
+      auto originalDeclareOp = cast<fir::DeclareOp>(op);
+      Type originalInType = originalDeclareOp.getMemref().getType();
+      Type clonedInType = clonedDeclareOp.getMemref().getType();
+
+      fir::ReferenceType originalRefType =
+          dyn_cast<fir::ReferenceType>(originalInType);
+      fir::ReferenceType clonedRefType =
+          dyn_cast<fir::ReferenceType>(clonedInType);
+      if (!originalRefType && clonedRefType) {
+        Type clonedEleTy = clonedRefType.getElementType();
+        if (clonedEleTy == originalDeclareOp.getType()) {
+          opsToReplace.push_back(clonedOp);
+        }
+      }
+    }
+    // Collect the ops to be replaced.
+    if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
+      opsToReplace.push_back(clonedOp);
+    // Check for runtime calls to be replaced.
+    if (isRuntimeCall(clonedOp)) {
+      fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+      auto funcName = (*runtimeCall.getCallee()).getRootReference().getValue();
+      if (funcName == FortranAssignStr) {
+        opsToReplace.push_back(clonedOp);
+      } else {
+        emitError(runtimeCall->getLoc(), "Unhandled runtime call hoisting.");
+        return failure();
+      }
+    }
+  }
+  // Replace fir.allocmem with omp.target_allocmem.
+  for (Operation *op : opsToReplace) {
+    if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
+      rewriter.setInsertionPoint(allocOp);
+      auto ompAllocmemOp = rewriter.create<omp::TargetAllocMemOp>(
+          allocOp.getLoc(), rewriter.getI64Type(), device,
+          allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(),
+          allocOp.getBindcNameAttr(), allocOp.getTypeparams(),
+          allocOp.getShape());
+      auto firConvertOp = rewriter.create<fir::ConvertOp>(
+          allocOp.getLoc(), allocOp.getResult().getType(),
+          ompAllocmemOp.getResult());
+      rewriter.replaceOp(allocOp, firConvertOp.getResult());
+    }
+    // Replace fir.freemem with omp.target_freemem.
+    else if (auto freeOp = dyn_cast<fir::FreeMemOp>(op)) {
+      rewriter.setInsertionPoint(freeOp);
+      auto firConvertOp = rewriter.create<fir::ConvertOp>(
+          freeOp.getLoc(), rewriter.getI64Type(), freeOp.getHeapref());
+      rewriter.create<omp::TargetFreeMemOp>(freeOp.getLoc(), device,
+                                            firConvertOp.getResult());
+      rewriter.eraseOp(freeOp);
+    }
+    // fir.declare changes its type when hoisted out of omp.target to
+    // omp.target_data. Introduce a load if the original declareOp input is
+    // not of reference type but the cloned declareOp input is.
+    else if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(op)) {
+      Type clonedInType = clonedDeclareOp.getMemref().getType();
+      fir::ReferenceType clonedRefType =
+          dyn_cast<fir::ReferenceType>(clonedInType);
+      Type clonedEleTy = clonedRefType.getElementType();
+      rewriter.setInsertionPoint(op);
+      Value loadedValue = rewriter.create<fir::LoadOp>(
+          clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref());
+      clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue);
+    }
+    // Replace runtime calls with omp versions.
+    else if (isRuntimeCall(op)) {
+      fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+      auto funcName = (*runtimeCall.getCallee()).getRootReference().getValue();
+      if (funcName == FortranAssignStr) {
+        rewriter.setInsertionPoint(op);
+        fir::FirOpBuilder builder{rewriter, op};
+
+        mlir::Location loc = runtimeCall.getLoc();
+        genFortranAssignOmpReplacement(builder, loc, runtimeCall, device,
+                                       module);
+        rewriter.eraseOp(op);
+      } else {
+        emitError(runtimeCall->getLoc(), "Unhandled runtime call hoisting.");
+        return failure();
+      }
+    } else {
+      emitError(op->getLoc(), "Unhandled op hoisting.");
+      return failure();
+    }
+  }
+
+  // Update the host_eval_vars to use the mapped values.
+  for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i) {
+    hostEvalVars.lbs[i] = mapping.lookup(hostEvalVars.lbs[i]);
+    hostEvalVars.ubs[i] = mapping.lookup(hostEvalVars.ubs[i]);
+    hostEvalVars.steps[i] = mapping.lookup(hostEvalVars.steps[i]);
+  }
+  // Finally, erase the original targetOp.
+  rewriter.eraseOp(targetOp);
+  return success();
+}
+
+/// Result of the isolateOp method.
+struct SplitResult {
+  omp::TargetOp preTargetOp;
+  omp::TargetOp isolatedTargetOp;
+  omp::TargetOp postTargetOp;
+};
+
+/// computeAllocsCacheRecomputable computes the allocs needed to cache
+/// the values that are used outside the split point. It also computes the ops
+/// that need to be cached and the ops that can be recomputed after the split.
+static void computeAllocsCacheRecomputable(
+    omp::TargetOp targetOp, Operation *splitBeforeOp, RewriterBase &rewriter,
+    SmallVector<Value> &preMapOperands, SmallVector<Value> &postMapOperands,
+    SmallVector<Value> &allocs, SmallVector<Value> &requiredVals,
+    SetVector<Operation *> &nonRecomputable, SetVector<Operation *> &toCache,
+    SetVector<Operation *> &toRecompute) {
+  auto *targetBlock = &targetOp.getRegion().front();
+  // Find all values that are used outside the split point.
+  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
+       it++) {
+    // Check if any of the results are used outside the split point.
+    for (auto res : it->getResults()) {
+      if (usedOutsideSplit(res, splitBeforeOp)) {
+        requiredVals.push_back(res);
+      }
+    }
+    // If the op is not recomputable, add it to the nonRecomputable set.
+    if (!isRecomputableAfterFission(&*it, splitBeforeOp)) {
+      nonRecomputable.insert(&*it);
+    }
+  }
+  // For each required value, collect its dependencies.
+  for (auto requiredVal : requiredVals)
+    collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache,
+                               toRecompute);
+  // For each op in toCache, create an alloc and update the pre and post map
+  // operands.
+  for (Operation *op : toCache) {
+    for (auto res : op->getResults()) {
+      auto alloc =
+          allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter);
+      allocs.push_back(res);
+      preMapOperands.push_back(alloc.from);
+      postMapOperands.push_back(alloc.to);
+    }
+  }
+}
+
+/// genPreTargetOp generates the preTargetOp that contains all the ops
+/// before the split point. It creates the block arguments, maps the
+/// values accordingly, and creates the store operations for the allocs.
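+///
+/// For illustration, splitting omp.target { A(); B(); C() } at an isolated
+/// op B gives roughly (a sketch; %tmp is the temporary map variable):
+///
+///   omp.target map(%tmp) { A(); ...store cached values to %tmp... }
+///   omp.target map(%tmp) { ...reload %tmp, recompute...; B() }
+///   omp.target map(%tmp) { ...reload %tmp, recompute...; C() }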
+static omp::TargetOp
+genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
+               SmallVector<Value> &allocs, Operation *splitBeforeOp,
+               RewriterBase &rewriter, struct HostEvalVars &hostEvalVars,
+               bool isTargetDevice) {
+  auto loc = targetOp.getLoc();
+  auto *targetBlock = &targetOp.getRegion().front();
+  SmallVector<Value> preHostEvalVars{targetOp.getHostEvalVars()};
+  // Update the hostEvalVars of preTargetOp.
+  omp::TargetOp preTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), preHostEvalVars,
+      targetOp.getIfExpr(), targetOp.getInReductionVars(),
+      targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
+      targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(),
+      targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
+      targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
+      targetOp.getPrivateMapsAttr());
+  auto *preTargetBlock = rewriter.createBlock(
+      &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
+  IRMapping preMapping;
+  // Create block arguments and map the values.
+  createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, preTargetBlock,
+                        preHostEvalVars, preMapOperands, allocs, preMapping);
+
+  // Handle the store operations for the allocs.
+  rewriter.setInsertionPointToStart(preTargetBlock);
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
+
+  // Clone the original operations.
+  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
+       it++) {
+    rewriter.clone(*it, preMapping);
+  }
+
+  unsigned originalHostEvalVarsSize = preHostEvalVars.size();
+  unsigned originalMapVarsSize = targetOp.getMapVars().size();
+  // Create stores for the allocs.
+  for (unsigned i = 0; i < allocs.size(); ++i) {
+    Value originalResult = allocs[i];
+    Value toStore = preMapping.lookup(originalResult);
+    // Get the new block argument for this specific allocated value.
+    Value newArg = preTargetBlock->getArgument(originalHostEvalVarsSize +
+                                               originalMapVarsSize + i);
+    // Create the store operation.
+    if (isPtr(originalResult.getType())) {
+      if (!isa<LLVM::LLVMPointerType>(toStore.getType()))
+        toStore = rewriter.create<fir::ConvertOp>(loc, llvmPtrTy, toStore);
+      rewriter.create<LLVM::StoreOp>(loc, toStore, newArg);
+    } else {
+      rewriter.create<fir::StoreOp>(loc, toStore, newArg);
+    }
+  }
+  rewriter.create<omp::TerminatorOp>(loc);
+
+  // Update hostEvalVars with the mapped values for the loop bounds if we have
+  // a loopNestOp and we are not generating code for the target device.
+  omp::LoopNestOp loopNestOp =
+      getLoopNestFromTeams(cast<omp::TeamsOp>(splitBeforeOp));
+  if (loopNestOp && !isTargetDevice) {
+    for (size_t i = 0; i < loopNestOp.getLoopLowerBounds().size(); ++i) {
+      Value lb = loopNestOp.getLoopLowerBounds()[i];
+      Value ub = loopNestOp.getLoopUpperBounds()[i];
+      Value step = loopNestOp.getLoopSteps()[i];
+
+      hostEvalVars.lbs.push_back(preMapping.lookup(lb));
+      hostEvalVars.ubs.push_back(preMapping.lookup(ub));
+      hostEvalVars.steps.push_back(preMapping.lookup(step));
+    }
+  }
+
+  return preTargetOp;
+}
+
+/// genIsolatedTargetOp generates the isolatedTargetOp that contains the
+/// ops at the split point. It creates the block arguments, maps the values
+/// accordingly, creates the load operations for the allocs, and recomputes
+/// the necessary ops.
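+///
+/// When not compiling for the target device, the loop bounds computed in the
+/// pre-kernel are forwarded through host_eval arguments so the host can
+/// derive the kernel launch bounds; a sketch:
+///
+///   omp.target host_eval(%lb -> %a_lb, %ub -> %a_ub, %step -> %a_step : ...) {
+///     ...
+///     omp.loop_nest (%i) : index = (%a_lb) to (%a_ub) inclusive step (%a_step)
+///   }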
+static omp::TargetOp
+genIsolatedTargetOp(omp::TargetOp targetOp,
+                    SmallVector<Value> &postMapOperands,
+                    Operation *splitBeforeOp, RewriterBase &rewriter,
+                    SmallVector<Value> &allocs,
+                    SetVector<Operation *> &toRecompute,
+                    struct HostEvalVars &hostEvalVars, bool isTargetDevice) {
+  auto loc = targetOp.getLoc();
+  auto *targetBlock = &targetOp.getRegion().front();
+  SmallVector<Value> isolatedHostEvalVars{targetOp.getHostEvalVars()};
+  // Update the hostEvalVars of isolatedTargetOp.
+  if (!hostEvalVars.lbs.empty() && !isTargetDevice) {
+    isolatedHostEvalVars.append(hostEvalVars.lbs.begin(),
+                                hostEvalVars.lbs.end());
+    isolatedHostEvalVars.append(hostEvalVars.ubs.begin(),
+                                hostEvalVars.ubs.end());
+    isolatedHostEvalVars.append(hostEvalVars.steps.begin(),
+                                hostEvalVars.steps.end());
+  }
+  // Create the isolated target op.
+  omp::TargetOp isolatedTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      isolatedHostEvalVars, targetOp.getIfExpr(),
+      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+      postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+  auto *isolatedTargetBlock =
+      rewriter.createBlock(&isolatedTargetOp.getRegion(),
+                           isolatedTargetOp.getRegion().begin(), {}, {});
+  IRMapping isolatedMapping;
+  // Create block arguments and map the values.
+  createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock,
+                        isolatedTargetBlock, isolatedHostEvalVars,
+                        postMapOperands, allocs, isolatedMapping);
+  // Handle the load operations for the allocs and recompute ops.
+  reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock,
+                          isolatedTargetBlock, isolatedHostEvalVars,
+                          postMapOperands, allocs, toRecompute,
+                          isolatedMapping);
+
+  // Clone the original operations.
+  rewriter.clone(*splitBeforeOp, isolatedMapping);
+  rewriter.create<omp::TerminatorOp>(loc);
+
+  // Update the loop bounds in the isolatedTargetOp if we have host_eval vars
+  // and we are not generating code for the target device.
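+  // For example (assuming a single loop dimension), bounds that were SSA
+  // values inside the original target region become host_eval block
+  // arguments:
+  //   omp.target host_eval(%lb -> %a_lb, %ub -> %a_ub, %step -> %a_step, ...)
+  //     ...
+  //     omp.loop_nest (%i) : index = (%a_lb) to (%a_ub) inclusive step (%a_step)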
+  if (!hostEvalVars.lbs.empty() && !isTargetDevice) {
+    omp::TeamsOp teamsOp;
+    for (auto &op : *isolatedTargetBlock) {
+      if (isa<omp::TeamsOp>(&op))
+        teamsOp = cast<omp::TeamsOp>(&op);
+    }
+    assert(teamsOp && "No teamsOp found in isolated target region");
+    // Get the loopNestOp inside the teamsOp.
+    auto loopNestOp = getLoopNestFromTeams(teamsOp);
+    // Get the block arguments related to the host_eval vars and update the
+    // loop_nest bounds to use them.
+    unsigned originalHostEvalVarsSize = targetOp.getHostEvalVars().size();
+    unsigned index = originalHostEvalVarsSize;
+    // Replace loop bounds with the block arguments passed down via host_eval.
+    SmallVector<Value> lbs, ubs, steps;
+
+    // Collect new lb/ub/step values from the target block args.
+    for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i)
+      lbs.push_back(isolatedTargetBlock->getArgument(index++));
+
+    for (size_t i = 0; i < hostEvalVars.ubs.size(); ++i)
+      ubs.push_back(isolatedTargetBlock->getArgument(index++));
+
+    for (size_t i = 0; i < hostEvalVars.steps.size(); ++i)
+      steps.push_back(isolatedTargetBlock->getArgument(index++));
+
+    // Reset the loop bounds.
+    loopNestOp.getLoopLowerBoundsMutable().assign(lbs);
+    loopNestOp.getLoopUpperBoundsMutable().assign(ubs);
+    loopNestOp.getLoopStepsMutable().assign(steps);
+  }
+
+  return isolatedTargetOp;
+}
+
+/// genPostTargetOp generates the postTargetOp containing all the ops after
+/// the split point. It creates the block arguments, maps the values
+/// accordingly, emits the load operations for the cached allocs, and
+/// recomputes the necessary ops.
+static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
+                                     Operation *splitBeforeOp,
+                                     SmallVector<Value> &postMapOperands,
+                                     RewriterBase &rewriter,
+                                     SmallVector<Value> &allocs,
+                                     SetVector<Operation *> &toRecompute) {
+  auto loc = targetOp.getLoc();
+  auto *targetBlock = &targetOp.getRegion().front();
+  SmallVector<Value> postHostEvalVars{targetOp.getHostEvalVars()};
+  // Create the post target op.
+  omp::TargetOp postTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), postHostEvalVars,
+      targetOp.getIfExpr(), targetOp.getInReductionVars(),
+      targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
+      targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
+      targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
+      targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
+      targetOp.getPrivateMapsAttr());
+  // Create the block for postTargetOp.
+  auto *postTargetBlock = rewriter.createBlock(
+      &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
+  IRMapping postMapping;
+  // Create block arguments and map the values.
+  createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, postTargetBlock,
+                        postHostEvalVars, postMapOperands, allocs,
+                        postMapping);
+  // Handle the load operations for the allocs and recompute ops.
+  reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock,
+                          postTargetBlock, postHostEvalVars, postMapOperands,
+                          allocs, toRecompute, postMapping);
+  assert(splitBeforeOp->getNumResults() == 0 ||
+         llvm::all_of(splitBeforeOp->getResults(),
+                      [](Value result) { return result.use_empty(); }));
+  // Clone the original operations after the split point.
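+  // (Their operands resolve through postMapping to the values reloaded or
+  // recomputed above.)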
+  for (auto it = std::next(splitBeforeOp->getIterator());
+       it != targetBlock->end(); it++)
+    rewriter.clone(*it, postMapping);
+  return postTargetOp;
+}
+
+/// isolateOp rewrites an omp.target_data { omp.target } into
+/// omp.target_data {
+///   // preTargetOp region contains the ops before splitBeforeOp.
+///   omp.target {}
+///   // isolatedTargetOp region contains splitBeforeOp.
+///   omp.target {}
+///   // postTargetOp region contains the ops after splitBeforeOp.
+///   omp.target {}
+/// }
+/// It also handles the mapping of variables and the caching/recomputing
+/// of values as needed.
+static FailureOr<SplitResult> isolateOp(Operation *splitBeforeOp,
+                                        bool splitAfter,
+                                        RewriterBase &rewriter,
+                                        mlir::ModuleOp module,
+                                        bool isTargetDevice) {
+  auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
+  assert(targetOp);
+  rewriter.setInsertionPoint(targetOp);
+
+  // Prepare the map operands for preTargetOp and postTargetOp.
+  auto preMapOperands = SmallVector<Value>(targetOp.getMapVars());
+  auto postMapOperands = SmallVector<Value>(targetOp.getMapVars());
+
+  // Vectors to hold the analysis results.
+  SmallVector<Value> requiredVals;
+  SetVector<Operation *> toCache;
+  SetVector<Operation *> toRecompute;
+  SetVector<Operation *> nonRecomputable;
+  SmallVector<Value> allocs;
+  struct HostEvalVars hostEvalVars;
+
+  // Analyze the ops in the target region to determine which ops need to be
+  // cached and which ops need to be recomputed.
+  computeAllocsCacheRecomputable(
+      targetOp, splitBeforeOp, rewriter, preMapOperands, postMapOperands,
+      allocs, requiredVals, nonRecomputable, toCache, toRecompute);
+
+  rewriter.setInsertionPoint(targetOp);
+
+  // Generate the preTargetOp that contains all the ops before splitBeforeOp.
+  auto preTargetOp =
+      genPreTargetOp(targetOp, preMapOperands, allocs, splitBeforeOp, rewriter,
+                     hostEvalVars, isTargetDevice);
+
+  // Move the ops of preTargetOp to the host.
+  auto res = moveToHost(preTargetOp, rewriter, module, hostEvalVars);
+  if (failed(res))
+    return failure();
+  rewriter.setInsertionPoint(targetOp);
+
+  // Generate the isolatedTargetOp.
+  omp::TargetOp isolatedTargetOp =
+      genIsolatedTargetOp(targetOp, postMapOperands, splitBeforeOp, rewriter,
+                          allocs, toRecompute, hostEvalVars, isTargetDevice);
+
+  omp::TargetOp postTargetOp = nullptr;
+  // Generate the postTargetOp that contains all the ops after splitBeforeOp.
+  if (splitAfter) {
+    rewriter.setInsertionPoint(targetOp);
+    postTargetOp = genPostTargetOp(targetOp, splitBeforeOp, postMapOperands,
+                                   rewriter, allocs, toRecompute);
+  }
+  // Finally erase the original targetOp.
+  rewriter.eraseOp(targetOp);
+  return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp};
+}
+
+/// Recursively fission target ops until no more nested ops can be isolated.
+static LogicalResult fissionTarget(omp::TargetOp targetOp,
+                                   RewriterBase &rewriter,
+                                   mlir::ModuleOp module,
+                                   bool isTargetDevice) {
+  auto tuple = getNestedOpToIsolate(targetOp);
+  if (!tuple) {
+    LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n");
+    struct HostEvalVars hostEvalVars;
+    return moveToHost(targetOp, rewriter, module, hostEvalVars);
+  }
+  Operation *toIsolate = std::get<0>(*tuple);
+  bool splitBefore = !std::get<1>(*tuple);
+  bool splitAfter = !std::get<2>(*tuple);
+  // Recursively isolate the target op.
+  if (splitBefore && splitAfter) {
+    auto res =
+        isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice);
+    if (failed(res))
+      return failure();
+    return fissionTarget((*res).postTargetOp, rewriter, module,
+                         isTargetDevice);
+  }
+  // Isolate only before the op.
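+  // For example, when the op to isolate is the last parallelizable op in the
+  // region, no postTargetOp is needed and the recursion stops here.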
+  if (splitBefore) {
+    auto res =
+        isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice);
+    if (failed(res))
+      return failure();
+  } else {
+    emitError(toIsolate->getLoc(), "Unhandled case in fissionTarget");
+    return failure();
+  }
+  return success();
+}
+
+/// Pass to lower omp.workdistribute ops.
+class LowerWorkdistributePass
+    : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
+public:
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    auto moduleOp = getOperation();
+    bool changed = false;
+    SetVector<omp::TargetOp> targetOpsToProcess;
+    auto verify =
+        moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+          if (failed(verifyTargetTeamsWorkdistribute(workdistribute)))
+            return WalkResult::interrupt();
+          return WalkResult::advance();
+        });
+    if (verify.wasInterrupted())
+      return signalPassFailure();
+
+    auto fission =
+        moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+          auto res = fissionWorkdistribute(workdistribute);
+          if (failed(res))
+            return WalkResult::interrupt();
+          changed |= *res;
+          return WalkResult::advance();
+        });
+    if (fission.wasInterrupted())
+      return signalPassFailure();
+
+    auto rtCallLower =
+        moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+          auto res = workdistributeRuntimeCallLower(workdistribute,
+                                                    targetOpsToProcess);
+          if (failed(res))
+            return WalkResult::interrupt();
+          changed |= *res;
+          return WalkResult::advance();
+        });
+    if (rtCallLower.wasInterrupted())
+      return signalPassFailure();
+
+    moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+      changed |= workdistributeDoLower(workdistribute, targetOpsToProcess);
+    });
+
+    moduleOp->walk([&](mlir::omp::TeamsOp teams) {
+      changed |= teamsWorkdistributeToSingleOp(teams, targetOpsToProcess);
+    });
+    if (changed) {
+      bool isTargetDevice =
+          llvm::cast<mlir::omp::OffloadModuleInterface>(*moduleOp)
+              .getIsTargetDevice();
+      IRRewriter rewriter(&context);
+      for (auto targetOp : targetOpsToProcess) {
+        auto res = splitTargetData(targetOp, rewriter);
+        if (failed(res))
+          return signalPassFailure();
+        if (*res) {
+          if (failed(fissionTarget(*res, rewriter, moduleOp, isTargetDevice)))
+            return signalPassFailure();
+        }
+      }
+    }
+  }
+};
+} // namespace
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index a83b0665eaf1f..1ecb6d383f939 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -301,8 +301,10 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm,
   addNestedPassToAllTopLevelOperations<PassConstructor>(
       pm, hlfir::createInlineHLFIRAssign);
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  if (enableOpenMP != EnableOpenMP::None)
+  if (enableOpenMP != EnableOpenMP::None) {
     pm.addPass(flangomp::createLowerWorkshare());
+    pm.addPass(flangomp::createLowerWorkdistribute());
+  }
   if (enableOpenMP == EnableOpenMP::Simd)
     pm.addPass(flangomp::createSimdOnlyPass());
 }
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 195e5ad7f9dc8..59f6c73ae84ee 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -69,6 +69,7 @@ func.func @_QQmain() {
 // PASSES-NEXT: InlineHLFIRAssign
 // PASSES-NEXT: ConvertHLFIRtoFIR
 // PASSES-NEXT: LowerWorkshare
+// PASSES-NEXT: LowerWorkdistribute
 // PASSES-NEXT: CSE
 // PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
 // PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Lower/OpenMP/workdistribute-multiple.f90 b/flang/test/Lower/OpenMP/workdistribute-multiple.f90
new
file mode 100644 index 0000000000000..cf1d9dd294cea --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-multiple.f90 @@ -0,0 +1,20 @@ +! RUN: not %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - 2>&1 | FileCheck %s + +! CHECK: error: teams has multiple workdistribute ops. +! CHECK-LABEL: func @_QPteams_workdistribute_1 +subroutine teams_workdistribute_1() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + !$omp teams + + !$omp workdistribute + y = a * x + y + !$omp end workdistribute + + !$omp workdistribute + y = a * y + x + !$omp end workdistribute + !$omp end teams +end subroutine teams_workdistribute_1 diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90 new file mode 100644 index 0000000000000..b2dbc0f15121e --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90 @@ -0,0 +1,39 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +subroutine target_teams_workdistribute() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + + ! CHECK: omp.target_data + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + + !$omp target teams workdistribute + y = a * x + y + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute + +! CHECK-LABEL: func @_QPteams_workdistribute +subroutine teams_workdistribute() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + + !$omp teams workdistribute + y = a * x + y + !$omp end teams workdistribute +end subroutine teams_workdistribute diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90 new file mode 100644 index 0000000000000..09e1211541edb --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90 @@ -0,0 +1,45 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +subroutine target_teams_workdistribute(a, x, y, rows, cols) + use iso_fortran_env + implicit none + + integer, intent(in) :: rows, cols + real(kind=real32) :: a + real(kind=real32), dimension(rows, cols) :: x, y + + ! CHECK: omp.target_data + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! CHECK: fir.do_loop + + !$omp target teams workdistribute + y = a * x + y + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute + +! CHECK-LABEL: func @_QPteams_workdistribute +subroutine teams_workdistribute(a, x, y, rows, cols) + use iso_fortran_env + implicit none + + integer, intent(in) :: rows, cols + real(kind=real32) :: a + real(kind=real32), dimension(rows, cols) :: x, y + + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! 
CHECK: fir.do_loop + + !$omp teams workdistribute + y = a * x + y + !$omp end teams workdistribute +end subroutine teams_workdistribute diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90 new file mode 100644 index 0000000000000..cf5d0234edb39 --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90 @@ -0,0 +1,47 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +subroutine target_teams_workdistribute(a, x, y, rows, cols, depth) + use iso_fortran_env + implicit none + + integer, intent(in) :: rows, cols, depth + real(kind=real32) :: a + real(kind=real32), dimension(rows, cols, depth) :: x, y + + ! CHECK: omp.target_data + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! CHECK: fir.do_loop + ! CHECK: fir.do_loop + + !$omp target teams workdistribute + y = a * x + y + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute + +! CHECK-LABEL: func @_QPteams_workdistribute +subroutine teams_workdistribute(a, x, y, rows, cols, depth) + use iso_fortran_env + implicit none + + integer, intent(in) :: rows, cols, depth + real(kind=real32) :: a + real(kind=real32), dimension(rows, cols, depth) :: x, y + + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! CHECK: fir.do_loop + ! CHECK: fir.do_loop + + !$omp teams workdistribute + y = a * x + y + !$omp end teams workdistribute +end subroutine teams_workdistribute diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90 new file mode 100644 index 0000000000000..516c4603bd5da --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90 @@ -0,0 +1,53 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +subroutine target_teams_workdistribute() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + !$omp target teams workdistribute + + ! CHECK: omp.target_data + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + + y = a * x + y + + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + + y = 2.0_real32 + + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute + +! CHECK-LABEL: func @_QPteams_workdistribute +subroutine teams_workdistribute() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + !$omp teams workdistribute + + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + + y = a * x + y + + ! CHECK: fir.call @_FortranAAssign + y = 2.0_real32 + + !$omp end teams workdistribute +end subroutine teams_workdistribute diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90 new file mode 100644 index 0000000000000..4aeb2e89140cc --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90 @@ -0,0 +1,68 @@ +! 
RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +subroutine target_teams_workdistribute(a, x, y, rows, cols) + use iso_fortran_env + implicit none + + integer, intent(in) :: rows, cols + real(kind=real32) :: a + real(kind=real32), dimension(rows, cols) :: x, y + + !$omp target teams workdistribute + + ! CHECK: omp.target_data + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! CHECK: fir.do_loop + + y = a * x + y + + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! CHECK: fir.do_loop + + y = a * y + x + + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute + +! CHECK-LABEL: func @_QPteams_workdistribute +subroutine teams_workdistribute(a, x, y, rows, cols) + use iso_fortran_env + implicit none + + integer, intent(in) :: rows, cols + real(kind=real32) :: a + real(kind=real32), dimension(rows, cols) :: x, y + + !$omp teams workdistribute + + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! CHECK: fir.do_loop + + y = a * x + y + + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! CHECK: fir.do_loop + + y = a * y + x + + !$omp end teams workdistribute +end subroutine teams_workdistribute diff --git a/flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90 b/flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90 new file mode 100644 index 0000000000000..3062b3598b8ae --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90 @@ -0,0 +1,29 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute_scalar_assign +subroutine target_teams_workdistribute_scalar_assign() + integer :: aa(10) + + ! CHECK: omp.target_data + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + + !$omp target teams workdistribute + aa = 20 + !$omp end target teams workdistribute + +end subroutine target_teams_workdistribute_scalar_assign + +! CHECK-LABEL: func @_QPteams_workdistribute_scalar_assign +subroutine teams_workdistribute_scalar_assign() + integer :: aa(10) + ! CHECK: fir.call @_FortranAAssign + !$omp teams workdistribute + aa = 20 + !$omp end teams workdistribute + +end subroutine teams_workdistribute_scalar_assign diff --git a/flang/test/Lower/OpenMP/workdistribute-target-teams-clauses.f90 b/flang/test/Lower/OpenMP/workdistribute-target-teams-clauses.f90 new file mode 100644 index 0000000000000..4a08e53bc316a --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-target-teams-clauses.f90 @@ -0,0 +1,32 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +! CHECK: omp.target_data map_entries({{.*}}) +! CHECK: omp.target thread_limit({{.*}}) host_eval({{.*}}) map_entries({{.*}}) +! CHECK: omp.teams num_teams({{.*}}) +! CHECK: omp.parallel +! CHECK: omp.distribute +! CHECK: omp.wsloop +! 
CHECK: omp.loop_nest + +subroutine target_teams_workdistribute() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + integer :: i + + a = 2.0_real32 + x = [(real(i, real32), i = 1, 10)] + y = [(real(i * 0.5, real32), i = 1, 10)] + + !$omp target teams workdistribute & + !$omp& num_teams(4) & + !$omp& thread_limit(8) & + !$omp& default(shared) & + !$omp& private(i) & + !$omp& map(to: x) & + !$omp& map(tofrom: y) + y = a * x + y + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute diff --git a/flang/test/Lower/OpenMP/workdistribute-teams-unsupported-after.f90 b/flang/test/Lower/OpenMP/workdistribute-teams-unsupported-after.f90 new file mode 100644 index 0000000000000..f9c5a771f401d --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-teams-unsupported-after.f90 @@ -0,0 +1,22 @@ +! RUN: not %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - 2>&1 | FileCheck %s + +! CHECK: error: teams has omp ops other than workdistribute. Lowering not implemented yet. +! CHECK-LABEL: func @_QPteams_workdistribute_1 +subroutine teams_workdistribute_1() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + !$omp teams + + !$omp workdistribute + y = a * x + y + !$omp end workdistribute + + !$omp distribute + do i = 1, 10 + x(i) = real(i, kind=real32) + end do + !$omp end distribute + !$omp end teams +end subroutine teams_workdistribute_1 diff --git a/flang/test/Lower/OpenMP/workdistribute-teams-unsupported-before.f90 b/flang/test/Lower/OpenMP/workdistribute-teams-unsupported-before.f90 new file mode 100644 index 0000000000000..3ef7f90087944 --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-teams-unsupported-before.f90 @@ -0,0 +1,22 @@ +! RUN: not %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - 2>&1 | FileCheck %s + +! CHECK: error: teams has omp ops other than workdistribute. Lowering not implemented yet. +! 
CHECK-LABEL: func @_QPteams_workdistribute_1
+subroutine teams_workdistribute_1()
+  use iso_fortran_env
+  real(kind=real32) :: a
+  real(kind=real32), dimension(10) :: x
+  real(kind=real32), dimension(10) :: y
+  !$omp teams
+
+  !$omp distribute
+  do i = 1, 10
+    x(i) = real(i, kind=real32)
+  end do
+  !$omp end distribute
+
+  !$omp workdistribute
+  y = a * x + y
+  !$omp end workdistribute
+  !$omp end teams
+end subroutine teams_workdistribute_1
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
new file mode 100644
index 0000000000000..00d10d6264ec9
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
@@ -0,0 +1,33 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL: func.func @x({{.*}})
+// CHECK: omp.teams {
+// CHECK: omp.parallel {
+// CHECK: omp.distribute {
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_1:.*]]) : index = (%[[ARG0:.*]]) to (%[[ARG1:.*]]) inclusive step (%[[ARG2:.*]]) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK: fir.store %[[VAL_0]] to %[[ARG4:.*]] : !fir.ref<index>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: } {omp.composite}
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+func.func @x(%lb : index, %ub : index, %step : index, %b : i1, %addr : !fir.ref<index>) {
+  omp.teams {
+    omp.workdistribute {
+      fir.do_loop %iv = %lb to %ub step %step unordered {
+        %zero = arith.constant 0 : index
+        fir.store %zero to %addr : !fir.ref<index>
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir
new file mode 100644
index 0000000000000..04e60ca8bbf37
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir
@@ -0,0 +1,117 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+// Test lowering of workdistribute after fission on the host device.
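+//
+// Rough shape of the expected output (illustrative summary, not CHECK
+// lines): the target region is wrapped in omp.target_data, ops before and
+// after the unordered loop are hoisted to the host, the fir.allocmem and
+// fir.freemem pair becomes omp.target_allocmem and omp.target_freemem, and
+// the loop bounds reach the isolated omp.target via host_eval.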
+
+// CHECK-LABEL: func.func @x(
+// CHECK: %[[VAL_0:.*]] = fir.alloca index {bindc_name = "lb"}
+// CHECK: fir.store %[[ARG0:.*]] to %[[VAL_0]] : !fir.ref<index>
+// CHECK: %[[VAL_1:.*]] = fir.alloca index {bindc_name = "ub"}
+// CHECK: fir.store %[[ARG1:.*]] to %[[VAL_1]] : !fir.ref<index>
+// CHECK: %[[VAL_2:.*]] = fir.alloca index {bindc_name = "step"}
+// CHECK: fir.store %[[ARG2:.*]] to %[[VAL_2]] : !fir.ref<index>
+// CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK: %[[VAL_4:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK: %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK: %[[VAL_6:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK: %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK: %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK: omp.target_data map_entries(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+// CHECK: %[[VAL_11:.*]] = fir.alloca index
+// CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_13:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_14:.*]] = fir.alloca index
+// CHECK: %[[VAL_15:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_16:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_17:.*]] = fir.alloca index
+// CHECK: %[[VAL_18:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_19:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_20:.*]] = fir.alloca !fir.heap<index>
+// CHECK: %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(from) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_22:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(to) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_23:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_0]] : !fir.ref<index>
+// CHECK: %[[VAL_25:.*]] = fir.load %[[VAL_1]] : !fir.ref<index>
+// CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_2]] : !fir.ref<index>
+// CHECK: %[[VAL_27:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_28:.*]] = arith.addi
%[[VAL_25]], %[[VAL_25]] : index
+// CHECK: %[[VAL_29:.*]] = omp.target_allocmem %[[VAL_23]] : i32, index, %[[VAL_27]] {uniq_name = "dev_buf"}
+// CHECK: %[[VAL_30:.*]] = fir.convert %[[VAL_29]] : (i64) -> !fir.heap<index>
+// CHECK: fir.store %[[VAL_24]] to %[[VAL_11]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_25]] to %[[VAL_14]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_30]] to %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK: omp.target host_eval(%[[VAL_24]] -> %[[VAL_31:.*]], %[[VAL_25]] -> %[[VAL_32:.*]], %[[VAL_26]] -> %[[VAL_33:.*]] : index, index, index) map_entries(%[[VAL_7]] -> %[[VAL_34:.*]], %[[VAL_8]] -> %[[VAL_35:.*]], %[[VAL_9]] -> %[[VAL_36:.*]], %[[VAL_10]] -> %[[VAL_37:.*]], %[[VAL_13]] -> %[[VAL_38:.*]], %[[VAL_16]] -> %[[VAL_39:.*]], %[[VAL_19]] -> %[[VAL_40:.*]], %[[VAL_22]] -> %[[VAL_41:.*]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<!fir.heap<index>>) {
+// CHECK: %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.ref<index>
+// CHECK: %[[VAL_43:.*]] = fir.load %[[VAL_39]] : !fir.ref<index>
+// CHECK: %[[VAL_44:.*]] = fir.load %[[VAL_40]] : !fir.ref<index>
+// CHECK: %[[VAL_45:.*]] = fir.load %[[VAL_41]] : !fir.ref<!fir.heap<index>>
+// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_43]], %[[VAL_43]] : index
+// CHECK: omp.teams {
+// CHECK: omp.parallel {
+// CHECK: omp.distribute {
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_47:.*]]) : index = (%[[VAL_31]]) to (%[[VAL_32]]) inclusive step (%[[VAL_33]]) {
+// CHECK: fir.store %[[VAL_46]] to %[[VAL_45]] : !fir.heap<index>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: } {omp.composite}
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: %[[VAL_48:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_49:.*]] = fir.load %[[VAL_11]] : !fir.ref<index>
+// CHECK: %[[VAL_50:.*]] = fir.load %[[VAL_14]] : !fir.ref<index>
+// CHECK: %[[VAL_51:.*]] = fir.load %[[VAL_17]] : !fir.ref<index>
+// CHECK: %[[VAL_52:.*]] = fir.load %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK: %[[VAL_53:.*]] = arith.addi %[[VAL_50]], %[[VAL_50]] : index
+// CHECK: fir.store %[[VAL_49]] to %[[VAL_52]] : !fir.heap<index>
+// CHECK: %[[VAL_54:.*]] = fir.convert %[[VAL_52]] : (!fir.heap<index>) -> i64
+// CHECK: omp.target_freemem %[[VAL_48]], %[[VAL_54]] : i32, i64
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu", omp.is_gpu = false, omp.is_target_device = false} {
+func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref<index>) {
+  %lb_ref = fir.alloca index {bindc_name = "lb"}
+  fir.store %lb to %lb_ref : !fir.ref<index>
+  %ub_ref = fir.alloca index {bindc_name = "ub"}
+  fir.store %ub to %ub_ref : !fir.ref<index>
+  %step_ref = fir.alloca index {bindc_name = "step"}
+  fir.store %step to %step_ref : !fir.ref<index>
+
+  %lb_map = omp.map.info var_ptr(%lb_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+  %ub_map = omp.map.info var_ptr(%ub_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+  %step_map = omp.map.info var_ptr(%step_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+  %addr_map = omp.map.info var_ptr(%addr : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+
+  omp.target map_entries(%lb_map -> %ARG0, %ub_map -> %ARG1, %step_map -> %ARG2, %addr_map -> %ARG3 : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+    %lb_val = fir.load %ARG0 : !fir.ref<index>
+    %ub_val = fir.load
%ARG1 : !fir.ref<index>
+    %step_val = fir.load %ARG2 : !fir.ref<index>
+    %one = arith.constant 1 : index
+
+    %20 = arith.addi %ub_val, %ub_val : index
+    omp.teams {
+      omp.workdistribute {
+        %dev_mem = fir.allocmem index, %one {uniq_name = "dev_buf"}
+        fir.do_loop %iv = %lb_val to %ub_val step %step_val unordered {
+          fir.store %20 to %dev_mem : !fir.heap<index>
+        }
+        fir.store %lb_val to %dev_mem : !fir.heap<index>
+        fir.freemem %dev_mem : !fir.heap<index>
+        omp.terminator
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
new file mode 100644
index 0000000000000..062eb701b52ef
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
@@ -0,0 +1,118 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+// Test lowering of workdistribute after fission on the target device.
+
+// CHECK-LABEL: func.func @x(
+// CHECK: %[[VAL_0:.*]] = fir.alloca index {bindc_name = "lb"}
+// CHECK: fir.store %[[ARG0:.*]] to %[[VAL_0]] : !fir.ref<index>
+// CHECK: %[[VAL_1:.*]] = fir.alloca index {bindc_name = "ub"}
+// CHECK: fir.store %[[ARG1:.*]] to %[[VAL_1]] : !fir.ref<index>
+// CHECK: %[[VAL_2:.*]] = fir.alloca index {bindc_name = "step"}
+// CHECK: fir.store %[[ARG2:.*]] to %[[VAL_2]] : !fir.ref<index>
+// CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK: %[[VAL_4:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK: %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK: %[[VAL_6:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK: %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK: %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK: omp.target_data map_entries(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+// CHECK: %[[VAL_11:.*]] = fir.alloca index
+// CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_13:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_14:.*]] = fir.alloca index
+// CHECK: %[[VAL_15:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_16:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_17:.*]] = fir.alloca index
+// CHECK: %[[VAL_18:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref<index>, index) map_clauses(from)
capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_19:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_20:.*]] = fir.alloca !fir.heap<index>
+// CHECK: %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(from) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_22:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(to) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_23:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_0]] : !fir.ref<index>
+// CHECK: %[[VAL_25:.*]] = fir.load %[[VAL_1]] : !fir.ref<index>
+// CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_2]] : !fir.ref<index>
+// CHECK: %[[VAL_27:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_25]], %[[VAL_25]] : index
+// CHECK: %[[VAL_29:.*]] = omp.target_allocmem %[[VAL_23]] : i32, index, %[[VAL_27]] {uniq_name = "dev_buf"}
+// CHECK: %[[VAL_30:.*]] = fir.convert %[[VAL_29]] : (i64) -> !fir.heap<index>
+// CHECK: fir.store %[[VAL_24]] to %[[VAL_11]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_25]] to %[[VAL_14]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_30]] to %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_31:.*]], %[[VAL_8]] -> %[[VAL_32:.*]], %[[VAL_9]] -> %[[VAL_33:.*]], %[[VAL_10]] -> %[[VAL_34:.*]], %[[VAL_13]] -> %[[VAL_35:.*]], %[[VAL_16]] -> %[[VAL_36:.*]], %[[VAL_19]] -> %[[VAL_37:.*]], %[[VAL_22]] -> %[[VAL_38:.*]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<!fir.heap<index>>) {
+// CHECK: %[[VAL_39:.*]] = fir.load %[[VAL_35]] : !fir.ref<index>
+// CHECK: %[[VAL_40:.*]] = fir.load %[[VAL_36]] : !fir.ref<index>
+// CHECK: %[[VAL_41:.*]] = fir.load %[[VAL_37]] : !fir.ref<index>
+// CHECK: %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.ref<!fir.heap<index>>
+// CHECK: %[[VAL_43:.*]] = arith.addi %[[VAL_40]], %[[VAL_40]] : index
+// CHECK: omp.teams {
+// CHECK: omp.parallel {
+// CHECK: omp.distribute {
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_44:.*]]) : index = (%[[VAL_39]]) to (%[[VAL_40]]) inclusive step (%[[VAL_41]]) {
+// CHECK: fir.store %[[VAL_43]] to %[[VAL_42]] : !fir.heap<index>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: } {omp.composite}
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: %[[VAL_45:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_46:.*]] = fir.load %[[VAL_11]] : !fir.ref<index>
+// CHECK: %[[VAL_47:.*]] = fir.load %[[VAL_14]] : !fir.ref<index>
+// CHECK: %[[VAL_48:.*]] = fir.load %[[VAL_17]] : !fir.ref<index>
+// CHECK: %[[VAL_49:.*]] = fir.load %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK: %[[VAL_50:.*]] = arith.addi %[[VAL_47]], %[[VAL_47]] : index
+// CHECK: fir.store %[[VAL_46]] to %[[VAL_49]] : !fir.heap<index>
+// CHECK: %[[VAL_51:.*]] = fir.convert %[[VAL_49]] : (!fir.heap<index>) -> i64
+// CHECK: omp.target_freemem %[[VAL_45]], %[[VAL_51]] : i32, i64
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
+func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref<index>) {
+  %lb_ref = fir.alloca index {bindc_name = "lb"}
+  fir.store %lb to %lb_ref : !fir.ref<index>
+  %ub_ref = fir.alloca index {bindc_name = "ub"}
+  fir.store
%ub to %ub_ref : !fir.ref<index>
+  %step_ref = fir.alloca index {bindc_name = "step"}
+  fir.store %step to %step_ref : !fir.ref<index>
+
+  %lb_map = omp.map.info var_ptr(%lb_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+  %ub_map = omp.map.info var_ptr(%ub_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+  %step_map = omp.map.info var_ptr(%step_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+  %addr_map = omp.map.info var_ptr(%addr : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+
+  omp.target map_entries(%lb_map -> %ARG0, %ub_map -> %ARG1, %step_map -> %ARG2, %addr_map -> %ARG3 : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+    %lb_val = fir.load %ARG0 : !fir.ref<index>
+    %ub_val = fir.load %ARG1 : !fir.ref<index>
+    %step_val = fir.load %ARG2 : !fir.ref<index>
+    %one = arith.constant 1 : index
+
+    %20 = arith.addi %ub_val, %ub_val : index
+    omp.teams {
+      omp.workdistribute {
+        %dev_mem = fir.allocmem index, %one {uniq_name = "dev_buf"}
+        fir.do_loop %iv = %lb_val to %ub_val step %step_val unordered {
+          fir.store %20 to %dev_mem : !fir.heap<index>
+        }
+        fir.store %lb_val to %dev_mem : !fir.heap<index>
+        fir.freemem %dev_mem : !fir.heap<index>
+        omp.terminator
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
new file mode 100644
index 0000000000000..c562b7009664d
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
@@ -0,0 +1,71 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL: func.func @test_fission_workdistribute(
+// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 9 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 5.000000e+00 : f32
+// CHECK: fir.store %[[VAL_3]] to %[[ARG2:.*]] : !fir.ref<f32>
+// CHECK: omp.teams {
+// CHECK: omp.parallel {
+// CHECK: omp.distribute {
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_4:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
+// CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
+// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK: fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: } {omp.composite}
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: fir.call @regular_side_effect_func(%[[ARG2:.*]]) : (!fir.ref<f32>) -> ()
+// CHECK: fir.call @my_fir_parallel_runtime_func(%[[ARG3:.*]]) : (!fir.ref<f32>) -> ()
+// CHECK: fir.do_loop %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] {
+// CHECK: %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0]], %[[VAL_8]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK: fir.store %[[VAL_3]] to %[[VAL_9]] : !fir.ref<f32>
+// CHECK: }
+// CHECK: %[[VAL_10:.*]] = fir.load %[[ARG2:.*]] : !fir.ref<f32>
+// CHECK: fir.store %[[VAL_10]] to %[[ARG3:.*]] : !fir.ref<f32>
+// CHECK: return
+// CHECK: }
module {
func.func @regular_side_effect_func(%arg0: !fir.ref<f32>) {
  return
}
func.func @my_fir_parallel_runtime_func(%arg0: !fir.ref<f32>) attributes {fir.runtime} {
  return
}
func.func @test_fission_workdistribute(%arr1: !fir.ref<!fir.array<10xf32>>, %arr2: !fir.ref<!fir.array<10xf32>>, %scalar_ref1:
!fir.ref<f32>, %scalar_ref2: !fir.ref<f32>) {
  %c0_idx = arith.constant 0 : index
  %c1_idx = arith.constant 1 : index
  %c9_idx = arith.constant 9 : index
  %float_val = arith.constant 5.0 : f32
  omp.teams {
    omp.workdistribute {
      fir.store %float_val to %scalar_ref1 : !fir.ref<f32>
      fir.do_loop %iv = %c0_idx to %c9_idx step %c1_idx unordered {
        %elem_ptr_arr1 = fir.coordinate_of %arr1, %iv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
        %loaded_val_loop1 = fir.load %elem_ptr_arr1 : !fir.ref<f32>
        %elem_ptr_arr2 = fir.coordinate_of %arr2, %iv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
        fir.store %loaded_val_loop1 to %elem_ptr_arr2 : !fir.ref<f32>
      }
      fir.call @regular_side_effect_func(%scalar_ref1) : (!fir.ref<f32>) -> ()
      fir.call @my_fir_parallel_runtime_func(%scalar_ref2) : (!fir.ref<f32>) -> ()
      fir.do_loop %jv = %c0_idx to %c9_idx step %c1_idx {
        %elem_ptr_ordered_loop = fir.coordinate_of %arr1, %jv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
        fir.store %float_val to %elem_ptr_ordered_loop : !fir.ref<f32>
      }
      %loaded_for_hoist = fir.load %scalar_ref1 : !fir.ref<f32>
      fir.store %loaded_for_hoist to %scalar_ref2 : !fir.ref<f32>
      omp.terminator
    }
    omp.terminator
  }
  return
}
}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-runtime-assign-scalar.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-runtime-assign-scalar.mlir
new file mode 100644
index 0000000000000..03d5d71df0a82
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-runtime-assign-scalar.mlir
@@ -0,0 +1,108 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// Test lowering of workdistribute for a scalar assignment within a target
+// teams workdistribute region. The test checks that the scalar assignment is
+// correctly lowered to wsloop and loop_nest operations.
+
+// Example Fortran code:
+// !$omp target teams workdistribute
+// y = 3.0_real32
+// !$omp end target teams workdistribute
+
+// CHECK-LABEL: func.func @x(
+// CHECK: omp.target {{.*}} {
+// CHECK: omp.teams {
+// CHECK: omp.parallel {
+// CHECK: omp.distribute {
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_73:.*]]) : index = (%[[VAL_66:.*]]) to (%[[VAL_72:.*]]) inclusive step (%[[VAL_67:.*]]) {
+// CHECK: %[[VAL_74:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_75:.*]]:3 = fir.box_dims %[[VAL_64:.*]], %[[VAL_74]] : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+// CHECK: %[[VAL_76:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_77:.*]]:3 = fir.box_dims %[[VAL_64]], %[[VAL_76]] : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+// CHECK: %[[VAL_78:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_79:.*]] = arith.remsi %[[VAL_73]], %[[VAL_77]]#1 : index
+// CHECK: %[[VAL_80:.*]] = arith.addi %[[VAL_79]], %[[VAL_78]] : index
+// CHECK: %[[VAL_81:.*]] = arith.divsi %[[VAL_73]], %[[VAL_77]]#1 : index
+// CHECK: %[[VAL_82:.*]] = arith.remsi %[[VAL_81]], %[[VAL_75]]#1 : index
+// CHECK: %[[VAL_83:.*]] = arith.addi %[[VAL_82]], %[[VAL_78]] : index
+// CHECK: %[[VAL_84:.*]] = fir.array_coor %[[VAL_64]] %[[VAL_83]], %[[VAL_80]] : (!fir.box<!fir.array<?x?xf32>>, index, index) -> !fir.ref<f32>
+// CHECK: fir.store %[[VAL_65:.*]] to %[[VAL_84]] : !fir.ref<f32>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: } {omp.composite}
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+// CHECK: func.func private @_FortranAAssign(!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.ref<i8>, i32) attributes {fir.runtime}

module attributes {llvm.target_triple =
"amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} { +func.func @x(%arr : !fir.ref>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c78 = arith.constant 78 : index + %cst = arith.constant 3.000000e+00 : f32 + %0 = fir.alloca i32 + %1 = fir.alloca i32 + %c10 = arith.constant 10 : index + %c20 = arith.constant 20 : index + %194 = arith.subi %c10, %c1 : index + %195 = omp.map.bounds lower_bound(%c0 : index) upper_bound(%194 : index) extent(%c10 : index) stride(%c1 : index) start_idx(%c1 : index) + %196 = arith.subi %c20, %c1 : index + %197 = omp.map.bounds lower_bound(%c0 : index) upper_bound(%196 : index) extent(%c20 : index) stride(%c1 : index) start_idx(%c1 : index) + %198 = omp.map.info var_ptr(%arr : !fir.ref>, f32) map_clauses(implicit, tofrom) capture(ByRef) bounds(%195, %197) -> !fir.ref> {name = "y"} + %199 = omp.map.info var_ptr(%1 : !fir.ref, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref {name = ""} + %200 = omp.map.info var_ptr(%0 : !fir.ref, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref {name = ""} + omp.target map_entries(%198 -> %arg5, %199 -> %arg6, %200 -> %arg7 : !fir.ref>, !fir.ref, !fir.ref) { + %c0_0 = arith.constant 0 : index + %201 = fir.load %arg7 : !fir.ref + %202 = fir.load %arg6 : !fir.ref + %203 = fir.convert %202 : (i32) -> i64 + %204 = fir.convert %201 : (i32) -> i64 + %205 = fir.convert %204 : (i64) -> index + %206 = arith.cmpi sgt, %205, %c0_0 : index + %207 = fir.convert %203 : (i64) -> index + %208 = arith.cmpi sgt, %207, %c0_0 : index + %209 = arith.select %208, %207, %c0_0 : index + %210 = arith.select %206, %205, %c0_0 : index + %211 = fir.shape %210, %209 : (index, index) -> !fir.shape<2> + %212 = fir.declare %arg5(%211) {uniq_name = "_QFFaxpy_array_workdistributeEy"} : (!fir.ref>, !fir.shape<2>) -> !fir.ref> + %213 = fir.embox %212(%211) : (!fir.ref>, !fir.shape<2>) -> !fir.box> + omp.teams { + %214 = fir.alloca !fir.box> {pinned} + omp.workdistribute { + %215 = fir.alloca f32 + %216 = fir.embox %215 : (!fir.ref) -> !fir.box + %217 = fir.shape %210, %209 : (index, index) -> !fir.shape<2> + %218 = fir.embox %212(%217) : (!fir.ref>, !fir.shape<2>) -> !fir.box> + fir.store %218 to %214 : !fir.ref>> + %219 = fir.address_of(@_QQclXf9c642d28e5bba1f07fa9a090b72f4fc) : !fir.ref> + %c39_i32 = arith.constant 39 : i32 + %220 = fir.convert %214 : (!fir.ref>>) -> !fir.ref> + %221 = fir.convert %216 : (!fir.box) -> !fir.box + %222 = fir.convert %219 : (!fir.ref>) -> !fir.ref + fir.call @_FortranAAssign(%220, %221, %222, %c39_i32) : (!fir.ref>, !fir.box, !fir.ref, i32) -> () + omp.terminator + } + omp.terminator + } + omp.terminator + } + return +} + +func.func private @_FortranAAssign(!fir.ref>, !fir.box, !fir.ref, i32) attributes {fir.runtime} + +fir.global linkonce @_QQclXf9c642d28e5bba1f07fa9a090b72f4fc constant : !fir.char<1,78> { + %0 = fir.string_lit "File: /work/github/skc7/llvm-project/build_fomp_reldebinfo/saxpy_tests/\00"(78) : !fir.char<1,78> + fir.has_value %0 : !fir.char<1,78> +} +}