@@ -63,8 +63,7 @@ const std::string &getDefaultString(HelperFuncType HFT) {
6363 const static std::string DefaultQueue =
6464 DpctGlobalInfo::useNoQueueDevice ()
6565 ? DpctGlobalInfo::getGlobalQueueName ()
66- : buildString (MapNames::getDpctNamespace () + " get_" +
67- DpctGlobalInfo::getDeviceQueueName () + " ()" );
66+ : DpctGlobalInfo::getDefaultQueueFreeFuncCall ();
6867 return DefaultQueue;
6968 }
7069 case clang::dpct::HelperFuncType::HFT_DefaultQueuePtr: {
@@ -74,8 +73,8 @@ const std::string &getDefaultString(HelperFuncType HFT) {
7473 : (DpctGlobalInfo::useSYCLCompat ()
7574 ? buildString (MapNames::getDpctNamespace () +
7675 " get_current_device().default_queue()" )
77- : buildString (" & " + MapNames::getDpctNamespace () + " get_ " +
78- DpctGlobalInfo::getDeviceQueueName () + " () " ));
76+ : buildString (
77+ " & " , DpctGlobalInfo::getDefaultQueueFreeFuncCall () ));
7978 return DefaultQueue;
8079 }
8180 case clang::dpct::HelperFuncType::HFT_CurrentDevice: {
@@ -269,6 +268,15 @@ void processTypeLoc(const TypeLoc &TL, ExprAnalysis &EA,
269268 }
270269 EA.applyAllSubExprRepl ();
271270}
271+ HelperFuncCatalog getQueueKind () {
272+ if (DpctGlobalInfo::useSYCLCompat ()) {
273+ return HelperFuncCatalog::GetDefaultQueue;
274+ }
275+ if (DpctGlobalInfo::getUsmLevel () == UsmLevel::UL_Restricted) {
276+ return HelperFuncCatalog::GetInOrderQueue;
277+ }
278+ return HelperFuncCatalog::GetOutOfOrderQueue;
279+ }
272280
273281// /// class FreeQueriesInfo /////
274282class FreeQueriesInfo {
@@ -930,6 +938,11 @@ void DpctFileInfo::insertHeader(HeaderType Type, unsigned Offset,
930938 << CCLVerValue << getNL ();
931939 insertHeader (MigratedMacroDefinitionOS.str (), FileBeginOffset,
932940 InsertPosition::IP_AlwaysLeft);
941+ for (const auto &File :
942+ DpctGlobalInfo::getCustomHelperFunctionAddtionalIncludes ()) {
943+ insertHeader (" #include \" " + File + +" \" " + getNL (), FirstIncludeOffset,
944+ InsertPosition::IP_Right);
945+ }
933946 return ;
934947
935948 // Because <dpct/dpl_utils.hpp> includes <oneapi/dpl/execution> and
@@ -1225,15 +1238,26 @@ std::string DpctGlobalInfo::getDefaultQueue(const Stmt *S) {
12251238
12261239 return buildString (RegexPrefix, ' Q' , Idx, RegexSuffix);
12271240}
1228- const std::string &DpctGlobalInfo::getDeviceQueueName () {
1229- static const std::string DeviceQueue = [&]() {
1241+ const std::string &DpctGlobalInfo::getDefaultQueueFreeFuncCall () {
1242+ static const std::string DefaultQueueFreeFuncCall = [&]() {
1243+ if (auto Iter = MapNames::CustomHelperFunctionMap.find (getQueueKind ());
1244+ Iter != MapNames::CustomHelperFunctionMap.end ()) {
1245+ return Iter->second ;
1246+ }
1247+ return MapNames::getDpctNamespace () + " get_" +
1248+ getDefaultQueueMemFuncName () + " ()" ;
1249+ }();
1250+ return DefaultQueueFreeFuncCall;
1251+ }
1252+ const std::string &DpctGlobalInfo::getDefaultQueueMemFuncName () {
1253+ static const std::string DefaultQueueMemFuncName = [&]() {
12301254 if (DpctGlobalInfo::useSYCLCompat ())
12311255 return " default_queue" ;
12321256 if (DpctGlobalInfo::getUsmLevel () == UsmLevel::UL_None)
12331257 return " out_of_order_queue" ;
12341258 return " in_order_queue" ;
12351259 }();
1236- return DeviceQueue ;
1260+ return DefaultQueueMemFuncName ;
12371261}
12381262void DpctGlobalInfo::setContext (ASTContext &C) {
12391263 Context = &C;
@@ -1588,7 +1612,8 @@ void DpctGlobalInfo::buildReplacements() {
15881612 QDecl << " &q_ct1 = " ;
15891613 if (DpctGlobalInfo::useSYCLCompat ())
15901614 QDecl << ' *' ;
1591- QDecl << " dev_ct1." << DpctGlobalInfo::getDeviceQueueName () << " ();" ;
1615+ QDecl << " dev_ct1." << DpctGlobalInfo::getDefaultQueueMemFuncName ()
1616+ << " ();" ;
15921617 } else {
15931618 DevDecl << MapNames::getClNamespace () + " device dev_ct1;" ;
15941619 // Now the UsmLevel must not be UL_None here.
@@ -2237,6 +2262,7 @@ void DpctGlobalInfo::resetInfo() {
22372262 SpellingLocToDFIsMapForAssumeNDRange.clear ();
22382263 DFIToSpellingLocsMapForAssumeNDRange.clear ();
22392264 FreeQueriesInfo::reset ();
2265+ CustomHelperFunctionAddtionalIncludes.clear ();
22402266}
22412267void DpctGlobalInfo::updateSpellingLocDFIMaps (
22422268 SourceLocation SL, std::shared_ptr<DeviceFunctionInfo> DFI) {
@@ -2454,6 +2480,8 @@ std::vector<std::pair<std::string, std::vector<std::string>>>
24542480std::vector<std::pair<std::string, std::vector<std::string>>>
24552481 DpctGlobalInfo::CodePinDumpFuncDepsVec;
24562482std::unordered_set<std::string> DpctGlobalInfo::NeedParenAPISet = {};
2483+ std::unordered_set<std::string>
2484+ DpctGlobalInfo::CustomHelperFunctionAddtionalIncludes = {};
24572485// /// class DpctNameGenerator /////
24582486void DpctNameGenerator::printName (const FunctionDecl *FD,
24592487 llvm::raw_ostream &OS) {
@@ -6072,6 +6100,7 @@ void KernelCallExpr::removeExtraIndent() {
60726100 getFilePath (), getOffset () - LocInfo.Indent .length (),
60736101 LocInfo.Indent .length (), " " , nullptr ));
60746102}
6103+
60756104void KernelCallExpr::addDevCapCheckStmt () {
60766105 llvm::SmallVector<std::string> AspectList;
60776106 if (getVarMap ().hasBF64 ()) {
@@ -6081,17 +6110,28 @@ void KernelCallExpr::addDevCapCheckStmt() {
60816110 AspectList.push_back (MapNames::getClNamespace () + " aspect::fp16" );
60826111 }
60836112 if (!AspectList.empty ()) {
6084- requestFeature (HelperFeatureEnum::device_ext);
60856113 std::string Str;
60866114 llvm::raw_string_ostream OS (Str);
6087- OS << MapNames::getDpctNamespace () << " get_device(" ;
6088- OS << MapNames::getDpctNamespace () << " get_device_id(" ;
6089- printStreamBase (OS);
6090- OS << " get_device())).has_capability_or_fail({" << AspectList.front ();
6091- for (size_t i = 1 ; i < AspectList.size (); ++i) {
6092- OS << " , " << AspectList[i];
6093- }
6094- OS << " });" ;
6115+ if (auto Iter = MapNames::CustomHelperFunctionMap.find (getQueueKind ());
6116+ Iter != MapNames::CustomHelperFunctionMap.end ()) {
6117+ OS << MapNames::getDpctNamespace () << " has_capability_or_fail(" ;
6118+ OS << Iter->second << " .get_device(), " ;
6119+ OS << " {" << AspectList.front ();
6120+ for (size_t i = 1 ; i < AspectList.size (); ++i) {
6121+ OS << " , " << AspectList[i];
6122+ }
6123+ OS << " });" ;
6124+ } else {
6125+ requestFeature (HelperFeatureEnum::device_ext);
6126+ OS << MapNames::getDpctNamespace () << " get_device(" ;
6127+ OS << MapNames::getDpctNamespace () << " get_device_id(" ;
6128+ printStreamBase (OS);
6129+ OS << " get_device())).has_capability_or_fail({" << AspectList.front ();
6130+ for (size_t i = 1 ; i < AspectList.size (); ++i) {
6131+ OS << " , " << AspectList[i];
6132+ }
6133+ OS << " });" ;
6134+ }
60956135 OuterStmts.OthersList .emplace_back (OS.str ());
60966136 }
60976137}
@@ -6141,8 +6181,7 @@ void KernelCallExpr::addStreamDecl() {
61416181 buildString (MapNames::getClNamespace () + " stream " ,
61426182 DpctGlobalInfo::getStreamName (), " (64 * 1024, 80, cgh);" ));
61436183 if (getVarMap ().hasSync ()) {
6144- auto DefaultQueue = buildString (MapNames::getDpctNamespace (), " get_" ,
6145- DpctGlobalInfo::getDeviceQueueName (), " ()" );
6184+ auto DefaultQueue = DpctGlobalInfo::getDefaultQueueFreeFuncCall ();
61466185 if (DpctGlobalInfo::getUsmLevel () == UsmLevel::UL_None) {
61476186 OuterStmts.OthersList .emplace_back (
61486187 buildString (MapNames::getDpctNamespace (), " global_memory<" ,
0 commit comments