Skip to content

Commit 2f699e2

Browse files
authored
[SYCLomatic] Fix the device memory migration issue in the complex macro scenario (#2439)
Signed-off-by: intwanghao <[email protected]>
1 parent 3eda52d commit 2f699e2

File tree

6 files changed

+234
-37
lines changed

6 files changed

+234
-37
lines changed

clang/lib/DPCT/ASTTraversal.cpp

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -364,11 +364,25 @@ void IncludesCallbacks::MacroExpands(const Token &MacroNameTok,
364364
for (i = 0; i < MI->getNumTokens(); i++) {
365365
std::shared_ptr<dpct::DpctGlobalInfo::MacroExpansionRecord> R =
366366
std::make_shared<dpct::DpctGlobalInfo::MacroExpansionRecord>(
367-
MacroNameTok.getIdentifierInfo(), MI, Range, IsInAnalysisScope, i);
367+
MacroNameTok.getIdentifierInfo(), MI, Range, IsInAnalysisScope,
368+
i);
368369
dpct::DpctGlobalInfo::getExpansionRangeToMacroRecord()
369370
[getCombinedStrFromLoc(MI->getReplacementToken(i).getLocation())] =
370371
R;
371372
}
373+
if (Args && IsInAnalysisScope) {
374+
for (unsigned int i = 0; i < Args->getNumMacroArguments(); ++i) {
375+
std::shared_ptr<dpct::DpctGlobalInfo::MacroArgRecord> R =
376+
std::make_shared<dpct::DpctGlobalInfo::MacroArgRecord>(MI, i);
377+
auto str =
378+
getCombinedStrFromLoc(Args->getUnexpArgument(i)->getLocation());
379+
auto &Global = DpctGlobalInfo::getInstance();
380+
dpct::DpctGlobalInfo::getMacroArgRecordMap()
381+
[Global.getMainFile()->getFilePath().getPath().str() +
382+
getCombinedStrFromLoc(
383+
Args->getUnexpArgument(i)->getLocation())] = R;
384+
}
385+
}
372386
std::shared_ptr<dpct::DpctGlobalInfo::MacroExpansionRecord> R =
373387
std::make_shared<dpct::DpctGlobalInfo::MacroExpansionRecord>(
374388
MacroNameTok.getIdentifierInfo(), MI, Range, IsInAnalysisScope,
@@ -8634,13 +8648,6 @@ void MemVarRefMigrationRule::runRule(const MatchFinder::MatchResult &Result) {
86348648
auto Info = Global.findMemVarInfo(Decl);
86358649

86368650
if (Info && Info->isUseDeviceGlobal()) {
8637-
if (Decl->hasInit()) {
8638-
auto InitStr = getInitForDeviceGlobal(Decl);
8639-
if (!InitStr.empty()) {
8640-
report(Decl->getBeginLoc(), Diagnostics::DEVICE_GLOBAL_INIT, false);
8641-
Info->setInitForDeviceGlobal(InitStr);
8642-
}
8643-
}
86448651
auto VarType = Info->getType();
86458652
if (VarType->isArray()) {
86468653
if (const auto *const ICE =
@@ -8786,13 +8793,8 @@ void ConstantMemVarMigrationRule::runRule(
87868793
if (!Info)
87878794
return;
87888795
if (Info->isUseDeviceGlobal()) {
8789-
if (MemVar->hasInit()) {
8790-
auto InitStr = getInitForDeviceGlobal(MemVar);
8791-
if (!InitStr.empty()) {
8792-
report(MemVar->getBeginLoc(), Diagnostics::DEVICE_GLOBAL_INIT, false);
8793-
Info->setInitForDeviceGlobal(InitStr);
8794-
}
8795-
}
8796+
Info->migrateToDeviceGlobal(MemVar);
8797+
return;
87968798
}
87978799

87988800
Info->setIgnoreFlag(true);
@@ -9242,13 +9244,8 @@ void MemVarMigrationRule::runRule(
92429244
if (!Info)
92439245
return;
92449246
if (Info->isUseDeviceGlobal()) {
9245-
if (MemVar->hasInit()) {
9246-
auto InitStr = getInitForDeviceGlobal(MemVar);
9247-
if (!InitStr.empty()) {
9248-
report(MemVar->getBeginLoc(), Diagnostics::DEVICE_GLOBAL_INIT, false);
9249-
Info->setInitForDeviceGlobal(InitStr);
9250-
}
9251-
}
9247+
Info->migrateToDeviceGlobal(MemVar);
9248+
return;
92529249
}
92539250

92549251
if (auto VTD = DpctGlobalInfo::findParent<VarTemplateDecl>(MemVar)) {

clang/lib/DPCT/AnalysisInfo.cpp

Lines changed: 175 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,8 @@ void DpctFileInfo::buildReplacements() {
619619
// found, postfix "_ct" is added to this __constant__ symbol's name.
620620
std::unordered_map<unsigned int, std::string> ReplUpdated;
621621
for (const auto &Entry : MemVarMap) {
622-
if (Entry.second->isIgnore() || !Entry.second->isConstant())
622+
if (Entry.second->isIgnore() || !Entry.second->isConstant() ||
623+
Entry.second->isUseDeviceGlobal())
623624
continue;
624625

625626
auto Name = Entry.second->getName();
@@ -1122,6 +1123,20 @@ DpctGlobalInfo::MacroDefRecord::MacroDefRecord(SourceLocation NTL, bool IIAS)
11221123
FilePath = LocInfo.first;
11231124
Offset = LocInfo.second;
11241125
}
1126+
1127+
DpctGlobalInfo::MacroArgRecord::MacroArgRecord(const MacroInfo *MI,
1128+
int ArgIndex)
1129+
: ArgIndex(ArgIndex) {
1130+
ArgName = MI->params()[ArgIndex]->getName().str();
1131+
for (auto Tok : MI->tokens()) {
1132+
auto II = Tok.getIdentifierInfo();
1133+
if (II && (II == MI->params()[ArgIndex])) {
1134+
ArgLoc = Tok.getLocation();
1135+
break;
1136+
}
1137+
}
1138+
}
1139+
11251140
DpctGlobalInfo::MacroExpansionRecord::MacroExpansionRecord(
11261141
IdentifierInfo *ID, const MacroInfo *MI, SourceRange Range,
11271142
bool IsInAnalysisScope, int TokenIndex) {
@@ -2399,6 +2414,8 @@ bool DpctGlobalInfo::CheckUnicodeSecurityFlag = false;
23992414
bool DpctGlobalInfo::EnablepProfilingFlag = false;
24002415
std::map<std::string, std::shared_ptr<DpctGlobalInfo::MacroExpansionRecord>>
24012416
DpctGlobalInfo::ExpansionRangeToMacroRecord;
2417+
std::unordered_map<std::string, std::shared_ptr<DpctGlobalInfo::MacroArgRecord>>
2418+
DpctGlobalInfo::MacroArgRecordMap;
24022419
std::map<std::string, SourceLocation> DpctGlobalInfo::EndifLocationOfIfdef;
24032420
std::vector<std::pair<clang::tooling::UnifiedPath, size_t>>
24042421
DpctGlobalInfo::ConditionalCompilationLoc;
@@ -2831,6 +2848,163 @@ std::shared_ptr<MemVarInfo> MemVarInfo::buildMemVarInfo(const VarDecl *Var) {
28312848
}
28322849
return DpctGlobalInfo::getInstance().insertMemVarInfo(Var);
28332850
}
2851+
2852+
// This function, `migrateToDeviceGlobal`, migrates a CUDA `__device__` or
2853+
// `__constant__` variable declaration to the SYCL device global equivalent. The
2854+
// migration process involves four key steps. The function handles various
2855+
// transformations as follows:
2856+
//
2857+
// 1. Remove any array brackets following the variable name.
2858+
// - It identifies the array type using TypeLoc and removes the brackets
2859+
// while preserving the dimensions for later use.
2860+
// - If the array size comes from a macro argument, it maps the macro
2861+
// argument correctly using the `MacroArgRecord`.
2862+
// 2. Process the initialization expression.
2863+
// - If the initialization style is C-style (with an equal sign), it removes
2864+
// the equal sign and adds braces around scalar initializers so that an
2865+
// initializer list is used in SYCL.
2866+
// 3. Replace the variable type.
2867+
// - Replace the original type with
2868+
// `sycl::ext::oneapi::experimental::device_global`
2869+
// and the correct base type and dimensions.
2870+
// - It manages macro arguments to correctly replace the base type when
2871+
// required.
2872+
// 4. Insert the `static` specifier if the variable is declared globally and
2873+
// does not already have the `static` storage class.
2874+
//
2875+
// Example1 (Specifier __device__ will be removed in preprocessor callbacks):
2876+
// Origin code:
2877+
// __device__ int var_a[3] = {1, 2, 3};
2878+
//
2879+
// The result after each of the key steps listed above is as follows:
2880+
// 1. int var_a = {1, 2, 3};
2881+
// 2. int var_a {1, 2, 3};
2882+
// 3. sycl::ext::oneapi::experimental::device_global<int[3]> var_a {1, 2, 3};
2883+
// 4. static sycl::ext::oneapi::experimental::device_global<int[3]> var_a {1, 2,
2884+
// 3};
2885+
//
2886+
// Example2 (Specifier __device__ will be removed in preprocessor callbacks):
2887+
// Origin code:
2888+
// #define VAR(type, name, size) static __device__ type name[size];
2889+
// VAR(int, a, 3)
2890+
//
2891+
// The result after each of the key steps listed above is as follows:
2892+
// 1. #define VAR(type, name, size) static type name;
2893+
// VAR(int, a, 3)
2894+
// 2. #define VAR(type, name, size) static type name;
2895+
// VAR(int, a, 3)
2896+
// 3. #define VAR(type, name, size) static
2897+
// sycl::ext::oneapi::experimental::device_global<type[size]> name;
2898+
// VAR(int, a, 3)
2899+
// 4. #define VAR(type, name, size) static
2900+
// sycl::ext::oneapi::experimental::device_global<type[size]> name;
2901+
// VAR(int, a, 3)
2902+
void MemVarInfo::migrateToDeviceGlobal(const VarDecl *MemVar) {
2903+
auto &SM = DpctGlobalInfo::getSourceManager();
2904+
auto &Ctx = DpctGlobalInfo::getContext();
2905+
auto &MacroArgMap = DpctGlobalInfo::getMacroArgRecordMap();
2906+
auto TSI = MemVar->getTypeSourceInfo();
2907+
auto OriginTL = TSI->getTypeLoc();
2908+
auto TL = OriginTL;
2909+
auto BegLoc = MemVar->getBeginLoc();
2910+
if (BegLoc.isMacroID()) {
2911+
BegLoc = SM.getExpansionLoc(BegLoc);
2912+
}
2913+
auto LocInfo = DpctGlobalInfo::getLocInfo(BegLoc);
2914+
std::string Dims;
2915+
bool IsArray = OriginTL.getType()->isArrayType();
2916+
// Step 1
2917+
while (auto ATL = TL.getAs<clang::ArrayTypeLoc>()) {
2918+
auto BRange = ATL.getBracketsRange();
2919+
BRange = getDefinitionRange(BRange.getBegin(), BRange.getEnd());
2920+
auto RT =
2921+
ReplaceText(SM.getSpellingLoc(BRange.getBegin()),
2922+
SM.getSpellingLoc(BRange.getEnd()).getLocWithOffset(1), "");
2923+
DpctGlobalInfo::getInstance().addReplacement(RT.getReplacement(Ctx));
2924+
Dims += "[";
2925+
std::string SizeStr;
2926+
if (clang::Expr *SE = ATL.getSizeExpr()) {
2927+
auto SizeLoc = SE->getBeginLoc();
2928+
if (SM.isMacroArgExpansion(SizeLoc)) {
2929+
auto Iter =
2930+
MacroArgMap.find(DpctGlobalInfo::getInstance()
2931+
.getMainFile()
2932+
->getFilePath()
2933+
.getPath()
2934+
.str() +
2935+
getCombinedStrFromLoc(SM.getSpellingLoc(SizeLoc)));
2936+
if (Iter != MacroArgMap.end()) {
2937+
SizeStr = Iter->second->ArgName;
2938+
}
2939+
}
2940+
if (SizeStr.empty()) {
2941+
SizeStr = ExprAnalysis::ref(SE);
2942+
}
2943+
}
2944+
Dims += SizeStr + "]";
2945+
TL = ATL.getElementLoc();
2946+
}
2947+
// Step 2
2948+
if (MemVar->hasInit()) {
2949+
if ((MemVar->getInitStyle() == VarDecl::InitializationStyle::CInit)) {
2950+
DiagnosticsUtils::report(LocInfo.first, LocInfo.second,
2951+
Diagnostics::DEVICE_GLOBAL_INIT, true, false);
2952+
if (!dyn_cast<InitListExpr>(
2953+
MemVar->getInit()->IgnoreImplicitAsWritten())) {
2954+
auto IBS = InsertBeforeStmt(MemVar->getInit(), "{");
2955+
auto IAS = InsertAfterStmt(MemVar->getInit(), "}");
2956+
DpctGlobalInfo::getInstance().addReplacement(IBS.getReplacement(Ctx));
2957+
DpctGlobalInfo::getInstance().addReplacement(IAS.getReplacement(Ctx));
2958+
}
2959+
auto NextTok = Lexer::findNextToken(
2960+
IsArray ? SM.getSpellingLoc(OriginTL.getEndLoc())
2961+
: SM.getSpellingLoc(MemVar->getLocation()),
2962+
SM, DpctGlobalInfo::getContext().getLangOpts());
2963+
if (NextTok.has_value() && NextTok.value().is(tok::equal)) {
2964+
auto RTok = ReplaceToken(NextTok.value().getLocation(), "");
2965+
DpctGlobalInfo::getInstance().addReplacement(RTok.getReplacement(Ctx));
2966+
}
2967+
}
2968+
}
2969+
// Step 3
2970+
std::string BaseTypeStr;
2971+
SourceLocation TypeReplLoc;
2972+
size_t TypeReplLen = 0;
2973+
if (SM.isMacroArgExpansion(OriginTL.getBeginLoc())) {
2974+
auto Iter = MacroArgMap.find(
2975+
DpctGlobalInfo::getInstance()
2976+
.getMainFile()
2977+
->getFilePath()
2978+
.getPath()
2979+
.str() +
2980+
getCombinedStrFromLoc(SM.getSpellingLoc(OriginTL.getBeginLoc())));
2981+
if (Iter != MacroArgMap.end()) {
2982+
BaseTypeStr = Iter->second->ArgName;
2983+
TypeReplLoc = Iter->second->ArgLoc;
2984+
TypeReplLen = BaseTypeStr.size();
2985+
}
2986+
}
2987+
if (BaseTypeStr.empty()) {
2988+
BaseTypeStr = getType()->getBaseNameWithoutQualifiers();
2989+
TypeReplLoc = TL.getBeginLoc();
2990+
TypeReplLen =
2991+
SM.getFileOffset(TL.getEndLoc()) - SM.getFileOffset(TL.getBeginLoc()) +
2992+
Lexer::MeasureTokenLength(TL.getEndLoc(), SM,
2993+
DpctGlobalInfo::getContext().getLangOpts());
2994+
}
2995+
std::string TypeStr = MapNames::getClNamespace() +
2996+
"ext::oneapi::experimental::device_global<" +
2997+
BaseTypeStr + Dims + ">";
2998+
auto RT = ReplaceText(TypeReplLoc, TypeReplLen, std::move(TypeStr));
2999+
DpctGlobalInfo::getInstance().addReplacement(RT.getReplacement(Ctx));
3000+
// Step 4
3001+
if (MemVar->getStorageClass() != SC_Static && getScope() == Global) {
3002+
DpctGlobalInfo::getInstance().addReplacement(
3003+
std::make_shared<ExtReplacement>(LocInfo.first, LocInfo.second, 0,
3004+
"static ", nullptr));
3005+
}
3006+
}
3007+
28343008
MemVarInfo::VarAttrKind MemVarInfo::getAddressAttr(const VarDecl *VD) {
28353009
if (VD->hasAttrs())
28363010
return getAddressAttr(VD->getAttrs());

clang/lib/DPCT/AnalysisInfo.h

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,22 @@ class DpctGlobalInfo {
633633
bool IsInAnalysisScope;
634634
MacroDefRecord(SourceLocation NTL, bool IIAS);
635635
};
636+
// This class is used to store information about macro arguments in a
637+
// macro definition. For example, consider the macro definition:
638+
// "#define CALL(x, y) x(y)".
639+
// - For the first argument "x", the member ArgName will be "x", ArgLoc will
640+
// be the source location of the first "x" token in the macro body, and
641+
// ArgIndex will be 0.
642+
// - For the second argument "y", the member ArgName will be "y", ArgLoc will
643+
// be the source location of the first "y" token in the macro body, and
644+
// ArgIndex will be 1.
645+
class MacroArgRecord {
646+
public:
647+
std::string ArgName;
648+
SourceLocation ArgLoc;
649+
int ArgIndex;
650+
MacroArgRecord(const MacroInfo *MI, int ArgIndex);
651+
};
636652

637653
class MacroExpansionRecord {
638654
public:
@@ -1189,6 +1205,10 @@ class DpctGlobalInfo {
11891205
getExpansionRangeBeginMap() {
11901206
return ExpansionRangeBeginMap;
11911207
}
1208+
static std::unordered_map<std::string, std::shared_ptr<MacroArgRecord>> &
1209+
getMacroArgRecordMap() {
1210+
return MacroArgRecordMap;
1211+
}
11921212
static std::map<std::string, std::shared_ptr<MacroExpansionRecord>> &
11931213
getExpansionRangeToMacroRecord() {
11941214
return ExpansionRangeToMacroRecord;
@@ -1568,6 +1588,11 @@ class DpctGlobalInfo {
15681588
static std::map<std::string,
15691589
std::shared_ptr<DpctGlobalInfo::MacroExpansionRecord>>
15701590
ExpansionRangeToMacroRecord;
1591+
// key: Main file path + combined location string of a function-like macro argument
1592+
// value: Function-like macro argument information
1593+
static std::unordered_map<std::string,
1594+
std::shared_ptr<DpctGlobalInfo::MacroArgRecord>>
1595+
MacroArgRecordMap;
15711596
static std::map<std::string, SourceLocation> EndifLocationOfIfdef;
15721597
static std::vector<std::pair<clang::tooling::UnifiedPath, size_t>>
15731598
ConditionalCompilationLoc;
@@ -1894,7 +1919,7 @@ class MemVarInfo : public VarInfo {
18941919
bool isUseHelperFunc() { return UseHelperFuncFlag; }
18951920
void setUseDeviceGlobalFlag(bool Flag) { UseDeviceGlobalFlag = Flag; }
18961921
bool isUseDeviceGlobal() { return UseDeviceGlobalFlag; }
1897-
void setInitForDeviceGlobal(std::string Init) { InitList = Init; }
1922+
void migrateToDeviceGlobal(const VarDecl *MemVar);
18981923

18991924
private:
19001925
bool isTreatPointerAsArray() {

clang/lib/DPCT/Utility.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4811,16 +4811,6 @@ std::string getNameSpace(const NamespaceDecl *NSD) {
48114811
return NameSpace;
48124812
}
48134813

4814-
std::string getInitForDeviceGlobal(const VarDecl *VD) {
4815-
auto Init = VD->getInit()->IgnoreImplicitAsWritten();
4816-
if (auto IL = dyn_cast<InitListExpr>(Init)) {
4817-
return dpct::ExprAnalysis::ref(IL);
4818-
} else if (dyn_cast<CXXConstructExpr>(Init)) {
4819-
return "";
4820-
}
4821-
return "{" + dpct::ExprAnalysis::ref(Init) + "}";
4822-
}
4823-
48244814
void getNameSpace(const NamespaceDecl *NSD,
48254815
std::vector<std::string> &Namespaces) {
48264816
if (!NSD)

clang/lib/DPCT/Utility.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,6 @@ bool containIterationSpaceBuiltinVar(const clang::Stmt *Node);
568568
bool containBuiltinWarpSize(const clang::Stmt *Node);
569569
bool isCapturedByLambda(const clang::TypeLoc *TL);
570570
std::string getNameSpace(const NamespaceDecl *NSD);
571-
std::string getInitForDeviceGlobal(const VarDecl *VD);
572571
void getNameSpace(const NamespaceDecl *NSD,
573572
std::vector<std::string> &Namespaces);
574573
std::string getTemplateArgumentAsString(const clang::TemplateArgument &Arg,

clang/test/dpct/device_global/device_global.cu

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,20 @@ __constant__ int arr_d[10] = {2};
6060
__constant__ A arr_e[10];
6161
__device__ B arr_f[10];
6262

63+
// CHECK: #define TABLE_BEGIN(type, name, size) static const sycl::ext::oneapi::experimental::device_global<type[size]> name {
64+
// CHECK: #define TABLE_END() };
65+
#define TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {
66+
#define TABLE_END() };
67+
68+
// CHECK: TABLE_BEGIN(int, arr_g, 10)
69+
// CHECK: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
70+
// CHECK: TABLE_END()
71+
TABLE_BEGIN(int, arr_g, 10)
72+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9
73+
TABLE_END()
6374

6475
// CHECK: int device_func() {
65-
// CHECK: float *p = arra_a.get();
76+
// CHECK: float *p = arr_a.get();
6677
// CHECK: arr_a[0] = 1;
6778
// CHECK: return arr_a[0] + arr_b[0] + arr_c[0] + arr_d[0] + arr_e[0].data + arr_f[0].data;
6879
// CHECK: }
@@ -73,8 +84,9 @@ __device__ B arr_f[10];
7384
// CHECK: *ptr = var_a.get() + var_b.get() + var_c.get() + var_d.get() + var_e.get().data + var_f.get().data + device_func();
7485
// CHECK: }
7586
__device__ int device_func() {
76-
float *p = arra_a;
87+
float *p = arr_a;
7788
arr_a[0] = 1;
89+
arr_g[0];
7890
return arr_a[0] + arr_b[0] + arr_c[0] + arr_d[0] + arr_e[0].data + arr_f[0].data;
7991
}
8092

0 commit comments

Comments
 (0)