Skip to content

Commit cce239a

Browse files
jpienaarEugeneZelenkolocalspook
authored andcommitted
[clang-tidy][mlir] Expand to cover pointer of builder (llvm#159423)
Previously this only checked for OpBuilder usage, but it could also be invoked via pointer. Also change how call range is calculated to avoid false overlaps which limits rewriting builder calls inside arguments of builder calls. --------- Co-authored-by: EugeneZelenko <[email protected]> Co-authored-by: Victor Chernyakin <[email protected]>
1 parent 86093d8 commit cce239a

File tree

2 files changed

+85
-36
lines changed

2 files changed

+85
-36
lines changed

clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,15 @@
1818
#include "llvm/Support/FormatVariadic.h"
1919

2020
namespace clang::tidy::llvm_check {
21-
namespace {
2221

2322
using namespace ::clang::ast_matchers;
2423
using namespace ::clang::transformer;
2524

26-
EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
27-
RangeSelector CallArgs) {
25+
static EditGenerator rewrite(RangeSelector Call, RangeSelector Builder) {
2826
// This is using an EditGenerator rather than ASTEdit as we want to warn even
2927
// if in macro.
30-
return [Call = std::move(Call), Builder = std::move(Builder),
31-
CallArgs =
32-
std::move(CallArgs)](const MatchFinder::MatchResult &Result)
28+
return [Call = std::move(Call),
29+
Builder = std::move(Builder)](const MatchFinder::MatchResult &Result)
3330
-> Expected<SmallVector<transformer::Edit, 1>> {
3431
Expected<CharSourceRange> CallRange = Call(Result);
3532
if (!CallRange)
@@ -54,7 +51,7 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
5451
auto NextToken = [&](std::optional<Token> CurrentToken) {
5552
if (!CurrentToken)
5653
return CurrentToken;
57-
if (CurrentToken->getEndLoc() >= CallRange->getEnd())
54+
if (CurrentToken->is(clang::tok::eof))
5855
return std::optional<Token>();
5956
return clang::Lexer::findNextToken(CurrentToken->getLocation(), SM,
6057
LangOpts);
@@ -68,9 +65,10 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
6865
return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
6966
"missing '<' token");
7067
}
68+
7169
std::optional<Token> EndToken = NextToken(LessToken);
72-
for (std::optional<Token> GreaterToken = NextToken(EndToken);
73-
GreaterToken && GreaterToken->getKind() != clang::tok::greater;
70+
std::optional<Token> GreaterToken = NextToken(EndToken);
71+
for (; GreaterToken && GreaterToken->getKind() != clang::tok::greater;
7472
GreaterToken = NextToken(GreaterToken)) {
7573
EndToken = GreaterToken;
7674
}
@@ -79,12 +77,21 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
7977
"missing '>' token");
8078
}
8179

80+
std::optional<Token> ArgStart = NextToken(GreaterToken);
81+
if (!ArgStart || ArgStart->getKind() != clang::tok::l_paren) {
82+
return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
83+
"missing '(' token");
84+
}
85+
std::optional<Token> Arg = NextToken(ArgStart);
86+
if (!Arg) {
87+
return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
88+
"unexpected end of file");
89+
}
90+
const bool HasArgs = Arg->getKind() != clang::tok::r_paren;
91+
8292
Expected<CharSourceRange> BuilderRange = Builder(Result);
8393
if (!BuilderRange)
8494
return BuilderRange.takeError();
85-
Expected<CharSourceRange> CallArgsRange = CallArgs(Result);
86-
if (!CallArgsRange)
87-
return CallArgsRange.takeError();
8895

8996
// Helper for concatting below.
9097
auto GetText = [&](const CharSourceRange &Range) {
@@ -93,43 +100,42 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
93100

94101
Edit Replace;
95102
Replace.Kind = EditKind::Range;
96-
Replace.Range = *CallRange;
97-
std::string CallArgsStr;
98-
// Only emit args if there are any.
99-
if (auto CallArgsText = GetText(*CallArgsRange).ltrim();
100-
!CallArgsText.rtrim().empty()) {
101-
CallArgsStr = llvm::formatv(", {}", CallArgsText);
103+
Replace.Range.setBegin(CallRange->getBegin());
104+
Replace.Range.setEnd(ArgStart->getEndLoc());
105+
const Expr *BuilderExpr = Result.Nodes.getNodeAs<Expr>("builder");
106+
std::string BuilderText = GetText(*BuilderRange).str();
107+
if (BuilderExpr->getType()->isPointerType()) {
108+
BuilderText = BuilderExpr->isImplicitCXXThis()
109+
? "*this"
110+
: llvm::formatv("*{}", BuilderText).str();
102111
}
103-
Replace.Replacement =
104-
llvm::formatv("{}::create({}{})",
105-
GetText(CharSourceRange::getTokenRange(
106-
LessToken->getEndLoc(), EndToken->getLastLoc())),
107-
GetText(*BuilderRange), CallArgsStr);
112+
const StringRef OpType = GetText(CharSourceRange::getTokenRange(
113+
LessToken->getEndLoc(), EndToken->getLastLoc()));
114+
Replace.Replacement = llvm::formatv("{}::create({}{}", OpType, BuilderText,
115+
HasArgs ? ", " : "");
108116

109117
return SmallVector<Edit, 1>({Replace});
110118
};
111119
}
112120

113-
RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule() {
121+
static RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule() {
114122
Stencil Message = cat("use 'OpType::create(builder, ...)' instead of "
115123
"'builder.create<OpType>(...)'");
116124
// Match a create call on an OpBuilder.
125+
auto BuilderType = cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder"));
117126
ast_matchers::internal::Matcher<Stmt> Base =
118127
cxxMemberCallExpr(
119-
on(expr(hasType(
120-
cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder"))))
128+
on(expr(anyOf(hasType(BuilderType), hasType(pointsTo(BuilderType))))
121129
.bind("builder")),
122-
callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()))),
123-
callee(cxxMethodDecl(hasName("create"))))
130+
callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()),
131+
hasName("create"))))
124132
.bind("call");
125133
return applyFirst(
126134
// Attempt rewrite given an lvalue builder, else just warn.
127135
{makeRule(cxxMemberCallExpr(unless(on(cxxTemporaryObjectExpr())), Base),
128-
rewrite(node("call"), node("builder"), callArgs("call")),
129-
Message),
136+
rewrite(node("call"), node("builder")), Message),
130137
makeRule(Base, noopEdit(node("call")), Message)});
131138
}
132-
} // namespace
133139

134140
UseNewMlirOpBuilderCheck::UseNewMlirOpBuilderCheck(StringRef Name,
135141
ClangTidyContext *Context)

clang-tools-extra/test/clang-tidy/checkers/llvm/use-new-mlir-op-builder.cpp

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
namespace mlir {
44
class Location {};
5+
class Value {};
56
class OpBuilder {
67
public:
78
template <typename OpTy, typename... Args>
@@ -28,6 +29,13 @@ struct NamedOp {
2829
static NamedOp create(OpBuilder &builder, Location location, const char* name) {
2930
return NamedOp(name);
3031
}
32+
Value getResult() { return Value(); }
33+
};
34+
struct OperandOp {
35+
OperandOp(Value val) {}
36+
static OperandOp create(OpBuilder &builder, Location location, Value val) {
37+
return OperandOp(val);
38+
}
3139
};
3240
} // namespace mlir
3341

@@ -40,22 +48,41 @@ void g(mlir::OpBuilder &b) {
4048
b.create<T>(b.getUnknownLoc(), "gaz");
4149
}
4250

51+
class CustomBuilder : public mlir::ImplicitLocOpBuilder {
52+
public:
53+
mlir::NamedOp f(const char *name) {
54+
// CHECK-MESSAGES: :[[@LINE+2]]:12: warning: use 'OpType::create(builder, ...)'
55+
// CHECK-FIXES: return mlir::NamedOp::create(*this, name);
56+
return create<mlir::NamedOp>(name);
57+
}
58+
59+
mlir::NamedOp g(const char *name) {
60+
using mlir::NamedOp;
61+
// CHECK-MESSAGES: :[[@LINE+2]]:12: warning: use 'OpType::create(builder, ...)'
62+
// CHECK-FIXES: return NamedOp::create(*this, name);
63+
return create<NamedOp>(name);
64+
}
65+
};
66+
4367
void f() {
4468
mlir::OpBuilder builder;
4569
// CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
4670
// CHECK-FIXES: mlir:: ModuleOp::create(builder, builder.getUnknownLoc());
4771
builder.create<mlir:: ModuleOp>(builder.getUnknownLoc());
4872

4973
using mlir::NamedOp;
74+
using mlir::OperandOp;
75+
5076
// CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
5177
// CHECK-FIXES: NamedOp::create(builder, builder.getUnknownLoc(), "baz");
5278
builder.create<NamedOp>(builder.getUnknownLoc(), "baz");
5379

54-
// CHECK-MESSAGES: :[[@LINE+3]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
55-
// CHECK-FIXES: NamedOp::create(builder, builder.getUnknownLoc(),
56-
// CHECK-FIXES: "caz");
80+
// CHECK-MESSAGES: :[[@LINE+4]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
81+
// CHECK-FIXES: NamedOp::create(builder,
82+
// CHECK-FIXES: builder.getUnknownLoc(),
83+
// CHECK-FIXES: "caz");
5784
builder.
58-
create<NamedOp>(
85+
create<NamedOp> (
5986
builder.getUnknownLoc(),
6087
"caz");
6188

@@ -66,10 +93,26 @@ void f() {
6693

6794
mlir::ImplicitLocOpBuilder ib;
6895
// CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
69-
// CHECK-FIXES: mlir::ModuleOp::create(ib);
96+
// CHECK-FIXES: mlir::ModuleOp::create(ib );
7097
ib.create<mlir::ModuleOp>( );
7198

7299
// CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
73100
// CHECK-FIXES: mlir::OpBuilder().create<mlir::ModuleOp>(builder.getUnknownLoc());
74101
mlir::OpBuilder().create<mlir::ModuleOp>(builder.getUnknownLoc());
102+
103+
auto *p = &builder;
104+
// CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)'
105+
// CHECK-FIXES: NamedOp::create(*p, builder.getUnknownLoc(), "eaz");
106+
p->create<NamedOp>(builder.getUnknownLoc(), "eaz");
107+
108+
CustomBuilder cb;
109+
cb.f("faz");
110+
cb.g("gaz");
111+
112+
// CHECK-FIXES: OperandOp::create(builder, builder.getUnknownLoc(),
113+
// CHECK-FIXES-NEXT: NamedOp::create(builder, builder.getUnknownLoc(), "haz").getResult());
114+
// CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
115+
// CHECK-MESSAGES: :[[@LINE+2]]:5: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
116+
builder.create<OperandOp>(builder.getUnknownLoc(),
117+
builder.create<NamedOp>(builder.getUnknownLoc(), "haz").getResult());
75118
}

0 commit comments

Comments
 (0)