|
| 1 | +//===--- OverridePureVirtuals.cpp --------------------------------*- C++-*-===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// Tweak to automatically generate stubs for pure virtual methods inherited from |
| 10 | +// base classes. |
| 11 | +// |
| 12 | +// Purpose: |
| 13 | +// - Simplifies making a derived class concrete by automating the creation of |
| 14 | +// required method overrides from abstract bases. |
| 15 | +// |
| 16 | +// Tweak Summary: |
| 17 | +// |
| 18 | +// 1. Activation Conditions (prepare): |
| 19 | +// - The tweak activates when the cursor is over a C++ class definition. |
| 20 | +// - The class must be abstract (it, or its base classes, have unimplemented |
| 21 | +// pure virtual functions). |
| 22 | +// - It must also inherit from at least one other abstract class. |
| 23 | +// |
| 24 | +// 2. Identifying Missing Methods: |
| 25 | +// - The tweak scans the inheritance hierarchy of the current class. |
| 26 | +// - It identifies all unique pure virtual methods from base classes |
| 27 | +// that are not yet implemented or overridden. |
| 28 | +// - These missing methods are then grouped by their original access |
| 29 | +// specifier (e.g., public, protected). |
| 30 | +// |
| 31 | +// 3. Code Generation and Insertion: |
| 32 | +// - For each group of missing methods, stubs are inserted. |
| 33 | +// - If an access specifier section (like `public:`) exists, stubs are |
| 34 | +// inserted there; otherwise, a new section is created and appended. |
| 35 | +// - Each generated stub includes the `override` keyword, a `// TODO:` |
| 36 | +// comment, and a `static_assert(false, ...)` to force a compile-time |
| 37 | +// error if the method remains unimplemented. |
| 38 | +// - The base method's signature is adjusted (e.g., `virtual` and `= 0` |
| 39 | +// are removed for the override). |
| 40 | +// |
| 41 | +// 4. Code Action Provided: |
| 42 | +// - A single code action titled "Override pure virtual methods" is offered. |
| 43 | +// - Applying this action results in a single source file modification |
| 44 | +// containing all the generated method stubs. |
| 45 | +// |
| 46 | +// Example: |
| 47 | +// |
| 48 | +// class Base { |
| 49 | +// public: |
| 50 | +// virtual void publicMethod() = 0; |
| 51 | +// protected: |
| 52 | +// virtual auto privateMethod() const -> int = 0; |
| 53 | +// }; |
| 54 | +// |
| 55 | +// Before: |
| 56 | +// // cursor here |
| 57 | +// class Derived : public Base {}^; |
| 58 | +// |
| 59 | +// After: |
| 60 | +// |
| 61 | +// class Derived : public Base { |
| 62 | +// public: |
| 63 | +// void publicMethod() override { |
| 64 | +// // TODO: Implement this pure virtual method. |
| 65 | +// static_assert(false, "Method `publicMethod` is not implemented."); |
| 66 | +// } |
| 67 | +// |
| 68 | +// protected: |
| 69 | +// auto privateMethod() const -> int override { |
| 70 | +// // TODO: Implement this pure virtual method. |
| 71 | +// static_assert(false, "Method `privateMethod` is not implemented."); |
| 72 | +// } |
| 73 | +// }; |
| 74 | +// |
| 75 | +//===----------------------------------------------------------------------===// |
| 76 | + |
| 77 | +#include "refactor/Tweak.h" |
| 78 | +#include "support/Token.h" |
| 79 | + |
| 80 | +#include "clang/AST/ASTContext.h" |
| 81 | +#include "clang/AST/DeclCXX.h" |
| 82 | +#include "clang/AST/Type.h" |
| 83 | +#include "clang/AST/TypeLoc.h" |
| 84 | +#include "clang/Basic/LLVM.h" |
| 85 | +#include "clang/Basic/SourceLocation.h" |
| 86 | +#include "clang/Tooling/Core/Replacement.h" |
| 87 | +#include "llvm/ADT/DenseSet.h" |
| 88 | +#include "llvm/Support/FormatVariadic.h" |
| 89 | +#include <string> |
| 90 | + |
| 91 | +namespace clang { |
| 92 | +namespace clangd { |
| 93 | +namespace { |
| 94 | + |
| 95 | +// This function removes the "virtual" and the "= 0" at the end; |
| 96 | +// e.g.: |
| 97 | +// "virtual void foo(int var = 0) = 0" // input. |
| 98 | +// "void foo(int var = 0)" // output. |
| 99 | +std::string removePureVirtualSyntax(const std::string &MethodDecl, |
| 100 | + const LangOptions &LangOpts) { |
| 101 | + assert(!MethodDecl.empty()); |
| 102 | + |
| 103 | + TokenStream TS = lex(MethodDecl, LangOpts); |
| 104 | + |
| 105 | + std::string DeclString; |
| 106 | + for (const clangd::Token &Tk : TS.tokens()) { |
| 107 | + if (Tk.Kind == clang::tok::raw_identifier && Tk.text() == "virtual") |
| 108 | + continue; |
| 109 | + |
| 110 | + // If the ending two tokens are "= 0", we break here and we already have the |
| 111 | + // method's string without the pure virtual syntax. |
| 112 | + const auto &Next = Tk.next(); |
| 113 | + if (Next.next().Kind == tok::eof && Tk.Kind == clang::tok::equal && |
| 114 | + Next.text() == "0") |
| 115 | + break; |
| 116 | + |
| 117 | + DeclString += Tk.text(); |
| 118 | + if (Tk.Kind != tok::l_paren && Next.Kind != tok::comma && |
| 119 | + Next.Kind != tok::r_paren && Next.Kind != tok::l_paren) |
| 120 | + DeclString += ' '; |
| 121 | + } |
| 122 | + // Trim the last whitespace. |
| 123 | + if (DeclString.back() == ' ') |
| 124 | + DeclString.pop_back(); |
| 125 | + |
| 126 | + return DeclString; |
| 127 | +} |
| 128 | + |
| 129 | +class OverridePureVirtuals final : public Tweak { |
| 130 | +public: |
| 131 | + const char *id() const final; // defined by REGISTER_TWEAK. |
| 132 | + bool prepare(const Selection &Sel) override; |
| 133 | + Expected<Effect> apply(const Selection &Sel) override; |
| 134 | + std::string title() const override { return "Override pure virtual methods"; } |
| 135 | + llvm::StringLiteral kind() const override { |
| 136 | + return CodeAction::QUICKFIX_KIND; |
| 137 | + } |
| 138 | + |
| 139 | +private: |
| 140 | + // Stores the CXXRecordDecl of the class being modified. |
| 141 | + const CXXRecordDecl *CurrentDeclDef = nullptr; |
| 142 | + // Stores pure virtual methods that need overriding, grouped by their original |
| 143 | + // access specifier. |
| 144 | + llvm::MapVector<AccessSpecifier, llvm::SmallVector<const CXXMethodDecl *>> |
| 145 | + MissingMethodsByAccess; |
| 146 | + // Stores the source locations of existing access specifiers in CurrentDecl. |
| 147 | + llvm::MapVector<AccessSpecifier, SourceLocation> AccessSpecifierLocations; |
| 148 | + // Helper function to gather information before applying the tweak. |
| 149 | + void collectMissingPureVirtuals(); |
| 150 | +}; |
| 151 | + |
| 152 | +REGISTER_TWEAK(OverridePureVirtuals) |
| 153 | + |
| 154 | +// Function to get all unique pure virtual methods from the entire |
| 155 | +// base class hierarchy of CurrentDeclDef. |
| 156 | +llvm::SmallVector<const clang::CXXMethodDecl *> |
| 157 | +getAllUniquePureVirtualsFromBaseHierarchy( |
| 158 | + const clang::CXXRecordDecl *CurrentDeclDef) { |
| 159 | + llvm::SmallVector<const clang::CXXMethodDecl *> AllPureVirtualsInHierarchy; |
| 160 | + llvm::DenseSet<const clang::CXXMethodDecl *> CanonicalPureVirtualsSeen; |
| 161 | + |
| 162 | + if (!CurrentDeclDef || !CurrentDeclDef->getDefinition()) |
| 163 | + return AllPureVirtualsInHierarchy; |
| 164 | + |
| 165 | + const clang::CXXRecordDecl *Def = CurrentDeclDef->getDefinition(); |
| 166 | + |
| 167 | + Def->forallBases([&](const clang::CXXRecordDecl *BaseDefinition) { |
| 168 | + for (const clang::CXXMethodDecl *Method : BaseDefinition->methods()) { |
| 169 | + if (Method->isPureVirtual() && |
| 170 | + CanonicalPureVirtualsSeen.insert(Method->getCanonicalDecl()).second) |
| 171 | + AllPureVirtualsInHierarchy.emplace_back(Method); |
| 172 | + } |
| 173 | + // Continue iterating through all bases. |
| 174 | + return true; |
| 175 | + }); |
| 176 | + |
| 177 | + return AllPureVirtualsInHierarchy; |
| 178 | +} |
| 179 | + |
| 180 | +// Gets canonical declarations of methods already overridden or implemented in |
| 181 | +// class D. |
| 182 | +llvm::SetVector<const CXXMethodDecl *> |
| 183 | +getImplementedOrOverriddenCanonicals(const CXXRecordDecl *D) { |
| 184 | + llvm::SetVector<const CXXMethodDecl *> ImplementedSet; |
| 185 | + for (const CXXMethodDecl *M : D->methods()) { |
| 186 | + // If M provides an implementation for any virtual method it overrides. |
| 187 | + // A method is an "implementation" if it's virtual and not pure. |
| 188 | + // Or if it directly overrides a base method. |
| 189 | + for (const CXXMethodDecl *OverriddenM : M->overridden_methods()) |
| 190 | + ImplementedSet.insert(OverriddenM->getCanonicalDecl()); |
| 191 | + } |
| 192 | + return ImplementedSet; |
| 193 | +} |
| 194 | + |
| 195 | +// Get the location of every colon of the `AccessSpecifier`. |
| 196 | +llvm::MapVector<AccessSpecifier, SourceLocation> |
| 197 | +getSpecifierLocations(const CXXRecordDecl *D) { |
| 198 | + llvm::MapVector<AccessSpecifier, SourceLocation> Locs; |
| 199 | + for (auto *DeclNode : D->decls()) { |
| 200 | + if (const auto *ASD = llvm::dyn_cast<AccessSpecDecl>(DeclNode)) |
| 201 | + Locs[ASD->getAccess()] = ASD->getColonLoc(); |
| 202 | + } |
| 203 | + return Locs; |
| 204 | +} |
| 205 | + |
| 206 | +bool hasAbstractBaseAncestor(const clang::CXXRecordDecl *CurrentDecl) { |
| 207 | + assert(CurrentDecl && CurrentDecl->getDefinition()); |
| 208 | + |
| 209 | + return llvm::any_of( |
| 210 | + CurrentDecl->getDefinition()->bases(), [](CXXBaseSpecifier BaseSpec) { |
| 211 | + const auto *D = BaseSpec.getType()->getAsCXXRecordDecl(); |
| 212 | + const auto *Def = D ? D->getDefinition() : nullptr; |
| 213 | + return Def && Def->isAbstract(); |
| 214 | + }); |
| 215 | +} |
| 216 | + |
| 217 | +// The tweak is available if the selection is over an abstract C++ class |
| 218 | +// definition that also inherits from at least one other abstract class. |
| 219 | +bool OverridePureVirtuals::prepare(const Selection &Sel) { |
| 220 | + const SelectionTree::Node *Node = Sel.ASTSelection.commonAncestor(); |
| 221 | + if (!Node) |
| 222 | + return false; |
| 223 | + |
| 224 | + // Make sure we have a definition. |
| 225 | + CurrentDeclDef = Node->ASTNode.get<CXXRecordDecl>(); |
| 226 | + if (!CurrentDeclDef || !CurrentDeclDef->getDefinition()) |
| 227 | + return false; |
| 228 | + |
| 229 | + // From now on, we should work with the definition. |
| 230 | + CurrentDeclDef = CurrentDeclDef->getDefinition(); |
| 231 | + |
| 232 | + // Only offer for abstract classes with abstract bases. |
| 233 | + return CurrentDeclDef->isAbstract() && |
| 234 | + hasAbstractBaseAncestor(CurrentDeclDef); |
| 235 | +} |
| 236 | + |
| 237 | +// Collects all pure virtual methods from base classes that `CurrentDeclDef` has |
| 238 | +// not yet overridden, grouped by their original access specifier. |
| 239 | +// |
| 240 | +// Results are stored in `MissingMethodsByAccess` and `AccessSpecifierLocations` |
| 241 | +// is also populated. |
| 242 | +void OverridePureVirtuals::collectMissingPureVirtuals() { |
| 243 | + if (!CurrentDeclDef) |
| 244 | + return; |
| 245 | + |
| 246 | + AccessSpecifierLocations = getSpecifierLocations(CurrentDeclDef); |
| 247 | + MissingMethodsByAccess.clear(); |
| 248 | + |
| 249 | + // Get all unique pure virtual methods from the entire base class hierarchy. |
| 250 | + llvm::SmallVector<const CXXMethodDecl *> AllPureVirtualsInHierarchy = |
| 251 | + getAllUniquePureVirtualsFromBaseHierarchy(CurrentDeclDef); |
| 252 | + |
| 253 | + // Get methods already implemented or overridden in CurrentDecl. |
| 254 | + const auto ImplementedOrOverriddenSet = |
| 255 | + getImplementedOrOverriddenCanonicals(CurrentDeclDef); |
| 256 | + |
| 257 | + // Filter AllPureVirtualsInHierarchy to find those not in |
| 258 | + // ImplementedOrOverriddenSet, which needs to be overriden. |
| 259 | + for (const CXXMethodDecl *BaseMethod : AllPureVirtualsInHierarchy) { |
| 260 | + bool AlreadyHandled = ImplementedOrOverriddenSet.contains(BaseMethod); |
| 261 | + if (!AlreadyHandled) |
| 262 | + MissingMethodsByAccess[BaseMethod->getAccess()].emplace_back(BaseMethod); |
| 263 | + } |
| 264 | +} |
| 265 | + |
| 266 | +std::string generateOverrideString(const CXXMethodDecl *Method, |
| 267 | + const LangOptions &LangOpts) { |
| 268 | + std::string MethodDecl; |
| 269 | + auto OS = llvm::raw_string_ostream(MethodDecl); |
| 270 | + Method->print(OS); |
| 271 | + |
| 272 | + return llvm::formatv( |
| 273 | + "\n {0} override {{\n" |
| 274 | + " // TODO: Implement this pure virtual method.\n" |
| 275 | + " static_assert(false, \"Method `{1}` is not implemented.\");\n" |
| 276 | + " }", |
| 277 | + removePureVirtualSyntax(MethodDecl, LangOpts), Method->getName()) |
| 278 | + .str(); |
| 279 | +} |
| 280 | + |
| 281 | +// Free function to generate the string for a group of method overrides. |
| 282 | +std::string generateOverridesStringForGroup( |
| 283 | + llvm::SmallVector<const CXXMethodDecl *> Methods, |
| 284 | + const LangOptions &LangOpts) { |
| 285 | + llvm::SmallVector<std::string> MethodsString; |
| 286 | + MethodsString.reserve(Methods.size()); |
| 287 | + |
| 288 | + for (const CXXMethodDecl *Method : Methods) { |
| 289 | + MethodsString.emplace_back(generateOverrideString(Method, LangOpts)); |
| 290 | + } |
| 291 | + |
| 292 | + return llvm::join(MethodsString, "\n") + '\n'; |
| 293 | +} |
| 294 | + |
| 295 | +Expected<Tweak::Effect> OverridePureVirtuals::apply(const Selection &Sel) { |
| 296 | + // The correctness of this tweak heavily relies on the accurate population of |
| 297 | + // these members. |
| 298 | + collectMissingPureVirtuals(); |
| 299 | + // The `prepare` should prevent this. If the prepare identifies an abstract |
| 300 | + // method, then is must have missing methods. |
| 301 | + assert(!MissingMethodsByAccess.empty()); |
| 302 | + |
| 303 | + const auto &SM = Sel.AST->getSourceManager(); |
| 304 | + const auto &LangOpts = Sel.AST->getLangOpts(); |
| 305 | + |
| 306 | + tooling::Replacements EditReplacements; |
| 307 | + // Stores text for new access specifier sections that are not already present |
| 308 | + // in the class. |
| 309 | + // Example: |
| 310 | + // public: // ... |
| 311 | + // protected: // ... |
| 312 | + std::string NewSectionsToAppendText; |
| 313 | + |
| 314 | + for (const auto &[AS, Methods] : MissingMethodsByAccess) { |
| 315 | + assert(!Methods.empty()); |
| 316 | + |
| 317 | + std::string MethodsGroupString = |
| 318 | + generateOverridesStringForGroup(Methods, LangOpts); |
| 319 | + |
| 320 | + auto *ExistingSpecLocIter = AccessSpecifierLocations.find(AS); |
| 321 | + bool ASExists = ExistingSpecLocIter != AccessSpecifierLocations.end(); |
| 322 | + if (ASExists) { |
| 323 | + // Access specifier section already exists in the class. |
| 324 | + // Get location immediately *after* the colon. |
| 325 | + SourceLocation InsertLoc = |
| 326 | + ExistingSpecLocIter->second.getLocWithOffset(1); |
| 327 | + |
| 328 | + // Create a replacement to insert the method declarations. |
| 329 | + // The replacement is at InsertLoc, has length 0 (insertion), and uses |
| 330 | + // InsertionText. |
| 331 | + std::string InsertionText = MethodsGroupString; |
| 332 | + tooling::Replacement Rep(SM, InsertLoc, 0, InsertionText); |
| 333 | + if (auto Err = EditReplacements.add(Rep)) |
| 334 | + return llvm::Expected<Tweak::Effect>(std::move(Err)); |
| 335 | + } else { |
| 336 | + // Access specifier section does not exist in the class. |
| 337 | + // These methods will be grouped into NewSectionsToAppendText and added |
| 338 | + // towards the end of the class definition. |
| 339 | + NewSectionsToAppendText += |
| 340 | + getAccessSpelling(AS).str() + ':' + MethodsGroupString; |
| 341 | + } |
| 342 | + } |
| 343 | + |
| 344 | + // After processing all access specifiers, add any newly created sections |
| 345 | + // (stored in NewSectionsToAppendText) to the end of the class. |
| 346 | + if (!NewSectionsToAppendText.empty()) { |
| 347 | + // AppendLoc is the SourceLocation of the closing brace '}' of the class. |
| 348 | + // The replacement will insert text *before* this closing brace. |
| 349 | + SourceLocation AppendLoc = CurrentDeclDef->getBraceRange().getEnd(); |
| 350 | + std::string FinalAppendText = std::move(NewSectionsToAppendText); |
| 351 | + |
| 352 | + if (!CurrentDeclDef->decls_empty() || !EditReplacements.empty()) { |
| 353 | + FinalAppendText = '\n' + FinalAppendText; |
| 354 | + } |
| 355 | + |
| 356 | + // Create a replacement to append the new sections. |
| 357 | + tooling::Replacement Rep(SM, AppendLoc, 0, FinalAppendText); |
| 358 | + if (auto Err = EditReplacements.add(Rep)) |
| 359 | + return llvm::Expected<Tweak::Effect>(std::move(Err)); |
| 360 | + } |
| 361 | + |
| 362 | + if (EditReplacements.empty()) { |
| 363 | + return llvm::make_error<llvm::StringError>( |
| 364 | + "No changes to apply (internal error or no methods generated).", |
| 365 | + llvm::inconvertibleErrorCode()); |
| 366 | + } |
| 367 | + |
| 368 | + // Return the collected replacements as the effect of this tweak. |
| 369 | + return Effect::mainFileEdit(SM, EditReplacements); |
| 370 | +} |
| 371 | + |
| 372 | +} // namespace |
| 373 | +} // namespace clangd |
| 374 | +} // namespace clang |
0 commit comments