Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions clang/lib/DPCT/APINamesMemory.inc
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
"cuMemPrefetchAsync",
MEMBER_CALL(CALL(MapNames::getDpctNamespace() + "get_device",
ARG(2)),
false, DpctGlobalInfo::getDeviceQueueName()),
false, DpctGlobalInfo::getDefaultQueueMemFuncName()),
DpctGlobalInfo::useSYCLCompat(), "prefetch", ARG(0), ARG(1)))))

ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
Expand All @@ -280,7 +280,7 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
"cuMemAdvise",
MEMBER_CALL(CALL(MapNames::getDpctNamespace() +
"cpu_device"),
false, DpctGlobalInfo::getDeviceQueueName()),
false, DpctGlobalInfo::getDefaultQueueMemFuncName()),
DpctGlobalInfo::useSYCLCompat(), "mem_advise", ARG(0), ARG(1), ARG("0"))),
Diagnostics::DEFAULT_MEM_ADVICE, ARG(" and was set to 0")),

Expand All @@ -289,7 +289,7 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
MEMBER_CALL_FACTORY_ENTRY(
"cuMemAdvise",
MEMBER_CALL(CALL(MapNames::getDpctNamespace() + "cpu_device"),
false, DpctGlobalInfo::getDeviceQueueName()),
false, DpctGlobalInfo::getDefaultQueueMemFuncName()),
DpctGlobalInfo::useSYCLCompat(), "mem_advise", ARG(0), ARG(1), ARG(2)))),

CONDITIONAL_FACTORY_ENTRY(
Expand All @@ -303,7 +303,7 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
MEMBER_CALL(CALL(MapNames::getDpctNamespace() +
"get_device",
ARG(3)),
false, DpctGlobalInfo::getDeviceQueueName()),
false, DpctGlobalInfo::getDefaultQueueMemFuncName()),
DpctGlobalInfo::useSYCLCompat(), "mem_advise", ARG(0), ARG(1), ARG("0"))),
Diagnostics::DEFAULT_MEM_ADVICE, ARG(" and was set to 0")),
FEATURE_REQUEST_FACTORY(
Expand All @@ -312,7 +312,7 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
"cuMemAdvise",
MEMBER_CALL(CALL(MapNames::getDpctNamespace() + "get_device",
ARG(3)),
false, DpctGlobalInfo::getDeviceQueueName()),
false, DpctGlobalInfo::getDefaultQueueMemFuncName()),
DpctGlobalInfo::useSYCLCompat(), "mem_advise", ARG(0), ARG(1), ARG(2))))))

ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
Expand Down
7 changes: 4 additions & 3 deletions clang/lib/DPCT/ASTTraversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10570,7 +10570,8 @@ void MemoryMigrationRule::prefetchMigration(
? "cpu_device()"
: "get_device(" + StmtStrArg2 + ")");
requestFeature(HelperFeatureEnum::device_ext);
Replacement = Prefix + "." + DpctGlobalInfo::getDeviceQueueName() + "()" +
Replacement = Prefix + "." +
DpctGlobalInfo::getDefaultQueueMemFuncName() + "()" +
(DpctGlobalInfo::useSYCLCompat() ? "->" : ".") +
"prefetch(" + StmtStrArg0 + "," + StmtStrArg1 + ")";
} else {
Expand Down Expand Up @@ -10772,15 +10773,15 @@ void MemoryMigrationRule::cudaMemAdvise(const MatchFinder::MatchResult &Result,
std::ostringstream OS;
if (getStmtSpelling(C->getArg(3)) == "cudaCpuDeviceId") {
OS << MapNames::getDpctNamespace() + "cpu_device()." +
DpctGlobalInfo::getDeviceQueueName() + "()";
DpctGlobalInfo::getDefaultQueueMemFuncName() + "()";
OS << (DpctGlobalInfo::useSYCLCompat() ? "->" : ".") << "mem_advise("
<< Arg0Str << ", " << Arg1Str << ", " << Arg2Str << ")";
emplaceTransformation(new ReplaceStmt(C, OS.str()));
requestFeature(HelperFeatureEnum::device_ext);
return;
}
OS << MapNames::getDpctNamespace() + "get_device(" << Arg3Str
<< ")." + DpctGlobalInfo::getDeviceQueueName() + "()";
<< ")." + DpctGlobalInfo::getDefaultQueueMemFuncName() + "()";
OS << (DpctGlobalInfo::useSYCLCompat() ? "->" : ".") << "mem_advise("
<< Arg0Str << ", " << Arg1Str << ", " << Arg2Str << ")";
emplaceTransformation(new ReplaceStmt(C, OS.str()));
Expand Down
115 changes: 90 additions & 25 deletions clang/lib/DPCT/AnalysisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ const std::string &getDefaultString(HelperFuncType HFT) {
const static std::string DefaultQueue =
DpctGlobalInfo::useNoQueueDevice()
? DpctGlobalInfo::getGlobalQueueName()
: buildString(MapNames::getDpctNamespace() + "get_" +
DpctGlobalInfo::getDeviceQueueName() + "()");
: buildString(MapNames::getDpctNamespace() +
DpctGlobalInfo::getDefaultQueueFreeFuncCall());
return DefaultQueue;
}
case clang::dpct::HelperFuncType::HFT_DefaultQueuePtr: {
Expand All @@ -74,8 +74,9 @@ const std::string &getDefaultString(HelperFuncType HFT) {
: (DpctGlobalInfo::useSYCLCompat()
? buildString(MapNames::getDpctNamespace() +
"get_current_device().default_queue()")
: buildString("&" + MapNames::getDpctNamespace() + "get_" +
DpctGlobalInfo::getDeviceQueueName() + "()"));
: buildString(
"&" + MapNames::getDpctNamespace() +
DpctGlobalInfo::getDefaultQueueFreeFuncCall()));
return DefaultQueue;
}
case clang::dpct::HelperFuncType::HFT_CurrentDevice: {
Expand Down Expand Up @@ -1225,15 +1226,30 @@ std::string DpctGlobalInfo::getDefaultQueue(const Stmt *S) {

return buildString(RegexPrefix, 'Q', Idx, RegexSuffix);
}
const std::string &DpctGlobalInfo::getDeviceQueueName() {
static const std::string DeviceQueue = [&]() {
if (DpctGlobalInfo::useSYCLCompat())
return "default_queue";
if (DpctGlobalInfo::getUsmLevel() == UsmLevel::UL_None)
return "out_of_order_queue";
return "in_order_queue";
}();
return DeviceQueue;
/// Returns the call text (without a namespace prefix) used to obtain the
/// default queue through a free function.
///
/// An entry stored under HelperFuncCatalog::DefaultQueue in
/// MapNames::CustomHelperFunctionMap takes precedence and is returned as-is.
/// Otherwise the text is selected by the SYCLcompat setting first and by the
/// USM level second.
///
/// NOTE(review): callers such as getDefaultString() prepend
/// MapNames::getDpctNamespace() to the returned text, so a custom entry like
/// "static_cast<...>(...)" ends up with the namespace in front of it --
/// confirm that this is intended.
const std::string &DpctGlobalInfo::getDefaultQueueFreeFuncCall() {
  const auto &CustomMap = MapNames::CustomHelperFunctionMap;
  const auto Custom = CustomMap.find(HelperFuncCatalog::DefaultQueue);
  if (Custom != CustomMap.end())
    return Custom->second;

  if (DpctGlobalInfo::useSYCLCompat()) {
    static const std::string CompatCall = "get_default_queue()";
    return CompatCall;
  }
  if (DpctGlobalInfo::getUsmLevel() == UsmLevel::UL_None) {
    static const std::string OutOfOrderCall = "get_out_of_order_queue()";
    return OutOfOrderCall;
  }
  static const std::string InOrderCall = "get_in_order_queue()";
  return InOrderCall;
}
/// Returns the name (no parentheses, no namespace) of the member function
/// that yields the default queue, for building text such as
/// "dev_ct1.<name>()".
///
/// The name depends only on the SYCLcompat setting and the USM level; unlike
/// the free-function variant, MapNames::CustomHelperFunctionMap is not
/// consulted here.
const std::string &DpctGlobalInfo::getDefaultQueueMemFuncName() {
  static const std::string CompatName = "default_queue";
  static const std::string OutOfOrderName = "out_of_order_queue";
  static const std::string InOrderName = "in_order_queue";

  if (DpctGlobalInfo::useSYCLCompat())
    return CompatName;
  // Without USM the out-of-order queue is used; otherwise the in-order one.
  return DpctGlobalInfo::getUsmLevel() == UsmLevel::UL_None ? OutOfOrderName
                                                            : InOrderName;
}
void DpctGlobalInfo::setContext(ASTContext &C) {
Context = &C;
Expand Down Expand Up @@ -1588,7 +1604,8 @@ void DpctGlobalInfo::buildReplacements() {
QDecl << "&q_ct1 = ";
if (DpctGlobalInfo::useSYCLCompat())
QDecl << '*';
QDecl << "dev_ct1." << DpctGlobalInfo::getDeviceQueueName() << "();";
QDecl << "dev_ct1." << DpctGlobalInfo::getDefaultQueueMemFuncName()
<< "();";
} else {
DevDecl << MapNames::getClNamespace() + "device dev_ct1;";
// Now the UsmLevel must not be UL_None here.
Expand Down Expand Up @@ -6060,6 +6077,38 @@ void KernelCallExpr::removeExtraIndent() {
getFilePath(), getOffset() - LocInfo.Indent.length(),
LocInfo.Indent.length(), "", nullptr));
}

namespace {
/// Writes the text described by \p OB to \p OS, substituting \p Aspects for
/// argument placeholders.
///
/// \param Aspects Text inserted for an Arg builder whose ArgIndex is not
///                greater than 1.
/// \param OS      Stream that receives the generated text.
/// \param OB      Builder to expand. A Top builder is expanded by visiting its
///                SubBuilders in order; a String builder contributes its Str.
///
/// Arg builders with an ArgIndex above 1 contribute nothing. Any other builder
/// kind is reported through DpctDebugs() and trips an assertion.
void buildHasCapabilityOrFailStr(const std::string &Aspects,
                                 llvm::raw_string_ostream &OS,
                                 const OutputBuilder &OB) {
  const auto BuilderKind = OB.Kind;

  if (BuilderKind == OutputBuilder::Kind::Top) {
    for (auto &Child : OB.SubBuilders)
      buildHasCapabilityOrFailStr(Aspects, OS, *Child);
    return;
  }

  if (BuilderKind == OutputBuilder::Kind::String) {
    OS << OB.Str;
    return;
  }

  if (BuilderKind == OutputBuilder::Kind::Arg) {
    // Only placeholders with an index of at most 1 receive the aspect list;
    // higher indices expand to nothing.
    if (!(OB.ArgIndex > 1))
      OS << Aspects;
    return;
  }

  DpctDebugs() << "[buildHasCapabilityOrFailStr OutputBuilder::Kind] "
                  "Unexpected value: "
               << OB.Kind << "\n";
  assert(0);
}
} // namespace

void KernelCallExpr::addDevCapCheckStmt() {
llvm::SmallVector<std::string> AspectList;
if (getVarMap().hasBF64()) {
Expand All @@ -6069,17 +6118,32 @@ void KernelCallExpr::addDevCapCheckStmt() {
AspectList.push_back(MapNames::getClNamespace() + "aspect::fp16");
}
if (!AspectList.empty()) {
requestFeature(HelperFeatureEnum::device_ext);
std::string Str;
llvm::raw_string_ostream OS(Str);
OS << MapNames::getDpctNamespace() << "get_device(";
OS << MapNames::getDpctNamespace() << "get_device_id(";
printStreamBase(OS);
OS << "get_device())).has_capability_or_fail({" << AspectList.front();
for (size_t i = 1; i < AspectList.size(); ++i) {
OS << ", " << AspectList[i];
}
OS << "});";
// Emit the device capability check. A user-supplied HasCapabilityOrFail entry
// in MapNames::CustomHelperFunctionMap replaces the default helper call; its
// template is expanded by buildHasCapabilityOrFailStr() with the
// brace-enclosed aspect list as the substituted argument text.
auto Iter = MapNames::CustomHelperFunctionMap.find(
    HelperFuncCatalog::HasCapabilityOrFail);
if (Iter != MapNames::CustomHelperFunctionMap.end()) {
  OutputBuilder OB;
  OB.parse(Iter->second);
  OB.Kind = OutputBuilder::Kind::Top;
  std::string Aspects = "{" + AspectList.front();
  for (size_t i = 1; i < AspectList.size(); ++i) {
    // Separate the entries with ", " exactly as the default branch below
    // does; plain concatenation would fuse several aspect names into one
    // invalid token when more than one aspect is required.
    Aspects += ", ";
    Aspects += AspectList[i];
  }
  Aspects += "}";
  buildHasCapabilityOrFailStr(Aspects, OS, OB);
  OS << ";";
} else {
  // Default form:
  //   <ns>get_device(<ns>get_device_id(<stream>get_device()))
  //       .has_capability_or_fail({a, b, ...});
  requestFeature(HelperFeatureEnum::device_ext);
  OS << MapNames::getDpctNamespace() << "get_device(";
  OS << MapNames::getDpctNamespace() << "get_device_id(";
  printStreamBase(OS);
  OS << "get_device())).has_capability_or_fail({" << AspectList.front();
  for (size_t i = 1; i < AspectList.size(); ++i) {
    OS << ", " << AspectList[i];
  }
  OS << "});";
}
OuterStmts.OthersList.emplace_back(OS.str());
}
}
Expand Down Expand Up @@ -6129,8 +6193,9 @@ void KernelCallExpr::addStreamDecl() {
buildString(MapNames::getClNamespace() + "stream ",
DpctGlobalInfo::getStreamName(), "(64 * 1024, 80, cgh);"));
if (getVarMap().hasSync()) {
auto DefaultQueue = buildString(MapNames::getDpctNamespace(), "get_",
DpctGlobalInfo::getDeviceQueueName(), "()");
auto DefaultQueue =
buildString(MapNames::getDpctNamespace(),
DpctGlobalInfo::getDefaultQueueFreeFuncCall());
if (DpctGlobalInfo::getUsmLevel() == UsmLevel::UL_None) {
OuterStmts.OthersList.emplace_back(
buildString(MapNames::getDpctNamespace(), "global_memory<",
Expand Down
13 changes: 7 additions & 6 deletions clang/lib/DPCT/AnalysisInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -663,15 +663,15 @@ class DpctGlobalInfo {
CurrentDeviceCounter(CurrentDeviceCounter),
PlaceholderStr{
"",
buildString(MapNames::getDpctNamespace(), "get_",
DpctGlobalInfo::getDeviceQueueName(), "()"),
buildString(MapNames::getDpctNamespace(),
DpctGlobalInfo::getDefaultQueueFreeFuncCall()),
MapNames::getDpctNamespace() + "get_current_device()",
(DpctGlobalInfo::useSYCLCompat()
? buildString(MapNames::getDpctNamespace() +
"get_current_device().default_queue()")
: buildString("&" + MapNames::getDpctNamespace() + "get_" +
DpctGlobalInfo::getDeviceQueueName() +
"()"))} {}
: buildString(
"&" + MapNames::getDpctNamespace() +
DpctGlobalInfo::getDefaultQueueFreeFuncCall()))} {}
int DefaultQueueCounter = 0;
int CurrentDeviceCounter = 0;
std::string PlaceholderStr[4];
Expand Down Expand Up @@ -749,7 +749,8 @@ class DpctGlobalInfo {
static std::string getSubGroup(const Stmt *,
const FunctionDecl *FD = nullptr);
static std::string getDefaultQueue(const Stmt *);
static const std::string &getDeviceQueueName();
static const std::string &getDefaultQueueFreeFuncCall();
static const std::string &getDefaultQueueMemFuncName();
static const std::string &getStreamName() {
const static std::string StreamName = "stream" + getCTFixedSuffix();
return StreamName;
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/DPCT/MapNames.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4560,6 +4560,8 @@ MapNames::MapTy TextureRule::ResourceTypeNames{{"devPtr", "data_ptr"},
{"numChannels", "channel_num"}};

std::vector<MetaRuleObject::PatternRewriter> MapNames::PatternRewriters;
std::map<clang::dpct::HelperFuncCatalog, std::string>
MapNames::CustomHelperFunctionMap;

const MapNames::MapTy MemoryDataTypeRule::PitchMemberNames{
{"pitch", "pitch"}, {"ptr", "data_ptr"}, {"xsize", "x"}, {"ysize", "y"}};
Expand Down
7 changes: 7 additions & 0 deletions clang/lib/DPCT/MapNames.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ namespace dpct {
enum class KernelArgType;
enum class HelperFileEnum : unsigned int;
struct HelperFunc;
/// Identifies a piece of generated helper text that a user rule can replace.
/// Values of this enum are the keys of MapNames::CustomHelperFunctionMap.
enum class HelperFuncCatalog {
  /// Text used to obtain the default queue.
  DefaultQueue,
  /// Text of the device capability check (has_capability_or_fail).
  HasCapabilityOrFail,
};
} // namespace dpct
} // namespace clang

Expand Down Expand Up @@ -420,6 +424,9 @@ class MapNames {
/// {Original API, {ToType, FromType}}
static std::unordered_map<std::string, std::pair<std::string, std::string>>
MathTypeCastingMap;

static std::map<clang::dpct::HelperFuncCatalog, std::string>
CustomHelperFunctionMap;
};

class MigrationStatistics {
Expand Down
16 changes: 16 additions & 0 deletions clang/lib/DPCT/Rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,19 @@ void registerPatternRewriterRule(MetaRuleObject &R) {
R.Priority));
}

/// Registers a HelperFunction rule in MapNames::CustomHelperFunctionMap.
///
/// Only rules with Takeover priority whose In field is "DefaultQueue" or
/// "HasCapabilityOrFail" are accepted; every other rule is ignored. For an
/// accepted rule, R.Out is inserted under the matching catalog key (an
/// existing entry for that key is kept, since insert() does not overwrite)
/// and the DRY pattern is switched off via setUsingDRYPattern(false).
void registerHelperFunctionRule(MetaRuleObject &R) {
  if (R.Priority != RulePriority::Takeover)
    return;

  dpct::HelperFuncCatalog Catalog;
  if (R.In == "DefaultQueue")
    Catalog = dpct::HelperFuncCatalog::DefaultQueue;
  else if (R.In == "HasCapabilityOrFail")
    Catalog = dpct::HelperFuncCatalog::HasCapabilityOrFail;
  else
    return;

  MapNames::CustomHelperFunctionMap.insert({Catalog, R.Out});
  dpct::DpctGlobalInfo::setUsingDRYPattern(false);
}

MetaRuleObject::PatternRewriter &MetaRuleObject::PatternRewriter::operator=(
const MetaRuleObject::PatternRewriter &PR) {
if (this != &PR) {
Expand Down Expand Up @@ -365,6 +378,9 @@ void importRules(std::vector<clang::tooling::UnifiedPath> &RuleFiles) {
case (RuleKind::CMakeRule):
registerCmakeMigrationRule(*r);
break;
case (RuleKind::HelperFunction):
registerHelperFunctionRule(*r);
break;
default:
break;
}
Expand Down
4 changes: 3 additions & 1 deletion clang/lib/DPCT/Rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ enum RuleKind {
Enum,
DisableAPIMigration,
PatternRewriter,
CMakeRule
CMakeRule,
HelperFunction
};

enum RulePriority { Takeover, Default, Fallback };
Expand Down Expand Up @@ -209,6 +210,7 @@ template <> struct llvm::yaml::ScalarEnumerationTraits<RuleKind> {
Io.enumCase(Value, "DisableAPIMigration", RuleKind::DisableAPIMigration);
Io.enumCase(Value, "PatternRewriter", RuleKind::PatternRewriter);
Io.enumCase(Value, "CMakeRule", RuleKind::CMakeRule);
Io.enumCase(Value, "HelperFunction", RuleKind::HelperFunction);
}
};

Expand Down
13 changes: 13 additions & 0 deletions clang/test/dpct/ipex_xpu.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
---
- Rule: rule1
Kind: HelperFunction
Priority: Takeover
In: DefaultQueue
Out: static_cast<sycl::queue&>(c10::xpu::getCurrentXPUStream())
Includes: [""]
- Rule: rule1
Kind: HelperFunction
Priority: Takeover
In: HasCapabilityOrFail
Out: dpct::has_capability_or_fail(static_cast<sycl::queue&>(c10::xpu::getCurrentXPUStream()).get_device(), $1)
Includes: [""]
37 changes: 37 additions & 0 deletions clang/test/dpct/user_defined_rule2.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// RUN: dpct --out-root %T/user_defined_rule2 %s --cuda-include-path="%cuda-path/include" --rule-file %S/ipex_xpu.yaml --format-range=none
// RUN: FileCheck --input-file %T/user_defined_rule2/user_defined_rule2.dp.cpp --match-full-lines %s
// RUN: %if build_lit %{icpx -c -fsycl -DBUILD_TEST %T/user_defined_rule2/user_defined_rule2.dp.cpp -o %T/user_defined_rule2/user_defined_rule2.dp.o %}

#ifndef BUILD_TEST

// Empty kernel without parameters; foo1() launches it with <<<1, 1>>>.
__global__ void foo1_kernel() {}
// Launches foo1_kernel once. The FileCheck expectations inside the body
// require the migrated launch to be submitted through the DefaultQueue text
// taken from the rule file passed on the RUN line (ipex_xpu.yaml).
// NOTE(review): the expected text starts with "dpct::static_cast<...>", i.e.
// the dpct namespace is placed in front of the custom expression. The icpx RUN
// line defines BUILD_TEST, which excludes this whole file body, so that text
// is never compiled -- confirm the prefix is intended.
void foo1() {
// CHECK: dpct::static_cast<sycl::queue&>(c10::xpu::getCurrentXPUStream()).parallel_for(
// CHECK-NEXT: sycl::nd_range<3>(sycl::range<3>(1, 1, 1), sycl::range<3>(1, 1, 1)),
// CHECK-NEXT: [=](sycl::nd_item<3> item_ct1) {
// CHECK-NEXT: foo1_kernel();
// CHECK-NEXT: });
foo1_kernel<<<1, 1>>>();
}

// Empty kernel taking a double pointer; foo2() launches it with <<<1, 1>>>
// and expects an fp64 capability check ahead of the migrated launch.
__global__ void foo2_kernel(double *d) {}

// Allocates one double with cudaMalloc, launches foo2_kernel on it and frees
// it. The FileCheck expectations inside the body require that:
//  - the allocation, the launch and the free all use the DefaultQueue text
//    from the rule file, and
//  - the launch is preceded by the HasCapabilityOrFail text from the rule
//    file, with "{sycl::aspect::fp64}" substituted for its argument.
void foo2() {
double *d;
// CHECK: d = sycl::malloc_device<double>(1, dpct::static_cast<sycl::queue&>(c10::xpu::getCurrentXPUStream()));
// CHECK-NEXT: {
// CHECK-NEXT: dpct::has_capability_or_fail(static_cast<sycl::queue&>(c10::xpu::getCurrentXPUStream()).get_device(), {sycl::aspect::fp64});
// CHECK-EMPTY:
// CHECK-NEXT: dpct::static_cast<sycl::queue&>(c10::xpu::getCurrentXPUStream()).parallel_for(
// CHECK-NEXT: sycl::nd_range<3>(sycl::range<3>(1, 1, 1), sycl::range<3>(1, 1, 1)),
// CHECK-NEXT: [=](sycl::nd_item<3> item_ct1) {
// CHECK-NEXT: foo2_kernel(d);
// CHECK-NEXT: });
// CHECK-NEXT: }
// CHECK-NEXT: dpct::dpct_free(d, dpct::static_cast<sycl::queue&>(c10::xpu::getCurrentXPUStream()));
cudaMalloc(&d, sizeof(double));
foo2_kernel<<<1, 1>>>(d);
cudaFree(d);
}

#endif
Loading