Skip to content

Commit 7208d06

Browse files
ftynsekiranchandramohan
authored andcommitted
[OMPIRBuilder] add minimalist reduction support
This introduces a builder function for emitting IR performing reductions in OpenMP. Reduction variable privatization and initialization to the reduction-neutral value is expected to be handled separately. The caller provides the reduction functions. Further commits can provide implementation of reduction functions for the reduction operators defined in the OpenMP specification. This implementation was tested on an MLIR fork targeting OpenMP from C and produced correct executable code. Reviewed By: Meinersbur Differential Revision: https://reviews.llvm.org/D104928
1 parent 713ca6b commit 7208d06

File tree

4 files changed

+854
-9
lines changed

4 files changed

+854
-9
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,122 @@ class OpenMPIRBuilder {
486486
/// \param Loc The location where the taskyield directive was encountered.
487487
void createTaskyield(const LocationDescription &Loc);
488488

489+
/// Functions used to generate reductions. Such functions take the insertion
/// point at which to emit the reduction body, two Values
490+
/// representing LHS and RHS of the reduction, respectively, and a reference
491+
/// to the value that is updated to refer to the reduction result. They return
/// the insertion point at which subsequent IR should be emitted.
492+
using ReductionGenTy =
493+
function_ref<InsertPointTy(InsertPointTy, Value *, Value *, Value *&)>;
494+
495+
/// Functions used to generate atomic reductions. Such functions take the
/// insertion point at which to emit the reduction body and two
496+
/// Values representing pointers to LHS and RHS of the reduction. They are
497+
/// expected to atomically update the LHS to the reduced value and to return
/// the insertion point at which subsequent IR should be emitted.
498+
using AtomicReductionGenTy =
499+
function_ref<InsertPointTy(InsertPointTy, Value *, Value *)>;
500+
501+
/// Information about an OpenMP reduction.
502+
struct ReductionInfo {
503+
/// Returns the type of the element being reduced.
504+
Type *getElementType() const {
505+
return Variable->getType()->getPointerElementType();
506+
}
507+
508+
/// Reduction variable of pointer type.
509+
Value *Variable;
510+
511+
/// Thread-private partial reduction variable.
512+
Value *PrivateVariable;
513+
514+
/// Callback for generating the reduction body. The IR produced by this will
515+
/// be used to combine two values in a thread-safe context, e.g., under
516+
/// lock or within the same thread, and therefore need not be atomic.
517+
ReductionGenTy ReductionGen;
518+
519+
/// Callback for generating the atomic reduction body, may be null. The IR
520+
/// produced by this will be used to atomically combine two values during
521+
/// reduction. If null, the implementation will use the non-atomic version
522+
/// along with the appropriate synchronization mechanisms.
523+
AtomicReductionGenTy AtomicReductionGen;
524+
};
525+
526+
// TODO: provide atomic and non-atomic reduction generators for reduction
527+
// operators defined by the OpenMP specification.
528+
529+
/// Generator for '#omp reduction'.
530+
///
531+
/// Emits the IR instructing the runtime to perform the specific kind of
532+
/// reductions. Expects reduction variables to have been privatized and
533+
/// initialized to reduction-neutral values separately. Emits the calls to
534+
/// runtime functions as well as the reduction function and the basic blocks
535+
/// performing the reduction atomically and non-atomically.
536+
///
537+
/// The code emitted for the following:
538+
///
539+
/// \code
540+
/// type var_1;
541+
/// type var_2;
542+
/// #pragma omp <directive> reduction(reduction-op:var_1,var_2)
543+
/// /* body */;
544+
/// \endcode
545+
///
546+
/// corresponds to the following sketch.
547+
///
548+
/// \code
549+
/// void _outlined_par() {
550+
/// // N is the number of different reductions.
551+
/// void *red_array[] = {privatized_var_1, privatized_var_2, ...};
552+
/// switch(__kmpc_reduce(..., N, /*size of data in red array*/, red_array,
553+
/// _omp_reduction_func,
554+
/// _gomp_critical_user.reduction.var)) {
555+
/// case 1: {
556+
/// var_1 = var_1 <reduction-op> privatized_var_1;
557+
/// var_2 = var_2 <reduction-op> privatized_var_2;
558+
/// // ...
559+
/// __kmpc_end_reduce(...);
560+
/// break;
561+
/// }
562+
/// case 2: {
563+
/// _Atomic<ReductionOp>(var_1, privatized_var_1);
564+
/// _Atomic<ReductionOp>(var_2, privatized_var_2);
565+
/// // ...
566+
/// break;
567+
/// }
568+
/// default: break;
569+
/// }
570+
/// }
571+
///
572+
/// void _omp_reduction_func(void **lhs, void **rhs) {
573+
/// *(type *)lhs[0] = *(type *)lhs[0] <reduction-op> *(type *)rhs[0];
574+
/// *(type *)lhs[1] = *(type *)lhs[1] <reduction-op> *(type *)rhs[1];
575+
/// // ...
576+
/// }
577+
/// \endcode
578+
///
579+
/// \param Loc The location where the reduction was
580+
/// encountered. Must be within the associated
581+
/// directive and after the last local access to the
582+
/// reduction variables.
583+
/// \param AllocaIP An insertion point suitable for allocas usable
584+
/// in reductions.
585+
/// \param ReductionInfos A list of descriptions of the reductions to
586+
/// perform, one entry per reduction variable. Each
587+
/// entry holds the variable in which the reduction
588+
/// result will be stored and the variable in which
589+
/// the partial reduction result is stored (values of
590+
/// the same pointer type). Privatization
591+
/// must be handled separately from this call.
592+
/// Each entry also holds a generator for the
593+
/// non-atomic reduction body, which takes a pair of
594+
/// partially reduced values and sets a new one, and
595+
/// a generator for the atomic reduction body, null
596+
/// if the reduction cannot be performed with
597+
/// atomics, which takes a pair of _pointers_ to
598+
/// partially reduced values and atomically stores
/// the result into the first.
599+
/// \param IsNoWait A flag set if the reduction is marked as nowait.
///
/// \returns An insertion point in the block that follows the reduction, or
/// an empty insertion point if the location could not be used or
/// a reduction generator returned an insertion point without a
/// block.
600+
InsertPointTy createReductions(const LocationDescription &Loc,
601+
InsertPointTy AllocaIP,
602+
ArrayRef<ReductionInfo> ReductionInfos,
603+
bool IsNoWait = false);
604+
489605
///}
490606

491607
/// Return the insertion point used by the underlying IRBuilder.

llvm/include/llvm/Frontend/OpenMP/OMPKinds.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,7 @@ __OMP_RTL_ATTRS(__kmpc_task_allow_completion_event, DefaultAttrs,
907907
OMP_IDENT_FLAG(OMP_IDENT_FLAG_##Name, #Name, Value)
908908

909909
__OMP_IDENT_FLAG(KMPC, 0x02)
910+
// Set on the ident used for reduction runtime calls when every reduction
// provides an atomic reduction generator (see
// OpenMPIRBuilder::createReductions).
__OMP_IDENT_FLAG(ATOMIC_REDUCE, 0x10)
910911
__OMP_IDENT_FLAG(BARRIER_EXPL, 0x20)
911912
__OMP_IDENT_FLAG(BARRIER_IMPL, 0x0040)
912913
__OMP_IDENT_FLAG(BARRIER_IMPL_MASK, 0x01C0)

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,179 @@ OpenMPIRBuilder::createSection(const LocationDescription &Loc,
10091009
/*IsCancellable*/ true);
10101010
}
10111011

1012+
/// Create a function with a unique name and a "void (i8*, i8*)" signature in
1013+
/// the given module and return it.
1014+
Function *getFreshReductionFunc(Module &M) {
1015+
Type *VoidTy = Type::getVoidTy(M.getContext());
1016+
Type *Int8PtrTy = Type::getInt8PtrTy(M.getContext());
1017+
auto *FuncTy =
1018+
FunctionType::get(VoidTy, {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ false);
1019+
return Function::Create(FuncTy, GlobalVariable::InternalLinkage,
1020+
M.getDataLayout().getDefaultGlobalsAddressSpace(),
1021+
".omp.reduction.func", &M);
1022+
}
1023+
1024+
/// Emit the IR performing the reductions described by \p ReductionInfos at
/// \p Loc.
///
/// The block containing \p Loc is split; the tail becomes the continuation
/// block "reduce.finalize". The head is extended with code that stores the
/// type-erased private variable pointers into an array allocated at
/// \p AllocaIP, calls __kmpc_reduce (or __kmpc_reduce_nowait when \p IsNoWait
/// is set) and switches on its result: 1 branches to the non-atomic reduction
/// block, 2 to the atomic reduction block, anything else to the continuation
/// block. A fresh reduction function combining all variables elementwise is
/// emitted and passed to the runtime call.
///
/// \returns An insertion point in the continuation block, or an empty
///          insertion point if \p Loc could not be used or a generator
///          callback returned an insertion point without a block. In the
///          latter case the IR emitted so far is left in place.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    ArrayRef<ReductionInfo> ReductionInfos, bool IsNoWait) {
  for (const ReductionInfo &RI : ReductionInfos) {
    // Silences unused-variable warnings when assertions are compiled out.
    (void)RI;
    assert(RI.Variable && "expected non-null variable");
    assert(RI.PrivateVariable && "expected non-null private variable");
    assert(RI.ReductionGen && "expected non-null reduction generator callback");
    assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
           "expected variables and their private equivalents to have the same "
           "type");
    assert(RI.Variable->getType()->isPointerTy() &&
           "expected variables to be pointers");
  }

  if (!updateToLocation(Loc))
    return InsertPointTy();

  // Split off everything after the insertion point into the continuation
  // block and drop the branch created by the split; the switch emitted below
  // becomes the new terminator of the insertion block.
  BasicBlock *InsertBlock = Loc.IP.getBlock();
  BasicBlock *ContinuationBlock =
      InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
  InsertBlock->getTerminator()->eraseFromParent();

  // Create and populate array of type-erased pointers to private reduction
  // values.
  unsigned NumReductions = ReductionInfos.size();
  Type *RedArrayTy = ArrayType::get(Builder.getInt8PtrTy(), NumReductions);
  Builder.restoreIP(AllocaIP);
  Value *RedArray = Builder.CreateAlloca(RedArrayTy, nullptr, "red.array");

  Builder.SetInsertPoint(InsertBlock, InsertBlock->end());

  for (auto En : enumerate(ReductionInfos)) {
    unsigned Index = En.index();
    const ReductionInfo &RI = En.value();
    Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
        RedArrayTy, RedArray, 0, Index, "red.array.elem." + Twine(Index));
    Value *Casted =
        Builder.CreateBitCast(RI.PrivateVariable, Builder.getInt8PtrTy(),
                              "private.red.var." + Twine(Index) + ".casted");
    Builder.CreateStore(Casted, RedArrayElemPtr);
  }

  // Emit a call to the runtime function that orchestrates the reduction.
  // Declare the reduction function in the process.
  Function *Func = Builder.GetInsertBlock()->getParent();
  Module *M = Func->getParent();
  Value *RedArrayPtr =
      Builder.CreateBitCast(RedArray, Builder.getInt8PtrTy(), "red.array.ptr");
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
  // The atomic path is only usable if every reduction has an atomic generator.
  bool CanGenerateAtomic =
      llvm::all_of(ReductionInfos, [](const ReductionInfo &RI) {
        return RI.AtomicReductionGen;
      });
  Value *Ident = getOrCreateIdent(
      SrcLocStr, CanGenerateAtomic ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
                                   : IdentFlag(0));
  Value *ThreadId = getOrCreateThreadID(Ident);
  Constant *NumVariables = Builder.getInt32(NumReductions);
  const DataLayout &DL = M->getDataLayout();
  // Keep the full 64-bit size; it is materialized as an i64 constant below.
  uint64_t RedArrayByteSize = DL.getTypeStoreSize(RedArrayTy);
  Constant *RedArraySize = Builder.getInt64(RedArrayByteSize);
  Function *ReductionFunc = getFreshReductionFunc(*M);
  Value *Lock = getOMPCriticalRegionLock(".reduction");
  Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
      IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
               : RuntimeFunction::OMPRTL___kmpc_reduce);
  CallInst *ReduceCall =
      Builder.CreateCall(ReduceFunc,
                         {Ident, ThreadId, NumVariables, RedArraySize,
                          RedArrayPtr, ReductionFunc, Lock},
                         "reduce");

  // Create final reduction entry blocks for the atomic and non-atomic case.
  // Emit IR that dispatches control flow to one of the blocks based on the
  // reduction supporting the atomic mode.
  BasicBlock *NonAtomicRedBlock =
      BasicBlock::Create(M->getContext(), "reduce.switch.nonatomic", Func);
  BasicBlock *AtomicRedBlock =
      BasicBlock::Create(M->getContext(), "reduce.switch.atomic", Func);
  SwitchInst *Switch =
      Builder.CreateSwitch(ReduceCall, ContinuationBlock, /* NumCases */ 2);
  Switch->addCase(Builder.getInt32(1), NonAtomicRedBlock);
  Switch->addCase(Builder.getInt32(2), AtomicRedBlock);

  // Populate the non-atomic reduction using the elementwise reduction function.
  // This loads the elements from the global and private variables and reduces
  // them before storing back the result to the global variable.
  Builder.SetInsertPoint(NonAtomicRedBlock);
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Type *ValueType = RI.getElementType();
    Value *RedValue = Builder.CreateLoad(ValueType, RI.Variable,
                                         "red.value." + Twine(En.index()));
    Value *PrivateRedValue =
        Builder.CreateLoad(ValueType, RI.PrivateVariable,
                           "red.private.value." + Twine(En.index()));
    // Initialized so that a generator that forgets to set the result is caught
    // by the assertion below instead of storing an indeterminate pointer.
    Value *Reduced = nullptr;
    Builder.restoreIP(
        RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced));
    // A generator signals failure by returning an insertion point without a
    // block.
    if (!Builder.GetInsertBlock())
      return InsertPointTy();
    assert(Reduced && "reduction generator did not produce a reduced value");
    Builder.CreateStore(Reduced, RI.Variable);
  }
  Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
      IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
               : RuntimeFunction::OMPRTL___kmpc_end_reduce);
  Builder.CreateCall(EndReduceFunc, {Ident, ThreadId, Lock});
  Builder.CreateBr(ContinuationBlock);

  // Populate the atomic reduction using the atomic elementwise reduction
  // function. There are no loads/stores here because they will be happening
  // inside the atomic elementwise reduction.
  Builder.SetInsertPoint(AtomicRedBlock);
  if (CanGenerateAtomic) {
    for (const ReductionInfo &RI : ReductionInfos) {
      Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.Variable,
                                              RI.PrivateVariable));
      if (!Builder.GetInsertBlock())
        return InsertPointTy();
    }
    Builder.CreateBr(ContinuationBlock);
  } else {
    // Without the atomic-reduce ident flag the atomic block is emitted as
    // unreachable.
    Builder.CreateUnreachable();
  }

  // Populate the outlined reduction function using the elementwise reduction
  // function. Partial values are extracted from the type-erased array of
  // pointers to private variables.
  BasicBlock *ReductionFuncBlock =
      BasicBlock::Create(M->getContext(), "", ReductionFunc);
  Builder.SetInsertPoint(ReductionFuncBlock);
  Value *LHSArrayPtr = Builder.CreateBitCast(ReductionFunc->getArg(0),
                                             RedArrayTy->getPointerTo());
  Value *RHSArrayPtr = Builder.CreateBitCast(ReductionFunc->getArg(1),
                                             RedArrayTy->getPointerTo());
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
        RedArrayTy, LHSArrayPtr, 0, En.index());
    Value *LHSI8Ptr = Builder.CreateLoad(Builder.getInt8PtrTy(), LHSI8PtrPtr);
    Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType());
    Value *LHS = Builder.CreateLoad(RI.getElementType(), LHSPtr);
    Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
        RedArrayTy, RHSArrayPtr, 0, En.index());
    Value *RHSI8Ptr = Builder.CreateLoad(Builder.getInt8PtrTy(), RHSI8PtrPtr);
    Value *RHSPtr =
        Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType());
    Value *RHS = Builder.CreateLoad(RI.getElementType(), RHSPtr);
    // See the non-atomic block above for why this is initialized and checked.
    Value *Reduced = nullptr;
    Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced));
    if (!Builder.GetInsertBlock())
      return InsertPointTy();
    assert(Reduced && "reduction generator did not produce a reduced value");
    Builder.CreateStore(Reduced, LHSPtr);
  }
  Builder.CreateRetVoid();

  Builder.SetInsertPoint(ContinuationBlock);
  return Builder.saveIP();
}
1184+
10121185
OpenMPIRBuilder::InsertPointTy
10131186
OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
10141187
BodyGenCallbackTy BodyGenCB,

0 commit comments

Comments
 (0)