Skip to content

Commit 05f0bb1

Browse files
[clang][Sema] Implement additional heuristics from SemaCodeComplete's getApproximateType() in HeuristicResolver
After this change, HeuristicResolver should be able to do everything that SemaCodeComplete's getApproximateType() can do (and more).
1 parent 64b9896 commit 05f0bb1

File tree

3 files changed

+167
-30
lines changed

3 files changed

+167
-30
lines changed

clang/include/clang/Sema/HeuristicResolver.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ class HeuristicResolver {
5454
std::vector<const NamedDecl *>
5555
resolveDeclRefExpr(const DependentScopeDeclRefExpr *RE) const;
5656
std::vector<const NamedDecl *>
57-
resolveTypeOfCallExpr(const CallExpr *CE) const;
58-
std::vector<const NamedDecl *>
5957
resolveCalleeOfCallExpr(const CallExpr *CE) const;
6058
std::vector<const NamedDecl *>
6159
resolveUsingValueDecl(const UnresolvedUsingValueDecl *UUVD) const;
@@ -93,6 +91,10 @@ class HeuristicResolver {
9391
// during simplification, and the operation fails if no pointer type is found.
9492
QualType simplifyType(QualType Type, const Expr *E, bool UnwrapPointer);
9593

94+
// Try to heuristically resolve the type of a possibly-dependent expression
95+
// `E`.
96+
QualType resolveExprToType(const Expr *E) const;
97+
9698
// Given an expression `Fn` representing the callee in a function call,
9799
// if the call is through a function pointer, try to find the declaration of
98100
// the corresponding function pointer type, so that we can recover argument

clang/lib/Sema/HeuristicResolver.cpp

Lines changed: 66 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ class HeuristicResolverImpl {
3636
resolveMemberExpr(const CXXDependentScopeMemberExpr *ME);
3737
std::vector<const NamedDecl *>
3838
resolveDeclRefExpr(const DependentScopeDeclRefExpr *RE);
39-
std::vector<const NamedDecl *> resolveTypeOfCallExpr(const CallExpr *CE);
4039
std::vector<const NamedDecl *> resolveCalleeOfCallExpr(const CallExpr *CE);
4140
std::vector<const NamedDecl *>
4241
resolveUsingValueDecl(const UnresolvedUsingValueDecl *UUVD);
@@ -51,6 +50,7 @@ class HeuristicResolverImpl {
5150
llvm::function_ref<bool(const NamedDecl *ND)> Filter);
5251
TagDecl *resolveTypeToTagDecl(QualType T);
5352
QualType simplifyType(QualType Type, const Expr *E, bool UnwrapPointer);
53+
QualType resolveExprToType(const Expr *E);
5454
FunctionProtoTypeLoc getFunctionProtoTypeLoc(const Expr *Fn);
5555

5656
private:
@@ -72,10 +72,8 @@ class HeuristicResolverImpl {
7272
resolveDependentMember(QualType T, DeclarationName Name,
7373
llvm::function_ref<bool(const NamedDecl *ND)> Filter);
7474

75-
// Try to heuristically resolve the type of a possibly-dependent expression
76-
// `E`.
77-
QualType resolveExprToType(const Expr *E);
7875
std::vector<const NamedDecl *> resolveExprToDecls(const Expr *E);
76+
QualType resolveTypeOfCallExpr(const CallExpr *CE);
7977

8078
bool findOrdinaryMemberInDependentClasses(const CXXBaseSpecifier *Specifier,
8179
CXXBasePath &Path,
@@ -97,18 +95,25 @@ const auto TemplateFilter = [](const NamedDecl *D) {
9795
return isa<TemplateDecl>(D);
9896
};
9997

100-
QualType resolveDeclsToType(const std::vector<const NamedDecl *> &Decls,
101-
ASTContext &Ctx) {
102-
if (Decls.size() != 1) // Names an overload set -- just bail.
103-
return QualType();
104-
if (const auto *TD = dyn_cast<TypeDecl>(Decls[0]))
98+
QualType resolveDeclToType(const NamedDecl *D, ASTContext &Ctx) {
99+
if (const auto *TempD = dyn_cast<TemplateDecl>(D)) {
100+
D = TempD->getTemplatedDecl();
101+
}
102+
if (const auto *TD = dyn_cast<TypeDecl>(D))
105103
return Ctx.getCanonicalTypeDeclType(TD);
106-
if (const auto *VD = dyn_cast<ValueDecl>(Decls[0])) {
104+
if (const auto *VD = dyn_cast<ValueDecl>(D)) {
107105
return VD->getType();
108106
}
109107
return QualType();
110108
}
111109

110+
QualType resolveDeclsToType(const std::vector<const NamedDecl *> &Decls,
111+
ASTContext &Ctx) {
112+
if (Decls.size() != 1) // Names an overload set -- just bail.
113+
return QualType();
114+
return resolveDeclToType(Decls[0], Ctx);
115+
}
116+
112117
TemplateName getReferencedTemplateName(const Type *T) {
113118
if (const auto *TST = T->getAs<TemplateSpecializationType>()) {
114119
return TST->getTemplateName();
@@ -330,19 +335,29 @@ HeuristicResolverImpl::resolveDeclRefExpr(const DependentScopeDeclRefExpr *RE) {
330335
return resolveDependentMember(Qualifier, RE->getDeclName(), StaticFilter);
331336
}
332337

333-
std::vector<const NamedDecl *>
334-
HeuristicResolverImpl::resolveTypeOfCallExpr(const CallExpr *CE) {
335-
QualType CalleeType = resolveExprToType(CE->getCallee());
336-
if (CalleeType.isNull())
337-
return {};
338-
if (const auto *FnTypePtr = CalleeType->getAs<PointerType>())
339-
CalleeType = FnTypePtr->getPointeeType();
340-
if (const FunctionType *FnType = CalleeType->getAs<FunctionType>()) {
341-
if (const auto *D = resolveTypeToTagDecl(FnType->getReturnType())) {
342-
return {D};
338+
QualType HeuristicResolverImpl::resolveTypeOfCallExpr(const CallExpr *CE) {
339+
// resolveExprToType(CE->getCallee()) would bail in the case of multiple
340+
// overloads, as it can't produce a single type for them. We can be more
341+
// permissive here, and allow multiple overloads with a common return type.
342+
std::vector<const NamedDecl *> CalleeDecls =
343+
resolveExprToDecls(CE->getCallee());
344+
QualType CommonReturnType;
345+
for (const NamedDecl *CalleeDecl : CalleeDecls) {
346+
QualType CalleeType = resolveDeclToType(CalleeDecl, Ctx);
347+
if (CalleeType.isNull())
348+
continue;
349+
if (const auto *FnTypePtr = CalleeType->getAs<PointerType>())
350+
CalleeType = FnTypePtr->getPointeeType();
351+
if (const FunctionType *FnType = CalleeType->getAs<FunctionType>()) {
352+
QualType ReturnType =
353+
simplifyType(FnType->getReturnType(), nullptr, false);
354+
if (!CommonReturnType.isNull() && CommonReturnType != ReturnType) {
355+
return {}; // conflicting return types
356+
}
357+
CommonReturnType = ReturnType;
343358
}
344359
}
345-
return {};
360+
return CommonReturnType;
346361
}
347362

348363
std::vector<const NamedDecl *>
@@ -393,15 +408,41 @@ HeuristicResolverImpl::resolveExprToDecls(const Expr *E) {
393408
return {OE->decls_begin(), OE->decls_end()};
394409
}
395410
if (const auto *CE = dyn_cast<CallExpr>(E)) {
396-
return resolveTypeOfCallExpr(CE);
411+
QualType T = resolveTypeOfCallExpr(CE);
412+
if (const auto *D = resolveTypeToTagDecl(T)) {
413+
return {D};
414+
}
415+
return {};
397416
}
398417
if (const auto *ME = dyn_cast<MemberExpr>(E))
399418
return {ME->getMemberDecl()};
419+
if (const auto *DRE = dyn_cast<DeclRefExpr>(E))
420+
return {DRE->getDecl()};
400421

401422
return {};
402423
}
403424

404425
QualType HeuristicResolverImpl::resolveExprToType(const Expr *E) {
426+
// resolveExprToDecls on a CallExpr only succeeds if the return type is
427+
// a TagDecl, but we may want the type of a call in other cases as well.
428+
// (FIXME: There are probably other cases where we can do something more
429+
// flexible than resoveExprToDecls + resolveDeclsToType, e.g. in the case
430+
// of OverloadExpr we can probably accept overloads with a common type).
431+
if (const auto *CE = dyn_cast<CallExpr>(E)) {
432+
if (QualType Resolved = resolveTypeOfCallExpr(CE); !Resolved.isNull())
433+
return Resolved;
434+
}
435+
// Similarly, unwrapping a unary dereference operation does not work via
436+
// resolveExprToDecls.
437+
if (const auto *UO = dyn_cast<UnaryOperator>(E->IgnoreParenCasts())) {
438+
if (UO->getOpcode() == UnaryOperatorKind::UO_Deref) {
439+
if (auto Pointee = getPointeeType(resolveExprToType(UO->getSubExpr()));
440+
!Pointee.isNull()) {
441+
return Pointee;
442+
}
443+
}
444+
}
445+
405446
std::vector<const NamedDecl *> Decls = resolveExprToDecls(E);
406447
if (!Decls.empty())
407448
return resolveDeclsToType(Decls, Ctx);
@@ -580,10 +621,6 @@ std::vector<const NamedDecl *> HeuristicResolver::resolveDeclRefExpr(
580621
return HeuristicResolverImpl(Ctx).resolveDeclRefExpr(RE);
581622
}
582623
std::vector<const NamedDecl *>
583-
HeuristicResolver::resolveTypeOfCallExpr(const CallExpr *CE) const {
584-
return HeuristicResolverImpl(Ctx).resolveTypeOfCallExpr(CE);
585-
}
586-
std::vector<const NamedDecl *>
587624
HeuristicResolver::resolveCalleeOfCallExpr(const CallExpr *CE) const {
588625
return HeuristicResolverImpl(Ctx).resolveCalleeOfCallExpr(CE);
589626
}
@@ -619,7 +656,9 @@ QualType HeuristicResolver::simplifyType(QualType Type, const Expr *E,
619656
bool UnwrapPointer) {
620657
return HeuristicResolverImpl(Ctx).simplifyType(Type, E, UnwrapPointer);
621658
}
622-
659+
QualType HeuristicResolver::resolveExprToType(const Expr *E) const {
660+
return HeuristicResolverImpl(Ctx).resolveExprToType(E);
661+
}
623662
FunctionProtoTypeLoc
624663
HeuristicResolver::getFunctionProtoTypeLoc(const Expr *Fn) const {
625664
return HeuristicResolverImpl(Ctx).getFunctionProtoTypeLoc(Fn);

clang/unittests/Sema/HeuristicResolverTest.cpp

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,6 @@ TEST(HeuristicResolver, MemberExpr_AutoTypeDeduction2) {
203203
struct B {
204204
int waldo;
205205
};
206-
207206
template <typename T>
208207
struct A {
209208
B b;
@@ -238,6 +237,103 @@ TEST(HeuristicResolver, MemberExpr_Chained) {
238237
cxxMethodDecl(hasName("foo")).bind("output"));
239238
}
240239

240+
TEST(HeuristicResolver, MemberExpr_Chained_ReferenceType) {
241+
std::string Code = R"cpp(
242+
struct B {
243+
int waldo;
244+
};
245+
template <typename T>
246+
struct A {
247+
B &foo();
248+
};
249+
template <typename T>
250+
void bar(A<T> a) {
251+
a.foo().waldo;
252+
}
253+
)cpp";
254+
// Test resolution of "waldo" in "a.foo().waldo"
255+
expectResolution(
256+
Code, &HeuristicResolver::resolveMemberExpr,
257+
cxxDependentScopeMemberExpr(hasMemberName("waldo")).bind("input"),
258+
fieldDecl(hasName("waldo")).bind("output"));
259+
}
260+
261+
TEST(HeuristicResolver, MemberExpr_Chained_PointerArrow) {
262+
std::string Code = R"cpp(
263+
struct B {
264+
int waldo;
265+
};
266+
template <typename T>
267+
B* foo(T);
268+
template <class T>
269+
void bar(T t) {
270+
foo(t)->waldo;
271+
}
272+
)cpp";
273+
// Test resolution of "waldo" in "foo(t)->waldo"
274+
expectResolution(
275+
Code, &HeuristicResolver::resolveMemberExpr,
276+
cxxDependentScopeMemberExpr(hasMemberName("waldo")).bind("input"),
277+
fieldDecl(hasName("waldo")).bind("output"));
278+
}
279+
280+
TEST(HeuristicResolver, MemberExpr_Chained_PointerDeref) {
281+
std::string Code = R"cpp(
282+
struct B {
283+
int waldo;
284+
};
285+
template <typename T>
286+
B* foo(T);
287+
template <class T>
288+
void bar(T t) {
289+
(*foo(t)).waldo;
290+
}
291+
)cpp";
292+
// Test resolution of "waldo" in "foo(t)->waldo"
293+
expectResolution(
294+
Code, &HeuristicResolver::resolveMemberExpr,
295+
cxxDependentScopeMemberExpr(hasMemberName("waldo")).bind("input"),
296+
fieldDecl(hasName("waldo")).bind("output"));
297+
}
298+
299+
TEST(HeuristicResolver, MemberExpr_Chained_Overload) {
300+
std::string Code = R"cpp(
301+
struct B {
302+
int waldo;
303+
};
304+
B overloaded(int);
305+
B overloaded(double);
306+
template <typename T>
307+
void foo(T t) {
308+
overloaded(t).waldo;
309+
}
310+
)cpp";
311+
// Test resolution of "waldo" in "overloaded(t).waldo"
312+
expectResolution(
313+
Code, &HeuristicResolver::resolveMemberExpr,
314+
cxxDependentScopeMemberExpr(hasMemberName("waldo")).bind("input"),
315+
fieldDecl(hasName("waldo")).bind("output"));
316+
}
317+
318+
TEST(HeuristicResolver, MemberExpr_CallToFunctionTemplate) {
319+
std::string Code = R"cpp(
320+
struct B {
321+
int waldo;
322+
};
323+
template <typename T>
324+
B bar(T);
325+
template <typename T>
326+
void foo(T t) {
327+
bar(t).waldo;
328+
}
329+
)cpp";
330+
// Test resolution of "waldo" in "bar(t).waldo"
331+
expectResolution(
332+
Code, &HeuristicResolver::resolveMemberExpr,
333+
cxxDependentScopeMemberExpr(hasMemberName("waldo")).bind("input"),
334+
fieldDecl(hasName("waldo")).bind("output"));
335+
}
336+
241337
TEST(HeuristicResolver, MemberExpr_ReferenceType) {
242338
std::string Code = R"cpp(
243339
struct B {

0 commit comments

Comments
 (0)