@@ -619,7 +619,8 @@ void DpctFileInfo::buildReplacements() {
619619 // found, postfix "_ct" is added to this __constant__ symbol's name.
620620 std::unordered_map<unsigned int , std::string> ReplUpdated;
621621 for (const auto &Entry : MemVarMap) {
622- if (Entry.second ->isIgnore () || !Entry.second ->isConstant ())
622+ if (Entry.second ->isIgnore () || !Entry.second ->isConstant () ||
623+ Entry.second ->isUseDeviceGlobal ())
623624 continue ;
624625
625626 auto Name = Entry.second ->getName ();
@@ -1122,6 +1123,20 @@ DpctGlobalInfo::MacroDefRecord::MacroDefRecord(SourceLocation NTL, bool IIAS)
11221123 FilePath = LocInfo.first ;
11231124 Offset = LocInfo.second ;
11241125}
1126+
1127+ DpctGlobalInfo::MacroArgRecord::MacroArgRecord (const MacroInfo *MI,
1128+ int ArgIndex)
1129+ : ArgIndex(ArgIndex) {
1130+ ArgName = MI->params ()[ArgIndex]->getName ().str ();
1131+ for (auto Tok : MI->tokens ()) {
1132+ auto II = Tok.getIdentifierInfo ();
1133+ if (II && (II == MI->params ()[ArgIndex])) {
1134+ ArgLoc = Tok.getLocation ();
1135+ break ;
1136+ }
1137+ }
1138+ }
1139+
11251140DpctGlobalInfo::MacroExpansionRecord::MacroExpansionRecord (
11261141 IdentifierInfo *ID, const MacroInfo *MI, SourceRange Range,
11271142 bool IsInAnalysisScope, int TokenIndex) {
@@ -2399,6 +2414,8 @@ bool DpctGlobalInfo::CheckUnicodeSecurityFlag = false;
23992414bool DpctGlobalInfo::EnablepProfilingFlag = false ;
24002415std::map<std::string, std::shared_ptr<DpctGlobalInfo::MacroExpansionRecord>>
24012416 DpctGlobalInfo::ExpansionRangeToMacroRecord;
2417+ std::unordered_map<std::string, std::shared_ptr<DpctGlobalInfo::MacroArgRecord>>
2418+ DpctGlobalInfo::MacroArgRecordMap;
24022419std::map<std::string, SourceLocation> DpctGlobalInfo::EndifLocationOfIfdef;
24032420std::vector<std::pair<clang::tooling::UnifiedPath, size_t >>
24042421 DpctGlobalInfo::ConditionalCompilationLoc;
@@ -2831,6 +2848,163 @@ std::shared_ptr<MemVarInfo> MemVarInfo::buildMemVarInfo(const VarDecl *Var) {
28312848 }
28322849 return DpctGlobalInfo::getInstance ().insertMemVarInfo (Var);
28332850}
2851+
2852+ // This function, `migrateToDeviceGlobal`, migrates a CUDA `__device__` or
2853+ // `__constant__` variable declaration to the SYCL device global equivalent. The
2854+ // migration process involves four key steps. The function handles various
2855+ // transformations as follows:
2856+ //
2857+ // 1. Remove any array brackets following the variable name.
2858+ // - It identifies the array type using TypeLoc and removes the brackets
2859+ // while preserving the dimensions for later use.
2860+ // - If the array size comes from a macro argument, it maps the macro
2861+ // argument correctly using the `MacroArgRecord`.
2862+ // 2. Process the initialization expression.
2863+ // - If the initialization style is C-style (with an equal sign), it remove
2864+ // the equal sign and adds braces around scalar initializers to use
2865+ // initializer list in SYCL.
2866+ // 3. Replace the variable type.
2867+ // - Replace the origin type with
2868+ // `sycl::ext::oneapi::experimental::device_global`
2869+ // and the correct base type and dimensions.
2870+ // - It manages macro arguments to correctly replace the base type when
2871+ // required.
2872+ // 4. Insert the `static` specifier if the variable is declared globally and
2873+ // does not already have the `static` storage class.
2874+ //
2875+ // Example1 (Specifier __device__ will be removed in preprocessor callbacks):
2876+ // Origin code:
2877+ // __device__ int var_a[3] = {1, 2, 3};
2878+ //
2879+ // As follow list the result after each key step listed in previous:
2880+ // 1. int var_a = {1, 2, 3};
2881+ // 2. int var_a {1, 2, 3};
2882+ // 3. sycl::ext::oneapi::experimental::device_global<int[3]> var_a {1, 2, 3};
2883+ // 4. static sycl::ext::oneapi::experimental::device_global<int[3]> var_a {1, 2,
2884+ // 3};
2885+ //
2886+ // Example2 (Specifier __device__ will be removed in preprocessor callbacks):
2887+ // Origin code:
2888+ // #define VAR(type, name, size) static __device__ type name[size];
2889+ // VAR(int, a, 3)
2890+ //
2891+ // As follow list the result after each key step listed in previous:
2892+ // 1. #define VAR(type, name, size) static type name;
2893+ // VAR(int, a, 3)
2894+ // 2. #define VAR(type, name, init) static type name;
2895+ // VAR(int, a, 3)
2896+ // 3. #define VAR(type, name, init) static
2897+ // sycl::ext::oneapi::experimental::device_global<type[size]> name;
2898+ // VAR(int, a, 3)
2899+ // 4. #define VAR(type, name, init) static
2900+ // sycl::ext::oneapi::experimental::device_global<type[size]> name;
2901+ // VAR(int, a, 3)
2902+ void MemVarInfo::migrateToDeviceGlobal (const VarDecl *MemVar) {
2903+ auto &SM = DpctGlobalInfo::getSourceManager ();
2904+ auto &Ctx = DpctGlobalInfo::getContext ();
2905+ auto &MacroArgMap = DpctGlobalInfo::getMacroArgRecordMap ();
2906+ auto TSI = MemVar->getTypeSourceInfo ();
2907+ auto OriginTL = TSI->getTypeLoc ();
2908+ auto TL = OriginTL;
2909+ auto BegLoc = MemVar->getBeginLoc ();
2910+ if (BegLoc.isMacroID ()) {
2911+ BegLoc = SM.getExpansionLoc (BegLoc);
2912+ }
2913+ auto LocInfo = DpctGlobalInfo::getLocInfo (BegLoc);
2914+ std::string Dims;
2915+ bool IsArray = OriginTL.getType ()->isArrayType ();
2916+ // Step 1
2917+ while (auto ATL = TL.getAs <clang::ArrayTypeLoc>()) {
2918+ auto BRange = ATL.getBracketsRange ();
2919+ BRange = getDefinitionRange (BRange.getBegin (), BRange.getEnd ());
2920+ auto RT =
2921+ ReplaceText (SM.getSpellingLoc (BRange.getBegin ()),
2922+ SM.getSpellingLoc (BRange.getEnd ()).getLocWithOffset (1 ), " " );
2923+ DpctGlobalInfo::getInstance ().addReplacement (RT.getReplacement (Ctx));
2924+ Dims += " [" ;
2925+ std::string SizeStr;
2926+ if (clang::Expr *SE = ATL.getSizeExpr ()) {
2927+ auto SizeLoc = SE->getBeginLoc ();
2928+ if (SM.isMacroArgExpansion (SizeLoc)) {
2929+ auto Iter =
2930+ MacroArgMap.find (DpctGlobalInfo::getInstance ()
2931+ .getMainFile ()
2932+ ->getFilePath ()
2933+ .getPath ()
2934+ .str () +
2935+ getCombinedStrFromLoc (SM.getSpellingLoc (SizeLoc)));
2936+ if (Iter != MacroArgMap.end ()) {
2937+ SizeStr = Iter->second ->ArgName ;
2938+ }
2939+ }
2940+ if (SizeStr.empty ()) {
2941+ SizeStr = ExprAnalysis::ref (SE);
2942+ }
2943+ }
2944+ Dims += SizeStr + " ]" ;
2945+ TL = ATL.getElementLoc ();
2946+ }
2947+ // Step 2
2948+ if (MemVar->hasInit ()) {
2949+ if ((MemVar->getInitStyle () == VarDecl::InitializationStyle::CInit)) {
2950+ DiagnosticsUtils::report (LocInfo.first , LocInfo.second ,
2951+ Diagnostics::DEVICE_GLOBAL_INIT, true , false );
2952+ if (!dyn_cast<InitListExpr>(
2953+ MemVar->getInit ()->IgnoreImplicitAsWritten ())) {
2954+ auto IBS = InsertBeforeStmt (MemVar->getInit (), " {" );
2955+ auto IAS = InsertAfterStmt (MemVar->getInit (), " }" );
2956+ DpctGlobalInfo::getInstance ().addReplacement (IBS.getReplacement (Ctx));
2957+ DpctGlobalInfo::getInstance ().addReplacement (IAS.getReplacement (Ctx));
2958+ }
2959+ auto NextTok = Lexer::findNextToken (
2960+ IsArray ? SM.getSpellingLoc (OriginTL.getEndLoc ())
2961+ : SM.getSpellingLoc (MemVar->getLocation ()),
2962+ SM, DpctGlobalInfo::getContext ().getLangOpts ());
2963+ if (NextTok.has_value () && NextTok.value ().is (tok::equal)) {
2964+ auto RTok = ReplaceToken (NextTok.value ().getLocation (), " " );
2965+ DpctGlobalInfo::getInstance ().addReplacement (RTok.getReplacement (Ctx));
2966+ }
2967+ }
2968+ }
2969+ // Step 3
2970+ std::string BaseTypeStr;
2971+ SourceLocation TypeReplLoc;
2972+ size_t TypeReplLen = 0 ;
2973+ if (SM.isMacroArgExpansion (OriginTL.getBeginLoc ())) {
2974+ auto Iter = MacroArgMap.find (
2975+ DpctGlobalInfo::getInstance ()
2976+ .getMainFile ()
2977+ ->getFilePath ()
2978+ .getPath ()
2979+ .str () +
2980+ getCombinedStrFromLoc (SM.getSpellingLoc (OriginTL.getBeginLoc ())));
2981+ if (Iter != MacroArgMap.end ()) {
2982+ BaseTypeStr = Iter->second ->ArgName ;
2983+ TypeReplLoc = Iter->second ->ArgLoc ;
2984+ TypeReplLen = BaseTypeStr.size ();
2985+ }
2986+ }
2987+ if (BaseTypeStr.empty ()) {
2988+ BaseTypeStr = getType ()->getBaseNameWithoutQualifiers ();
2989+ TypeReplLoc = TL.getBeginLoc ();
2990+ TypeReplLen =
2991+ SM.getFileOffset (TL.getEndLoc ()) - SM.getFileOffset (TL.getBeginLoc ()) +
2992+ Lexer::MeasureTokenLength (TL.getEndLoc (), SM,
2993+ DpctGlobalInfo::getContext ().getLangOpts ());
2994+ }
2995+ std::string TypeStr = MapNames::getClNamespace () +
2996+ " ext::oneapi::experimental::device_global<" +
2997+ BaseTypeStr + Dims + " >" ;
2998+ auto RT = ReplaceText (TypeReplLoc, TypeReplLen, std::move (TypeStr));
2999+ DpctGlobalInfo::getInstance ().addReplacement (RT.getReplacement (Ctx));
3000+ // Step 4
3001+ if (MemVar->getStorageClass () != SC_Static && getScope () == Global) {
3002+ DpctGlobalInfo::getInstance ().addReplacement (
3003+ std::make_shared<ExtReplacement>(LocInfo.first , LocInfo.second , 0 ,
3004+ " static " , nullptr ));
3005+ }
3006+ }
3007+
28343008MemVarInfo::VarAttrKind MemVarInfo::getAddressAttr (const VarDecl *VD) {
28353009 if (VD->hasAttrs ())
28363010 return getAddressAttr (VD->getAttrs ());
0 commit comments