@@ -682,7 +682,7 @@ void CodeGenFunction::EmitXteamRedCode(const OMPExecutableDirective &D,
682
682
// EmitStmt(CapturedForStmt);
683
683
684
684
// Now emit the calls to xteam_sum, one for each reduction variable
685
- EmitXteamRedSum (CapturedForStmt, *Args, CGM.getXteamRedBlockSize (D));
685
+ EmitXteamRedOperation (CapturedForStmt, *Args, CGM.getXteamRedBlockSize (D));
686
686
}
687
687
688
688
// Xteam codegen done
@@ -706,14 +706,9 @@ void CodeGenFunction::EmitXteamLocalAggregator(const ForStmt *FStmt) {
706
706
RedVarType->isIntegerTy ()) &&
707
707
" Unhandled type" );
708
708
llvm::AllocaInst *XteamRedInst = Builder.CreateAlloca (RedVarType);
709
- llvm::Value *InitVal = nullptr ;
710
- if (RedVarType->isFloatTy () || RedVarType->isDoubleTy () ||
711
- RedVarType->isHalfTy () || RedVarType->isBFloatTy ())
712
- InitVal = llvm::ConstantFP::getZero (RedVarType);
713
- else if (RedVarType->isIntegerTy ())
714
- InitVal = llvm::ConstantInt::get (RedVarType, 0 );
715
- else
716
- llvm_unreachable (" Unhandled type" );
709
+ // The initial value (referred to as the sentinel value) of the local
710
+ // reduction variable depends on the opcode.
711
+ llvm::Value *InitVal = getXteamRedSentinel (RedVarType, Itr->second .Opcode );
717
712
Address XteamRedVarAddr (
718
713
XteamRedInst, RedVarType,
719
714
getContext ().getTypeAlignInChars (RedVarExpr->getType ()));
@@ -726,11 +721,56 @@ void CodeGenFunction::EmitXteamLocalAggregator(const ForStmt *FStmt) {
726
721
}
727
722
}
728
723
729
- // Emit __kmpc_xteam_sum(*xteam_red_local_addr, red_var_addr) for each reduction
730
- // in the helper map for the given For Stmt
731
- void CodeGenFunction::EmitXteamRedSum (const ForStmt *FStmt,
732
- const FunctionArgList &Args,
733
- int BlockSize) {
724
+ llvm::Value *
725
+ CodeGenFunction::getXteamRedSentinel (llvm::Type *RedVarType,
726
+ CodeGenModule::XteamRedOpKind Opcode) {
727
+ assert ((RedVarType->isFloatTy () || RedVarType->isDoubleTy () ||
728
+ RedVarType->isHalfTy () || RedVarType->isBFloatTy () ||
729
+ RedVarType->isIntegerTy ()) &&
730
+ " Unhandled type" );
731
+ assert (Opcode != CodeGenModule::XR_OP_unknown &&
732
+ " Unexpected Xteam reduction opcode" );
733
+ if (RedVarType->isFloatTy () || RedVarType->isDoubleTy () ||
734
+ RedVarType->isHalfTy () || RedVarType->isBFloatTy ()) {
735
+ if (Opcode == CodeGenModule::XR_OP_add)
736
+ return llvm::ConstantFP::getZero (RedVarType);
737
+ else if (Opcode == CodeGenModule::XR_OP_min)
738
+ return llvm::ConstantFP::getInfinity (RedVarType);
739
+ else // max operator
740
+ return llvm::ConstantFP::getInfinity (RedVarType, /* Negative=*/ true );
741
+ } else {
742
+ // Integer type
743
+ if (RedVarType->getPrimitiveSizeInBits () == 16 )
744
+ return llvm::ConstantInt::get (Int16Ty,
745
+ Opcode == CodeGenModule::XR_OP_add ? 0
746
+ : Opcode == CodeGenModule::XR_OP_min
747
+ ? std::numeric_limits<int16_t >::max ()
748
+ : std::numeric_limits<int16_t >::min ());
749
+ else if (RedVarType->getPrimitiveSizeInBits () == 32 )
750
+ return llvm::ConstantInt::get (Int32Ty,
751
+ Opcode == CodeGenModule::XR_OP_add ? 0
752
+ : Opcode == CodeGenModule::XR_OP_min
753
+ ? std::numeric_limits<int32_t >::max ()
754
+ : std::numeric_limits<int32_t >::min ());
755
+ else {
756
+ assert (RedVarType->getPrimitiveSizeInBits () == 64 &&
757
+ " Expected a 64-bit integer" );
758
+ return llvm::ConstantInt::get (Int64Ty,
759
+ Opcode == CodeGenModule::XR_OP_add ? 0
760
+ : Opcode == CodeGenModule::XR_OP_min
761
+ ? std::numeric_limits<int64_t >::max ()
762
+ : std::numeric_limits<int64_t >::min ());
763
+ }
764
+ }
765
+ llvm_unreachable (
766
+ " Unexpected type or opcode in Xteam reduction sentinel generation" );
767
+ }
768
+
769
+ // Emit a call to the DeviceRTL Xteam reduction function for each reduction
770
+ // variable in the helper map for the given For Stmt.
771
+ void CodeGenFunction::EmitXteamRedOperation (const ForStmt *FStmt,
772
+ const FunctionArgList &Args,
773
+ int BlockSize) {
734
774
auto &RT = static_cast <CGOpenMPRuntimeGPU &>(CGM.getOpenMPRuntime ());
735
775
const CodeGenModule::XteamRedVarMap &RedVarMap = CGM.getXteamRedVarMap (FStmt);
736
776
@@ -760,10 +800,12 @@ void CodeGenFunction::EmitXteamRedSum(const ForStmt *FStmt,
760
800
const Expr *OrigRedVarExpr = RVI.RedVarExpr ;
761
801
const DeclRefExpr *DRE = cast<DeclRefExpr>(OrigRedVarExpr);
762
802
Address OrigRedVarAddr = EmitLValue (DRE).getAddress ();
763
- // Pass in OrigRedVarAddr.getPointer to kmpc_xteam_sum
764
- RT.getXteamRedSum (*this , Builder.CreateLoad (RVI.RedVarAddr ),
765
- OrigRedVarAddr.emitRawPointer (*this ), DTeamVals, DTeamsDonePtr,
766
- ThreadStartIdx, NumTeams, BlockSize, IsFast);
803
+ // Note that fast Xteam reduction is available only for sum operator.
804
+ RT.getXteamRedOperation (*this , Builder.CreateLoad (RVI.RedVarAddr ),
805
+ OrigRedVarAddr.emitRawPointer (*this ), DTeamVals,
806
+ DTeamsDonePtr, ThreadStartIdx, NumTeams, BlockSize,
807
+ RVI.Opcode ,
808
+ IsFast && RVI.Opcode == CodeGenModule::XR_OP_add);
767
809
}
768
810
}
769
811
@@ -857,8 +899,6 @@ void CodeGenFunction::EmitXteamScanPhaseTwo(const ForStmt *FStmt,
857
899
}
858
900
}
859
901
860
- // Emit reduction into local aggregator for a statement within the reduced loop
861
- // where applicable
862
902
bool CodeGenFunction::EmitXteamRedStmt (const Stmt *S) {
863
903
if (CGM.getCurrentXteamRedStmt () == nullptr )
864
904
return false ;
@@ -882,6 +922,10 @@ bool CodeGenFunction::EmitXteamRedStmt(const Stmt *S) {
882
922
const CodeGenModule::XteamRedVarMap &RedVarMap =
883
923
CGM.getXteamRedVarMap (CGM.getCurrentXteamRedStmt ());
884
924
925
+ // Currently, there is limited support in Xteam reduction for calls with
926
+ // reduction variables in arguments. Either the call has to be at the
927
+ // statement level or it has to be a call to a builtin function (e.g. min/max)
928
+ // on the rhs of an assignment statement. Handle call at the statement level.
885
929
if (isa<CallExpr>(S)) {
886
930
const CallExpr *CE = cast<CallExpr>(S);
887
931
assert (CE && " Unexpected null call expression" );
@@ -962,45 +1006,131 @@ bool CodeGenFunction::EmitXteamRedStmt(const Stmt *S) {
962
1006
RedRHSExpr = RedBO->getRHS ()->IgnoreImpCasts ();
963
1007
} else {
964
1008
const Expr *L1RhsExpr = RedBO->getRHS ()->IgnoreImpCasts ();
965
- assert (isa<BinaryOperator>(L1RhsExpr) &&
1009
+ assert ((isa<BinaryOperator>(L1RhsExpr) || isa<CallExpr>(L1RhsExpr) ||
1010
+ isa<PseudoObjectExpr>(L1RhsExpr)) &&
966
1011
" Expected rhs to be a binary operator" );
967
- const BinaryOperator *L2BO = cast<BinaryOperator>(L1RhsExpr);
968
- auto OpcL2BO = L2BO->getOpcode ();
969
- assert (OpcL2BO == BO_Add && " Unexpected operator" );
970
- // If the redvar is lhs, use the rhs in the generated reduction statement
971
- // and vice-versa.
972
- if (CGM.isXteamRedVarExpr (L2BO->getLHS ()->IgnoreImpCasts (), RedVarDecl))
973
- RedRHSExpr = L2BO->getRHS ();
974
- else if (CGM.isXteamRedVarExpr (L2BO->getRHS ()->IgnoreImpCasts (),
975
- RedVarDecl))
976
- RedRHSExpr = L2BO->getLHS ();
977
- else
978
- llvm_unreachable (" Unhandled add expression during xteam reduction" );
1012
+ if (isa<BinaryOperator>(L1RhsExpr)) {
1013
+ const BinaryOperator *L2BO = cast<BinaryOperator>(L1RhsExpr);
1014
+ auto OpcL2BO = L2BO->getOpcode ();
1015
+ assert (OpcL2BO == BO_Add && " Unexpected operator" );
1016
+ // If the redvar is lhs, use the rhs in the generated reduction statement
1017
+ // and vice-versa.
1018
+ if (CGM.isXteamRedVarExpr (L2BO->getLHS ()->IgnoreImpCasts (), RedVarDecl))
1019
+ RedRHSExpr = L2BO->getRHS ();
1020
+ else if (CGM.isXteamRedVarExpr (L2BO->getRHS ()->IgnoreImpCasts (),
1021
+ RedVarDecl))
1022
+ RedRHSExpr = L2BO->getLHS ();
1023
+ else
1024
+ llvm_unreachable (" Unhandled add expression during xteam reduction" );
1025
+ } else if (isa<CallExpr>(L1RhsExpr)) {
1026
+ const CallExpr *Call = cast<CallExpr>(L1RhsExpr);
1027
+ assert (CGM.getStatusOptKernelBuiltin (Call) == CodeGenModule::NxSuccess &&
1028
+ " Expected a call to an Xteam supported builtin" );
1029
+ EmitXteamRedStmtForBuiltinCall (Call, RedVarDecl, RedVarMap);
1030
+ return true ;
1031
+ } else {
1032
+ assert (isa<PseudoObjectExpr>(L1RhsExpr) && " Expected a PseudoObjectExpr" );
1033
+ auto [Status, ReturnExpr] = CGM.getStatusXteamSupportedPseudoObject (
1034
+ cast<PseudoObjectExpr>(L1RhsExpr));
1035
+ assert (Status == CodeGenModule::NxSuccess &&
1036
+ " Expected call expression from analysis of PseudoObjectExpr" );
1037
+ const CallExpr *Call = cast<CallExpr>(ReturnExpr);
1038
+ assert (CGM.getStatusOptKernelBuiltin (Call) == CodeGenModule::NxSuccess &&
1039
+ " Expected a call to an Xteam supported builtin" );
1040
+ EmitXteamRedStmtForBuiltinCall (Call, RedVarDecl, RedVarMap);
1041
+ return true ;
1042
+ }
979
1043
}
980
1044
assert (RedRHSExpr != nullptr && " Did not find a valid reduction rhs" );
981
- llvm::Value *RHSValue = EmitScalarExpr (RedRHSExpr);
1045
+
1046
+ EmitLocalReductionStmt (RedRHSExpr, RedVarDecl, RedVarMap,
1047
+ CodeGenModule::XR_OP_add);
1048
+ return true ;
1049
+ }
1050
+
1051
+ void CodeGenFunction::EmitLocalReductionStmt (
1052
+ const Expr *E, const VarDecl *RedVarDecl,
1053
+ const CodeGenModule::XteamRedVarMap &RedVarMap,
1054
+ CodeGenModule::XteamRedOpKind OpKind) {
1055
+ // For add, generate *xteam_local = *xteam_local + rhs_value
1056
+ // For min/max, generate *xteam_local = min/max(*xteam_local, other_operand)
1057
+
1058
+ // First, generate the other operand.
1059
+ llvm::Value *RHSValue = EmitScalarExpr (E);
1060
+ // Now handle the local reduction variable accesses.
982
1061
auto It = RedVarMap.find (RedVarDecl);
983
1062
assert (It != RedVarMap.end () && " Variable must be found in reduction map" );
984
1063
Address XteamRedLocalAddr = It->second .RedVarAddr ;
985
- // Compute *xteam_red_local_addr + rhs_value
986
- llvm::Value *RedRHS = nullptr ;
987
1064
llvm::Type *RedVarType = ConvertTypeForMem (It->second .RedVarExpr ->getType ());
1065
+ llvm::Value *Op1 = Builder.CreateLoad (XteamRedLocalAddr);
1066
+ llvm::Value *RedRHS = nullptr ;
988
1067
if (RedVarType->isFloatTy () || RedVarType->isDoubleTy () ||
989
1068
RedVarType->isHalfTy () || RedVarType->isBFloatTy ()) {
990
- auto RHSOp = RHSValue->getType ()->isIntegerTy ()
991
- ? Builder.CreateSIToFP (RHSValue, RedVarType)
992
- : Builder.CreateFPCast (RHSValue, RedVarType);
993
- RedRHS = Builder.CreateFAdd (Builder.CreateLoad (XteamRedLocalAddr), RHSOp);
1069
+ auto Op2 = RHSValue->getType ()->isIntegerTy ()
1070
+ ? Builder.CreateSIToFP (RHSValue, RedVarType)
1071
+ : Builder.CreateFPCast (RHSValue, RedVarType);
1072
+ if (OpKind == CodeGenModule::XR_OP_add)
1073
+ RedRHS = Builder.CreateFAdd (Op1, Op2);
1074
+ else if (OpKind == CodeGenModule::XR_OP_min)
1075
+ RedRHS =
1076
+ Builder.CreateMinNum (Op1, Op2, /* FMFSource=*/ nullptr , " xteam.min" );
1077
+ else if (OpKind == CodeGenModule::XR_OP_max)
1078
+ RedRHS =
1079
+ Builder.CreateMaxNum (Op1, Op2, /* FMFSource=*/ nullptr , " xteam.max" );
1080
+ else
1081
+ llvm_unreachable (" Unexpected reduction kind" );
994
1082
} else if (RedVarType->isIntegerTy ()) {
995
- auto RHSOp = RHSValue->getType ()->isIntegerTy ()
996
- ? Builder.CreateIntCast (RHSValue, RedVarType, false )
997
- : Builder.CreateFPToSI (RHSValue, RedVarType);
998
- RedRHS = Builder.CreateAdd (Builder.CreateLoad (XteamRedLocalAddr), RHSOp);
1083
+ auto Op2 = RHSValue->getType ()->isIntegerTy ()
1084
+ ? Builder.CreateIntCast (RHSValue, RedVarType, false )
1085
+ : Builder.CreateFPToSI (RHSValue, RedVarType);
1086
+ if (OpKind == CodeGenModule::XR_OP_add)
1087
+ RedRHS = Builder.CreateAdd (Op1, Op2);
1088
+ else if (OpKind == CodeGenModule::XR_OP_min)
1089
+ // TODO Fix when unsigned
1090
+ RedRHS = Builder.CreateBinaryIntrinsic (llvm::Intrinsic::smin, Op1, Op2,
1091
+ nullptr , " xteam.min" );
1092
+ else if (OpKind == CodeGenModule::XR_OP_max)
1093
+ // TODO fix when unsigned
1094
+ RedRHS = Builder.CreateBinaryIntrinsic (llvm::Intrinsic::smax, Op1, Op2,
1095
+ nullptr , " xteam.max" );
1096
+ else
1097
+ llvm_unreachable (" Unexpected reduction kind" );
999
1098
} else
1000
1099
llvm_unreachable (" Unhandled type" );
1001
- // *xteam_red_local_addr = *xteam_red_local_addr + rhs_value
1100
+ assert (RedRHS && " Right hand side of statement cannot be null " );
1002
1101
Builder.CreateStore (RedRHS, XteamRedLocalAddr);
1003
- return true ;
1102
+ }
1103
+
1104
+ std::pair<const Expr *, CodeGenModule::XteamRedOpKind>
1105
+ CodeGenFunction::ExtractXteamRedRhsExpr (const CallExpr *Call,
1106
+ const VarDecl *RedVarDecl) {
1107
+ // Traverse arguments, identifying and ignoring the reduction variable, and
1108
+ // then extracting the other argument.
1109
+ CodeGenModule::XteamRedOpKind Opcode;
1110
+ std::string CallName = Call->getDirectCallee ()->getNameInfo ().getAsString ();
1111
+ if (CGM.isOptKernelAMDGCNMax (CallName))
1112
+ Opcode = CodeGenModule::XR_OP_max;
1113
+ else if (CGM.isOptKernelAMDGCNMin (CallName))
1114
+ Opcode = CodeGenModule::XR_OP_min;
1115
+ else
1116
+ llvm_unreachable (" Epecting either min or max" );
1117
+
1118
+ for (unsigned ArgIndex = 0 ; ArgIndex < Call->getNumArgs (); ++ArgIndex) {
1119
+ const Expr *Arg = Call->getArg (ArgIndex);
1120
+ while (isa<ImplicitCastExpr>(Arg))
1121
+ Arg = cast<ImplicitCastExpr>(Arg)->getSubExpr ();
1122
+ if (CGM.isXteamRedVarExpr (Arg, RedVarDecl))
1123
+ continue ;
1124
+ return std::make_pair (Call->getArg (ArgIndex), Opcode);
1125
+ }
1126
+ llvm_unreachable (" Could not extract expected arg of min/max" );
1127
+ }
1128
+
1129
+ void CodeGenFunction::EmitXteamRedStmtForBuiltinCall (
1130
+ const CallExpr *Call, const VarDecl *RedVarDecl,
1131
+ const CodeGenModule::XteamRedVarMap &RedVarMap) {
1132
+ auto [RhsExpr, Opcode] = ExtractXteamRedRhsExpr (Call, RedVarDecl);
1133
+ EmitLocalReductionStmt (RhsExpr, RedVarDecl, RedVarMap, Opcode);
1004
1134
}
1005
1135
1006
1136
void CodeGenFunction::EmitStmt (const Stmt *S, ArrayRef<const Attr *> Attrs) {
0 commit comments