Skip to content

Commit 018ae02

Browse files
authored
[clang] fix transform for constant template parameter type subst node (#162587)
This fixes the transform to use the correct parameter type for an AssociatedDecl which has been fully specialized. Instead of using the type for the parameter of the specialized template, this uses the type of the argument it has been specialized with. This fixes a regression reported here: #161029 (comment) Since this regression was never released, there are no release notes.
1 parent 822446d commit 018ae02

File tree

7 files changed

+93
-63
lines changed

7 files changed

+93
-63
lines changed

clang/include/clang/AST/DeclTemplate.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3395,9 +3395,10 @@ inline UnsignedOrNone getExpandedPackSize(const NamedDecl *Param) {
33953395
return std::nullopt;
33963396
}
33973397

3398-
/// Internal helper used by Subst* nodes to retrieve the parameter list
3399-
/// for their AssociatedDecl.
3400-
TemplateParameterList *getReplacedTemplateParameterList(const Decl *D);
3398+
/// Internal helper used by Subst* nodes to retrieve a parameter from the
3399+
/// AssociatedDecl, and the template argument substituted into it, if any.
3400+
std::tuple<NamedDecl *, TemplateArgument>
3401+
getReplacedTemplateParameter(Decl *D, unsigned Index);
34013402

34023403
/// If we have a 'templated' declaration for a template, adjust 'D' to
34033404
/// refer to the actual template.

clang/lib/AST/DeclTemplate.cpp

Lines changed: 44 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1653,57 +1653,65 @@ void TemplateParamObjectDecl::printAsInit(llvm::raw_ostream &OS,
16531653
getValue().printPretty(OS, Policy, getType(), &getASTContext());
16541654
}
16551655

1656-
TemplateParameterList *clang::getReplacedTemplateParameterList(const Decl *D) {
1656+
std::tuple<NamedDecl *, TemplateArgument>
1657+
clang::getReplacedTemplateParameter(Decl *D, unsigned Index) {
16571658
switch (D->getKind()) {
1658-
case Decl::Kind::CXXRecord:
1659-
return cast<CXXRecordDecl>(D)
1660-
->getDescribedTemplate()
1661-
->getTemplateParameters();
1659+
case Decl::Kind::BuiltinTemplate:
16621660
case Decl::Kind::ClassTemplate:
1663-
return cast<ClassTemplateDecl>(D)->getTemplateParameters();
1661+
case Decl::Kind::Concept:
1662+
case Decl::Kind::FunctionTemplate:
1663+
case Decl::Kind::TemplateTemplateParm:
1664+
case Decl::Kind::TypeAliasTemplate:
1665+
case Decl::Kind::VarTemplate:
1666+
return {cast<TemplateDecl>(D)->getTemplateParameters()->getParam(Index),
1667+
{}};
16641668
case Decl::Kind::ClassTemplateSpecialization: {
16651669
const auto *CTSD = cast<ClassTemplateSpecializationDecl>(D);
16661670
auto P = CTSD->getSpecializedTemplateOrPartial();
1671+
TemplateParameterList *TPL;
16671672
if (const auto *CTPSD =
16681673
dyn_cast<ClassTemplatePartialSpecializationDecl *>(P))
1669-
return CTPSD->getTemplateParameters();
1670-
return cast<ClassTemplateDecl *>(P)->getTemplateParameters();
1674+
TPL = CTPSD->getTemplateParameters();
1675+
else
1676+
TPL = cast<ClassTemplateDecl *>(P)->getTemplateParameters();
1677+
return {TPL->getParam(Index), CTSD->getTemplateArgs()[Index]};
1678+
}
1679+
case Decl::Kind::VarTemplateSpecialization: {
1680+
const auto *VTSD = cast<VarTemplateSpecializationDecl>(D);
1681+
auto P = VTSD->getSpecializedTemplateOrPartial();
1682+
TemplateParameterList *TPL;
1683+
if (const auto *VTPSD = dyn_cast<VarTemplatePartialSpecializationDecl *>(P))
1684+
TPL = VTPSD->getTemplateParameters();
1685+
else
1686+
TPL = cast<VarTemplateDecl *>(P)->getTemplateParameters();
1687+
return {TPL->getParam(Index), VTSD->getTemplateArgs()[Index]};
16711688
}
16721689
case Decl::Kind::ClassTemplatePartialSpecialization:
1673-
return cast<ClassTemplatePartialSpecializationDecl>(D)
1674-
->getTemplateParameters();
1675-
case Decl::Kind::TypeAliasTemplate:
1676-
return cast<TypeAliasTemplateDecl>(D)->getTemplateParameters();
1677-
case Decl::Kind::BuiltinTemplate:
1678-
return cast<BuiltinTemplateDecl>(D)->getTemplateParameters();
1690+
return {cast<ClassTemplatePartialSpecializationDecl>(D)
1691+
->getTemplateParameters()
1692+
->getParam(Index),
1693+
{}};
1694+
case Decl::Kind::VarTemplatePartialSpecialization:
1695+
return {cast<VarTemplatePartialSpecializationDecl>(D)
1696+
->getTemplateParameters()
1697+
->getParam(Index),
1698+
{}};
1699+
// This is used as the AssociatedDecl for placeholder type deduction.
1700+
case Decl::TemplateTypeParm:
1701+
return {cast<NamedDecl>(D), {}};
1702+
// FIXME: Always use the template decl as the AssociatedDecl.
1703+
case Decl::Kind::CXXRecord:
1704+
return getReplacedTemplateParameter(
1705+
cast<CXXRecordDecl>(D)->getDescribedClassTemplate(), Index);
16791706
case Decl::Kind::CXXDeductionGuide:
16801707
case Decl::Kind::CXXConversion:
16811708
case Decl::Kind::CXXConstructor:
16821709
case Decl::Kind::CXXDestructor:
16831710
case Decl::Kind::CXXMethod:
16841711
case Decl::Kind::Function:
1685-
return cast<FunctionDecl>(D)
1686-
->getTemplateSpecializationInfo()
1687-
->getTemplate()
1688-
->getTemplateParameters();
1689-
case Decl::Kind::FunctionTemplate:
1690-
return cast<FunctionTemplateDecl>(D)->getTemplateParameters();
1691-
case Decl::Kind::VarTemplate:
1692-
return cast<VarTemplateDecl>(D)->getTemplateParameters();
1693-
case Decl::Kind::VarTemplateSpecialization: {
1694-
const auto *VTSD = cast<VarTemplateSpecializationDecl>(D);
1695-
auto P = VTSD->getSpecializedTemplateOrPartial();
1696-
if (const auto *VTPSD = dyn_cast<VarTemplatePartialSpecializationDecl *>(P))
1697-
return VTPSD->getTemplateParameters();
1698-
return cast<VarTemplateDecl *>(P)->getTemplateParameters();
1699-
}
1700-
case Decl::Kind::VarTemplatePartialSpecialization:
1701-
return cast<VarTemplatePartialSpecializationDecl>(D)
1702-
->getTemplateParameters();
1703-
case Decl::Kind::TemplateTemplateParm:
1704-
return cast<TemplateTemplateParmDecl>(D)->getTemplateParameters();
1705-
case Decl::Kind::Concept:
1706-
return cast<ConceptDecl>(D)->getTemplateParameters();
1712+
return getReplacedTemplateParameter(
1713+
cast<FunctionDecl>(D)->getTemplateSpecializationInfo()->getTemplate(),
1714+
Index);
17071715
default:
17081716
llvm_unreachable("Unhandled templated declaration kind");
17091717
}

clang/lib/AST/ExprCXX.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1727,7 +1727,7 @@ SizeOfPackExpr *SizeOfPackExpr::CreateDeserialized(ASTContext &Context,
17271727

17281728
NonTypeTemplateParmDecl *SubstNonTypeTemplateParmExpr::getParameter() const {
17291729
return cast<NonTypeTemplateParmDecl>(
1730-
getReplacedTemplateParameterList(getAssociatedDecl())->asArray()[Index]);
1730+
std::get<0>(getReplacedTemplateParameter(getAssociatedDecl(), Index)));
17311731
}
17321732

17331733
PackIndexingExpr *PackIndexingExpr::Create(
@@ -1793,7 +1793,7 @@ SubstNonTypeTemplateParmPackExpr::SubstNonTypeTemplateParmPackExpr(
17931793
NonTypeTemplateParmDecl *
17941794
SubstNonTypeTemplateParmPackExpr::getParameterPack() const {
17951795
return cast<NonTypeTemplateParmDecl>(
1796-
getReplacedTemplateParameterList(getAssociatedDecl())->asArray()[Index]);
1796+
std::get<0>(getReplacedTemplateParameter(getAssociatedDecl(), Index)));
17971797
}
17981798

17991799
TemplateArgument SubstNonTypeTemplateParmPackExpr::getArgumentPack() const {

clang/lib/AST/TemplateName.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,14 @@ SubstTemplateTemplateParmPackStorage::getArgumentPack() const {
6464

6565
TemplateTemplateParmDecl *
6666
SubstTemplateTemplateParmPackStorage::getParameterPack() const {
67-
return cast<TemplateTemplateParmDecl>(
68-
getReplacedTemplateParameterList(getAssociatedDecl())
69-
->asArray()[Bits.Index]);
67+
return cast<TemplateTemplateParmDecl>(std::get<0>(
68+
getReplacedTemplateParameter(getAssociatedDecl(), Bits.Index)));
7069
}
7170

7271
TemplateTemplateParmDecl *
7372
SubstTemplateTemplateParmStorage::getParameter() const {
74-
return cast<TemplateTemplateParmDecl>(
75-
getReplacedTemplateParameterList(getAssociatedDecl())
76-
->asArray()[Bits.Index]);
73+
return cast<TemplateTemplateParmDecl>(std::get<0>(
74+
getReplacedTemplateParameter(getAssociatedDecl(), Bits.Index)));
7775
}
7876

7977
void SubstTemplateTemplateParmStorage::Profile(llvm::FoldingSetNodeID &ID) {

clang/lib/AST/Type.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4436,14 +4436,6 @@ IdentifierInfo *TemplateTypeParmType::getIdentifier() const {
44364436
return isCanonicalUnqualified() ? nullptr : getDecl()->getIdentifier();
44374437
}
44384438

4439-
static const TemplateTypeParmDecl *getReplacedParameter(Decl *D,
4440-
unsigned Index) {
4441-
if (const auto *TTP = dyn_cast<TemplateTypeParmDecl>(D))
4442-
return TTP;
4443-
return cast<TemplateTypeParmDecl>(
4444-
getReplacedTemplateParameterList(D)->getParam(Index));
4445-
}
4446-
44474439
SubstTemplateTypeParmType::SubstTemplateTypeParmType(QualType Replacement,
44484440
Decl *AssociatedDecl,
44494441
unsigned Index,
@@ -4466,7 +4458,8 @@ SubstTemplateTypeParmType::SubstTemplateTypeParmType(QualType Replacement,
44664458

44674459
const TemplateTypeParmDecl *
44684460
SubstTemplateTypeParmType::getReplacedParameter() const {
4469-
return ::getReplacedParameter(getAssociatedDecl(), getIndex());
4461+
return cast<TemplateTypeParmDecl>(std::get<0>(
4462+
getReplacedTemplateParameter(getAssociatedDecl(), getIndex())));
44704463
}
44714464

44724465
void SubstTemplateTypeParmType::Profile(llvm::FoldingSetNodeID &ID,
@@ -4532,7 +4525,8 @@ bool SubstTemplateTypeParmPackType::getFinal() const {
45324525

45334526
const TemplateTypeParmDecl *
45344527
SubstTemplateTypeParmPackType::getReplacedParameter() const {
4535-
return ::getReplacedParameter(getAssociatedDecl(), getIndex());
4528+
return cast<TemplateTypeParmDecl>(std::get<0>(
4529+
getReplacedTemplateParameter(getAssociatedDecl(), getIndex())));
45364530
}
45374531

45384532
IdentifierInfo *SubstTemplateTypeParmPackType::getIdentifier() const {

clang/lib/Sema/TreeTransform.h

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16364,16 +16364,21 @@ ExprResult TreeTransform<Derived>::TransformSubstNonTypeTemplateParmExpr(
1636416364
AssociatedDecl == E->getAssociatedDecl())
1636516365
return E;
1636616366

16367+
auto getParamAndType = [Index = E->getIndex()](Decl *AssociatedDecl)
16368+
-> std::tuple<NonTypeTemplateParmDecl *, QualType> {
16369+
auto [PDecl, Arg] = getReplacedTemplateParameter(AssociatedDecl, Index);
16370+
auto *Param = cast<NonTypeTemplateParmDecl>(PDecl);
16371+
return {Param, Arg.isNull() ? Param->getType()
16372+
: Arg.getNonTypeTemplateArgumentType()};
16373+
};
16374+
1636716375
// If the replacement expression did not change, and the parameter type
1636816376
// did not change, we can skip the semantic action because it would
1636916377
// produce the same result anyway.
16370-
auto *Param = cast<NonTypeTemplateParmDecl>(
16371-
getReplacedTemplateParameterList(AssociatedDecl)
16372-
->asArray()[E->getIndex()]);
16373-
if (QualType ParamType = Param->getType();
16374-
!SemaRef.Context.hasSameType(ParamType, E->getParameter()->getType()) ||
16378+
if (auto [Param, ParamType] = getParamAndType(AssociatedDecl);
16379+
!SemaRef.Context.hasSameType(
16380+
ParamType, std::get<1>(getParamAndType(E->getAssociatedDecl()))) ||
1637516381
Replacement.get() != OrigReplacement) {
16376-
1637716382
// When transforming the replacement expression previously, all Sema
1637816383
// specific annotations, such as implicit casts, are discarded. Calling the
1637916384
// corresponding sema action is necessary to recover those. Otherwise,
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: %clang_cc1 %s -O0 -disable-llvm-passes -triple=x86_64 -std=c++20 -emit-llvm -o - | FileCheck %s
2+
3+
namespace GH161029_regression1 {
4+
template <class _Fp> auto f(int) { _Fp{}(0); }
5+
template <class _Fp, int... _Js> void g() {
6+
(..., f<_Fp>(_Js));
7+
}
8+
enum E { k };
9+
template <int, E> struct ElementAt;
10+
template <E First> struct ElementAt<0, First> {
11+
static int value;
12+
};
13+
template <typename T, T Item> struct TagSet {
14+
template <int Index> using Tag = ElementAt<Index, Item>;
15+
};
16+
template <typename TagSet> struct S {
17+
void U() { (void)TagSet::template Tag<0>::value; }
18+
};
19+
S<TagSet<E, k>> s;
20+
void h() {
21+
g<decltype([](auto) -> void { s.U(); }), 0>();
22+
}
23+
// CHECK: call void @_ZN20GH161029_regression11SINS_6TagSetINS_1EELS2_0EEEE1UEv
24+
}

0 commit comments

Comments
 (0)