Skip to content

Commit 9169d52

Browse files
authored
[SYCLomatic] Do not add dpct_operator_overloading namespace for user-defined types (#2011)
Signed-off-by: Jiang, Zhiwei <[email protected]>
1 parent 7e4ac49 commit 9169d52

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

clang/lib/DPCT/ASTTraversal.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2214,7 +2214,11 @@ void TypeInDeclRule::runRule(const MatchFinder::MatchResult &Result) {
22142214
return;
22152215
}
22162216
if (auto TL = getNodeAsType<TypeLoc>(Result, "cudaTypeDef")) {
2217-
2217+
if (const auto *ND = getNamedDecl(TL->getTypePtr())) {
2218+
auto Loc = ND->getBeginLoc();
2219+
if (DpctGlobalInfo::isInAnalysisScope(Loc))
2220+
return;
2221+
}
22182222
// if TL is the T in
22192223
// template<typename T> void foo(T a);
22202224
if (TL->getType()->getTypeClass() == clang::Type::SubstTemplateTypeParm ||
@@ -2828,8 +2832,16 @@ AST_MATCHER(FunctionDecl, overloadedVectorOperator) {
28282832
return false;
28292833

28302834
const std::string TypeName = IDInfo->getName().str();
2831-
return (MapNames::SupportedVectorTypes.find(TypeName) !=
2832-
MapNames::SupportedVectorTypes.end());
2835+
if (MapNames::SupportedVectorTypes.find(TypeName) !=
2836+
MapNames::SupportedVectorTypes.end()) {
2837+
if (const auto *ND = getNamedDecl(PD->getType().getTypePtr())) {
2838+
auto Loc = ND->getBeginLoc();
2839+
if (DpctGlobalInfo::isInAnalysisScope(Loc))
2840+
return false;
2841+
}
2842+
return true;
2843+
}
2844+
return false;
28332845
};
28342846

28352847
// As long as one parameter is vector type

clang/lib/DPCT/Utility.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3262,6 +3262,10 @@ const NamedDecl *getNamedDecl(const clang::Type *TypePtr) {
32623262
ND = getNamedDecl(ET->getNamedType().getTypePtr());
32633263
} else if (auto TST = dyn_cast<clang::TemplateSpecializationType>(TypePtr)) {
32643264
ND = TST->getTemplateName().getAsTemplateDecl();
3265+
} else if (auto LVRT = dyn_cast<clang::LValueReferenceType>(TypePtr)) {
3266+
ND = getNamedDecl(LVRT->getPointeeType().getTypePtr());
3267+
} else if (auto RVRT = dyn_cast<clang::RValueReferenceType>(TypePtr)) {
3268+
ND = getNamedDecl(RVRT->getPointeeType().getTypePtr());
32653269
}
32663270
return ND;
32673271
}

clang/test/dpct/overload_operator.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: dpct --format-range=none --out-root %T/overload_operator %s --cuda-include-path="%cuda-path/include" --extra-arg="-xc++" || true
2+
// RUN: FileCheck %s --match-full-lines --input-file %T/overload_operator/overload_operator.cpp
3+
// RUN: %if build_lit %{icpx -c -fsycl %T/overload_operator/overload_operator.cpp -o %T/overload_operator/overload_operator.o %}
4+
5+
// CHECK: struct half {};
6+
// CHECK-NEXT: half operator+(half &&a, half &&b) { return a; }
7+
// CHECK-NEXT: half operator-(half &a, half &b) { return a; }
8+
// CHECK-NEXT: void foo1() { half() + half(); }
9+
// CHECK-NEXT: void foo2(half &a, half &b) { a - b; }
10+
struct half {};
11+
half operator+(half &&a, half &&b) { return a; }
12+
half operator-(half &a, half &b) { return a; }
13+
void foo1() { half() + half(); }
14+
void foo2(half &a, half &b) { a - b; }

0 commit comments

Comments
 (0)