Skip to content

Commit c614a29

Browse files
Improve template instantiate
Match parameter types to argument types in cases of when the parameter type itself is templated
1 parent 294da5b commit c614a29

File tree

3 files changed

+175
-0
lines changed

3 files changed

+175
-0
lines changed

include/clang/Interpreter/CppInterOp.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,15 @@ namespace Cpp {
676676
const char* m_IntegralValue;
677677
TemplateArgInfo(TCppScope_t type, const char* integral_value = nullptr)
678678
: m_Type(type), m_IntegralValue(integral_value) {}
679+
friend bool operator==(const TemplateArgInfo& lhs,
680+
const TemplateArgInfo& rhs) {
681+
return (lhs.m_Type == rhs.m_Type &&
682+
lhs.m_IntegralValue == rhs.m_IntegralValue);
683+
}
684+
friend bool operator!=(const TemplateArgInfo& lhs,
685+
const TemplateArgInfo& rhs) {
686+
return !(lhs == rhs);
687+
}
679688
};
680689
/// Builds a template instantiation for a given templated declaration.
681690
/// Offers a single interface for instantiation of class, function and

lib/Interpreter/CppInterOp.cpp

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "clang/AST/CXXInheritance.h"
1515
#include "clang/AST/Decl.h"
1616
#include "clang/AST/DeclCXX.h"
17+
#include "clang/AST/DeclTemplate.h"
1718
#include "clang/AST/GlobalDecl.h"
1819
#include "clang/AST/Mangle.h"
1920
#include "clang/AST/QualTypeNames.h"
@@ -42,6 +43,7 @@
4243
#include <set>
4344
#include <sstream>
4445
#include <string>
46+
#include <vector>
4547

4648
// Stream redirect.
4749
#ifdef _WIN32
@@ -1026,11 +1028,103 @@ namespace Cpp {
10261028
funcs.push_back(Found);
10271029
}
10281030

1031+
namespace {
1032+
inline void
1033+
collectUniqueTemplateArgs(const std::vector<TemplateArgInfo>& templ_types,
1034+
std::vector<TemplateArgInfo>& result) {
1035+
std::unique_copy(templ_types.begin(), templ_types.end(),
1036+
std::back_inserter(result));
1037+
}
1038+
bool
1039+
IsTemplateFunctionGoodMatch(const FunctionTemplateDecl* FTD,
1040+
const std::vector<TemplateArgInfo>& arg_types,
1041+
std::vector<TemplateArgInfo>& templ_types) {
1042+
const FunctionDecl* F = FTD->getTemplatedDecl();
1043+
clang::TemplateParameterList* tpl = FTD->getTemplateParameters();
1044+
1045+
if (arg_types.size() != F->getNumParams())
1046+
return false;
1047+
1048+
for (size_t i = 0; i < arg_types.size(); i++) {
1049+
QualType fn_arg_type = F->getParamDecl(i)->getType();
1050+
QualType arg_type = QualType::getFromOpaquePtr(arg_types[i].m_Type);
1051+
1052+
// dereference
1053+
if (fn_arg_type->isReferenceType())
1054+
fn_arg_type = fn_arg_type.getNonReferenceType();
1055+
if (arg_type->isReferenceType())
1056+
arg_type = arg_type.getNonReferenceType();
1057+
1058+
fn_arg_type = fn_arg_type.getCanonicalType();
1059+
arg_type = arg_type.getCanonicalType();
1060+
1061+
// matching parameter and argument types
1062+
// resolving parameter
1063+
const auto* fn_TST =
1064+
fn_arg_type->getAs<clang::TemplateSpecializationType>();
1065+
const TemplateDecl* fn_TD = nullptr;
1066+
if (fn_TST)
1067+
fn_TD = fn_TST->getTemplateName().getAsTemplateDecl();
1068+
1069+
// resolving argument
1070+
const auto* arg_RT = arg_type->getAs<clang::RecordType>();
1071+
ClassTemplateSpecializationDecl* arg_CTSD = nullptr;
1072+
if (arg_RT)
1073+
arg_CTSD = llvm::dyn_cast<clang::ClassTemplateSpecializationDecl>(
1074+
arg_RT->getDecl());
1075+
1076+
if ((!arg_CTSD || !fn_TD) && (arg_CTSD || fn_TD))
1077+
return false;
1078+
1079+
// check if types match
1080+
if (arg_CTSD) {
1081+
auto* arg_D = arg_CTSD->getSpecializedTemplate()->getCanonicalDecl();
1082+
if (arg_D != fn_TD->getCanonicalDecl())
1083+
return false;
1084+
if (templ_types.size() < tpl->size()) {
1085+
Cpp::GetClassTemplateInstantiationArgs(arg_CTSD, templ_types);
1086+
break;
1087+
}
1088+
} else if (templ_types.size() < tpl->size()) {
1089+
templ_types.push_back(arg_types[i]);
1090+
}
1091+
}
1092+
return true;
1093+
}
1094+
} // namespace
1095+
10291096
TCppFunction_t
10301097
BestTemplateFunctionMatch(const std::vector<TCppFunction_t>& candidates,
10311098
const std::vector<TemplateArgInfo>& explicit_types,
10321099
const std::vector<TemplateArgInfo>& arg_types) {
10331100

1101+
/*
1102+
Try matching function with templated class as arguments first
1103+
Example:
1104+
1105+
template<typename T>
1106+
struct A { T value; };
1107+
1108+
template<typename T>
1109+
void somefunc(A<T> arg); // overload 1
1110+
1111+
template<typename T>
1112+
void somefunc(T arg); // overload 2
1113+
1114+
somefunc(A<int>()); // should call overload 1; resolve this first
1115+
somefunc(3); // should call overload 2
1116+
*/
1117+
for (const auto& candidate : candidates) {
1118+
std::vector<TemplateArgInfo> templ_types;
1119+
auto* TFD = static_cast<FunctionTemplateDecl*>(candidate);
1120+
if (IsTemplateFunctionGoodMatch(TFD, arg_types, templ_types)) {
1121+
TCppFunction_t instantiated = InstantiateTemplate(
1122+
candidate, templ_types.data(), templ_types.size());
1123+
if (instantiated)
1124+
return instantiated;
1125+
}
1126+
}
1127+
10341128
for (const auto& candidate : candidates) {
10351129
auto* TFD = (FunctionTemplateDecl*)candidate;
10361130
clang::TemplateParameterList* tpl = TFD->getTemplateParameters();
@@ -1060,9 +1154,19 @@ namespace Cpp {
10601154
if (instantiated)
10611155
return instantiated;
10621156

1157+
std::vector<TemplateArgInfo> unique_arg_types;
1158+
collectUniqueTemplateArgs(arg_types, unique_arg_types);
1159+
instantiated = InstantiateTemplate(candidate, unique_arg_types.data(),
1160+
unique_arg_types.size());
1161+
if (instantiated)
1162+
return instantiated;
1163+
10631164
// Force the instantiation with template params in case of no args
10641165
// maybe steer instantiation better with arg set returned from
10651166
// TemplateProxy?
1167+
if (explicit_types.empty())
1168+
continue;
1169+
10661170
instantiated = InstantiateTemplate(candidate, explicit_types.data(),
10671171
explicit_types.size());
10681172
if (instantiated)

unittests/CppInterOp/FunctionReflectionTest.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,68 @@ TEST(FunctionReflectionTest, BestTemplateFunctionMatch) {
653653
"template<> long MyTemplatedMethodClass::get_size<float>(float &)");
654654
}
655655

656+
TEST(FunctionReflectionTest, BestTemplateFunctionMatch2) {
657+
std::vector<Decl*> Decls;
658+
std::string code = R"(
659+
template<typename T>
660+
struct A { T value; };
661+
662+
A<int> a;
663+
664+
template<typename T>
665+
void somefunc(A<T> arg) {}
666+
667+
template<typename T>
668+
void somefunc(T arg) {}
669+
670+
template<typename T>
671+
void somefunc(A<T> arg1, A<T> arg2) {}
672+
673+
template<typename T>
674+
void somefunc(T arg1, T arg2) {}
675+
)";
676+
677+
GetAllTopLevelDecls(code, Decls);
678+
std::vector<Cpp::TCppFunction_t> candidates;
679+
680+
for (auto decl : Decls)
681+
if (Cpp::IsTemplatedFunction(decl))
682+
candidates.push_back((Cpp::TCppFunction_t)decl);
683+
684+
EXPECT_EQ(candidates.size(), 4);
685+
686+
ASTContext& C = Interp->getCI()->getASTContext();
687+
688+
std::vector<Cpp::TemplateArgInfo> args1 = {C.IntTy.getAsOpaquePtr()};
689+
std::vector<Cpp::TemplateArgInfo> args2 = {
690+
Cpp::GetVariableType(Cpp::GetNamed("a"))};
691+
std::vector<Cpp::TemplateArgInfo> args3 = {C.IntTy.getAsOpaquePtr(),
692+
C.IntTy.getAsOpaquePtr()};
693+
std::vector<Cpp::TemplateArgInfo> args4 = {
694+
Cpp::GetVariableType(Cpp::GetNamed("a")),
695+
Cpp::GetVariableType(Cpp::GetNamed("a"))};
696+
697+
std::vector<Cpp::TemplateArgInfo> explicit_args;
698+
699+
Cpp::TCppFunction_t func1 =
700+
Cpp::BestTemplateFunctionMatch(candidates, explicit_args, args1);
701+
Cpp::TCppFunction_t func2 =
702+
Cpp::BestTemplateFunctionMatch(candidates, explicit_args, args2);
703+
Cpp::TCppFunction_t func3 =
704+
Cpp::BestTemplateFunctionMatch(candidates, explicit_args, args3);
705+
Cpp::TCppFunction_t func4 =
706+
Cpp::BestTemplateFunctionMatch(candidates, explicit_args, args4);
707+
708+
EXPECT_EQ(Cpp::GetFunctionSignature(func1),
709+
"template<> void somefunc<int>(int arg)");
710+
EXPECT_EQ(Cpp::GetFunctionSignature(func2),
711+
"template<> void somefunc<int>(A<int> arg)");
712+
EXPECT_EQ(Cpp::GetFunctionSignature(func3),
713+
"template<> void somefunc<int>(int arg1, int arg2)");
714+
EXPECT_EQ(Cpp::GetFunctionSignature(func4),
715+
"template<> void somefunc<int>(A<int> arg1, A<int> arg2)");
716+
}
717+
656718
TEST(FunctionReflectionTest, IsPublicMethod) {
657719
std::vector<Decl *> Decls, SubDecls;
658720
std::string code = R"(

0 commit comments

Comments
 (0)