Skip to content

Commit 185e9a6

Browse files
authored
[SYCLomatic] Support migration of cusparse<T>csrgemm and cusparseXcsrgemmNnz (#2065)
Signed-off-by: Jiang, Zhiwei <[email protected]>
1 parent 0748273 commit 185e9a6

File tree

6 files changed

+562
-9
lines changed

6 files changed

+562
-9
lines changed

clang/lib/DPCT/Diagnostics/Diagnostics.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,8 @@ DEF_WARNING(UNSUPPORT_SYCLCOMPAT, 1131, MEDIUM_LEVEL, "The migration of \"%0\" i
292292
DEF_COMMENT(UNSUPPORT_SYCLCOMPAT, 1131, MEDIUM_LEVEL, "The migration of \"{0}\" is not currently supported with SYCLcompat. Please adjust the code manually.")
293293
DEF_WARNING(UNSUPPORTED_KERNEL_ATTRIBUTE, 1132, HIGH_LEVEL, "SYCL 2020 does not support accessing the %0 for the kernel. The API is replaced with member variable \"%1\". Please set the appropriate value for \"%1\".")
294294
DEF_COMMENT(UNSUPPORTED_KERNEL_ATTRIBUTE, 1132, HIGH_LEVEL, "SYCL 2020 does not support accessing the {0} for the kernel. The API is replaced with member variable \"{1}\". Please set the appropriate value for \"{1}\".")
295+
DEF_WARNING(SPARSE_NNZ, 1134, MEDIUM_LEVEL, "The tool cannot deduce the consumer API (\"dpct::sparse::csrgemm\") of this API, and this API has 2 arguments depending on the 8th and the 12th parameters of the consumer API. Please replace the 2 arguments tagged as \"dpct_placeholder\" with the corresponding value.")
296+
DEF_COMMENT(SPARSE_NNZ, 1134, MEDIUM_LEVEL, "The tool cannot deduce the consumer API (\"dpct::sparse::csrgemm\") of this API, and this API has 2 arguments depending on the 8th and the 12th parameters of the consumer API. Please replace the 2 arguments tagged as \"dpct_placeholder\" with the corresponding value.")
295297
// clang-format on
296298

297299
#undef DEF_COMMENT

clang/lib/DPCT/RulesLang/RulesLang.cpp

Lines changed: 111 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2844,8 +2844,9 @@ void SPBLASFunctionCallRule::registerMatcher(MatchFinder &MF) {
28442844
"cusparseCsrsv_solveEx",
28452845
/*level 3*/
28462846
"cusparseScsrmm", "cusparseDcsrmm", "cusparseCcsrmm", "cusparseZcsrmm",
2847-
"cusparseScsrmm2", "cusparseDcsrmm2", "cusparseCcsrmm2",
2848-
"cusparseZcsrmm2",
2847+
"cusparseScsrgemm", "cusparseDcsrgemm", "cusparseCcsrgemm",
2848+
"cusparseZcsrgemm", "cusparseXcsrgemmNnz", "cusparseScsrmm2",
2849+
"cusparseDcsrmm2", "cusparseCcsrmm2", "cusparseZcsrmm2",
28492850
/*Generic*/
28502851
"cusparseCreateCsr", "cusparseDestroySpMat", "cusparseCsrGet",
28512852
"cusparseSpMatGetFormat", "cusparseSpMatGetIndexBase",
@@ -2913,6 +2914,114 @@ void SPBLASFunctionCallRule::runRule(const MatchFinder::MatchResult &Result) {
29132914
EA.applyAllSubExprRepl();
29142915
return;
29152916
}
2917+
if (FuncName == "cusparseXcsrgemmNnz") {
2918+
std::vector<std::string> MigratedArgs;
2919+
for (const auto &Arg : CE->arguments()) {
2920+
MigratedArgs.push_back(ExprAnalysis::ref(Arg));
2921+
}
2922+
// We need find the next cusparse<T>csrgemm API call which is using the
2923+
// result of this API call, otherwise a warning will be emitted.
2924+
auto findOuterCS = [](const Stmt *Input) {
2925+
const CompoundStmt *CS = nullptr;
2926+
DpctGlobalInfo::findAncestor<Stmt>(
2927+
Input, [&](const DynTypedNode &Cur) -> bool {
2928+
if (Cur.get<DoStmt>() || Cur.get<ForStmt>() ||
2929+
Cur.get<WhileStmt>() || Cur.get<SwitchStmt>() ||
2930+
Cur.get<IfStmt>())
2931+
return true;
2932+
if (const CompoundStmt *S = Cur.get<CompoundStmt>())
2933+
CS = S;
2934+
return false;
2935+
});
2936+
return CS;
2937+
};
2938+
const CompoundStmt *CS1 = findOuterCS(CE);
2939+
// Find all the cusparse<T>csrgemm calls in this range.
2940+
using namespace clang::ast_matchers;
2941+
auto Matcher =
2942+
findAll(callExpr(callee(functionDecl(hasAnyName(
2943+
"cusparseScsrgemm", "cusparseDcsrgemm",
2944+
"cusparseCcsrgemm", "cusparseZcsrgemm"))))
2945+
.bind("CallExpr"));
2946+
auto CEResults = match(Matcher, *CS1, DpctGlobalInfo::getContext());
2947+
// Find the correct call
2948+
const CallExpr* CorrectCall = nullptr;
2949+
for (auto &Result : CEResults) {
2950+
const CallExpr *MatchedCE = Result.getNodeAs<CallExpr>("CallExpr");
2951+
if (MatchedCE) {
2952+
// 1. The context should be the same
2953+
const CompoundStmt *CS2 = findOuterCS(MatchedCE);
2954+
if (CS1 != CS2)
2955+
continue;
2956+
// 2. The args should be the same
2957+
std::vector<std::string> MatchedCEMigratedArgs;
2958+
for (const auto &Arg : MatchedCE->arguments()) {
2959+
MatchedCEMigratedArgs.push_back(ExprAnalysis::ref(Arg));
2960+
}
2961+
if ([&]() -> bool {
2962+
const static std::map<unsigned /*CE*/, unsigned /*MatchedCE*/>
2963+
IdxMap = {
2964+
{0, 0}, {1, 1}, {2, 2}, {3, 3},
2965+
{4, 4}, {5, 5}, {6, 6}, {7, 7},
2966+
{8, 9}, {9, 10}, {10, 11}, {11, 12},
2967+
{12, 14}, {13, 15}, {14, 16}, {15, 18},
2968+
};
2969+
for (const auto &P : IdxMap) {
2970+
if (MigratedArgs[P.first] != MatchedCEMigratedArgs[P.second]) {
2971+
return false;
2972+
}
2973+
}
2974+
return true;
2975+
}()) {
2976+
CorrectCall = MatchedCE;
2977+
break;
2978+
}
2979+
}
2980+
}
2981+
const constexpr int Placeholder = -1;
2982+
std::map<int /*CE*/, int /*MatchedCE*/> InsertBeforeIdxMap;
2983+
if (CorrectCall) {
2984+
InsertBeforeIdxMap = {
2985+
{8, 8},
2986+
{12, 13},
2987+
};
2988+
} else {
2989+
report(
2990+
DpctGlobalInfo::getSourceManager().getExpansionLoc(CE->getBeginLoc()),
2991+
Diagnostics::SPARSE_NNZ, true);
2992+
InsertBeforeIdxMap = {
2993+
{8, Placeholder},
2994+
{12, Placeholder},
2995+
};
2996+
}
2997+
std::string MigratedCall;
2998+
MigratedCall = MapNames::getDpctNamespace() + "sparse::csrgemm_nnz(";
2999+
for (unsigned i = 0; i < MigratedArgs.size(); i++) {
3000+
if (auto Iter = InsertBeforeIdxMap.find(i);
3001+
Iter != InsertBeforeIdxMap.end()) {
3002+
if (Iter->second == Placeholder) {
3003+
MigratedCall += ("dpct_placeholder, ");
3004+
} else {
3005+
MigratedCall += (ExprAnalysis::ref(
3006+
CorrectCall->getArg(InsertBeforeIdxMap.at(i))) +
3007+
", ");
3008+
}
3009+
}
3010+
MigratedCall += MigratedArgs[i];
3011+
if (i != MigratedArgs.size() - 1)
3012+
MigratedCall += ", ";
3013+
}
3014+
MigratedCall += ")";
3015+
auto DefRange = getDefinitionRange(CE->getBeginLoc(), CE->getEndLoc());
3016+
SourceLocation Begin = DefRange.getBegin();
3017+
SourceLocation End = DefRange.getEnd();
3018+
End = End.getLocWithOffset(
3019+
Lexer::MeasureTokenLength(End, DpctGlobalInfo::getSourceManager(),
3020+
DpctGlobalInfo::getContext().getLangOpts()));
3021+
emplaceTransformation(replaceText(Begin, End, std::move(MigratedCall),
3022+
DpctGlobalInfo::getSourceManager()));
3023+
return;
3024+
}
29163025
}
29173026

29183027

clang/lib/DPCT/RulesMathLib/APINamesCUSPARSE.inc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,3 +531,28 @@ ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
531531
"cusparseSpSM_solve", CALL(MapNames::getLibraryHelperNamespace() + "sparse::spsm",
532532
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(1),
533533
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7))))
534+
535+
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
536+
"cusparseScsrgemm",
537+
CALL(MapNames::getDpctNamespace() + "sparse::csrgemm", ARG(0), ARG(1),
538+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(8), ARG(9), ARG(10),
539+
ARG(11), ARG(13), ARG(14), ARG(15), ARG(16), ARG(17), ARG(18),
540+
ARG(19))))
541+
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
542+
"cusparseDcsrgemm",
543+
CALL(MapNames::getDpctNamespace() + "sparse::csrgemm", ARG(0), ARG(1),
544+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(8), ARG(9), ARG(10),
545+
ARG(11), ARG(13), ARG(14), ARG(15), ARG(16), ARG(17), ARG(18),
546+
ARG(19))))
547+
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
548+
"cusparseCcsrgemm",
549+
CALL(MapNames::getDpctNamespace() + "sparse::csrgemm", ARG(0), ARG(1),
550+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(8), ARG(9), ARG(10),
551+
ARG(11), ARG(13), ARG(14), ARG(15), ARG(16), ARG(17), ARG(18),
552+
ARG(19))))
553+
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
554+
"cusparseZcsrgemm",
555+
CALL(MapNames::getDpctNamespace() + "sparse::csrgemm", ARG(0), ARG(1),
556+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(8), ARG(9), ARG(10),
557+
ARG(11), ARG(13), ARG(14), ARG(15), ARG(16), ARG(17), ARG(18),
558+
ARG(19))))

clang/lib/DPCT/SrcAPI/APINames_cuSPARSE.inc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,11 @@ ENTRY(cusparseScsrgeam2, cusparseScsrgeam2, false, NO_FLAG, P4, "comment")
244244
ENTRY(cusparseDcsrgeam2, cusparseDcsrgeam2, false, NO_FLAG, P4, "comment")
245245
ENTRY(cusparseCcsrgeam2, cusparseCcsrgeam2, false, NO_FLAG, P4, "comment")
246246
ENTRY(cusparseZcsrgeam2, cusparseZcsrgeam2, false, NO_FLAG, P4, "comment")
247-
ENTRY(cusparseXcsrgemmNnz, cusparseXcsrgemmNnz, false, NO_FLAG, P4, "comment")
248-
ENTRY(cusparseScsrgemm, cusparseScsrgemm, false, NO_FLAG, P4, "comment")
249-
ENTRY(cusparseDcsrgemm, cusparseDcsrgemm, false, NO_FLAG, P4, "comment")
250-
ENTRY(cusparseCcsrgemm, cusparseCcsrgemm, false, NO_FLAG, P4, "comment")
251-
ENTRY(cusparseZcsrgemm, cusparseZcsrgemm, false, NO_FLAG, P4, "comment")
247+
ENTRY(cusparseXcsrgemmNnz, cusparseXcsrgemmNnz, true, NO_FLAG, P4, "DPCT1130")
248+
ENTRY(cusparseScsrgemm, cusparseScsrgemm, true, NO_FLAG, P4, "comment")
249+
ENTRY(cusparseDcsrgemm, cusparseDcsrgemm, true, NO_FLAG, P4, "comment")
250+
ENTRY(cusparseCcsrgemm, cusparseCcsrgemm, true, NO_FLAG, P4, "comment")
251+
ENTRY(cusparseZcsrgemm, cusparseZcsrgemm, true, NO_FLAG, P4, "comment")
252252
ENTRY(cusparseScsrgemm2_bufferSizeExt, cusparseScsrgemm2_bufferSizeExt, false, NO_FLAG, P4, "comment")
253253
ENTRY(cusparseDcsrgemm2_bufferSizeExt, cusparseDcsrgemm2_bufferSizeExt, false, NO_FLAG, P4, "comment")
254254
ENTRY(cusparseCcsrgemm2_bufferSizeExt, cusparseCcsrgemm2_bufferSizeExt, false, NO_FLAG, P4, "comment")

0 commit comments

Comments
 (0)