Skip to content

Commit 8b2ec71

Browse files
perform type inference on template arguments
transform template arguments from {MyClass<T1, T2>} to {T1, T2} where templated function is defined as template<class T1, class T2> void fn(MyClass<T1, T2> arg); before instantiation
1 parent e4adda3 commit 8b2ec71

File tree

2 files changed

+132
-0
lines changed

2 files changed

+132
-0
lines changed

lib/Interpreter/CppInterOp.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,77 @@ namespace Cpp {
10281028
if (instantiated)
10291029
return instantiated;
10301030

1031+
if (func->getMinRequiredArguments() == 1 && arg_types.size() == 1) {
1032+
/*
1033+
reconstruct TemplateArgs for cases such as
1034+
template<class T1, class T2>
1035+
void fn(MyClass<T1, T2> arg);
1036+
by decomposing the argument's template
1037+
1038+
transforms
1039+
template_args = {MyClass<T1, T2>}
1040+
to
1041+
template_args = {T1, T2}
1042+
then performs the instantiation
1043+
*/
1044+
QualType fn_arg_type = func->getParamDecl(0)->getType();
1045+
QualType arg_type = QualType::getFromOpaquePtr(arg_types[0].m_Type);
1046+
1047+
// dereference
1048+
if (fn_arg_type->isReferenceType())
1049+
fn_arg_type = fn_arg_type.getNonReferenceType();
1050+
if (arg_type->isReferenceType())
1051+
arg_type = arg_type.getNonReferenceType();
1052+
1053+
// matching parameter and argument types
1054+
// resolving parameter
1055+
if (const auto* ET = fn_arg_type->getAs<clang::ElaboratedType>()) {
1056+
if (const auto* TST =
1057+
ET->getNamedType()
1058+
->getAs<clang::TemplateSpecializationType>()) {
1059+
if (const auto* TD = TST->getTemplateName().getAsTemplateDecl()) {
1060+
// resolving argument
1061+
if (const auto* RT = arg_type->getAs<clang::RecordType>()) {
1062+
if (auto* CTSD =
1063+
llvm::dyn_cast<clang::ClassTemplateSpecializationDecl>(
1064+
RT->getDecl())) {
1065+
if (CTSD->getSpecializedTemplate()->getCanonicalDecl() ==
1066+
TD->getCanonicalDecl()) {
1067+
// parameter type matches argument type
1068+
std::vector<TemplateArgInfo> total_arg_set;
1069+
1070+
const TemplateArgumentList& TAL = CTSD->getTemplateArgs();
1071+
1072+
total_arg_set.insert(total_arg_set.end(),
1073+
explicit_types.begin(),
1074+
explicit_types.end());
1075+
1076+
for (size_t i = 0; i < TAL.size(); i++) {
1077+
// FIXME: handle the case where TemplateArgument is
1078+
// Integral value
1079+
if (TAL[i].getKind() == clang::TemplateArgument::Pack) {
1080+
for (auto i : TAL[i].pack_elements()) {
1081+
total_arg_set.emplace_back(
1082+
i.getAsType().getAsOpaquePtr());
1083+
}
1084+
} else if (TAL[i].getKind() ==
1085+
clang::TemplateArgument::Type) {
1086+
QualType TA = TAL[i].getAsType();
1087+
total_arg_set.emplace_back(TA.getAsOpaquePtr());
1088+
}
1089+
}
1090+
instantiated = InstantiateTemplate(
1091+
candidate, total_arg_set.data(), total_arg_set.size());
1092+
if (instantiated)
1093+
return instantiated;
1094+
}
1095+
}
1096+
}
1097+
}
1098+
}
1099+
}
1100+
}
1101+
10311102
// join explicit and arg_types
10321103
std::vector<TemplateArgInfo> total_arg_set;
10331104
total_arg_set.reserve(explicit_types.size() + arg_types.size());

unittests/CppInterOp/FunctionReflectionTest.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "gtest/gtest.h"
99

1010
#include <string>
11+
#include <tuple>
1112

1213
using namespace TestUtils;
1314
using namespace llvm;
@@ -1154,3 +1155,63 @@ TEST(FunctionReflectionTest, Destruct) {
11541155
output = testing::internal::GetCapturedStdout();
11551156
EXPECT_EQ(output, "Destructor Executed");
11561157
}
1158+
1159+
TEST(FunctionReflectionTest, NestedTemplate) {
1160+
if (llvm::sys::RunningOnValgrind())
1161+
GTEST_SKIP() << "XFAIL due to Valgrind report";
1162+
1163+
Cpp::CreateInterpreter();
1164+
1165+
Interp->declare(R"(
1166+
#include <tuple>
1167+
#include <string>
1168+
)");
1169+
1170+
ASTContext& C = Interp->getCI()->getASTContext();
1171+
1172+
std::vector<Cpp::TCppFunction_t> make_tuple_candidate_methods;
1173+
Cpp::GetClassTemplatedMethods("make_tuple", Cpp::GetScope("std"),
1174+
make_tuple_candidate_methods);
1175+
EXPECT_GE(make_tuple_candidate_methods.size(), 1);
1176+
1177+
std::vector<Cpp::TemplateArgInfo> make_tuple_arg_types = {
1178+
{C.IntTy.getAsOpaquePtr()},
1179+
{C.DoubleTy.getAsOpaquePtr()},
1180+
};
1181+
std::vector<Cpp::TemplateArgInfo> make_tuple_templ_params = {};
1182+
Cpp::TCppFunction_t make_tuple_scope = Cpp::BestTemplateFunctionMatch(
1183+
make_tuple_candidate_methods, make_tuple_templ_params,
1184+
make_tuple_arg_types);
1185+
EXPECT_TRUE(make_tuple_scope);
1186+
1187+
auto make_tuple = Cpp::MakeFunctionCallable(make_tuple_scope);
1188+
EXPECT_TRUE(make_tuple);
1189+
1190+
int x = 2;
1191+
double y = 4.0;
1192+
void* args0[2] = {(void*)&x, (void*)&y};
1193+
void* tuple = new std::tuple<int, double>;
1194+
make_tuple.Invoke(tuple, {args0, 2});
1195+
1196+
std::vector<Cpp::TCppFunction_t> get_candidate_methods;
1197+
Cpp::GetClassTemplatedMethods("get", Cpp::GetScope("std"),
1198+
get_candidate_methods);
1199+
EXPECT_GE(get_candidate_methods.size(), 1);
1200+
1201+
std::vector<Cpp::TemplateArgInfo> get_arg_types = {
1202+
{Cpp::GetFunctionReturnType(make_tuple_scope)},
1203+
};
1204+
std::vector<Cpp::TemplateArgInfo> get_templ_params = {
1205+
{C.IntTy.getAsOpaquePtr(), "0"}};
1206+
Cpp::TCppFunction_t get_scope = Cpp::BestTemplateFunctionMatch(
1207+
get_candidate_methods, get_templ_params, get_arg_types);
1208+
EXPECT_TRUE(get_scope);
1209+
1210+
auto get = Cpp::MakeFunctionCallable(get_scope);
1211+
EXPECT_TRUE(get);
1212+
1213+
// int get0_result = 0;
1214+
// void *args1[1] = { (void *) &tuple };
1215+
// get.Invoke(&get0_result, {args1, 1});
1216+
// EXPECT_EQ(get0_result, 2);
1217+
}

0 commit comments

Comments
 (0)