Skip to content

Commit 5654efd

Browse files
authored
[MLIR][OpenMP] Support lowering of host_eval to LLVM IR (#179)
This patch updates the MLIR to LLVM IR lowering of `omp.target` to support passing `num_teams`, `num_threads`, `thread_limit` and SPMD loop bounds through the `host_eval` argument of `omp.target`. This replaces the previous implementation where this information was directly attached to the `omp.target` operation rather than captured to be used by the corresponding nested operation.
1 parent 2f3acbd commit 5654efd

File tree

6 files changed

+273
-102
lines changed

6 files changed

+273
-102
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -728,13 +728,12 @@ class OpenMPIRBuilder {
728728
LoopBodyGenCallbackTy BodyGenCB, Value *TripCount,
729729
const Twine &Name = "loop");
730730

731-
/// Generator for the control flow structure of an OpenMP canonical loop.
731+
/// Calculate the trip count of a canonical loop.
732732
///
733-
/// Instead of a logical iteration space, this allows specifying user-defined
734-
/// loop counter values using increment, upper- and lower bounds. To
735-
/// disambiguate the terminology when counting downwards, instead of lower
736-
/// bounds we use \p Start for the loop counter value in the first body
737-
/// iteration.
733+
/// This allows specifying user-defined loop counter values using increment,
734+
/// upper- and lower bounds. To disambiguate the terminology when counting
735+
/// downwards, instead of lower bounds we use \p Start for the loop counter
736+
/// value in the first body iteration.
738737
///
739738
/// Consider the following limitations:
740739
///
@@ -758,7 +757,32 @@ class OpenMPIRBuilder {
758757
///
759758
/// for (int i = 0; i < 42; i -= 1u)
760759
///
761-
//
760+
/// \param Loc The insert and source location description.
761+
/// \param Start Value of the loop counter for the first iterations.
762+
/// \param Stop Loop counter values past this will stop the loop.
763+
/// \param Step Loop counter increment after each iteration; negative
764+
/// means counting down.
765+
/// \param IsSigned Whether Start, Stop and Step are signed integers.
766+
/// \param InclusiveStop Whether \p Stop itself is a valid value for the loop
767+
/// counter.
768+
/// \param Name Base name used to derive instruction names.
769+
///
770+
/// \returns The value holding the calculated trip count.
771+
Value *calculateCanonicalLoopTripCount(const LocationDescription &Loc,
772+
Value *Start, Value *Stop, Value *Step,
773+
bool IsSigned, bool InclusiveStop,
774+
const Twine &Name = "loop");
775+
776+
/// Generator for the control flow structure of an OpenMP canonical loop.
777+
///
778+
/// Instead of a logical iteration space, this allows specifying user-defined
779+
/// loop counter values using increment, upper- and lower bounds. To
780+
/// disambiguate the terminology when counting downwards, instead of lower
781+
/// bounds we use \p Start for the loop counter value in the first body
782+
///
783+
/// It calls \see calculateCanonicalLoopTripCount for trip count calculations,
784+
/// so limitations of that method apply here as well.
785+
///
762786
/// \param Loc The insert and source location description.
763787
/// \param BodyGenCB Callback that will generate the loop body code.
764788
/// \param Start Value of the loop counter for the first iterations.

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4032,11 +4032,9 @@ OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
40324032
return CL;
40334033
}
40344034

4035-
Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
4036-
const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
4037-
Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
4038-
InsertPointTy ComputeIP, const Twine &Name) {
4039-
4035+
Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
4036+
const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
4037+
bool IsSigned, bool InclusiveStop, const Twine &Name) {
40404038
// Consider the following difficulties (assuming 8-bit signed integers):
40414039
// * Adding \p Step to the loop counter which passes \p Stop may overflow:
40424040
// DO I = 1, 100, 50
@@ -4048,9 +4046,7 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
40484046
assert(IndVarTy == Stop->getType() && "Stop type mismatch");
40494047
assert(IndVarTy == Step->getType() && "Step type mismatch");
40504048

4051-
LocationDescription ComputeLoc =
4052-
ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
4053-
updateToLocation(ComputeLoc);
4049+
updateToLocation(Loc);
40544050

40554051
ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
40564052
ConstantInt *One = ConstantInt::get(IndVarTy, 1);
@@ -4090,8 +4086,20 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
40904086
Value *OneCmp = Builder.CreateICmp(CmpInst::ICMP_ULE, Span, Incr);
40914087
CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
40924088
}
4093-
Value *TripCount = Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
4094-
"omp_" + Name + ".tripcount");
4089+
4090+
return Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
4091+
"omp_" + Name + ".tripcount");
4092+
}
4093+
4094+
Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
4095+
const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
4096+
Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
4097+
InsertPointTy ComputeIP, const Twine &Name) {
4098+
LocationDescription ComputeLoc =
4099+
ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
4100+
4101+
Value *TripCount = calculateCanonicalLoopTripCount(
4102+
ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);
40954103

40964104
auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
40974105
Builder.restoreIP(CodeGenIP);

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1427,8 +1427,7 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopSimple) {
14271427
EXPECT_EQ(&Loop->getAfter()->front(), RetInst);
14281428
}
14291429

1430-
TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
1431-
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1430+
TEST_F(OpenMPIRBuilderTest, CanonicalLoopTripCount) {
14321431
OpenMPIRBuilder OMPBuilder(*M);
14331432
OMPBuilder.initialize();
14341433
IRBuilder<> Builder(BB);
@@ -1444,17 +1443,8 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
14441443
Value *StartVal = ConstantInt::get(LCTy, Start);
14451444
Value *StopVal = ConstantInt::get(LCTy, Stop);
14461445
Value *StepVal = ConstantInt::get(LCTy, Step);
1447-
auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, llvm::Value *LC) {
1448-
return Error::success();
1449-
};
1450-
Expected<CanonicalLoopInfo *> LoopResult =
1451-
OMPBuilder.createCanonicalLoop(Loc, LoopBodyGenCB, StartVal, StopVal,
1452-
StepVal, IsSigned, InclusiveStop);
1453-
assert(LoopResult && "unexpected error");
1454-
CanonicalLoopInfo *Loop = *LoopResult;
1455-
Loop->assertOK();
1456-
Builder.restoreIP(Loop->getAfterIP());
1457-
Value *TripCount = Loop->getTripCount();
1446+
Value *TripCount = OMPBuilder.calculateCanonicalLoopTripCount(
1447+
Loc, StartVal, StopVal, StepVal, IsSigned, InclusiveStop);
14581448
return cast<ConstantInt>(TripCount)->getValue().getZExtValue();
14591449
};
14601450

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 23 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1772,55 +1772,34 @@ LogicalResult TargetOp::verify() {
17721772
Operation *TargetOp::getInnermostCapturedOmpOp() {
17731773
Dialect *ompDialect = (*this)->getDialect();
17741774
Operation *capturedOp = nullptr;
1775-
Region *capturedParentRegion = nullptr;
17761775

1777-
walk<WalkOrder::PostOrder>([&](Operation *op) {
1776+
// Process in pre-order to check operations from outermost to innermost,
1777+
// ensuring we only enter the region of an operation if it meets the criteria
1778+
// for being captured. We stop the exploration of nested operations as soon as
1779+
// we process a region with no operation to be captured.
1780+
walk<WalkOrder::PreOrder>([&](Operation *op) {
17781781
if (op == *this)
1779-
return;
1780-
1781-
// Reset captured op if crossing through an omp.loop_nest, so that the top
1782-
// level one will be the one captured.
1783-
if (llvm::isa<LoopNestOp>(op)) {
1784-
capturedOp = nullptr;
1785-
capturedParentRegion = nullptr;
1786-
}
1782+
return WalkResult::advance();
17871783

1784+
// Ignore operations of other dialects or omp operations with no regions,
1785+
// because these will only be checked if they are siblings of an omp
1786+
// operation that can potentially be captured.
17881787
bool isOmpDialect = op->getDialect() == ompDialect;
17891788
bool hasRegions = op->getNumRegions() > 0;
1790-
1791-
if (capturedOp) {
1792-
bool isImmediateParent = false;
1793-
for (Region &region : op->getRegions()) {
1794-
if (&region == capturedParentRegion) {
1795-
isImmediateParent = true;
1796-
capturedParentRegion = op->getParentRegion();
1797-
break;
1798-
}
1799-
}
1800-
1801-
// Make sure the captured op is part of a (possibly multi-level) nest of
1802-
// OpenMP-only operations containing no unsupported siblings at any level.
1803-
if ((hasRegions && isOmpDialect != isImmediateParent) ||
1804-
(!isImmediateParent && !siblingAllowedInCapture(op))) {
1805-
capturedOp = nullptr;
1806-
capturedParentRegion = nullptr;
1807-
}
1808-
} else {
1809-
// The first OpenMP dialect op containing a region found while visiting
1810-
// in post-order should be the innermost captured OpenMP operation.
1811-
if (isOmpDialect && hasRegions) {
1812-
capturedOp = op;
1813-
capturedParentRegion = op->getParentRegion();
1814-
1815-
// Don't capture this op if it has a not-allowed sibling.
1816-
for (Operation &sibling : op->getParentRegion()->getOps()) {
1817-
if (&sibling != op && !siblingAllowedInCapture(&sibling)) {
1818-
capturedOp = nullptr;
1819-
capturedParentRegion = nullptr;
1820-
}
1821-
}
1822-
}
1823-
}
1789+
if (!isOmpDialect || !hasRegions)
1790+
return WalkResult::skip();
1791+
1792+
// Don't capture this op if it has a not-allowed sibling, and stop recursing
1793+
// into nested operations.
1794+
for (Operation &sibling : op->getParentRegion()->getOps())
1795+
if (&sibling != op && !siblingAllowedInCapture(&sibling))
1796+
return WalkResult::interrupt();
1797+
1798+
// Don't continue capturing nested operations if we reach an omp.loop_nest.
1799+
// Otherwise, process the contents of this operation.
1800+
capturedOp = op;
1801+
return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
1802+
: WalkResult::advance();
18241803
});
18251804

18261805
return capturedOp;

0 commit comments

Comments
 (0)