1818#include " llvm/Support/FormatVariadic.h"
1919
2020namespace clang ::tidy::llvm_check {
21- namespace {
2221
2322using namespace ::clang::ast_matchers;
2423using namespace ::clang::transformer;
2524
26- EditGenerator rewrite (RangeSelector Call, RangeSelector Builder,
27- RangeSelector CallArgs) {
25+ static EditGenerator rewrite (RangeSelector Call, RangeSelector Builder) {
2826 // This is using an EditGenerator rather than ASTEdit as we want to warn even
2927 // if in macro.
30- return [Call = std::move (Call), Builder = std::move (Builder),
31- CallArgs =
32- std::move (CallArgs)](const MatchFinder::MatchResult &Result)
28+ return [Call = std::move (Call),
29+ Builder = std::move (Builder)](const MatchFinder::MatchResult &Result)
3330 -> Expected<SmallVector<transformer::Edit, 1 >> {
3431 Expected<CharSourceRange> CallRange = Call (Result);
3532 if (!CallRange)
@@ -54,7 +51,7 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
5451 auto NextToken = [&](std::optional<Token> CurrentToken) {
5552 if (!CurrentToken)
5653 return CurrentToken;
57- if (CurrentToken->getEndLoc () >= CallRange-> getEnd ( ))
54+ if (CurrentToken->is (clang::tok::eof ))
5855 return std::optional<Token>();
5956 return clang::Lexer::findNextToken (CurrentToken->getLocation (), SM,
6057 LangOpts);
@@ -68,9 +65,10 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
6865 return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
6966 " missing '<' token" );
7067 }
68+
7169 std::optional<Token> EndToken = NextToken (LessToken);
72- for ( std::optional<Token> GreaterToken = NextToken (EndToken);
73- GreaterToken && GreaterToken->getKind () != clang::tok::greater;
70+ std::optional<Token> GreaterToken = NextToken (EndToken);
71+ for (; GreaterToken && GreaterToken->getKind () != clang::tok::greater;
7472 GreaterToken = NextToken (GreaterToken)) {
7573 EndToken = GreaterToken;
7674 }
@@ -79,12 +77,21 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
7977 " missing '>' token" );
8078 }
8179
80+ std::optional<Token> ArgStart = NextToken (GreaterToken);
81+ if (!ArgStart || ArgStart->getKind () != clang::tok::l_paren) {
82+ return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
83+ " missing '(' token" );
84+ }
85+ std::optional<Token> Arg = NextToken (ArgStart);
86+ if (!Arg) {
87+ return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
88+ " unexpected end of file" );
89+ }
90+ const bool HasArgs = Arg->getKind () != clang::tok::r_paren;
91+
8292 Expected<CharSourceRange> BuilderRange = Builder (Result);
8393 if (!BuilderRange)
8494 return BuilderRange.takeError ();
85- Expected<CharSourceRange> CallArgsRange = CallArgs (Result);
86- if (!CallArgsRange)
87- return CallArgsRange.takeError ();
8895
8996 // Helper for concatting below.
9097 auto GetText = [&](const CharSourceRange &Range) {
@@ -93,43 +100,42 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
93100
94101 Edit Replace;
95102 Replace.Kind = EditKind::Range;
96- Replace.Range = *CallRange;
97- std::string CallArgsStr;
98- // Only emit args if there are any.
99- if (auto CallArgsText = GetText (*CallArgsRange).ltrim ();
100- !CallArgsText.rtrim ().empty ()) {
101- CallArgsStr = llvm::formatv (" , {}" , CallArgsText);
103+ Replace.Range .setBegin (CallRange->getBegin ());
104+ Replace.Range .setEnd (ArgStart->getEndLoc ());
105+ const Expr *BuilderExpr = Result.Nodes .getNodeAs <Expr>(" builder" );
106+ std::string BuilderText = GetText (*BuilderRange).str ();
107+ if (BuilderExpr->getType ()->isPointerType ()) {
108+ BuilderText = BuilderExpr->isImplicitCXXThis ()
109+ ? " *this"
110+ : llvm::formatv (" *{}" , BuilderText).str ();
102111 }
103- Replace.Replacement =
104- llvm::formatv (" {}::create({}{})" ,
105- GetText (CharSourceRange::getTokenRange (
106- LessToken->getEndLoc (), EndToken->getLastLoc ())),
107- GetText (*BuilderRange), CallArgsStr);
112+ const StringRef OpType = GetText (CharSourceRange::getTokenRange (
113+ LessToken->getEndLoc (), EndToken->getLastLoc ()));
114+ Replace.Replacement = llvm::formatv (" {}::create({}{}" , OpType, BuilderText,
115+ HasArgs ? " , " : " " );
108116
109117 return SmallVector<Edit, 1 >({Replace});
110118 };
111119}
112120
113- RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule () {
121+ static RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule () {
114122 Stencil Message = cat (" use 'OpType::create(builder, ...)' instead of "
115123 " 'builder.create<OpType>(...)'" );
116124 // Match a create call on an OpBuilder.
125+ auto BuilderType = cxxRecordDecl (isSameOrDerivedFrom (" ::mlir::OpBuilder" ));
117126 ast_matchers::internal::Matcher<Stmt> Base =
118127 cxxMemberCallExpr (
119- on (expr (hasType (
120- cxxRecordDecl (isSameOrDerivedFrom (" ::mlir::OpBuilder" ))))
128+ on (expr (anyOf (hasType (BuilderType), hasType (pointsTo (BuilderType))))
121129 .bind (" builder" )),
122- callee (cxxMethodDecl (hasTemplateArgument (0 , templateArgument ()))) ,
123- callee ( cxxMethodDecl ( hasName (" create" ))))
130+ callee (cxxMethodDecl (hasTemplateArgument (0 , templateArgument ()),
131+ hasName (" create" ))))
124132 .bind (" call" );
125133 return applyFirst (
126134 // Attempt rewrite given an lvalue builder, else just warn.
127135 {makeRule (cxxMemberCallExpr (unless (on (cxxTemporaryObjectExpr ())), Base),
128- rewrite (node (" call" ), node (" builder" ), callArgs (" call" )),
129- Message),
136+ rewrite (node (" call" ), node (" builder" )), Message),
130137 makeRule (Base, noopEdit (node (" call" )), Message)});
131138}
132- } // namespace
133139
134140UseNewMlirOpBuilderCheck::UseNewMlirOpBuilderCheck (StringRef Name,
135141 ClangTidyContext *Context)
0 commit comments