Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions clang/lib/DPCT/APINamesMemory.inc
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
"cuMemPrefetchAsync",
MEMBER_CALL(CALL(MapNames::getDpctNamespace() + "get_device",
ARG(2)),
false, DpctGlobalInfo::getDeviceQueueName()),
false, DpctGlobalInfo::getDefaultQueueMemFuncName()),
DpctGlobalInfo::useSYCLCompat(), "prefetch", ARG(0), ARG(1)))))

ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
Expand All @@ -280,7 +280,7 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
"cuMemAdvise",
MEMBER_CALL(CALL(MapNames::getDpctNamespace() +
"cpu_device"),
false, DpctGlobalInfo::getDeviceQueueName()),
false, DpctGlobalInfo::getDefaultQueueMemFuncName()),
DpctGlobalInfo::useSYCLCompat(), "mem_advise", ARG(0), ARG(1), ARG("0"))),
Diagnostics::DEFAULT_MEM_ADVICE, ARG(" and was set to 0")),

Expand All @@ -289,7 +289,7 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
MEMBER_CALL_FACTORY_ENTRY(
"cuMemAdvise",
MEMBER_CALL(CALL(MapNames::getDpctNamespace() + "cpu_device"),
false, DpctGlobalInfo::getDeviceQueueName()),
false, DpctGlobalInfo::getDefaultQueueMemFuncName()),
DpctGlobalInfo::useSYCLCompat(), "mem_advise", ARG(0), ARG(1), ARG(2)))),

CONDITIONAL_FACTORY_ENTRY(
Expand All @@ -303,7 +303,7 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
MEMBER_CALL(CALL(MapNames::getDpctNamespace() +
"get_device",
ARG(3)),
false, DpctGlobalInfo::getDeviceQueueName()),
false, DpctGlobalInfo::getDefaultQueueMemFuncName()),
DpctGlobalInfo::useSYCLCompat(), "mem_advise", ARG(0), ARG(1), ARG("0"))),
Diagnostics::DEFAULT_MEM_ADVICE, ARG(" and was set to 0")),
FEATURE_REQUEST_FACTORY(
Expand All @@ -312,7 +312,7 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
"cuMemAdvise",
MEMBER_CALL(CALL(MapNames::getDpctNamespace() + "get_device",
ARG(3)),
false, DpctGlobalInfo::getDeviceQueueName()),
false, DpctGlobalInfo::getDefaultQueueMemFuncName()),
DpctGlobalInfo::useSYCLCompat(), "mem_advise", ARG(0), ARG(1), ARG(2))))))

ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
Expand Down
7 changes: 4 additions & 3 deletions clang/lib/DPCT/ASTTraversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10584,7 +10584,8 @@ void MemoryMigrationRule::prefetchMigration(
? "cpu_device()"
: "get_device(" + StmtStrArg2 + ")");
requestFeature(HelperFeatureEnum::device_ext);
Replacement = Prefix + "." + DpctGlobalInfo::getDeviceQueueName() + "()" +
Replacement = Prefix + "." +
DpctGlobalInfo::getDefaultQueueMemFuncName() + "()" +
(DpctGlobalInfo::useSYCLCompat() ? "->" : ".") +
"prefetch(" + StmtStrArg0 + "," + StmtStrArg1 + ")";
} else {
Expand Down Expand Up @@ -10786,15 +10787,15 @@ void MemoryMigrationRule::cudaMemAdvise(const MatchFinder::MatchResult &Result,
std::ostringstream OS;
if (getStmtSpelling(C->getArg(3)) == "cudaCpuDeviceId") {
OS << MapNames::getDpctNamespace() + "cpu_device()." +
DpctGlobalInfo::getDeviceQueueName() + "()";
DpctGlobalInfo::getDefaultQueueMemFuncName() + "()";
OS << (DpctGlobalInfo::useSYCLCompat() ? "->" : ".") << "mem_advise("
<< Arg0Str << ", " << Arg1Str << ", " << Arg2Str << ")";
emplaceTransformation(new ReplaceStmt(C, OS.str()));
requestFeature(HelperFeatureEnum::device_ext);
return;
}
OS << MapNames::getDpctNamespace() + "get_device(" << Arg3Str
<< ")." + DpctGlobalInfo::getDeviceQueueName() + "()";
<< ")." + DpctGlobalInfo::getDefaultQueueMemFuncName() + "()";
OS << (DpctGlobalInfo::useSYCLCompat() ? "->" : ".") << "mem_advise("
<< Arg0Str << ", " << Arg1Str << ", " << Arg2Str << ")";
emplaceTransformation(new ReplaceStmt(C, OS.str()));
Expand Down
76 changes: 57 additions & 19 deletions clang/lib/DPCT/AnalysisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ const std::string &getDefaultString(HelperFuncType HFT) {
const static std::string DefaultQueue =
DpctGlobalInfo::useNoQueueDevice()
? DpctGlobalInfo::getGlobalQueueName()
: buildString(MapNames::getDpctNamespace() + "get_" +
DpctGlobalInfo::getDeviceQueueName() + "()");
: DpctGlobalInfo::getDefaultQueueFreeFuncCall();
return DefaultQueue;
}
case clang::dpct::HelperFuncType::HFT_DefaultQueuePtr: {
Expand All @@ -74,8 +73,8 @@ const std::string &getDefaultString(HelperFuncType HFT) {
: (DpctGlobalInfo::useSYCLCompat()
? buildString(MapNames::getDpctNamespace() +
"get_current_device().default_queue()")
: buildString("&" + MapNames::getDpctNamespace() + "get_" +
DpctGlobalInfo::getDeviceQueueName() + "()"));
: buildString(
"&", DpctGlobalInfo::getDefaultQueueFreeFuncCall()));
return DefaultQueue;
}
case clang::dpct::HelperFuncType::HFT_CurrentDevice: {
Expand Down Expand Up @@ -269,6 +268,15 @@ void processTypeLoc(const TypeLoc &TL, ExprAnalysis &EA,
}
EA.applyAllSubExprRepl();
}
HelperFuncCatalog getQueueKind() {
if (DpctGlobalInfo::useSYCLCompat()) {
return HelperFuncCatalog::GetDefaultQueue;
}
if (DpctGlobalInfo::getUsmLevel() == UsmLevel::UL_Restricted) {
return HelperFuncCatalog::GetInOrderQueue;
}
return HelperFuncCatalog::GetOutOfOrderQueue;
}

///// class FreeQueriesInfo /////
class FreeQueriesInfo {
Expand Down Expand Up @@ -930,6 +938,11 @@ void DpctFileInfo::insertHeader(HeaderType Type, unsigned Offset,
<< CCLVerValue << getNL();
insertHeader(MigratedMacroDefinitionOS.str(), FileBeginOffset,
InsertPosition::IP_AlwaysLeft);
for (const auto &File :
DpctGlobalInfo::getCustomHelperFunctionAddtionalIncludes()) {
insertHeader("#include \"" + File + +"\"" + getNL(), FirstIncludeOffset,
InsertPosition::IP_Right);
}
return;

// Because <dpct/dpl_utils.hpp> includes <oneapi/dpl/execution> and
Expand Down Expand Up @@ -1225,15 +1238,26 @@ std::string DpctGlobalInfo::getDefaultQueue(const Stmt *S) {

return buildString(RegexPrefix, 'Q', Idx, RegexSuffix);
}
const std::string &DpctGlobalInfo::getDeviceQueueName() {
static const std::string DeviceQueue = [&]() {
const std::string &DpctGlobalInfo::getDefaultQueueFreeFuncCall() {
static const std::string DefaultQueueFreeFuncCall = [&]() {
if (auto Iter = MapNames::CustomHelperFunctionMap.find(getQueueKind());
Iter != MapNames::CustomHelperFunctionMap.end()) {
return Iter->second;
}
return MapNames::getDpctNamespace() + "get_" +
getDefaultQueueMemFuncName() + "()";
}();
return DefaultQueueFreeFuncCall;
}
const std::string &DpctGlobalInfo::getDefaultQueueMemFuncName() {
static const std::string DefaultQueueMemFuncName = [&]() {
if (DpctGlobalInfo::useSYCLCompat())
return "default_queue";
if (DpctGlobalInfo::getUsmLevel() == UsmLevel::UL_None)
return "out_of_order_queue";
return "in_order_queue";
}();
return DeviceQueue;
return DefaultQueueMemFuncName;
}
void DpctGlobalInfo::setContext(ASTContext &C) {
Context = &C;
Expand Down Expand Up @@ -1588,7 +1612,8 @@ void DpctGlobalInfo::buildReplacements() {
QDecl << "&q_ct1 = ";
if (DpctGlobalInfo::useSYCLCompat())
QDecl << '*';
QDecl << "dev_ct1." << DpctGlobalInfo::getDeviceQueueName() << "();";
QDecl << "dev_ct1." << DpctGlobalInfo::getDefaultQueueMemFuncName()
<< "();";
} else {
DevDecl << MapNames::getClNamespace() + "device dev_ct1;";
// Now the UsmLevel must not be UL_None here.
Expand Down Expand Up @@ -2454,6 +2479,8 @@ std::vector<std::pair<std::string, std::vector<std::string>>>
std::vector<std::pair<std::string, std::vector<std::string>>>
DpctGlobalInfo::CodePinDumpFuncDepsVec;
std::unordered_set<std::string> DpctGlobalInfo::NeedParenAPISet = {};
std::unordered_set<std::string>
DpctGlobalInfo::CustomHelperFunctionAddtionalIncludes = {};
///// class DpctNameGenerator /////
void DpctNameGenerator::printName(const FunctionDecl *FD,
llvm::raw_ostream &OS) {
Expand Down Expand Up @@ -6070,6 +6097,7 @@ void KernelCallExpr::removeExtraIndent() {
getFilePath(), getOffset() - LocInfo.Indent.length(),
LocInfo.Indent.length(), "", nullptr));
}

void KernelCallExpr::addDevCapCheckStmt() {
llvm::SmallVector<std::string> AspectList;
if (getVarMap().hasBF64()) {
Expand All @@ -6079,17 +6107,28 @@ void KernelCallExpr::addDevCapCheckStmt() {
AspectList.push_back(MapNames::getClNamespace() + "aspect::fp16");
}
if (!AspectList.empty()) {
requestFeature(HelperFeatureEnum::device_ext);
std::string Str;
llvm::raw_string_ostream OS(Str);
OS << MapNames::getDpctNamespace() << "get_device(";
OS << MapNames::getDpctNamespace() << "get_device_id(";
printStreamBase(OS);
OS << "get_device())).has_capability_or_fail({" << AspectList.front();
for (size_t i = 1; i < AspectList.size(); ++i) {
OS << ", " << AspectList[i];
}
OS << "});";
if (auto Iter = MapNames::CustomHelperFunctionMap.find(getQueueKind());
Iter != MapNames::CustomHelperFunctionMap.end()) {
OS << MapNames::getDpctNamespace() << "has_capability_or_fail(";
OS << Iter->second << ".get_device(), ";
OS << "{" << AspectList.front();
for (size_t i = 1; i < AspectList.size(); ++i) {
OS << ", " << AspectList[i];
}
OS << "});";
} else {
requestFeature(HelperFeatureEnum::device_ext);
OS << MapNames::getDpctNamespace() << "get_device(";
OS << MapNames::getDpctNamespace() << "get_device_id(";
printStreamBase(OS);
OS << "get_device())).has_capability_or_fail({" << AspectList.front();
for (size_t i = 1; i < AspectList.size(); ++i) {
OS << ", " << AspectList[i];
}
OS << "});";
}
OuterStmts.OthersList.emplace_back(OS.str());
}
}
Expand Down Expand Up @@ -6139,8 +6178,7 @@ void KernelCallExpr::addStreamDecl() {
buildString(MapNames::getClNamespace() + "stream ",
DpctGlobalInfo::getStreamName(), "(64 * 1024, 80, cgh);"));
if (getVarMap().hasSync()) {
auto DefaultQueue = buildString(MapNames::getDpctNamespace(), "get_",
DpctGlobalInfo::getDeviceQueueName(), "()");
auto DefaultQueue = DpctGlobalInfo::getDefaultQueueFreeFuncCall();
if (DpctGlobalInfo::getUsmLevel() == UsmLevel::UL_None) {
OuterStmts.OthersList.emplace_back(
buildString(MapNames::getDpctNamespace(), "global_memory<",
Expand Down
18 changes: 11 additions & 7 deletions clang/lib/DPCT/AnalysisInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -662,16 +662,14 @@ class DpctGlobalInfo {
: DefaultQueueCounter(DefaultQueueCounter),
CurrentDeviceCounter(CurrentDeviceCounter),
PlaceholderStr{
"",
buildString(MapNames::getDpctNamespace(), "get_",
DpctGlobalInfo::getDeviceQueueName(), "()"),
"", DpctGlobalInfo::getDefaultQueueFreeFuncCall(),
MapNames::getDpctNamespace() + "get_current_device()",
(DpctGlobalInfo::useSYCLCompat()
? buildString(MapNames::getDpctNamespace() +
"get_current_device().default_queue()")
: buildString("&" + MapNames::getDpctNamespace() + "get_" +
DpctGlobalInfo::getDeviceQueueName() +
"()"))} {}
: buildString(
"&", DpctGlobalInfo::getDefaultQueueFreeFuncCall()))} {
}
int DefaultQueueCounter = 0;
int CurrentDeviceCounter = 0;
std::string PlaceholderStr[4];
Expand Down Expand Up @@ -749,7 +747,8 @@ class DpctGlobalInfo {
static std::string getSubGroup(const Stmt *,
const FunctionDecl *FD = nullptr);
static std::string getDefaultQueue(const Stmt *);
static const std::string &getDeviceQueueName();
static const std::string &getDefaultQueueFreeFuncCall();
static const std::string &getDefaultQueueMemFuncName();
static const std::string &getStreamName() {
const static std::string StreamName = "stream" + getCTFixedSuffix();
return StreamName;
Expand Down Expand Up @@ -1328,6 +1327,10 @@ class DpctGlobalInfo {
static bool useBFloat16() {
return getUsingExtensionDE(DPCPPExtensionsDefaultEnabled::ExtDE_BFloat16);
}
static std::unordered_set<std::string> &
getCustomHelperFunctionAddtionalIncludes() {
return CustomHelperFunctionAddtionalIncludes;
}
std::shared_ptr<DpctFileInfo>
insertFile(const clang::tooling::UnifiedPath &FilePath) {
return insertObject(FileMap, FilePath);
Expand Down Expand Up @@ -1643,6 +1646,7 @@ class DpctGlobalInfo {
static std::vector<std::pair<std::string, std::vector<std::string>>>
CodePinDumpFuncDepsVec;
static std::unordered_set<std::string> NeedParenAPISet;
static std::unordered_set<std::string> CustomHelperFunctionAddtionalIncludes;
};

/// Generate mangle name of FunctionDecl as key of DeviceFunctionInfo.
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/DPCT/MapNames.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4597,6 +4597,8 @@ MapNames::MapTy TextureRule::ResourceTypeNames{{"devPtr", "data_ptr"},
{"numChannels", "channel_num"}};

std::vector<MetaRuleObject::PatternRewriter> MapNames::PatternRewriters;
std::map<clang::dpct::HelperFuncCatalog, std::string>
MapNames::CustomHelperFunctionMap;

const MapNames::MapTy MemoryDataTypeRule::PitchMemberNames{
{"pitch", "pitch"}, {"ptr", "data_ptr"}, {"xsize", "x"}, {"ysize", "y"}};
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/DPCT/MapNames.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ namespace dpct {
enum class KernelArgType;
enum class HelperFileEnum : unsigned int;
struct HelperFunc;
enum class HelperFuncCatalog {
GetDefaultQueue,
GetOutOfOrderQueue,
GetInOrderQueue,
};
} // namespace dpct
} // namespace clang

Expand Down Expand Up @@ -420,6 +425,9 @@ class MapNames {
/// {Original API, {ToType, FromType}}
static std::unordered_map<std::string, std::pair<std::string, std::string>>
MathTypeCastingMap;

static std::map<clang::dpct::HelperFuncCatalog, std::string>
CustomHelperFunctionMap;
};

class MigrationStatistics {
Expand Down
21 changes: 21 additions & 0 deletions clang/lib/DPCT/Rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,25 @@ void registerPatternRewriterRule(MetaRuleObject &R) {
R.BuildScriptSyntax, R.Priority));
}

void registerHelperFunctionRule(MetaRuleObject &R) {
if ((R.In == "get_default_queue" || R.In == "get_in_order_queue" ||
R.In == "get_out_of_order_queue") &&
R.Priority == RulePriority::Takeover) {
if (R.In == "get_default_queue")
MapNames::CustomHelperFunctionMap.insert(
{dpct::HelperFuncCatalog::GetDefaultQueue, R.Out});
else if (R.In == "get_in_order_queue")
MapNames::CustomHelperFunctionMap.insert(
{dpct::HelperFuncCatalog::GetInOrderQueue, R.Out});
else
MapNames::CustomHelperFunctionMap.insert(
{dpct::HelperFuncCatalog::GetOutOfOrderQueue, R.Out});
dpct::DpctGlobalInfo::setUsingDRYPattern(false);
dpct::DpctGlobalInfo::getCustomHelperFunctionAddtionalIncludes().insert(
R.Includes.begin(), R.Includes.end());
}
}

MetaRuleObject::PatternRewriter &MetaRuleObject::PatternRewriter::operator=(
const MetaRuleObject::PatternRewriter &PR) {
if (this != &PR) {
Expand Down Expand Up @@ -366,6 +385,8 @@ void importRules(std::vector<clang::tooling::UnifiedPath> &RuleFiles) {
case (RuleKind::CMakeRule):
registerCmakeMigrationRule(*r);
break;
case (RuleKind::HelperFunction):
registerHelperFunctionRule(*r);
case (RuleKind::PythonRule):
registerPythonMigrationRule(*r);
break;
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/DPCT/Rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ enum RuleKind {
DisableAPIMigration,
PatternRewriter,
CMakeRule,
HelperFunction,
PythonRule
};

Expand Down Expand Up @@ -210,6 +211,7 @@ template <> struct llvm::yaml::ScalarEnumerationTraits<RuleKind> {
Io.enumCase(Value, "DisableAPIMigration", RuleKind::DisableAPIMigration);
Io.enumCase(Value, "PatternRewriter", RuleKind::PatternRewriter);
Io.enumCase(Value, "CMakeRule", RuleKind::CMakeRule);
Io.enumCase(Value, "HelperFunction", RuleKind::HelperFunction);
Io.enumCase(Value, "PythonRule", RuleKind::PythonRule);
}
};
Expand Down
Loading
Loading