Skip to content

Commit 4a32f94

Browse files
jeremystuckiVenyla
andcommitted
[clangd] Add tweak to abbreviate function templates
Co-authored-by: Vina Zahnd <[email protected]>
1 parent 44b928e commit 4a32f94

File tree

5 files changed

+433
-0
lines changed

5 files changed

+433
-0
lines changed
Lines changed: 352 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,352 @@
1+
//===-- AbbreviateFunctionTemplate.cpp ---------------------------*- C++-*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#include "FindTarget.h"
9+
#include "SourceCode.h"
10+
#include "XRefs.h"
11+
#include "refactor/Tweak.h"
12+
#include "support/Logger.h"
13+
#include "clang/AST/ASTContext.h"
14+
#include "clang/AST/ExprConcepts.h"
15+
#include "clang/Tooling/Core/Replacement.h"
16+
#include "llvm/ADT/StringRef.h"
17+
#include "llvm/Support/Casting.h"
18+
#include "llvm/Support/Error.h"
19+
#include <numeric>
20+
21+
namespace clang {
22+
namespace clangd {
23+
namespace {
24+
/// Converts a function template to its abbreviated form using auto parameters.
25+
/// Before:
26+
/// template <std::integral T>
27+
/// auto foo(T param) { }
28+
/// ^^^^^^^^^^^
29+
/// After:
30+
/// auto foo(std::integral auto param) { }
31+
class AbbreviateFunctionTemplate : public Tweak {
32+
public:
33+
const char *id() const final;
34+
35+
auto prepare(const Selection &Inputs) -> bool override;
36+
auto apply(const Selection &Inputs) -> Expected<Effect> override;
37+
38+
auto title() const -> std::string override {
39+
return llvm::formatv("Abbreviate function template");
40+
}
41+
42+
auto kind() const -> llvm::StringLiteral override {
43+
return CodeAction::REFACTOR_KIND;
44+
}
45+
46+
private:
47+
static const char *AutoKeywordSpelling;
48+
const FunctionTemplateDecl *FunctionTemplateDeclaration;
49+
50+
struct TemplateParameterInfo {
51+
const TypeConstraint *Constraint;
52+
unsigned int FunctionParameterIndex;
53+
std::vector<tok::TokenKind> FunctionParameterQualifiers;
54+
std::vector<tok::TokenKind> FunctionParameterTypeQualifiers;
55+
};
56+
57+
std::vector<TemplateParameterInfo> TemplateParameterInfoList;
58+
59+
auto traverseFunctionParameters(size_t NumberOfTemplateParameters) -> bool;
60+
61+
auto generateFunctionParameterReplacements(const ASTContext &Context)
62+
-> llvm::Expected<tooling::Replacements>;
63+
64+
auto generateFunctionParameterReplacement(
65+
const TemplateParameterInfo &TemplateParameterInfo,
66+
const ASTContext &Context) -> llvm::Expected<tooling::Replacement>;
67+
68+
auto generateTemplateDeclarationReplacement(const ASTContext &Context)
69+
-> llvm::Expected<tooling::Replacement>;
70+
71+
static auto deconstructType(QualType Type)
72+
-> std::tuple<QualType, std::vector<tok::TokenKind>,
73+
std::vector<tok::TokenKind>>;
74+
};
75+
76+
REGISTER_TWEAK(AbbreviateFunctionTemplate)
77+
78+
const char *AbbreviateFunctionTemplate::AutoKeywordSpelling =
79+
getKeywordSpelling(tok::kw_auto);
80+
81+
template <typename T>
82+
auto findDeclaration(const SelectionTree::Node &Root) -> const T * {
83+
for (const auto *Node = &Root; Node; Node = Node->Parent) {
84+
if (const T *Result = dyn_cast_or_null<T>(Node->ASTNode.get<Decl>()))
85+
return Result;
86+
}
87+
88+
return nullptr;
89+
}
90+
91+
auto getSpellingForQualifier(tok::TokenKind const &Qualifier) -> const char * {
92+
if (const auto *Spelling = getKeywordSpelling(Qualifier))
93+
return Spelling;
94+
95+
if (const auto *Spelling = getPunctuatorSpelling(Qualifier))
96+
return Spelling;
97+
98+
return nullptr;
99+
}
100+
101+
bool AbbreviateFunctionTemplate::prepare(const Selection &Inputs) {
102+
const auto *CommonAncestor = Inputs.ASTSelection.commonAncestor();
103+
if (!CommonAncestor)
104+
return false;
105+
106+
FunctionTemplateDeclaration =
107+
findDeclaration<FunctionTemplateDecl>(*CommonAncestor);
108+
109+
if (!FunctionTemplateDeclaration)
110+
return false;
111+
112+
auto *TemplateParameters =
113+
FunctionTemplateDeclaration->getTemplateParameters();
114+
115+
auto NumberOfTemplateParameters = TemplateParameters->size();
116+
TemplateParameterInfoList =
117+
std::vector<TemplateParameterInfo>(NumberOfTemplateParameters);
118+
119+
// Check how many times each template parameter is referenced.
120+
// Depending on the number of references it can be checked
121+
// if the refactoring is possible:
122+
// - exactly one: The template parameter was declared but never used, which
123+
// means we know for sure it doesn't appear as a parameter.
124+
// - exactly two: The template parameter was used exactly once, either as a
125+
// parameter or somewhere else. This is the case we are
126+
// interested in.
127+
// - more than two: The template parameter was either used for multiple
128+
// parameters or somewhere else in the function.
129+
for (unsigned TemplateParameterIndex = 0;
130+
TemplateParameterIndex < NumberOfTemplateParameters;
131+
TemplateParameterIndex++) {
132+
auto *TemplateParameter =
133+
TemplateParameters->getParam(TemplateParameterIndex);
134+
auto *TemplateParameterInfo =
135+
&TemplateParameterInfoList[TemplateParameterIndex];
136+
137+
auto *TemplateParameterDeclaration =
138+
dyn_cast_or_null<TemplateTypeParmDecl>(TemplateParameter);
139+
if (!TemplateParameterDeclaration)
140+
return false;
141+
142+
TemplateParameterInfo->Constraint =
143+
TemplateParameterDeclaration->getTypeConstraint();
144+
145+
auto TemplateParameterPosition = sourceLocToPosition(
146+
Inputs.AST->getSourceManager(), TemplateParameter->getEndLoc());
147+
148+
auto FindReferencesLimit = 3;
149+
auto ReferencesResult =
150+
findReferences(*Inputs.AST, TemplateParameterPosition,
151+
FindReferencesLimit, Inputs.Index);
152+
153+
if (ReferencesResult.References.size() != 2)
154+
return false;
155+
}
156+
157+
return traverseFunctionParameters(NumberOfTemplateParameters);
158+
}
159+
160+
auto AbbreviateFunctionTemplate::apply(const Selection &Inputs)
161+
-> Expected<Tweak::Effect> {
162+
auto &Context = Inputs.AST->getASTContext();
163+
auto FunctionParameterReplacements =
164+
generateFunctionParameterReplacements(Context);
165+
166+
if (auto Err = FunctionParameterReplacements.takeError())
167+
return Err;
168+
169+
auto Replacements = *FunctionParameterReplacements;
170+
auto TemplateDeclarationReplacement =
171+
generateTemplateDeclarationReplacement(Context);
172+
173+
if (auto Err = TemplateDeclarationReplacement.takeError())
174+
return Err;
175+
176+
if (auto Err = Replacements.add(*TemplateDeclarationReplacement))
177+
return Err;
178+
179+
return Effect::mainFileEdit(Context.getSourceManager(), Replacements);
180+
}
181+
182+
auto AbbreviateFunctionTemplate::traverseFunctionParameters(
183+
size_t NumberOfTemplateParameters) -> bool {
184+
auto CurrentTemplateParameterBeingChecked = 0u;
185+
auto FunctionParameters =
186+
FunctionTemplateDeclaration->getAsFunction()->parameters();
187+
188+
for (auto ParameterIndex = 0u; ParameterIndex < FunctionParameters.size();
189+
ParameterIndex++) {
190+
auto [RawType, ParameterTypeQualifiers, ParameterQualifiers] =
191+
deconstructType(FunctionParameters[ParameterIndex]->getOriginalType());
192+
193+
if (!RawType->isTemplateTypeParmType())
194+
continue;
195+
196+
auto TemplateParameterIndex =
197+
dyn_cast<TemplateTypeParmType>(RawType)->getIndex();
198+
199+
if (TemplateParameterIndex != CurrentTemplateParameterBeingChecked)
200+
return false;
201+
202+
auto *TemplateParameterInfo =
203+
&TemplateParameterInfoList[TemplateParameterIndex];
204+
TemplateParameterInfo->FunctionParameterIndex = ParameterIndex;
205+
TemplateParameterInfo->FunctionParameterTypeQualifiers =
206+
ParameterTypeQualifiers;
207+
TemplateParameterInfo->FunctionParameterQualifiers = ParameterQualifiers;
208+
209+
CurrentTemplateParameterBeingChecked++;
210+
}
211+
212+
// All defined template parameters need to be used as function parameters
213+
return CurrentTemplateParameterBeingChecked == NumberOfTemplateParameters;
214+
}
215+
216+
auto AbbreviateFunctionTemplate::generateFunctionParameterReplacements(
217+
const ASTContext &Context) -> llvm::Expected<tooling::Replacements> {
218+
tooling::Replacements Replacements;
219+
for (const auto &TemplateParameterInfo : TemplateParameterInfoList) {
220+
auto FunctionParameterReplacement =
221+
generateFunctionParameterReplacement(TemplateParameterInfo, Context);
222+
223+
if (auto Err = FunctionParameterReplacement.takeError())
224+
return Err;
225+
226+
if (auto Err = Replacements.add(*FunctionParameterReplacement))
227+
return Err;
228+
}
229+
230+
return Replacements;
231+
}
232+
233+
auto AbbreviateFunctionTemplate::generateFunctionParameterReplacement(
234+
const TemplateParameterInfo &TemplateParameterInfo,
235+
const ASTContext &Context) -> llvm::Expected<tooling::Replacement> {
236+
auto &SourceManager = Context.getSourceManager();
237+
238+
const auto *Function = FunctionTemplateDeclaration->getAsFunction();
239+
auto *Parameter =
240+
Function->getParamDecl(TemplateParameterInfo.FunctionParameterIndex);
241+
auto ParameterName = Parameter->getDeclName().getAsString();
242+
243+
std::vector<std::string> ParameterTokens{};
244+
245+
if (const auto *TypeConstraint = TemplateParameterInfo.Constraint) {
246+
auto *ConceptReference = TypeConstraint->getConceptReference();
247+
auto *NamedConcept = ConceptReference->getNamedConcept();
248+
249+
ParameterTokens.push_back(NamedConcept->getQualifiedNameAsString());
250+
251+
if (const auto *TemplateArgs = TypeConstraint->getTemplateArgsAsWritten()) {
252+
auto TemplateArgsRange = SourceRange(TemplateArgs->getLAngleLoc(),
253+
TemplateArgs->getRAngleLoc());
254+
auto TemplateArgsSource = toSourceCode(SourceManager, TemplateArgsRange);
255+
ParameterTokens.push_back(TemplateArgsSource.str() + '>');
256+
}
257+
}
258+
259+
ParameterTokens.push_back(AutoKeywordSpelling);
260+
261+
for (const auto &Qualifier :
262+
TemplateParameterInfo.FunctionParameterTypeQualifiers) {
263+
ParameterTokens.push_back(getSpellingForQualifier(Qualifier));
264+
}
265+
266+
ParameterTokens.push_back(ParameterName);
267+
268+
for (const auto &Qualifier :
269+
TemplateParameterInfo.FunctionParameterQualifiers) {
270+
ParameterTokens.push_back(getSpellingForQualifier(Qualifier));
271+
}
272+
273+
auto FunctionTypeReplacementText = std::accumulate(
274+
ParameterTokens.begin(), ParameterTokens.end(), std::string{},
275+
[](auto Result, auto Token) { return std::move(Result) + " " + Token; });
276+
277+
auto FunctionParameterRange = toHalfOpenFileRange(
278+
SourceManager, Context.getLangOpts(), Parameter->getSourceRange());
279+
280+
if (!FunctionParameterRange)
281+
return error("Could not obtain range of the template parameter. Macros?");
282+
283+
return tooling::Replacement(
284+
SourceManager, CharSourceRange::getCharRange(*FunctionParameterRange),
285+
FunctionTypeReplacementText);
286+
}
287+
288+
auto AbbreviateFunctionTemplate::generateTemplateDeclarationReplacement(
289+
const ASTContext &Context) -> llvm::Expected<tooling::Replacement> {
290+
auto &SourceManager = Context.getSourceManager();
291+
auto *TemplateParameters =
292+
FunctionTemplateDeclaration->getTemplateParameters();
293+
294+
auto TemplateDeclarationRange =
295+
toHalfOpenFileRange(SourceManager, Context.getLangOpts(),
296+
TemplateParameters->getSourceRange());
297+
298+
if (!TemplateDeclarationRange)
299+
return error("Could not obtain range of the template parameter. Macros?");
300+
301+
auto CharRange = CharSourceRange::getCharRange(*TemplateDeclarationRange);
302+
return tooling::Replacement(SourceManager, CharRange, "");
303+
}
304+
305+
auto AbbreviateFunctionTemplate::deconstructType(QualType Type)
306+
-> std::tuple<QualType, std::vector<tok::TokenKind>,
307+
std::vector<tok::TokenKind>> {
308+
std::vector<tok::TokenKind> ParameterTypeQualifiers{};
309+
std::vector<tok::TokenKind> ParameterQualifiers{};
310+
311+
if (Type->isIncompleteArrayType()) {
312+
ParameterQualifiers.push_back(tok::l_square);
313+
ParameterQualifiers.push_back(tok::r_square);
314+
Type = Type->castAsArrayTypeUnsafe()->getElementType();
315+
}
316+
317+
if (isa<PackExpansionType>(Type))
318+
ParameterTypeQualifiers.push_back(tok::ellipsis);
319+
320+
Type = Type.getNonPackExpansionType();
321+
322+
if (Type->isRValueReferenceType()) {
323+
ParameterTypeQualifiers.push_back(tok::ampamp);
324+
Type = Type.getNonReferenceType();
325+
}
326+
327+
if (Type->isLValueReferenceType()) {
328+
ParameterTypeQualifiers.push_back(tok::amp);
329+
Type = Type.getNonReferenceType();
330+
}
331+
332+
if (Type.isConstQualified()) {
333+
ParameterTypeQualifiers.push_back(tok::kw_const);
334+
}
335+
336+
while (Type->isPointerType()) {
337+
ParameterTypeQualifiers.push_back(tok::star);
338+
Type = Type->getPointeeType();
339+
340+
if (Type.isConstQualified()) {
341+
ParameterTypeQualifiers.push_back(tok::kw_const);
342+
}
343+
}
344+
345+
std::reverse(ParameterTypeQualifiers.begin(), ParameterTypeQualifiers.end());
346+
347+
return {Type, ParameterTypeQualifiers, ParameterQualifiers};
348+
}
349+
350+
} // namespace
351+
} // namespace clangd
352+
} // namespace clang

clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ set(LLVM_LINK_COMPONENTS
1212
# $<TARGET_OBJECTS:obj.clangDaemonTweaks> to a list of sources, see
1313
# clangd/tool/CMakeLists.txt for an example.
1414
add_clang_library(clangDaemonTweaks OBJECT
15+
AbbreviateFunctionTemplate.cpp
1516
AddUsing.cpp
1617
AnnotateHighlightings.cpp
1718
DumpAST.cpp

clang-tools-extra/clangd/unittests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ add_unittest(ClangdUnitTests ClangdTests
117117
support/ThreadingTests.cpp
118118
support/TraceTests.cpp
119119

120+
tweaks/AbbreviateFunctionTemplateTests.cpp
120121
tweaks/AddUsingTests.cpp
121122
tweaks/AnnotateHighlightingsTests.cpp
122123
tweaks/DefineInlineTests.cpp

0 commit comments

Comments
 (0)