@@ -1009,6 +1009,179 @@ OpenMPIRBuilder::createSection(const LocationDescription &Loc,
1009
1009
/* IsCancellable*/ true );
1010
1010
}
1011
1011
1012
+ // / Create a function with a unique name and a "void (i8*, i8*)" signature in
1013
+ // / the given module and return it.
1014
+ Function *getFreshReductionFunc (Module &M) {
1015
+ Type *VoidTy = Type::getVoidTy (M.getContext ());
1016
+ Type *Int8PtrTy = Type::getInt8PtrTy (M.getContext ());
1017
+ auto *FuncTy =
1018
+ FunctionType::get (VoidTy, {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ false );
1019
+ return Function::Create (FuncTy, GlobalVariable::InternalLinkage,
1020
+ M.getDataLayout ().getDefaultGlobalsAddressSpace (),
1021
+ " .omp.reduction.func" , &M);
1022
+ }
1023
+
1024
+ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions (
1025
+ const LocationDescription &Loc, InsertPointTy AllocaIP,
1026
+ ArrayRef<ReductionInfo> ReductionInfos, bool IsNoWait) {
1027
+ for (const ReductionInfo &RI : ReductionInfos) {
1028
+ (void )RI;
1029
+ assert (RI.Variable && " expected non-null variable" );
1030
+ assert (RI.PrivateVariable && " expected non-null private variable" );
1031
+ assert (RI.ReductionGen && " expected non-null reduction generator callback" );
1032
+ assert (RI.Variable ->getType () == RI.PrivateVariable ->getType () &&
1033
+ " expected variables and their private equivalents to have the same "
1034
+ " type" );
1035
+ assert (RI.Variable ->getType ()->isPointerTy () &&
1036
+ " expected variables to be pointers" );
1037
+ }
1038
+
1039
+ if (!updateToLocation (Loc))
1040
+ return InsertPointTy ();
1041
+
1042
+ BasicBlock *InsertBlock = Loc.IP .getBlock ();
1043
+ BasicBlock *ContinuationBlock =
1044
+ InsertBlock->splitBasicBlock (Loc.IP .getPoint (), " reduce.finalize" );
1045
+ InsertBlock->getTerminator ()->eraseFromParent ();
1046
+
1047
+ // Create and populate array of type-erased pointers to private reduction
1048
+ // values.
1049
+ unsigned NumReductions = ReductionInfos.size ();
1050
+ Type *RedArrayTy = ArrayType::get (Builder.getInt8PtrTy (), NumReductions);
1051
+ Builder.restoreIP (AllocaIP);
1052
+ Value *RedArray = Builder.CreateAlloca (RedArrayTy, nullptr , " red.array" );
1053
+
1054
+ Builder.SetInsertPoint (InsertBlock, InsertBlock->end ());
1055
+
1056
+ for (auto En : enumerate(ReductionInfos)) {
1057
+ unsigned Index = En.index ();
1058
+ const ReductionInfo &RI = En.value ();
1059
+ Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64 (
1060
+ RedArrayTy, RedArray, 0 , Index, " red.array.elem." + Twine (Index));
1061
+ Value *Casted =
1062
+ Builder.CreateBitCast (RI.PrivateVariable , Builder.getInt8PtrTy (),
1063
+ " private.red.var." + Twine (Index) + " .casted" );
1064
+ Builder.CreateStore (Casted, RedArrayElemPtr);
1065
+ }
1066
+
1067
+ // Emit a call to the runtime function that orchestrates the reduction.
1068
+ // Declare the reduction function in the process.
1069
+ Function *Func = Builder.GetInsertBlock ()->getParent ();
1070
+ Module *Module = Func->getParent ();
1071
+ Value *RedArrayPtr =
1072
+ Builder.CreateBitCast (RedArray, Builder.getInt8PtrTy (), " red.array.ptr" );
1073
+ Constant *SrcLocStr = getOrCreateSrcLocStr (Loc);
1074
+ bool CanGenerateAtomic =
1075
+ llvm::all_of (ReductionInfos, [](const ReductionInfo &RI) {
1076
+ return RI.AtomicReductionGen ;
1077
+ });
1078
+ Value *Ident = getOrCreateIdent (
1079
+ SrcLocStr, CanGenerateAtomic ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
1080
+ : IdentFlag (0 ));
1081
+ Value *ThreadId = getOrCreateThreadID (Ident);
1082
+ Constant *NumVariables = Builder.getInt32 (NumReductions);
1083
+ const DataLayout &DL = Module->getDataLayout ();
1084
+ unsigned RedArrayByteSize = DL.getTypeStoreSize (RedArrayTy);
1085
+ Constant *RedArraySize = Builder.getInt64 (RedArrayByteSize);
1086
+ Function *ReductionFunc = getFreshReductionFunc (*Module);
1087
+ Value *Lock = getOMPCriticalRegionLock (" .reduction" );
1088
+ Function *ReduceFunc = getOrCreateRuntimeFunctionPtr (
1089
+ IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
1090
+ : RuntimeFunction::OMPRTL___kmpc_reduce);
1091
+ CallInst *ReduceCall =
1092
+ Builder.CreateCall (ReduceFunc,
1093
+ {Ident, ThreadId, NumVariables, RedArraySize,
1094
+ RedArrayPtr, ReductionFunc, Lock},
1095
+ " reduce" );
1096
+
1097
+ // Create final reduction entry blocks for the atomic and non-atomic case.
1098
+ // Emit IR that dispatches control flow to one of the blocks based on the
1099
+ // reduction supporting the atomic mode.
1100
+ BasicBlock *NonAtomicRedBlock =
1101
+ BasicBlock::Create (Module->getContext (), " reduce.switch.nonatomic" , Func);
1102
+ BasicBlock *AtomicRedBlock =
1103
+ BasicBlock::Create (Module->getContext (), " reduce.switch.atomic" , Func);
1104
+ SwitchInst *Switch =
1105
+ Builder.CreateSwitch (ReduceCall, ContinuationBlock, /* NumCases */ 2 );
1106
+ Switch->addCase (Builder.getInt32 (1 ), NonAtomicRedBlock);
1107
+ Switch->addCase (Builder.getInt32 (2 ), AtomicRedBlock);
1108
+
1109
+ // Populate the non-atomic reduction using the elementwise reduction function.
1110
+ // This loads the elements from the global and private variables and reduces
1111
+ // them before storing back the result to the global variable.
1112
+ Builder.SetInsertPoint (NonAtomicRedBlock);
1113
+ for (auto En : enumerate(ReductionInfos)) {
1114
+ const ReductionInfo &RI = En.value ();
1115
+ Type *ValueType = RI.getElementType ();
1116
+ Value *RedValue = Builder.CreateLoad (ValueType, RI.Variable ,
1117
+ " red.value." + Twine (En.index ()));
1118
+ Value *PrivateRedValue =
1119
+ Builder.CreateLoad (ValueType, RI.PrivateVariable ,
1120
+ " red.private.value." + Twine (En.index ()));
1121
+ Value *Reduced;
1122
+ Builder.restoreIP (
1123
+ RI.ReductionGen (Builder.saveIP (), RedValue, PrivateRedValue, Reduced));
1124
+ if (!Builder.GetInsertBlock ())
1125
+ return InsertPointTy ();
1126
+ Builder.CreateStore (Reduced, RI.Variable );
1127
+ }
1128
+ Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr (
1129
+ IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
1130
+ : RuntimeFunction::OMPRTL___kmpc_end_reduce);
1131
+ Builder.CreateCall (EndReduceFunc, {Ident, ThreadId, Lock});
1132
+ Builder.CreateBr (ContinuationBlock);
1133
+
1134
+ // Populate the atomic reduction using the atomic elementwise reduction
1135
+ // function. There are no loads/stores here because they will be happening
1136
+ // inside the atomic elementwise reduction.
1137
+ Builder.SetInsertPoint (AtomicRedBlock);
1138
+ if (CanGenerateAtomic) {
1139
+ for (const ReductionInfo &RI : ReductionInfos) {
1140
+ Builder.restoreIP (RI.AtomicReductionGen (Builder.saveIP (), RI.Variable ,
1141
+ RI.PrivateVariable ));
1142
+ if (!Builder.GetInsertBlock ())
1143
+ return InsertPointTy ();
1144
+ }
1145
+ Builder.CreateBr (ContinuationBlock);
1146
+ } else {
1147
+ Builder.CreateUnreachable ();
1148
+ }
1149
+
1150
+ // Populate the outlined reduction function using the elementwise reduction
1151
+ // function. Partial values are extracted from the type-erased array of
1152
+ // pointers to private variables.
1153
+ BasicBlock *ReductionFuncBlock =
1154
+ BasicBlock::Create (Module->getContext (), " " , ReductionFunc);
1155
+ Builder.SetInsertPoint (ReductionFuncBlock);
1156
+ Value *LHSArrayPtr = Builder.CreateBitCast (ReductionFunc->getArg (0 ),
1157
+ RedArrayTy->getPointerTo ());
1158
+ Value *RHSArrayPtr = Builder.CreateBitCast (ReductionFunc->getArg (1 ),
1159
+ RedArrayTy->getPointerTo ());
1160
+ for (auto En : enumerate(ReductionInfos)) {
1161
+ const ReductionInfo &RI = En.value ();
1162
+ Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64 (
1163
+ RedArrayTy, LHSArrayPtr, 0 , En.index ());
1164
+ Value *LHSI8Ptr = Builder.CreateLoad (Builder.getInt8PtrTy (), LHSI8PtrPtr);
1165
+ Value *LHSPtr = Builder.CreateBitCast (LHSI8Ptr, RI.Variable ->getType ());
1166
+ Value *LHS = Builder.CreateLoad (RI.getElementType (), LHSPtr);
1167
+ Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64 (
1168
+ RedArrayTy, RHSArrayPtr, 0 , En.index ());
1169
+ Value *RHSI8Ptr = Builder.CreateLoad (Builder.getInt8PtrTy (), RHSI8PtrPtr);
1170
+ Value *RHSPtr =
1171
+ Builder.CreateBitCast (RHSI8Ptr, RI.PrivateVariable ->getType ());
1172
+ Value *RHS = Builder.CreateLoad (RI.getElementType (), RHSPtr);
1173
+ Value *Reduced;
1174
+ Builder.restoreIP (RI.ReductionGen (Builder.saveIP (), LHS, RHS, Reduced));
1175
+ if (!Builder.GetInsertBlock ())
1176
+ return InsertPointTy ();
1177
+ Builder.CreateStore (Reduced, LHSPtr);
1178
+ }
1179
+ Builder.CreateRetVoid ();
1180
+
1181
+ Builder.SetInsertPoint (ContinuationBlock);
1182
+ return Builder.saveIP ();
1183
+ }
1184
+
1012
1185
OpenMPIRBuilder::InsertPointTy
1013
1186
OpenMPIRBuilder::createMaster (const LocationDescription &Loc,
1014
1187
BodyGenCallbackTy BodyGenCB,
0 commit comments