Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions clang/lib/DPCT/CUBAPIMigration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,14 +558,14 @@ void CubDeviceLevelRule::removeRedundantTempVar(const CallExpr *CE) {
void CubRule::registerMatcher(ast_matchers::MatchFinder &MF) {
MF.addMatcher(
typeLoc(loc(qualType(hasDeclaration(namedDecl(hasAnyName(
"WarpScan", "WarpReduce", "BlockScan", "BlockReduce"))))))
"WarpScan", "WarpReduce", "BlockScan", "BlockReduce", "BlockLoad"))))))
.bind("TypeLoc"),
this);

MF.addMatcher(
typedefDecl(
hasType(hasCanonicalType(qualType(hasDeclaration(namedDecl(hasAnyName(
"WarpScan", "WarpReduce", "BlockScan", "BlockReduce")))))))
"WarpScan", "WarpReduce", "BlockScan", "BlockReduce", "BlockLoad")))))))
.bind("TypeDefDecl"),
this);

Expand Down Expand Up @@ -684,7 +684,8 @@ void CubRule::processCubDeclStmt(const DeclStmt *DS) {
ObjTypeStr.find("class cub::WarpReduce") == 0) {
Repl = DpctGlobalInfo::getSubGroup(DRE);
} else if (ObjTypeStr.find("class cub::BlockScan") == 0 ||
ObjTypeStr.find("class cub::BlockReduce") == 0) {
ObjTypeStr.find("class cub::BlockReduce") == 0 ||
ObjTypeStr.find("class cub::BlockLoad") == 0) {
Repl = DpctGlobalInfo::getGroup(DRE);
} else {
continue;
Expand Down Expand Up @@ -749,7 +750,8 @@ void CubRule::processCubTypeDef(const TypedefDecl *TD) {
!(ObjTypeStr.find("class cub::WarpScan") == 0 ||
ObjTypeStr.find("class cub::WarpReduce") == 0 ||
ObjTypeStr.find("class cub::BlockScan") == 0 ||
ObjTypeStr.find("class cub::BlockReduce") == 0)) {
ObjTypeStr.find("class cub::BlockReduce") == 0 ||
ObjTypeStr.find("class cub::BlockLoad") == 0)) {
DeleteFlag = false;
break;
}
Expand Down Expand Up @@ -1304,7 +1306,8 @@ void CubRule::processCubMemberCall(const CXXMemberCallExpr *MC) {
ObjTypeStr.find("class cub::WarpReduce") == 0) {
processWarpLevelMemberCall(MC);
} else if (ObjTypeStr.find("class cub::BlockScan") == 0 ||
ObjTypeStr.find("class cub::BlockReduce") == 0) {
ObjTypeStr.find("class cub::BlockReduce") == 0 ||
ObjTypeStr.find("class cub::BlockLoad") == 0) {
processBlockLevelMemberCall(MC);
} else {
report(MC->getBeginLoc(), Diagnostics::API_NOT_MIGRATED, false, ObjTypeStr);
Expand All @@ -1328,7 +1331,8 @@ void CubRule::processTypeLoc(const TypeLoc *TL) {
MapNames::getClNamespace() + "sub_group",
SM));
} else if (TypeName.find("class cub::BlockScan") == 0 ||
TypeName.find("class cub::BlockReduce") == 0) {
TypeName.find("class cub::BlockReduce") == 0 ||
ObjTypeStr.find("class cub::BlockLoad") == 0) {
auto DeviceFuncDecl = DpctGlobalInfo::findAncestor<FunctionDecl>(TL);
if (DeviceFuncDecl && (DeviceFuncDecl->hasAttr<CUDADeviceAttr>() ||
DeviceFuncDecl->hasAttr<CUDAGlobalAttr>())) {
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/DPCT/ExprAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,7 @@ void ExprAnalysis::analyzeType(TypeLoc TL, const Expr *CSCE,
}
}
if (OS.str() != "cub::WarpScan" && OS.str() != "cub::WarpReduce" &&
OS.str() != "cub::BlockReduce" && OS.str() != "cub::BlockScan") {
OS.str() != "cub::BlockReduce" && OS.str() != "cub::BlockScan" && OS.str != "cub::BlockLoad") {
SR.setEnd(TSTL.getTemplateNameLoc());
}
analyzeTemplateSpecializationType(TSTL);
Expand Down