Skip to content

Commit f23c42b

Browse files
committed
[clang-tidy][mlir] Expand to cover pointer of builder
Previously this only checked for OpBuilder usage, but it could also be invoked via pointer. Also change how call range is calculated to avoid false overlaps which limits rewriting builder calls inside arguments of builder calls.
1 parent fe8e703 commit f23c42b

File tree

2 files changed

+75
-34
lines changed

2 files changed

+75
-34
lines changed

clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,11 @@ namespace {
2323
using namespace ::clang::ast_matchers;
2424
using namespace ::clang::transformer;
2525

26-
EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
27-
RangeSelector CallArgs) {
26+
EditGenerator rewrite(RangeSelector Call, RangeSelector Builder) {
2827
// This is using an EditGenerator rather than ASTEdit as we want to warn even
2928
// if in macro.
30-
return [Call = std::move(Call), Builder = std::move(Builder),
31-
CallArgs =
32-
std::move(CallArgs)](const MatchFinder::MatchResult &Result)
29+
return [Call = std::move(Call),
30+
Builder = std::move(Builder)](const MatchFinder::MatchResult &Result)
3331
-> Expected<SmallVector<transformer::Edit, 1>> {
3432
Expected<CharSourceRange> CallRange = Call(Result);
3533
if (!CallRange)
@@ -54,7 +52,7 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
5452
auto NextToken = [&](std::optional<Token> CurrentToken) {
5553
if (!CurrentToken)
5654
return CurrentToken;
57-
if (CurrentToken->getEndLoc() >= CallRange->getEnd())
55+
if (CurrentToken->is(clang::tok::eof))
5856
return std::optional<Token>();
5957
return clang::Lexer::findNextToken(CurrentToken->getLocation(), SM,
6058
LangOpts);
@@ -68,9 +66,10 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
6866
return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
6967
"missing '<' token");
7068
}
69+
7170
std::optional<Token> EndToken = NextToken(LessToken);
72-
for (std::optional<Token> GreaterToken = NextToken(EndToken);
73-
GreaterToken && GreaterToken->getKind() != clang::tok::greater;
71+
std::optional<Token> GreaterToken = NextToken(EndToken);
72+
for (; GreaterToken && GreaterToken->getKind() != clang::tok::greater;
7473
GreaterToken = NextToken(GreaterToken)) {
7574
EndToken = GreaterToken;
7675
}
@@ -79,12 +78,21 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
7978
"missing '>' token");
8079
}
8180

81+
std::optional<Token> ArgStart = NextToken(GreaterToken);
82+
if (!ArgStart || ArgStart->getKind() != clang::tok::l_paren) {
83+
return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
84+
"missing '(' token");
85+
}
86+
std::optional<Token> Arg = NextToken(ArgStart);
87+
if (!Arg) {
88+
return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
89+
"unexpected end of file");
90+
}
91+
bool hasArgs = Arg->getKind() != clang::tok::r_paren;
92+
8293
Expected<CharSourceRange> BuilderRange = Builder(Result);
8394
if (!BuilderRange)
8495
return BuilderRange.takeError();
85-
Expected<CharSourceRange> CallArgsRange = CallArgs(Result);
86-
if (!CallArgsRange)
87-
return CallArgsRange.takeError();
8896

8997
// Helper for concatting below.
9098
auto GetText = [&](const CharSourceRange &Range) {
@@ -93,18 +101,19 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
93101

94102
Edit Replace;
95103
Replace.Kind = EditKind::Range;
96-
Replace.Range = *CallRange;
97-
std::string CallArgsStr;
98-
// Only emit args if there are any.
99-
if (auto CallArgsText = GetText(*CallArgsRange).ltrim();
100-
!CallArgsText.rtrim().empty()) {
101-
CallArgsStr = llvm::formatv(", {}", CallArgsText);
104+
Replace.Range.setBegin(CallRange->getBegin());
105+
Replace.Range.setEnd(ArgStart->getEndLoc());
106+
const Expr *BuilderExpr = Result.Nodes.getNodeAs<Expr>("builder");
107+
std::string BuilderText = GetText(*BuilderRange).str();
108+
if (BuilderExpr->getType()->isPointerType()) {
109+
BuilderText = BuilderExpr->isImplicitCXXThis()
110+
? "*this"
111+
: llvm::formatv("*{}", BuilderText).str();
102112
}
103-
Replace.Replacement =
104-
llvm::formatv("{}::create({}{})",
105-
GetText(CharSourceRange::getTokenRange(
106-
LessToken->getEndLoc(), EndToken->getLastLoc())),
107-
GetText(*BuilderRange), CallArgsStr);
113+
StringRef OpType = GetText(CharSourceRange::getTokenRange(
114+
LessToken->getEndLoc(), EndToken->getLastLoc()));
115+
Replace.Replacement = llvm::formatv("{}::create({}{}", OpType, BuilderText,
116+
hasArgs ? ", " : "");
108117

109118
return SmallVector<Edit, 1>({Replace});
110119
};
@@ -114,19 +123,17 @@ RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule() {
114123
Stencil message = cat("use 'OpType::create(builder, ...)' instead of "
115124
"'builder.create<OpType>(...)'");
116125
// Match a create call on an OpBuilder.
117-
ast_matchers::internal::Matcher<Stmt> base =
118-
cxxMemberCallExpr(
119-
on(expr(hasType(
120-
cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder"))))
121-
.bind("builder")),
122-
callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()))),
123-
callee(cxxMethodDecl(hasName("create"))))
124-
.bind("call");
126+
auto BuilderType = cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder"));
127+
ast_matchers::internal::Matcher<Stmt> base = cxxMemberCallExpr(
128+
on(expr(anyOf(hasType(BuilderType), hasType(pointsTo(BuilderType))))
129+
.bind("builder")),
130+
callee(expr().bind("call")),
131+
callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()))),
132+
callee(cxxMethodDecl(hasName("create"))));
125133
return applyFirst(
126134
// Attempt rewrite given an lvalue builder, else just warn.
127135
{makeRule(cxxMemberCallExpr(unless(on(cxxTemporaryObjectExpr())), base),
128-
rewrite(node("call"), node("builder"), callArgs("call")),
129-
message),
136+
rewrite(node("call"), node("builder")), message),
130137
makeRule(base, noopEdit(node("call")), message)});
131138
}
132139
} // namespace

clang-tools-extra/test/clang-tidy/checkers/llvm/use-new-mlir-op-builder.cpp

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
namespace mlir {
44
class Location {};
5+
class Value {};
56
class OpBuilder {
67
public:
78
template <typename OpTy, typename... Args>
@@ -28,6 +29,13 @@ struct NamedOp {
2829
static NamedOp create(OpBuilder &builder, Location location, const char* name) {
2930
return NamedOp(name);
3031
}
32+
Value getResult() { return Value(); }
33+
};
34+
struct OperandOp {
35+
OperandOp(Value val) {}
36+
static OperandOp create(OpBuilder &builder, Location location, Value val) {
37+
return OperandOp(val);
38+
}
3139
};
3240
} // namespace mlir
3341

@@ -40,13 +48,24 @@ void g(mlir::OpBuilder &b) {
4048
b.create<T>(b.getUnknownLoc(), "gaz");
4149
}
4250

51+
class CustomBuilder : public mlir::ImplicitLocOpBuilder {
52+
public:
53+
mlir::NamedOp f(const char *name) {
54+
// CHECK-MESSAGES: :[[@LINE+2]]:12: warning: use 'OpType::create(builder, ...)'
55+
// CHECK-FIXES: NamedOp::create(*this, name);
56+
return create<mlir::NamedOp>(name);
57+
}
58+
};
59+
4360
void f() {
4461
mlir::OpBuilder builder;
4562
// CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
4663
// CHECK-FIXES: mlir:: ModuleOp::create(builder, builder.getUnknownLoc())
4764
builder.create<mlir:: ModuleOp>(builder.getUnknownLoc());
4865

4966
using mlir::NamedOp;
67+
using mlir::OperandOp;
68+
5069
// CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
5170
// CHECK-FIXES: NamedOp::create(builder, builder.getUnknownLoc(), "baz")
5271
builder.create<NamedOp>(builder.getUnknownLoc(), "baz");
@@ -56,7 +75,7 @@ void f() {
5675
// CHECK-FIXES: builder.getUnknownLoc(),
5776
// CHECK-FIXES: "caz")
5877
builder.
59-
create<NamedOp>(
78+
create<NamedOp> (
6079
builder.getUnknownLoc(),
6180
"caz");
6281

@@ -67,10 +86,25 @@ void f() {
6786

6887
mlir::ImplicitLocOpBuilder ib;
6988
// CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
70-
// CHECK-FIXES: mlir::ModuleOp::create(ib)
89+
// CHECK-FIXES: mlir::ModuleOp::create(ib )
7190
ib.create<mlir::ModuleOp>( );
7291

7392
// CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
7493
// CHECK-FIXES: mlir::OpBuilder().create<mlir::ModuleOp>(builder.getUnknownLoc());
7594
mlir::OpBuilder().create<mlir::ModuleOp>(builder.getUnknownLoc());
95+
96+
auto *p = &builder;
97+
// CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)'
98+
// CHECK-FIXES: NamedOp::create(*p, builder.getUnknownLoc(), "eaz")
99+
p->create<NamedOp>(builder.getUnknownLoc(), "eaz");
100+
101+
CustomBuilder cb;
102+
cb.f("faz");
103+
104+
// CHECK-MESSAGES: :[[@LINE+4]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
105+
// CHECK-FIXES: OperandOp::create(builder, builder.getUnknownLoc(),
106+
// CHECK-MESSAGES: :[[@LINE+3]]:5: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
107+
// CHECK-FIXES: NamedOp::create(builder,
108+
builder.create<OperandOp>(builder.getUnknownLoc(),
109+
builder.create<NamedOp>(builder.getUnknownLoc(), "gaz").getResult());
76110
}

0 commit comments

Comments
 (0)