diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index a7a0af231af33..f28eb51b6e942 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2144,6 +2144,12 @@ void LoopNestOp::gatherWrappers( wrappers.push_back(wrapper); parent = parent->getParentOp(); } + + // omp.parallel can be misidentified as a loop wrapper when it's not taking + // that role but it contains no other operations in its region (e.g. parallel + // do/for). + if (llvm::isa(wrappers.back())) + wrappers.pop_back(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 0c9c699a1f390..432b5a757a989 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -770,6 +770,83 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, return bodyGenStatus; } +namespace { +using WrappedLoopBodyGenCallbackTy = function_ref; +} // namespace + +static LogicalResult +convertOmpLoopNest(omp::LoopNestOp loopOp, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + llvm::StringRef blockName, + llvm::CanonicalLoopInfo *&loopInfo, + llvm::IRBuilderBase::InsertPoint &afterIP) { + llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + + // TODO: support error propagation in OpenMPIRBuilder and use it instead of + // relying on captured variables. + SmallVector loopInfos; + SmallVector bodyInsertPoints; + LogicalResult bodyGenStatus = success(); + auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) { + // Make sure further conversions know about the induction variable. + moduleTranslation.mapValue( + loopOp.getRegion().front().getArgument(loopInfos.size()), iv); + + // Capture the body insertion point for use in nested loops. BodyIP of the + // CanonicalLoopInfo always points to the beginning of the entry block of + // the body. + bodyInsertPoints.push_back(ip); + + if (loopInfos.size() != loopOp.getNumLoops() - 1) + return; + + // Convert the body of the loop. + builder.restoreIP(ip); + convertOmpOpRegions(loopOp.getRegion(), blockName, builder, + moduleTranslation, bodyGenStatus); + }; + + // Delegate actual loop construction to the OpenMP IRBuilder. + // TODO: this currently assumes omp.loop_nest is semantically similar to SCF + // loop, i.e. it has a positive step, uses signed integer semantics. + // Reconsider this code when the nested loop operation clearly supports more + // cases. + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) { + llvm::Value *lowerBound = + moduleTranslation.lookupValue(loopOp.getLowerBound()[i]); + llvm::Value *upperBound = + moduleTranslation.lookupValue(loopOp.getUpperBound()[i]); + llvm::Value *step = moduleTranslation.lookupValue(loopOp.getStep()[i]); + + // Make sure loop trip count are emitted in the preheader of the outermost + // loop at the latest so that they are all available for the new collapsed + // loop will be created below. + llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc; + llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP; + if (i != 0) { + loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(), + ompLoc.DL); + computeIP = loopInfos.front()->getPreheaderIP(); + } + loopInfos.push_back(ompBuilder->createCanonicalLoop( + loc, bodyGen, lowerBound, upperBound, step, + /*IsSigned=*/true, loopOp.getInclusive(), computeIP)); + + if (failed(bodyGenStatus)) + return failure(); + } + + // Collapse loops. Store the insertion point because LoopInfos may get + // invalidated. + afterIP = loopInfos.front()->getAfterIP(); + loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {}); + + return success(); +} + /// Allocate space for privatized reduction variables. template static void allocByValReductionVars( @@ -896,14 +973,9 @@ static ArrayRef getIsByRef(std::optional> attr) { /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder. static LogicalResult -convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { - auto wsloopOp = cast(opInst); - // FIXME: Here any other nested wrappers (e.g. omp.simd) are skipped, so - // codegen for composite constructs like 'DO/FOR SIMD' will be the same as for - // 'DO/FOR'. - auto loopOp = cast(wsloopOp.getWrappedLoop()); - +convertOmpWsloop(omp::WsloopOp wsloopOp, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + WrappedLoopBodyGenCallbackTy loopCB) { llvm::ArrayRef isByRef = getIsByRef(wsloopOp.getReductionVarsByref()); assert(isByRef.size() == wsloopOp.getNumReductionVars()); @@ -912,7 +984,9 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, wsloopOp.getScheduleVal().value_or(omp::ClauseScheduleKind::Static); // Find the loop configuration. - llvm::Value *step = moduleTranslation.lookupValue(loopOp.getStep()[0]); + OperandRange loopStep = + cast(wsloopOp.getWrappedLoop()).getStep(); + llvm::Value *step = moduleTranslation.lookupValue(loopStep.front()); llvm::Type *ivType = step->getType(); llvm::Value *chunk = nullptr; if (wsloopOp.getScheduleChunkVar()) { @@ -986,65 +1060,10 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); // Generator of the canonical loop body. - // TODO: support error propagation in OpenMPIRBuilder and use it instead of - // relying on captured variables. - SmallVector loopInfos; - SmallVector bodyInsertPoints; - LogicalResult bodyGenStatus = success(); - auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) { - // Make sure further conversions know about the induction variable. - moduleTranslation.mapValue( - loopOp.getRegion().front().getArgument(loopInfos.size()), iv); - - // Capture the body insertion point for use in nested loops. BodyIP of the - // CanonicalLoopInfo always points to the beginning of the entry block of - // the body. - bodyInsertPoints.push_back(ip); - - if (loopInfos.size() != loopOp.getNumLoops() - 1) - return; - - // Convert the body of the loop. - builder.restoreIP(ip); - convertOmpOpRegions(loopOp.getRegion(), "omp.wsloop.region", builder, - moduleTranslation, bodyGenStatus); - }; - - // Delegate actual loop construction to the OpenMP IRBuilder. - // TODO: this currently assumes omp.loop_nest is semantically similar to SCF - // loop, i.e. it has a positive step, uses signed integer semantics. - // Reconsider this code when the nested loop operation clearly supports more - // cases. - llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) { - llvm::Value *lowerBound = - moduleTranslation.lookupValue(loopOp.getLowerBound()[i]); - llvm::Value *upperBound = - moduleTranslation.lookupValue(loopOp.getUpperBound()[i]); - llvm::Value *step = moduleTranslation.lookupValue(loopOp.getStep()[i]); - - // Make sure loop trip count are emitted in the preheader of the outermost - // loop at the latest so that they are all available for the new collapsed - // loop will be created below. - llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc; - llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP; - if (i != 0) { - loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back()); - computeIP = loopInfos.front()->getPreheaderIP(); - } - loopInfos.push_back(ompBuilder->createCanonicalLoop( - loc, bodyGen, lowerBound, upperBound, step, - /*IsSigned=*/true, loopOp.getInclusive(), computeIP)); - - if (failed(bodyGenStatus)) - return failure(); - } - - // Collapse loops. Store the insertion point because LoopInfos may get - // invalidated. - llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP(); - llvm::CanonicalLoopInfo *loopInfo = - ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {}); + llvm::CanonicalLoopInfo *loopInfo; + llvm::IRBuilderBase::InsertPoint afterIP; + if (failed(loopCB("omp.wsloop.region", loopInfo, afterIP))) + return failure(); allocaIP = findAllocaInsertPoint(builder, moduleTranslation); @@ -1054,6 +1073,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, wsloopOp.getScheduleModifier(); bool isSimd = wsloopOp.getSimdModifier(); + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); ompBuilder->applyWorkshareLoop( ompLoc.DL, loopInfo, allocaIP, !wsloopOp.getNowait(), convertToScheduleKind(schedule), chunk, isSimd, @@ -1475,74 +1495,15 @@ convertOrderKind(std::optional o) { } /// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder. -static LogicalResult -convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { - auto simdOp = cast(opInst); - auto loopOp = cast(simdOp.getWrappedLoop()); - - llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); - +static LogicalResult convertOmpSimd(omp::SimdOp simdOp, + llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + WrappedLoopBodyGenCallbackTy loopCB) { // Generator of the canonical loop body. - // TODO: support error propagation in OpenMPIRBuilder and use it instead of - // relying on captured variables. - SmallVector loopInfos; - SmallVector bodyInsertPoints; - LogicalResult bodyGenStatus = success(); - auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) { - // Make sure further conversions know about the induction variable. - moduleTranslation.mapValue( - loopOp.getRegion().front().getArgument(loopInfos.size()), iv); - - // Capture the body insertion point for use in nested loops. BodyIP of the - // CanonicalLoopInfo always points to the beginning of the entry block of - // the body. - bodyInsertPoints.push_back(ip); - - if (loopInfos.size() != loopOp.getNumLoops() - 1) - return; - - // Convert the body of the loop. - builder.restoreIP(ip); - convertOmpOpRegions(loopOp.getRegion(), "omp.simd.region", builder, - moduleTranslation, bodyGenStatus); - }; - - // Delegate actual loop construction to the OpenMP IRBuilder. - // TODO: this currently assumes omp.loop_nest is semantically similar to SCF - // loop, i.e. it has a positive step, uses signed integer semantics. - // Reconsider this code when the nested loop operation clearly supports more - // cases. - llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) { - llvm::Value *lowerBound = - moduleTranslation.lookupValue(loopOp.getLowerBound()[i]); - llvm::Value *upperBound = - moduleTranslation.lookupValue(loopOp.getUpperBound()[i]); - llvm::Value *step = moduleTranslation.lookupValue(loopOp.getStep()[i]); - - // Make sure loop trip count are emitted in the preheader of the outermost - // loop at the latest so that they are all available for the new collapsed - // loop will be created below. - llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc; - llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP; - if (i != 0) { - loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(), - ompLoc.DL); - computeIP = loopInfos.front()->getPreheaderIP(); - } - loopInfos.push_back(ompBuilder->createCanonicalLoop( - loc, bodyGen, lowerBound, upperBound, step, - /*IsSigned=*/true, /*Inclusive=*/true, computeIP)); - - if (failed(bodyGenStatus)) - return failure(); - } - - // Collapse loops. - llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP(); - llvm::CanonicalLoopInfo *loopInfo = - ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {}); + llvm::CanonicalLoopInfo *loopInfo; + llvm::IRBuilderBase::InsertPoint afterIP; + if (failed(loopCB("omp.simd.region", loopInfo, afterIP))) + return failure(); llvm::ConstantInt *simdlen = nullptr; if (std::optional simdlenVar = simdOp.getSimdlen()) @@ -1554,6 +1515,8 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, llvm::MapVector alignedVars; llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrderVal()); + + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); ompBuilder->applySimd(loopInfo, alignedVars, simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr()) @@ -3294,14 +3257,87 @@ static bool isTargetDeviceOp(Operation *op) { return false; } +static LogicalResult +convertWrappedLoopNest(omp::LoopNestOp loopOp, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + SmallVector wrappers; + loopOp.gatherWrappers(wrappers); + + auto loopCB = [&](llvm::StringRef blockName, + llvm::CanonicalLoopInfo *&loopInfo, + llvm::IRBuilderBase::InsertPoint &afterIP) { + return convertOmpLoopNest(loopOp, builder, moduleTranslation, blockName, + loopInfo, afterIP); + }; + + switch (wrappers.size()) { + case 1: + return llvm::TypeSwitch(wrappers.front()) + .Case([&](omp::WsloopOp op) { + return convertOmpWsloop(op, builder, moduleTranslation, loopCB); + }) + .Case([&](omp::SimdOp op) { + return convertOmpSimd(op, builder, moduleTranslation, loopCB); + }) + .Default([&](Operation *op) { + // TODO: Support omp.distribute, omp.taskloop. + return op->emitError("unsupported OpenMP operation: ") + << op->getName(); + }); + case 2: { + auto simdOp = llvm::dyn_cast(*wrappers[0]); + assert(simdOp && + (llvm::isa( + *wrappers[1])) && + "invalid loop wrappers"); + + // TODO: Take omp.simd information into account. + if (auto wsloopOp = llvm::dyn_cast(*wrappers[1])) + return convertOmpWsloop(wsloopOp, builder, moduleTranslation, loopCB); + + // TODO: Support (omp.distribute, omp.taskloop) + omp.simd. + return loopOp->emitError("unsupported composite OpenMP construct: ") + << wrappers[1]->getName() << " simd"; + } + case 3: { + // TODO: Support omp.distribute + omp.parallel + omp.wsloop. + auto wsloopOp = llvm::dyn_cast(*wrappers[0]); + auto parallelOp = llvm::dyn_cast(*wrappers[1]); + auto distributeOp = llvm::dyn_cast(*wrappers[2]); + assert(wsloopOp && parallelOp && distributeOp && "invalid loop wrappers"); + return loopOp->emitError( + "unsupported composite OpenMP construct: distribute parallel wsloop"); + } + case 4: { + // TODO: Support omp.distribute + omp.parallel + omp.wsloop + omp.simd. + auto simdOp = llvm::dyn_cast(*wrappers[0]); + auto wsloopOp = llvm::dyn_cast(*wrappers[1]); + auto parallelOp = llvm::dyn_cast(*wrappers[2]); + auto distributeOp = llvm::dyn_cast(*wrappers[3]); + assert(simdOp && wsloopOp && parallelOp && distributeOp && + "invalid loop wrappers"); + return loopOp->emitError( + "unsupported composite OpenMP construct: distribute " + "parallel wsloop simd"); + } + default: + llvm_unreachable("invalid loop wrappers"); + } +} + /// Given an OpenMP MLIR operation, create the corresponding LLVM IR /// (including OpenMP runtime calls). static LogicalResult convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { + auto wrapperOp = llvm::dyn_cast_if_present(op); + if (!llvm::isa(op) && wrapperOp && wrapperOp.isWrapper()) { + return convertWrappedLoopNest( + llvm::cast(wrapperOp.getWrappedLoop()), builder, + moduleTranslation); + } llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - return llvm::TypeSwitch(op) .Case([&](omp::BarrierOp) { ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier); @@ -3342,12 +3378,6 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, .Case([&](omp::OrderedOp) { return convertOmpOrdered(*op, builder, moduleTranslation); }) - .Case([&](omp::WsloopOp) { - return convertOmpWsloop(*op, builder, moduleTranslation); - }) - .Case([&](omp::SimdOp) { - return convertOmpSimd(*op, builder, moduleTranslation); - }) .Case([&](omp::AtomicReadOp) { return convertOmpAtomicRead(*op, builder, moduleTranslation); }) @@ -3404,6 +3434,11 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, // and then discarded return success(); }) + .Case([](auto op) { + llvm_unreachable("unexpected loop-asociated construct"); + return failure(); + }) .Default([&](Operation *inst) { return inst->emitError("unsupported OpenMP operation: ") << inst->getName(); diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir index dfeaf4be33adb..fea0fe85d859d 100644 --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -699,7 +699,7 @@ llvm.func @simd_simple(%lb : i64, %ub : i64, %step : i64, %arg0: !llvm.ptr) { // CHECK-LABEL: @simd_simple_multiple llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) { omp.simd { - omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) { + omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) inclusive step (%step1, %step2) { %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32 // The form of the emitted IR is controlled by OpenMPIRBuilder and // tested there. Just check that the right metadata is added and collapsed