3030#include " llvm/Frontend/OpenMP/OMPConstants.h"
3131#include " llvm/Frontend/OpenMP/OMPDeviceConstants.h"
3232#include " llvm/Frontend/OpenMP/OMPIRBuilder.h"
33+ #include " llvm/IR/Constants.h"
3334#include " llvm/IR/DebugInfoMetadata.h"
3435#include " llvm/IR/DerivedTypes.h"
3536#include " llvm/IR/IRBuilder.h"
@@ -542,6 +543,20 @@ static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
542543 llvm_unreachable (" Unknown ClauseProcBindKind kind" );
543544}
544545
546+ // / Maps block arguments from \p blockArgIface (which are MLIR values) to the
547+ // / corresponding LLVM values of \p the interface's operands. This is useful
548+ // / when an OpenMP region with entry block arguments is converted to LLVM. In
549+ // / this case the block arguments are (part of) of the OpenMP region's entry
550+ // / arguments and the operands are (part of) of the operands to the OpenMP op
551+ // / containing the region.
552+ static void forwardArgs (LLVM::ModuleTranslation &moduleTranslation,
553+ omp::BlockArgOpenMPOpInterface blockArgIface) {
554+ llvm::SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
555+ blockArgIface.getBlockArgsPairs (blockArgsPairs);
556+ for (auto [var, arg] : blockArgsPairs)
557+ moduleTranslation.mapValue (arg, moduleTranslation.lookupValue (var));
558+ }
559+
545560// / Helper function to map block arguments defined by ignored loop wrappers to
546561// / LLVM values and prevent any uses of those from triggering null pointer
547562// / dereferences.
@@ -554,17 +569,10 @@ convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
554569 // Map block arguments directly to the LLVM value associated to the
555570 // corresponding operand. This is semantically equivalent to this wrapper not
556571 // being present.
557- auto forwardArgs =
558- [&moduleTranslation](omp::BlockArgOpenMPOpInterface blockArgIface) {
559- llvm::SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
560- blockArgIface.getBlockArgsPairs (blockArgsPairs);
561- for (auto [var, arg] : blockArgsPairs)
562- moduleTranslation.mapValue (arg, moduleTranslation.lookupValue (var));
563- };
564-
565572 return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
566573 .Case ([&](omp::SimdOp op) {
567- forwardArgs (cast<omp::BlockArgOpenMPOpInterface>(*op));
574+ forwardArgs (moduleTranslation,
575+ cast<omp::BlockArgOpenMPOpInterface>(*op));
568576 op.emitWarning () << " simd information on composite construct discarded" ;
569577 return success ();
570578 })
@@ -5803,6 +5811,61 @@ convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
58035811 return WalkResult::interrupt ();
58045812 return WalkResult::skip ();
58055813 }
5814+
5815+ // Non-target ops might nest target-related ops, therefore, we
5816+ // translate them as non-OpenMP scopes. Translating them is needed by
5817+ // nested target-related ops since they might need LLVM values defined
5818+ // in their parent non-target ops.
5819+ if (isa<omp::OpenMPDialect>(oper->getDialect ()) &&
5820+ oper->getParentOfType <LLVM::LLVMFuncOp>() &&
5821+ !oper->getRegions ().empty ()) {
5822+ if (auto blockArgsIface =
5823+ dyn_cast<omp::BlockArgOpenMPOpInterface>(oper))
5824+ forwardArgs (moduleTranslation, blockArgsIface);
5825+ else {
5826+ // Here we map entry block arguments of
5827+ // non-BlockArgOpenMPOpInterface ops if they can be encountered
5828+ // inside of a function and they define any of these arguments.
5829+ if (isa<mlir::omp::AtomicUpdateOp>(oper))
5830+ for (auto [operand, arg] :
5831+ llvm::zip_equal (oper->getOperands (),
5832+ oper->getRegion (0 ).getArguments ())) {
5833+ moduleTranslation.mapValue (
5834+ arg, builder.CreateLoad (
5835+ moduleTranslation.convertType (arg.getType ()),
5836+ moduleTranslation.lookupValue (operand)));
5837+ }
5838+ }
5839+
5840+ if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
5841+ assert (builder.GetInsertBlock () &&
5842+ " No insert block is set for the builder" );
5843+ for (auto iv : loopNest.getIVs ()) {
5844+ // Map iv to an undefined value just to keep the IR validity.
5845+ moduleTranslation.mapValue (
5846+ iv, llvm::PoisonValue::get (
5847+ moduleTranslation.convertType (iv.getType ())));
5848+ }
5849+ }
5850+
5851+ for (Region ®ion : oper->getRegions ()) {
5852+ // Regions are fake in the sense that they are not a truthful
5853+ // translation of the OpenMP construct being converted (e.g. no
5854+ // OpenMP runtime calls will be generated). We just need this to
5855+ // prepare the kernel invocation args.
5856+ SmallVector<llvm::PHINode *> phis;
5857+ auto result = convertOmpOpRegions (
5858+ region, oper->getName ().getStringRef ().str () + " .fake.region" ,
5859+ builder, moduleTranslation, &phis);
5860+ if (failed (handleError (result, *oper)))
5861+ return WalkResult::interrupt ();
5862+
5863+ builder.SetInsertPoint (result.get (), result.get ()->end ());
5864+ }
5865+
5866+ return WalkResult::skip ();
5867+ }
5868+
58065869 return WalkResult::advance ();
58075870 }).wasInterrupted ();
58085871 return failure (interrupted);
0 commit comments