Skip to content

Commit aabd6ca

Browse files
authored
[clang] [OpenMP] [Offload] Added support for Xteam min/max reduction. (llvm#1412)
2 parents ddc4e69 + 953acfc commit aabd6ca

15 files changed

+8345
-273
lines changed

clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp

Lines changed: 269 additions & 91 deletions
Large diffs are not rendered by default.

clang/lib/CodeGen/CGOpenMPRuntimeGPU.h

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -169,15 +169,17 @@ class CGOpenMPRuntimeGPU : public CGOpenMPRuntime {
169169
llvm::Value *initSpecializedKernel(CodeGenFunction &CGF);
170170

171171
std::pair<llvm::Value *, llvm::Value *>
172-
getXteamRedFunctionPtrs(CodeGenFunction &CGF, llvm::Type *RedVarType);
173-
174-
/// Call cross-team sum
175-
llvm::Value *getXteamRedSum(CodeGenFunction &CGF, llvm::Value *Val,
176-
llvm::Value *SumPtr, llvm::Value *DTeamVals,
177-
llvm::Value *DTeamsDonePtr,
178-
llvm::Value *ThreadStartIndex,
179-
llvm::Value *NumTeams, int BlockSize,
180-
bool IsFast);
172+
getXteamRedFunctionPtrs(CodeGenFunction &CGF, llvm::Type *RedVarType,
173+
CodeGenModule::XteamRedOpKind Opcode);
174+
175+
/// Generate a call to cross-team operation.
176+
llvm::Value *getXteamRedOperation(CodeGenFunction &CGF, llvm::Value *Val,
177+
llvm::Value *OrigVarPtr,
178+
llvm::Value *DTeamVals,
179+
llvm::Value *DTeamsDonePtr,
180+
llvm::Value *ThreadStartIndex,
181+
llvm::Value *NumTeams, int BlockSize,
182+
CodeGenModule::XteamRedOpKind, bool IsFast);
181183

182184
/// Emit call to Cross-team scan entry points
183185
llvm::Value *

clang/lib/CodeGen/CGStmt.cpp

Lines changed: 176 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,7 @@ void CodeGenFunction::EmitXteamRedCode(const OMPExecutableDirective &D,
682682
// EmitStmt(CapturedForStmt);
683683

684684
// Now emit the calls to the Xteam reduction operation, one for each reduction variable
685-
EmitXteamRedSum(CapturedForStmt, *Args, CGM.getXteamRedBlockSize(D));
685+
EmitXteamRedOperation(CapturedForStmt, *Args, CGM.getXteamRedBlockSize(D));
686686
}
687687

688688
// Xteam codegen done
@@ -706,14 +706,9 @@ void CodeGenFunction::EmitXteamLocalAggregator(const ForStmt *FStmt) {
706706
RedVarType->isIntegerTy()) &&
707707
"Unhandled type");
708708
llvm::AllocaInst *XteamRedInst = Builder.CreateAlloca(RedVarType);
709-
llvm::Value *InitVal = nullptr;
710-
if (RedVarType->isFloatTy() || RedVarType->isDoubleTy() ||
711-
RedVarType->isHalfTy() || RedVarType->isBFloatTy())
712-
InitVal = llvm::ConstantFP::getZero(RedVarType);
713-
else if (RedVarType->isIntegerTy())
714-
InitVal = llvm::ConstantInt::get(RedVarType, 0);
715-
else
716-
llvm_unreachable("Unhandled type");
709+
// The initial value (referred to as the sentinel value) of the local
710+
// reduction variable depends on the opcode.
711+
llvm::Value *InitVal = getXteamRedSentinel(RedVarType, Itr->second.Opcode);
717712
Address XteamRedVarAddr(
718713
XteamRedInst, RedVarType,
719714
getContext().getTypeAlignInChars(RedVarExpr->getType()));
@@ -726,11 +721,56 @@ void CodeGenFunction::EmitXteamLocalAggregator(const ForStmt *FStmt) {
726721
}
727722
}
728723

729-
// Emit __kmpc_xteam_sum(*xteam_red_local_addr, red_var_addr) for each reduction
730-
// in the helper map for the given For Stmt
731-
void CodeGenFunction::EmitXteamRedSum(const ForStmt *FStmt,
732-
const FunctionArgList &Args,
733-
int BlockSize) {
724+
// Return the sentinel (initial) value of the local Xteam reduction variable
// for the given reduction operator.
//
// RedVarType is the LLVM type of the reduction variable; it must be half,
// bfloat, float, double, or an integer type. Opcode must not be
// XR_OP_unknown.
//
// The returned constant always has type RedVarType:
//   add -> 0
//   min -> +infinity (floating point) or the largest signed value (integer)
//   max -> -infinity (floating point) or the smallest signed value (integer)
llvm::Value *
CodeGenFunction::getXteamRedSentinel(llvm::Type *RedVarType,
                                     CodeGenModule::XteamRedOpKind Opcode) {
  assert(Opcode != CodeGenModule::XR_OP_unknown &&
         "Unexpected Xteam reduction opcode");

  const bool IsFloatingPoint = RedVarType->isFloatTy() ||
                               RedVarType->isDoubleTy() ||
                               RedVarType->isHalfTy() ||
                               RedVarType->isBFloatTy();
  if (IsFloatingPoint) {
    if (Opcode == CodeGenModule::XR_OP_add)
      return llvm::ConstantFP::getZero(RedVarType);
    // min starts at +infinity; anything else is treated as max and starts at
    // -infinity.
    return llvm::ConstantFP::getInfinity(
        RedVarType, /*Negative=*/Opcode != CodeGenModule::XR_OP_min);
  }

  assert(RedVarType->isIntegerTy() && "Unhandled type");

  if (Opcode == CodeGenModule::XR_OP_add)
    return llvm::ConstantInt::get(RedVarType, 0);

  // Build the signed limits at the bit width of the reduction variable itself
  // so that the constant matches RedVarType for every integer width, not only
  // 16, 32 and 64 bits.
  const unsigned BitWidth = RedVarType->getIntegerBitWidth();
  if (Opcode == CodeGenModule::XR_OP_min)
    return llvm::ConstantInt::get(RedVarType,
                                  llvm::APInt::getSignedMaxValue(BitWidth));
  // max operator
  return llvm::ConstantInt::get(RedVarType,
                                llvm::APInt::getSignedMinValue(BitWidth));
}
768+
769+
// Emit a call to the DeviceRTL Xteam reduction function for each reduction
770+
// variable in the helper map for the given For Stmt.
771+
void CodeGenFunction::EmitXteamRedOperation(const ForStmt *FStmt,
772+
const FunctionArgList &Args,
773+
int BlockSize) {
734774
auto &RT = static_cast<CGOpenMPRuntimeGPU &>(CGM.getOpenMPRuntime());
735775
const CodeGenModule::XteamRedVarMap &RedVarMap = CGM.getXteamRedVarMap(FStmt);
736776

@@ -760,10 +800,12 @@ void CodeGenFunction::EmitXteamRedSum(const ForStmt *FStmt,
760800
const Expr *OrigRedVarExpr = RVI.RedVarExpr;
761801
const DeclRefExpr *DRE = cast<DeclRefExpr>(OrigRedVarExpr);
762802
Address OrigRedVarAddr = EmitLValue(DRE).getAddress();
763-
// Pass in OrigRedVarAddr.getPointer to kmpc_xteam_sum
764-
RT.getXteamRedSum(*this, Builder.CreateLoad(RVI.RedVarAddr),
765-
OrigRedVarAddr.emitRawPointer(*this), DTeamVals, DTeamsDonePtr,
766-
ThreadStartIdx, NumTeams, BlockSize, IsFast);
803+
// Note that fast Xteam reduction is available only for sum operator.
804+
RT.getXteamRedOperation(*this, Builder.CreateLoad(RVI.RedVarAddr),
805+
OrigRedVarAddr.emitRawPointer(*this), DTeamVals,
806+
DTeamsDonePtr, ThreadStartIdx, NumTeams, BlockSize,
807+
RVI.Opcode,
808+
IsFast && RVI.Opcode == CodeGenModule::XR_OP_add);
767809
}
768810
}
769811

@@ -857,8 +899,6 @@ void CodeGenFunction::EmitXteamScanPhaseTwo(const ForStmt *FStmt,
857899
}
858900
}
859901

860-
// Emit reduction into local aggregator for a statement within the reduced loop
861-
// where applicable
862902
bool CodeGenFunction::EmitXteamRedStmt(const Stmt *S) {
863903
if (CGM.getCurrentXteamRedStmt() == nullptr)
864904
return false;
@@ -882,6 +922,10 @@ bool CodeGenFunction::EmitXteamRedStmt(const Stmt *S) {
882922
const CodeGenModule::XteamRedVarMap &RedVarMap =
883923
CGM.getXteamRedVarMap(CGM.getCurrentXteamRedStmt());
884924

925+
// Currently, there is limited support in Xteam reduction for calls with
926+
// reduction variables in arguments. Either the call has to be at the
927+
// statement level or it has to be a call to a builtin function (e.g. min/max)
928+
// on the rhs of an assignment statement. Handle call at the statement level.
885929
if (isa<CallExpr>(S)) {
886930
const CallExpr *CE = cast<CallExpr>(S);
887931
assert(CE && "Unexpected null call expression");
@@ -962,45 +1006,131 @@ bool CodeGenFunction::EmitXteamRedStmt(const Stmt *S) {
9621006
RedRHSExpr = RedBO->getRHS()->IgnoreImpCasts();
9631007
} else {
9641008
const Expr *L1RhsExpr = RedBO->getRHS()->IgnoreImpCasts();
965-
assert(isa<BinaryOperator>(L1RhsExpr) &&
1009+
assert((isa<BinaryOperator>(L1RhsExpr) || isa<CallExpr>(L1RhsExpr) ||
1010+
isa<PseudoObjectExpr>(L1RhsExpr)) &&
9661011
"Expected rhs to be a binary operator");
967-
const BinaryOperator *L2BO = cast<BinaryOperator>(L1RhsExpr);
968-
auto OpcL2BO = L2BO->getOpcode();
969-
assert(OpcL2BO == BO_Add && "Unexpected operator");
970-
// If the redvar is lhs, use the rhs in the generated reduction statement
971-
// and vice-versa.
972-
if (CGM.isXteamRedVarExpr(L2BO->getLHS()->IgnoreImpCasts(), RedVarDecl))
973-
RedRHSExpr = L2BO->getRHS();
974-
else if (CGM.isXteamRedVarExpr(L2BO->getRHS()->IgnoreImpCasts(),
975-
RedVarDecl))
976-
RedRHSExpr = L2BO->getLHS();
977-
else
978-
llvm_unreachable("Unhandled add expression during xteam reduction");
1012+
if (isa<BinaryOperator>(L1RhsExpr)) {
1013+
const BinaryOperator *L2BO = cast<BinaryOperator>(L1RhsExpr);
1014+
auto OpcL2BO = L2BO->getOpcode();
1015+
assert(OpcL2BO == BO_Add && "Unexpected operator");
1016+
// If the redvar is lhs, use the rhs in the generated reduction statement
1017+
// and vice-versa.
1018+
if (CGM.isXteamRedVarExpr(L2BO->getLHS()->IgnoreImpCasts(), RedVarDecl))
1019+
RedRHSExpr = L2BO->getRHS();
1020+
else if (CGM.isXteamRedVarExpr(L2BO->getRHS()->IgnoreImpCasts(),
1021+
RedVarDecl))
1022+
RedRHSExpr = L2BO->getLHS();
1023+
else
1024+
llvm_unreachable("Unhandled add expression during xteam reduction");
1025+
} else if (isa<CallExpr>(L1RhsExpr)) {
1026+
const CallExpr *Call = cast<CallExpr>(L1RhsExpr);
1027+
assert(CGM.getStatusOptKernelBuiltin(Call) == CodeGenModule::NxSuccess &&
1028+
"Expected a call to an Xteam supported builtin");
1029+
EmitXteamRedStmtForBuiltinCall(Call, RedVarDecl, RedVarMap);
1030+
return true;
1031+
} else {
1032+
assert(isa<PseudoObjectExpr>(L1RhsExpr) && "Expected a PseudoObjectExpr");
1033+
auto [Status, ReturnExpr] = CGM.getStatusXteamSupportedPseudoObject(
1034+
cast<PseudoObjectExpr>(L1RhsExpr));
1035+
assert(Status == CodeGenModule::NxSuccess &&
1036+
"Expected call expression from analysis of PseudoObjectExpr");
1037+
const CallExpr *Call = cast<CallExpr>(ReturnExpr);
1038+
assert(CGM.getStatusOptKernelBuiltin(Call) == CodeGenModule::NxSuccess &&
1039+
"Expected a call to an Xteam supported builtin");
1040+
EmitXteamRedStmtForBuiltinCall(Call, RedVarDecl, RedVarMap);
1041+
return true;
1042+
}
9791043
}
9801044
assert(RedRHSExpr != nullptr && "Did not find a valid reduction rhs");
981-
llvm::Value *RHSValue = EmitScalarExpr(RedRHSExpr);
1045+
1046+
EmitLocalReductionStmt(RedRHSExpr, RedVarDecl, RedVarMap,
1047+
CodeGenModule::XR_OP_add);
1048+
return true;
1049+
}
1050+
1051+
void CodeGenFunction::EmitLocalReductionStmt(
1052+
const Expr *E, const VarDecl *RedVarDecl,
1053+
const CodeGenModule::XteamRedVarMap &RedVarMap,
1054+
CodeGenModule::XteamRedOpKind OpKind) {
1055+
// For add, generate *xteam_local = *xteam_local + rhs_value
1056+
// For min/max, generate *xteam_local = min/max(*xteam_local, other_operand)
1057+
1058+
// First, generate the other operand.
1059+
llvm::Value *RHSValue = EmitScalarExpr(E);
1060+
// Now handle the local reduction variable accesses.
9821061
auto It = RedVarMap.find(RedVarDecl);
9831062
assert(It != RedVarMap.end() && "Variable must be found in reduction map");
9841063
Address XteamRedLocalAddr = It->second.RedVarAddr;
985-
// Compute *xteam_red_local_addr + rhs_value
986-
llvm::Value *RedRHS = nullptr;
9871064
llvm::Type *RedVarType = ConvertTypeForMem(It->second.RedVarExpr->getType());
1065+
llvm::Value *Op1 = Builder.CreateLoad(XteamRedLocalAddr);
1066+
llvm::Value *RedRHS = nullptr;
9881067
if (RedVarType->isFloatTy() || RedVarType->isDoubleTy() ||
9891068
RedVarType->isHalfTy() || RedVarType->isBFloatTy()) {
990-
auto RHSOp = RHSValue->getType()->isIntegerTy()
991-
? Builder.CreateSIToFP(RHSValue, RedVarType)
992-
: Builder.CreateFPCast(RHSValue, RedVarType);
993-
RedRHS = Builder.CreateFAdd(Builder.CreateLoad(XteamRedLocalAddr), RHSOp);
1069+
auto Op2 = RHSValue->getType()->isIntegerTy()
1070+
? Builder.CreateSIToFP(RHSValue, RedVarType)
1071+
: Builder.CreateFPCast(RHSValue, RedVarType);
1072+
if (OpKind == CodeGenModule::XR_OP_add)
1073+
RedRHS = Builder.CreateFAdd(Op1, Op2);
1074+
else if (OpKind == CodeGenModule::XR_OP_min)
1075+
RedRHS =
1076+
Builder.CreateMinNum(Op1, Op2, /*FMFSource=*/nullptr, "xteam.min");
1077+
else if (OpKind == CodeGenModule::XR_OP_max)
1078+
RedRHS =
1079+
Builder.CreateMaxNum(Op1, Op2, /*FMFSource=*/nullptr, "xteam.max");
1080+
else
1081+
llvm_unreachable("Unexpected reduction kind");
9941082
} else if (RedVarType->isIntegerTy()) {
995-
auto RHSOp = RHSValue->getType()->isIntegerTy()
996-
? Builder.CreateIntCast(RHSValue, RedVarType, false)
997-
: Builder.CreateFPToSI(RHSValue, RedVarType);
998-
RedRHS = Builder.CreateAdd(Builder.CreateLoad(XteamRedLocalAddr), RHSOp);
1083+
auto Op2 = RHSValue->getType()->isIntegerTy()
1084+
? Builder.CreateIntCast(RHSValue, RedVarType, false)
1085+
: Builder.CreateFPToSI(RHSValue, RedVarType);
1086+
if (OpKind == CodeGenModule::XR_OP_add)
1087+
RedRHS = Builder.CreateAdd(Op1, Op2);
1088+
else if (OpKind == CodeGenModule::XR_OP_min)
1089+
// TODO Fix when unsigned
1090+
RedRHS = Builder.CreateBinaryIntrinsic(llvm::Intrinsic::smin, Op1, Op2,
1091+
nullptr, "xteam.min");
1092+
else if (OpKind == CodeGenModule::XR_OP_max)
1093+
// TODO fix when unsigned
1094+
RedRHS = Builder.CreateBinaryIntrinsic(llvm::Intrinsic::smax, Op1, Op2,
1095+
nullptr, "xteam.max");
1096+
else
1097+
llvm_unreachable("Unexpected reduction kind");
9991098
} else
10001099
llvm_unreachable("Unhandled type");
1001-
// *xteam_red_local_addr = *xteam_red_local_addr + rhs_value
1100+
assert(RedRHS && "Right hand side of statement cannot be null");
10021101
Builder.CreateStore(RedRHS, XteamRedLocalAddr);
1003-
return true;
1102+
}
1103+
1104+
std::pair<const Expr *, CodeGenModule::XteamRedOpKind>
1105+
CodeGenFunction::ExtractXteamRedRhsExpr(const CallExpr *Call,
1106+
const VarDecl *RedVarDecl) {
1107+
// Traverse arguments, identifying and ignoring the reduction variable, and
1108+
// then extracting the other argument.
1109+
CodeGenModule::XteamRedOpKind Opcode;
1110+
std::string CallName = Call->getDirectCallee()->getNameInfo().getAsString();
1111+
if (CGM.isOptKernelAMDGCNMax(CallName))
1112+
Opcode = CodeGenModule::XR_OP_max;
1113+
else if (CGM.isOptKernelAMDGCNMin(CallName))
1114+
Opcode = CodeGenModule::XR_OP_min;
1115+
else
1116+
llvm_unreachable("Epecting either min or max");
1117+
1118+
for (unsigned ArgIndex = 0; ArgIndex < Call->getNumArgs(); ++ArgIndex) {
1119+
const Expr *Arg = Call->getArg(ArgIndex);
1120+
while (isa<ImplicitCastExpr>(Arg))
1121+
Arg = cast<ImplicitCastExpr>(Arg)->getSubExpr();
1122+
if (CGM.isXteamRedVarExpr(Arg, RedVarDecl))
1123+
continue;
1124+
return std::make_pair(Call->getArg(ArgIndex), Opcode);
1125+
}
1126+
llvm_unreachable("Could not extract expected arg of min/max");
1127+
}
1128+
1129+
void CodeGenFunction::EmitXteamRedStmtForBuiltinCall(
1130+
const CallExpr *Call, const VarDecl *RedVarDecl,
1131+
const CodeGenModule::XteamRedVarMap &RedVarMap) {
1132+
auto [RhsExpr, Opcode] = ExtractXteamRedRhsExpr(Call, RedVarDecl);
1133+
EmitLocalReductionStmt(RhsExpr, RedVarDecl, RedVarMap, Opcode);
10041134
}
10051135

10061136
void CodeGenFunction::EmitStmt(const Stmt *S, ArrayRef<const Attr *> Attrs) {

clang/lib/CodeGen/CGStmtOpenMP.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -956,14 +956,19 @@ llvm::Function *CodeGenFunction::GenerateOpenMPCapturedStmtFunction(
956956
// If Xteam found, use it. Otherwise, query again. This is required to make
957957
// sure that the outlined routines have the correct signature.
958958
if (FStmt) {
959-
if (!CGM.isXteamRedKernel(FStmt))
960-
isXteamKernel =
961-
CGM.checkAndSetXteamRedKernel(D) == CodeGenModule::NxSuccess;
962-
else
959+
if (!CGM.isXteamRedKernel(FStmt)) {
960+
CodeGenModule::NoLoopXteamErr NxStatus =
961+
CGM.checkAndSetXteamRedKernel(D);
962+
DEBUG_WITH_TYPE(NO_LOOP_XTEAM_RED,
963+
CGM.emitNxResult("[Xteam-host]", D, NxStatus));
964+
isXteamKernel = (NxStatus == CodeGenModule::NxSuccess);
965+
} else
963966
isXteamKernel = true;
964967
} else {
965-
isXteamKernel =
966-
CGM.checkAndSetXteamRedKernel(D) == CodeGenModule::NxSuccess;
968+
CodeGenModule::NoLoopXteamErr NxStatus = CGM.checkAndSetXteamRedKernel(D);
969+
DEBUG_WITH_TYPE(NO_LOOP_XTEAM_RED,
970+
CGM.emitNxResult("[Xteam-host]", D, NxStatus));
971+
isXteamKernel = (NxStatus == CodeGenModule::NxSuccess);
967972
}
968973
}
969974

@@ -7677,8 +7682,11 @@ static void emitCommonOMPTargetDirective(CodeGenFunction &CGF,
76777682
const Stmt *OptKernelKey = CGM.getOptKernelKey(S);
76787683
if (OptKernelKey)
76797684
FStmt = CGM.getSingleForStmt(OptKernelKey);
7680-
if (FStmt && CGM.getLangOpts().OpenMPOffloadMandatory)
7681-
CGM.checkAndSetXteamRedKernel(S);
7685+
if (FStmt && CGM.getLangOpts().OpenMPOffloadMandatory) {
7686+
CodeGenModule::NoLoopXteamErr NxStatus = CGM.checkAndSetXteamRedKernel(S);
7687+
DEBUG_WITH_TYPE(NO_LOOP_XTEAM_RED,
7688+
CGM.emitNxResult("[Xteam-host]", S, NxStatus));
7689+
}
76827690

76837691
if (CGM.getLangOpts().OpenMPOffloadMandatory && !IsOffloadEntry) {
76847692
unsigned DiagID = CGM.getDiags().getCustomDiagID(

0 commit comments

Comments
 (0)