@@ -23,13 +23,11 @@ namespace {
23
23
using namespace ::clang::ast_matchers;
24
24
using namespace ::clang::transformer;
25
25
26
- EditGenerator rewrite (RangeSelector Call, RangeSelector Builder,
27
- RangeSelector CallArgs) {
26
+ EditGenerator rewrite (RangeSelector Call, RangeSelector Builder) {
28
27
// This is using an EditGenerator rather than ASTEdit as we want to warn even
29
28
// if in macro.
30
- return [Call = std::move (Call), Builder = std::move (Builder),
31
- CallArgs =
32
- std::move (CallArgs)](const MatchFinder::MatchResult &Result)
29
+ return [Call = std::move (Call),
30
+ Builder = std::move (Builder)](const MatchFinder::MatchResult &Result)
33
31
-> Expected<SmallVector<transformer::Edit, 1 >> {
34
32
Expected<CharSourceRange> CallRange = Call (Result);
35
33
if (!CallRange)
@@ -54,7 +52,7 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
54
52
auto NextToken = [&](std::optional<Token> CurrentToken) {
55
53
if (!CurrentToken)
56
54
return CurrentToken;
57
- if (CurrentToken->getEndLoc () >= CallRange-> getEnd ( ))
55
+ if (CurrentToken->is (clang::tok::eof ))
58
56
return std::optional<Token>();
59
57
return clang::Lexer::findNextToken (CurrentToken->getLocation (), SM,
60
58
LangOpts);
@@ -68,9 +66,10 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
68
66
return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
69
67
" missing '<' token" );
70
68
}
69
+
71
70
std::optional<Token> EndToken = NextToken (LessToken);
72
- for ( std::optional<Token> GreaterToken = NextToken (EndToken);
73
- GreaterToken && GreaterToken->getKind () != clang::tok::greater;
71
+ std::optional<Token> GreaterToken = NextToken (EndToken);
72
+ for (; GreaterToken && GreaterToken->getKind () != clang::tok::greater;
74
73
GreaterToken = NextToken (GreaterToken)) {
75
74
EndToken = GreaterToken;
76
75
}
@@ -79,12 +78,21 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
79
78
" missing '>' token" );
80
79
}
81
80
81
+ std::optional<Token> ArgStart = NextToken (GreaterToken);
82
+ if (!ArgStart || ArgStart->getKind () != clang::tok::l_paren) {
83
+ return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
84
+ " missing '(' token" );
85
+ }
86
+ std::optional<Token> Arg = NextToken (ArgStart);
87
+ if (!Arg) {
88
+ return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
89
+ " unexpected end of file" );
90
+ }
91
+ bool hasArgs = Arg->getKind () != clang::tok::r_paren;
92
+
82
93
Expected<CharSourceRange> BuilderRange = Builder (Result);
83
94
if (!BuilderRange)
84
95
return BuilderRange.takeError ();
85
- Expected<CharSourceRange> CallArgsRange = CallArgs (Result);
86
- if (!CallArgsRange)
87
- return CallArgsRange.takeError ();
88
96
89
97
// Helper for concatting below.
90
98
auto GetText = [&](const CharSourceRange &Range) {
@@ -93,18 +101,19 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
93
101
94
102
Edit Replace;
95
103
Replace.Kind = EditKind::Range;
96
- Replace.Range = *CallRange;
97
- std::string CallArgsStr;
98
- // Only emit args if there are any.
99
- if (auto CallArgsText = GetText (*CallArgsRange).ltrim ();
100
- !CallArgsText.rtrim ().empty ()) {
101
- CallArgsStr = llvm::formatv (" , {}" , CallArgsText);
104
+ Replace.Range .setBegin (CallRange->getBegin ());
105
+ Replace.Range .setEnd (ArgStart->getEndLoc ());
106
+ const Expr *BuilderExpr = Result.Nodes .getNodeAs <Expr>(" builder" );
107
+ std::string BuilderText = GetText (*BuilderRange).str ();
108
+ if (BuilderExpr->getType ()->isPointerType ()) {
109
+ BuilderText = BuilderExpr->isImplicitCXXThis ()
110
+ ? " *this"
111
+ : llvm::formatv (" *{}" , BuilderText).str ();
102
112
}
103
- Replace.Replacement =
104
- llvm::formatv (" {}::create({}{})" ,
105
- GetText (CharSourceRange::getTokenRange (
106
- LessToken->getEndLoc (), EndToken->getLastLoc ())),
107
- GetText (*BuilderRange), CallArgsStr);
113
+ StringRef OpType = GetText (CharSourceRange::getTokenRange (
114
+ LessToken->getEndLoc (), EndToken->getLastLoc ()));
115
+ Replace.Replacement = llvm::formatv (" {}::create({}{}" , OpType, BuilderText,
116
+ hasArgs ? " , " : " " );
108
117
109
118
return SmallVector<Edit, 1 >({Replace});
110
119
};
@@ -114,19 +123,17 @@ RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule() {
114
123
Stencil message = cat (" use 'OpType::create(builder, ...)' instead of "
115
124
" 'builder.create<OpType>(...)'" );
116
125
// Match a create call on an OpBuilder.
117
- ast_matchers::internal::Matcher<Stmt> base =
118
- cxxMemberCallExpr (
119
- on (expr (hasType (
120
- cxxRecordDecl (isSameOrDerivedFrom (" ::mlir::OpBuilder" ))))
121
- .bind (" builder" )),
122
- callee (cxxMethodDecl (hasTemplateArgument (0 , templateArgument ()))),
123
- callee (cxxMethodDecl (hasName (" create" ))))
124
- .bind (" call" );
126
+ auto BuilderType = cxxRecordDecl (isSameOrDerivedFrom (" ::mlir::OpBuilder" ));
127
+ ast_matchers::internal::Matcher<Stmt> base = cxxMemberCallExpr (
128
+ on (expr (anyOf (hasType (BuilderType), hasType (pointsTo (BuilderType))))
129
+ .bind (" builder" )),
130
+ callee (expr ().bind (" call" )),
131
+ callee (cxxMethodDecl (hasTemplateArgument (0 , templateArgument ()))),
132
+ callee (cxxMethodDecl (hasName (" create" ))));
125
133
return applyFirst (
126
134
// Attempt rewrite given an lvalue builder, else just warn.
127
135
{makeRule (cxxMemberCallExpr (unless (on (cxxTemporaryObjectExpr ())), base),
128
- rewrite (node (" call" ), node (" builder" ), callArgs (" call" )),
129
- message),
136
+ rewrite (node (" call" ), node (" builder" )), message),
130
137
makeRule (base, noopEdit (node (" call" )), message)});
131
138
}
132
139
} // namespace
0 commit comments