Skip to content

Commit b454b2c

Browse files
authored
[SYCLomatic] Introduce a new kind of user defined migration rule to customize the helper function in migration (#2413)
Current only support replace/custom the usage of following 3 helper functions: -GetDefaultQueue -GetOutOfOrderQueue -GetInOrderQueue Signed-off-by: Jiang, Zhiwei <[email protected]>
1 parent 866a894 commit b454b2c

14 files changed

+281
-34
lines changed

clang/lib/DPCT/APINamesMemory.inc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
265265
"cuMemPrefetchAsync",
266266
MEMBER_CALL(CALL(MapNames::getDpctNamespace() + "get_device",
267267
ARG(2)),
268-
false, DpctGlobalInfo::getDeviceQueueName()),
268+
false, DpctGlobalInfo::getDefaultQueueMemFuncName()),
269269
DpctGlobalInfo::useSYCLCompat(), "prefetch", ARG(0), ARG(1)))))
270270

271271
ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
@@ -280,7 +280,7 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
280280
"cuMemAdvise",
281281
MEMBER_CALL(CALL(MapNames::getDpctNamespace() +
282282
"cpu_device"),
283-
false, DpctGlobalInfo::getDeviceQueueName()),
283+
false, DpctGlobalInfo::getDefaultQueueMemFuncName()),
284284
DpctGlobalInfo::useSYCLCompat(), "mem_advise", ARG(0), ARG(1), ARG("0"))),
285285
Diagnostics::DEFAULT_MEM_ADVICE, ARG(" and was set to 0")),
286286

@@ -289,7 +289,7 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
289289
MEMBER_CALL_FACTORY_ENTRY(
290290
"cuMemAdvise",
291291
MEMBER_CALL(CALL(MapNames::getDpctNamespace() + "cpu_device"),
292-
false, DpctGlobalInfo::getDeviceQueueName()),
292+
false, DpctGlobalInfo::getDefaultQueueMemFuncName()),
293293
DpctGlobalInfo::useSYCLCompat(), "mem_advise", ARG(0), ARG(1), ARG(2)))),
294294

295295
CONDITIONAL_FACTORY_ENTRY(
@@ -303,7 +303,7 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
303303
MEMBER_CALL(CALL(MapNames::getDpctNamespace() +
304304
"get_device",
305305
ARG(3)),
306-
false, DpctGlobalInfo::getDeviceQueueName()),
306+
false, DpctGlobalInfo::getDefaultQueueMemFuncName()),
307307
DpctGlobalInfo::useSYCLCompat(), "mem_advise", ARG(0), ARG(1), ARG("0"))),
308308
Diagnostics::DEFAULT_MEM_ADVICE, ARG(" and was set to 0")),
309309
FEATURE_REQUEST_FACTORY(
@@ -312,7 +312,7 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
312312
"cuMemAdvise",
313313
MEMBER_CALL(CALL(MapNames::getDpctNamespace() + "get_device",
314314
ARG(3)),
315-
false, DpctGlobalInfo::getDeviceQueueName()),
315+
false, DpctGlobalInfo::getDefaultQueueMemFuncName()),
316316
DpctGlobalInfo::useSYCLCompat(), "mem_advise", ARG(0), ARG(1), ARG(2))))))
317317

318318
ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(

clang/lib/DPCT/ASTTraversal.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10595,7 +10595,8 @@ void MemoryMigrationRule::prefetchMigration(
1059510595
? "cpu_device()"
1059610596
: "get_device(" + StmtStrArg2 + ")");
1059710597
requestFeature(HelperFeatureEnum::device_ext);
10598-
Replacement = Prefix + "." + DpctGlobalInfo::getDeviceQueueName() + "()" +
10598+
Replacement = Prefix + "." +
10599+
DpctGlobalInfo::getDefaultQueueMemFuncName() + "()" +
1059910600
(DpctGlobalInfo::useSYCLCompat() ? "->" : ".") +
1060010601
"prefetch(" + StmtStrArg0 + "," + StmtStrArg1 + ")";
1060110602
} else {
@@ -10797,15 +10798,15 @@ void MemoryMigrationRule::cudaMemAdvise(const MatchFinder::MatchResult &Result,
1079710798
std::ostringstream OS;
1079810799
if (getStmtSpelling(C->getArg(3)) == "cudaCpuDeviceId") {
1079910800
OS << MapNames::getDpctNamespace() + "cpu_device()." +
10800-
DpctGlobalInfo::getDeviceQueueName() + "()";
10801+
DpctGlobalInfo::getDefaultQueueMemFuncName() + "()";
1080110802
OS << (DpctGlobalInfo::useSYCLCompat() ? "->" : ".") << "mem_advise("
1080210803
<< Arg0Str << ", " << Arg1Str << ", " << Arg2Str << ")";
1080310804
emplaceTransformation(new ReplaceStmt(C, OS.str()));
1080410805
requestFeature(HelperFeatureEnum::device_ext);
1080510806
return;
1080610807
}
1080710808
OS << MapNames::getDpctNamespace() + "get_device(" << Arg3Str
10808-
<< ")." + DpctGlobalInfo::getDeviceQueueName() + "()";
10809+
<< ")." + DpctGlobalInfo::getDefaultQueueMemFuncName() + "()";
1080910810
OS << (DpctGlobalInfo::useSYCLCompat() ? "->" : ".") << "mem_advise("
1081010811
<< Arg0Str << ", " << Arg1Str << ", " << Arg2Str << ")";
1081110812
emplaceTransformation(new ReplaceStmt(C, OS.str()));

clang/lib/DPCT/AnalysisInfo.cpp

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ const std::string &getDefaultString(HelperFuncType HFT) {
6363
const static std::string DefaultQueue =
6464
DpctGlobalInfo::useNoQueueDevice()
6565
? DpctGlobalInfo::getGlobalQueueName()
66-
: buildString(MapNames::getDpctNamespace() + "get_" +
67-
DpctGlobalInfo::getDeviceQueueName() + "()");
66+
: DpctGlobalInfo::getDefaultQueueFreeFuncCall();
6867
return DefaultQueue;
6968
}
7069
case clang::dpct::HelperFuncType::HFT_DefaultQueuePtr: {
@@ -74,8 +73,8 @@ const std::string &getDefaultString(HelperFuncType HFT) {
7473
: (DpctGlobalInfo::useSYCLCompat()
7574
? buildString(MapNames::getDpctNamespace() +
7675
"get_current_device().default_queue()")
77-
: buildString("&" + MapNames::getDpctNamespace() + "get_" +
78-
DpctGlobalInfo::getDeviceQueueName() + "()"));
76+
: buildString(
77+
"&", DpctGlobalInfo::getDefaultQueueFreeFuncCall()));
7978
return DefaultQueue;
8079
}
8180
case clang::dpct::HelperFuncType::HFT_CurrentDevice: {
@@ -269,6 +268,15 @@ void processTypeLoc(const TypeLoc &TL, ExprAnalysis &EA,
269268
}
270269
EA.applyAllSubExprRepl();
271270
}
271+
HelperFuncCatalog getQueueKind() {
272+
if (DpctGlobalInfo::useSYCLCompat()) {
273+
return HelperFuncCatalog::GetDefaultQueue;
274+
}
275+
if (DpctGlobalInfo::getUsmLevel() == UsmLevel::UL_Restricted) {
276+
return HelperFuncCatalog::GetInOrderQueue;
277+
}
278+
return HelperFuncCatalog::GetOutOfOrderQueue;
279+
}
272280

273281
///// class FreeQueriesInfo /////
274282
class FreeQueriesInfo {
@@ -930,6 +938,11 @@ void DpctFileInfo::insertHeader(HeaderType Type, unsigned Offset,
930938
<< CCLVerValue << getNL();
931939
insertHeader(MigratedMacroDefinitionOS.str(), FileBeginOffset,
932940
InsertPosition::IP_AlwaysLeft);
941+
for (const auto &File :
942+
DpctGlobalInfo::getCustomHelperFunctionAddtionalIncludes()) {
943+
insertHeader("#include \"" + File + +"\"" + getNL(), FirstIncludeOffset,
944+
InsertPosition::IP_Right);
945+
}
933946
return;
934947

935948
// Because <dpct/dpl_utils.hpp> includes <oneapi/dpl/execution> and
@@ -1225,15 +1238,26 @@ std::string DpctGlobalInfo::getDefaultQueue(const Stmt *S) {
12251238

12261239
return buildString(RegexPrefix, 'Q', Idx, RegexSuffix);
12271240
}
1228-
const std::string &DpctGlobalInfo::getDeviceQueueName() {
1229-
static const std::string DeviceQueue = [&]() {
1241+
const std::string &DpctGlobalInfo::getDefaultQueueFreeFuncCall() {
1242+
static const std::string DefaultQueueFreeFuncCall = [&]() {
1243+
if (auto Iter = MapNames::CustomHelperFunctionMap.find(getQueueKind());
1244+
Iter != MapNames::CustomHelperFunctionMap.end()) {
1245+
return Iter->second;
1246+
}
1247+
return MapNames::getDpctNamespace() + "get_" +
1248+
getDefaultQueueMemFuncName() + "()";
1249+
}();
1250+
return DefaultQueueFreeFuncCall;
1251+
}
1252+
const std::string &DpctGlobalInfo::getDefaultQueueMemFuncName() {
1253+
static const std::string DefaultQueueMemFuncName = [&]() {
12301254
if (DpctGlobalInfo::useSYCLCompat())
12311255
return "default_queue";
12321256
if (DpctGlobalInfo::getUsmLevel() == UsmLevel::UL_None)
12331257
return "out_of_order_queue";
12341258
return "in_order_queue";
12351259
}();
1236-
return DeviceQueue;
1260+
return DefaultQueueMemFuncName;
12371261
}
12381262
void DpctGlobalInfo::setContext(ASTContext &C) {
12391263
Context = &C;
@@ -1588,7 +1612,8 @@ void DpctGlobalInfo::buildReplacements() {
15881612
QDecl << "&q_ct1 = ";
15891613
if (DpctGlobalInfo::useSYCLCompat())
15901614
QDecl << '*';
1591-
QDecl << "dev_ct1." << DpctGlobalInfo::getDeviceQueueName() << "();";
1615+
QDecl << "dev_ct1." << DpctGlobalInfo::getDefaultQueueMemFuncName()
1616+
<< "();";
15921617
} else {
15931618
DevDecl << MapNames::getClNamespace() + "device dev_ct1;";
15941619
// Now the UsmLevel must not be UL_None here.
@@ -2237,6 +2262,7 @@ void DpctGlobalInfo::resetInfo() {
22372262
SpellingLocToDFIsMapForAssumeNDRange.clear();
22382263
DFIToSpellingLocsMapForAssumeNDRange.clear();
22392264
FreeQueriesInfo::reset();
2265+
CustomHelperFunctionAddtionalIncludes.clear();
22402266
}
22412267
void DpctGlobalInfo::updateSpellingLocDFIMaps(
22422268
SourceLocation SL, std::shared_ptr<DeviceFunctionInfo> DFI) {
@@ -2454,6 +2480,8 @@ std::vector<std::pair<std::string, std::vector<std::string>>>
24542480
std::vector<std::pair<std::string, std::vector<std::string>>>
24552481
DpctGlobalInfo::CodePinDumpFuncDepsVec;
24562482
std::unordered_set<std::string> DpctGlobalInfo::NeedParenAPISet = {};
2483+
std::unordered_set<std::string>
2484+
DpctGlobalInfo::CustomHelperFunctionAddtionalIncludes = {};
24572485
///// class DpctNameGenerator /////
24582486
void DpctNameGenerator::printName(const FunctionDecl *FD,
24592487
llvm::raw_ostream &OS) {
@@ -6072,6 +6100,7 @@ void KernelCallExpr::removeExtraIndent() {
60726100
getFilePath(), getOffset() - LocInfo.Indent.length(),
60736101
LocInfo.Indent.length(), "", nullptr));
60746102
}
6103+
60756104
void KernelCallExpr::addDevCapCheckStmt() {
60766105
llvm::SmallVector<std::string> AspectList;
60776106
if (getVarMap().hasBF64()) {
@@ -6081,17 +6110,28 @@ void KernelCallExpr::addDevCapCheckStmt() {
60816110
AspectList.push_back(MapNames::getClNamespace() + "aspect::fp16");
60826111
}
60836112
if (!AspectList.empty()) {
6084-
requestFeature(HelperFeatureEnum::device_ext);
60856113
std::string Str;
60866114
llvm::raw_string_ostream OS(Str);
6087-
OS << MapNames::getDpctNamespace() << "get_device(";
6088-
OS << MapNames::getDpctNamespace() << "get_device_id(";
6089-
printStreamBase(OS);
6090-
OS << "get_device())).has_capability_or_fail({" << AspectList.front();
6091-
for (size_t i = 1; i < AspectList.size(); ++i) {
6092-
OS << ", " << AspectList[i];
6093-
}
6094-
OS << "});";
6115+
if (auto Iter = MapNames::CustomHelperFunctionMap.find(getQueueKind());
6116+
Iter != MapNames::CustomHelperFunctionMap.end()) {
6117+
OS << MapNames::getDpctNamespace() << "has_capability_or_fail(";
6118+
OS << Iter->second << ".get_device(), ";
6119+
OS << "{" << AspectList.front();
6120+
for (size_t i = 1; i < AspectList.size(); ++i) {
6121+
OS << ", " << AspectList[i];
6122+
}
6123+
OS << "});";
6124+
} else {
6125+
requestFeature(HelperFeatureEnum::device_ext);
6126+
OS << MapNames::getDpctNamespace() << "get_device(";
6127+
OS << MapNames::getDpctNamespace() << "get_device_id(";
6128+
printStreamBase(OS);
6129+
OS << "get_device())).has_capability_or_fail({" << AspectList.front();
6130+
for (size_t i = 1; i < AspectList.size(); ++i) {
6131+
OS << ", " << AspectList[i];
6132+
}
6133+
OS << "});";
6134+
}
60956135
OuterStmts.OthersList.emplace_back(OS.str());
60966136
}
60976137
}
@@ -6141,8 +6181,7 @@ void KernelCallExpr::addStreamDecl() {
61416181
buildString(MapNames::getClNamespace() + "stream ",
61426182
DpctGlobalInfo::getStreamName(), "(64 * 1024, 80, cgh);"));
61436183
if (getVarMap().hasSync()) {
6144-
auto DefaultQueue = buildString(MapNames::getDpctNamespace(), "get_",
6145-
DpctGlobalInfo::getDeviceQueueName(), "()");
6184+
auto DefaultQueue = DpctGlobalInfo::getDefaultQueueFreeFuncCall();
61466185
if (DpctGlobalInfo::getUsmLevel() == UsmLevel::UL_None) {
61476186
OuterStmts.OthersList.emplace_back(
61486187
buildString(MapNames::getDpctNamespace(), "global_memory<",

clang/lib/DPCT/AnalysisInfo.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -663,16 +663,14 @@ class DpctGlobalInfo {
663663
: DefaultQueueCounter(DefaultQueueCounter),
664664
CurrentDeviceCounter(CurrentDeviceCounter),
665665
PlaceholderStr{
666-
"",
667-
buildString(MapNames::getDpctNamespace(), "get_",
668-
DpctGlobalInfo::getDeviceQueueName(), "()"),
666+
"", DpctGlobalInfo::getDefaultQueueFreeFuncCall(),
669667
MapNames::getDpctNamespace() + "get_current_device()",
670668
(DpctGlobalInfo::useSYCLCompat()
671669
? buildString(MapNames::getDpctNamespace() +
672670
"get_current_device().default_queue()")
673-
: buildString("&" + MapNames::getDpctNamespace() + "get_" +
674-
DpctGlobalInfo::getDeviceQueueName() +
675-
"()"))} {}
671+
: buildString(
672+
"&", DpctGlobalInfo::getDefaultQueueFreeFuncCall()))} {
673+
}
676674
int DefaultQueueCounter = 0;
677675
int CurrentDeviceCounter = 0;
678676
std::string PlaceholderStr[4];
@@ -750,7 +748,8 @@ class DpctGlobalInfo {
750748
static std::string getSubGroup(const Stmt *,
751749
const FunctionDecl *FD = nullptr);
752750
static std::string getDefaultQueue(const Stmt *);
753-
static const std::string &getDeviceQueueName();
751+
static const std::string &getDefaultQueueFreeFuncCall();
752+
static const std::string &getDefaultQueueMemFuncName();
754753
static const std::string &getStreamName() {
755754
const static std::string StreamName = "stream" + getCTFixedSuffix();
756755
return StreamName;
@@ -1329,6 +1328,10 @@ class DpctGlobalInfo {
13291328
static bool useBFloat16() {
13301329
return getUsingExtensionDE(DPCPPExtensionsDefaultEnabled::ExtDE_BFloat16);
13311330
}
1331+
static std::unordered_set<std::string> &
1332+
getCustomHelperFunctionAddtionalIncludes() {
1333+
return CustomHelperFunctionAddtionalIncludes;
1334+
}
13321335
std::shared_ptr<DpctFileInfo>
13331336
insertFile(const clang::tooling::UnifiedPath &FilePath) {
13341337
return insertObject(FileMap, FilePath);
@@ -1644,6 +1647,7 @@ class DpctGlobalInfo {
16441647
static std::vector<std::pair<std::string, std::vector<std::string>>>
16451648
CodePinDumpFuncDepsVec;
16461649
static std::unordered_set<std::string> NeedParenAPISet;
1650+
static std::unordered_set<std::string> CustomHelperFunctionAddtionalIncludes;
16471651
};
16481652

16491653
/// Generate mangle name of FunctionDecl as key of DeviceFunctionInfo.

clang/lib/DPCT/MapNames.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4600,6 +4600,8 @@ MapNames::MapTy TextureRule::ResourceTypeNames{{"devPtr", "data_ptr"},
46004600
{"numChannels", "channel_num"}};
46014601

46024602
std::vector<MetaRuleObject::PatternRewriter> MapNames::PatternRewriters;
4603+
std::map<clang::dpct::HelperFuncCatalog, std::string>
4604+
MapNames::CustomHelperFunctionMap;
46034605

46044606
const MapNames::MapTy MemoryDataTypeRule::PitchMemberNames{
46054607
{"pitch", "pitch"}, {"ptr", "data_ptr"}, {"xsize", "x"}, {"ysize", "y"}};

clang/lib/DPCT/MapNames.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ namespace dpct {
2020
enum class KernelArgType;
2121
enum class HelperFileEnum : unsigned int;
2222
struct HelperFunc;
23+
enum class HelperFuncCatalog {
24+
GetDefaultQueue,
25+
GetOutOfOrderQueue,
26+
GetInOrderQueue,
27+
};
2328
} // namespace dpct
2429
} // namespace clang
2530

@@ -420,6 +425,9 @@ class MapNames {
420425
/// {Original API, {ToType, FromType}}
421426
static std::unordered_map<std::string, std::pair<std::string, std::string>>
422427
MathTypeCastingMap;
428+
429+
static std::map<clang::dpct::HelperFuncCatalog, std::string>
430+
CustomHelperFunctionMap;
423431
};
424432

425433
class MigrationStatistics {

clang/lib/DPCT/Rules.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,32 @@ void registerPatternRewriterRule(MetaRuleObject &R) {
277277
R.BuildScriptSyntax, R.Priority));
278278
}
279279

280+
void registerHelperFunctionRule(MetaRuleObject &R) {
281+
static const std::unordered_map<std::string, dpct::HelperFuncCatalog>
282+
String2HelperFuncCatalogMap{
283+
{"get_default_queue", dpct::HelperFuncCatalog::GetDefaultQueue},
284+
{"get_in_order_queue", dpct::HelperFuncCatalog::GetInOrderQueue},
285+
{"get_out_of_order_queue",
286+
dpct::HelperFuncCatalog::GetOutOfOrderQueue}};
287+
if (R.Priority == RulePriority::Takeover) {
288+
if (auto Iter = String2HelperFuncCatalogMap.find(R.In);
289+
Iter != String2HelperFuncCatalogMap.end()) {
290+
// This map is inited here.
291+
// It saves the customized string which used for each kind of helper
292+
// function call in the migrated code.
293+
MapNames::CustomHelperFunctionMap.insert({Iter->second, R.Out});
294+
dpct::DpctGlobalInfo::setUsingDRYPattern(false);
295+
dpct::DpctGlobalInfo::getCustomHelperFunctionAddtionalIncludes().insert(
296+
R.Includes.begin(), R.Includes.end());
297+
} else {
298+
llvm::outs()
299+
<< "Warning: The rule named " << R.RuleId
300+
<< " (Kind: HelperFunction) is ignored, as the API specified "
301+
"in the \"In\" field is not supported for customization.\n";
302+
}
303+
}
304+
}
305+
280306
MetaRuleObject::PatternRewriter &MetaRuleObject::PatternRewriter::operator=(
281307
const MetaRuleObject::PatternRewriter &PR) {
282308
if (this != &PR) {
@@ -366,6 +392,9 @@ void importRules(std::vector<clang::tooling::UnifiedPath> &RuleFiles) {
366392
case (RuleKind::CMakeRule):
367393
registerCmakeMigrationRule(*r);
368394
break;
395+
case (RuleKind::HelperFunction):
396+
registerHelperFunctionRule(*r);
397+
break;
369398
case (RuleKind::PythonRule):
370399
registerPythonMigrationRule(*r);
371400
break;

clang/lib/DPCT/Rules.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ enum RuleKind {
2525
DisableAPIMigration,
2626
PatternRewriter,
2727
CMakeRule,
28+
HelperFunction,
2829
PythonRule
2930
};
3031

@@ -210,6 +211,7 @@ template <> struct llvm::yaml::ScalarEnumerationTraits<RuleKind> {
210211
Io.enumCase(Value, "DisableAPIMigration", RuleKind::DisableAPIMigration);
211212
Io.enumCase(Value, "PatternRewriter", RuleKind::PatternRewriter);
212213
Io.enumCase(Value, "CMakeRule", RuleKind::CMakeRule);
214+
Io.enumCase(Value, "HelperFunction", RuleKind::HelperFunction);
213215
Io.enumCase(Value, "PythonRule", RuleKind::PythonRule);
214216
}
215217
};

0 commit comments

Comments
 (0)