diff --git a/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp b/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp index 0d81b9a9e38ca..0e868b704a3f3 100644 --- a/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp +++ b/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp @@ -23,13 +23,11 @@ namespace { using namespace ::clang::ast_matchers; using namespace ::clang::transformer; -EditGenerator rewrite(RangeSelector Call, RangeSelector Builder, - RangeSelector CallArgs) { +EditGenerator rewrite(RangeSelector Call, RangeSelector Builder) { // This is using an EditGenerator rather than ASTEdit as we want to warn even // if in macro. - return [Call = std::move(Call), Builder = std::move(Builder), - CallArgs = - std::move(CallArgs)](const MatchFinder::MatchResult &Result) + return [Call = std::move(Call), + Builder = std::move(Builder)](const MatchFinder::MatchResult &Result) -> Expected> { Expected CallRange = Call(Result); if (!CallRange) @@ -54,7 +52,7 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder, auto NextToken = [&](std::optional CurrentToken) { if (!CurrentToken) return CurrentToken; - if (CurrentToken->getEndLoc() >= CallRange->getEnd()) + if (CurrentToken->is(clang::tok::eof)) return std::optional(); return clang::Lexer::findNextToken(CurrentToken->getLocation(), SM, LangOpts); @@ -68,9 +66,10 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder, return llvm::make_error(llvm::errc::invalid_argument, "missing '<' token"); } + std::optional EndToken = NextToken(LessToken); - for (std::optional GreaterToken = NextToken(EndToken); - GreaterToken && GreaterToken->getKind() != clang::tok::greater; + std::optional GreaterToken = NextToken(EndToken); + for (; GreaterToken && GreaterToken->getKind() != clang::tok::greater; GreaterToken = NextToken(GreaterToken)) { EndToken = GreaterToken; } @@ -79,12 +78,21 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder, "missing '>' token"); } + std::optional ArgStart = NextToken(GreaterToken); + if (!ArgStart || ArgStart->getKind() != clang::tok::l_paren) { + return llvm::make_error(llvm::errc::invalid_argument, + "missing '(' token"); + } + std::optional Arg = NextToken(ArgStart); + if (!Arg) { + return llvm::make_error(llvm::errc::invalid_argument, + "unexpected end of file"); + } + const bool HasArgs = Arg->getKind() != clang::tok::r_paren; + Expected BuilderRange = Builder(Result); if (!BuilderRange) return BuilderRange.takeError(); - Expected CallArgsRange = CallArgs(Result); - if (!CallArgsRange) - return CallArgsRange.takeError(); // Helper for concatting below. auto GetText = [&](const CharSourceRange &Range) { @@ -93,18 +101,19 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder, Edit Replace; Replace.Kind = EditKind::Range; - Replace.Range = *CallRange; - std::string CallArgsStr; - // Only emit args if there are any. - if (auto CallArgsText = GetText(*CallArgsRange).ltrim(); - !CallArgsText.rtrim().empty()) { - CallArgsStr = llvm::formatv(", {}", CallArgsText); + Replace.Range.setBegin(CallRange->getBegin()); + Replace.Range.setEnd(ArgStart->getEndLoc()); + const Expr *BuilderExpr = Result.Nodes.getNodeAs("builder"); + std::string BuilderText = GetText(*BuilderRange).str(); + if (BuilderExpr->getType()->isPointerType()) { + BuilderText = BuilderExpr->isImplicitCXXThis() + ? "*this" + : llvm::formatv("*{}", BuilderText).str(); } - Replace.Replacement = - llvm::formatv("{}::create({}{})", - GetText(CharSourceRange::getTokenRange( - LessToken->getEndLoc(), EndToken->getLastLoc())), - GetText(*BuilderRange), CallArgsStr); + const StringRef OpType = GetText(CharSourceRange::getTokenRange( + LessToken->getEndLoc(), EndToken->getLastLoc())); + Replace.Replacement = llvm::formatv("{}::create({}{}", OpType, BuilderText, + HasArgs ? ", " : ""); return SmallVector({Replace}); }; @@ -114,20 +123,19 @@ RewriteRuleWith useNewMlirOpBuilderCheckRule() { Stencil message = cat("use 'OpType::create(builder, ...)' instead of " "'builder.create(...)'"); // Match a create call on an OpBuilder. - ast_matchers::internal::Matcher base = + auto BuilderType = cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder")); + ast_matchers::internal::Matcher Base = cxxMemberCallExpr( - on(expr(hasType( - cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder")))) + on(expr(anyOf(hasType(BuilderType), hasType(pointsTo(BuilderType)))) .bind("builder")), - callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()))), - callee(cxxMethodDecl(hasName("create")))) + callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()), + hasName("create")))) .bind("call"); return applyFirst( // Attempt rewrite given an lvalue builder, else just warn. - {makeRule(cxxMemberCallExpr(unless(on(cxxTemporaryObjectExpr())), base), - rewrite(node("call"), node("builder"), callArgs("call")), - message), - makeRule(base, noopEdit(node("call")), message)}); + {makeRule(cxxMemberCallExpr(unless(on(cxxTemporaryObjectExpr())), Base), + rewrite(node("call"), node("builder")), message), + makeRule(Base, noopEdit(node("call")), message)}); } } // namespace diff --git a/clang-tools-extra/test/clang-tidy/checkers/llvm/use-new-mlir-op-builder.cpp b/clang-tools-extra/test/clang-tidy/checkers/llvm/use-new-mlir-op-builder.cpp index 0971a1611e3cb..ea58a6c93e324 100644 --- a/clang-tools-extra/test/clang-tidy/checkers/llvm/use-new-mlir-op-builder.cpp +++ b/clang-tools-extra/test/clang-tidy/checkers/llvm/use-new-mlir-op-builder.cpp @@ -2,6 +2,7 @@ namespace mlir { class Location {}; +class Value {}; class OpBuilder { public: template @@ -28,6 +29,13 @@ struct NamedOp { static NamedOp create(OpBuilder &builder, Location location, const char* name) { return NamedOp(name); } + Value getResult() { return Value(); } +}; +struct OperandOp { + OperandOp(Value val) {} + static OperandOp create(OpBuilder &builder, Location location, Value val) { + return OperandOp(val); + } }; } // namespace mlir @@ -40,6 +48,15 @@ void g(mlir::OpBuilder &b) { b.create(b.getUnknownLoc(), "gaz"); } +class CustomBuilder : public mlir::ImplicitLocOpBuilder { +public: + mlir::NamedOp f(const char *name) { + // CHECK-MESSAGES: :[[@LINE+2]]:12: warning: use 'OpType::create(builder, ...)' + // CHECK-FIXES: mlir::NamedOp::create(*this, name); + return create(name); + } +}; + void f() { mlir::OpBuilder builder; // CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create(...)' [llvm-use-new-mlir-op-builder] @@ -47,6 +64,8 @@ void f() { builder.create(builder.getUnknownLoc()); using mlir::NamedOp; + using mlir::OperandOp; + // CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create(...)' [llvm-use-new-mlir-op-builder] // CHECK-FIXES: NamedOp::create(builder, builder.getUnknownLoc(), "baz") builder.create(builder.getUnknownLoc(), "baz"); @@ -56,7 +75,7 @@ void f() { // CHECK-FIXES: builder.getUnknownLoc(), // CHECK-FIXES: "caz") builder. - create( + create ( builder.getUnknownLoc(), "caz"); @@ -67,10 +86,25 @@ void f() { mlir::ImplicitLocOpBuilder ib; // CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create(...)' [llvm-use-new-mlir-op-builder] - // CHECK-FIXES: mlir::ModuleOp::create(ib) + // CHECK-FIXES: mlir::ModuleOp::create(ib ) ib.create( ); // CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create(...)' [llvm-use-new-mlir-op-builder] // CHECK-FIXES: mlir::OpBuilder().create(builder.getUnknownLoc()); mlir::OpBuilder().create(builder.getUnknownLoc()); + + auto *p = &builder; + // CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)' + // CHECK-FIXES: NamedOp::create(*p, builder.getUnknownLoc(), "eaz") + p->create(builder.getUnknownLoc(), "eaz"); + + CustomBuilder cb; + cb.f("faz"); + + // CHECK-MESSAGES: :[[@LINE+4]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create(...)' [llvm-use-new-mlir-op-builder] + // CHECK-FIXES: OperandOp::create(builder, builder.getUnknownLoc(), + // CHECK-MESSAGES: :[[@LINE+3]]:5: warning: use 'OpType::create(builder, ...)' instead of 'builder.create(...)' [llvm-use-new-mlir-op-builder] + // CHECK-FIXES: NamedOp::create(builder, + builder.create(builder.getUnknownLoc(), + builder.create(builder.getUnknownLoc(), "gaz").getResult()); }