diff --git a/src/main/java/org/openrewrite/java/migrate/lang/MigrateMainMethodToInstanceMain.java b/src/main/java/org/openrewrite/java/migrate/lang/MigrateMainMethodToInstanceMain.java index 1d0b684c03..027c8c73bb 100644 --- a/src/main/java/org/openrewrite/java/migrate/lang/MigrateMainMethodToInstanceMain.java +++ b/src/main/java/org/openrewrite/java/migrate/lang/MigrateMainMethodToInstanceMain.java @@ -20,15 +20,24 @@ import org.openrewrite.Recipe; import org.openrewrite.TreeVisitor; import org.openrewrite.java.JavaIsoVisitor; +import org.openrewrite.java.MethodMatcher; +import org.openrewrite.java.search.DeclaresMethod; import org.openrewrite.java.search.UsesJavaVersion; import org.openrewrite.java.tree.J; import org.openrewrite.java.tree.JavaType; import org.openrewrite.java.tree.TypeUtils; import org.openrewrite.staticanalysis.VariableReferences; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + import static java.util.Collections.emptyList; +import static java.util.stream.Collectors.toList; public class MigrateMainMethodToInstanceMain extends Recipe { + + private static final MethodMatcher MAIN_METHOD_MATCHER = new MethodMatcher("*..* main(String[])", false); + @Override public String getDisplayName() { return "Migrate `public static void main(String[] args)` to instance `void main()`"; @@ -41,19 +50,23 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return Preconditions.check(new UsesJavaVersion<>(25), new JavaIsoVisitor() { + TreeVisitor preconditions = Preconditions.and( + new UsesJavaVersion<>(25), + new DeclaresMethod<>(MAIN_METHOD_MATCHER) + ); + return Preconditions.check(preconditions, new JavaIsoVisitor() { @Override public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) { + J.ClassDeclaration enclosingClass = getCursor().firstEnclosing(J.ClassDeclaration.class); J.MethodDeclaration md = super.visitMethodDeclaration(method, ctx); // Check if this is a main method: public static void main(String[] args) - if (!"main".equals(md.getSimpleName()) || + if (enclosingClass == null || + !MAIN_METHOD_MATCHER.matches(md, enclosingClass) || md.getReturnTypeExpression() == null || md.getReturnTypeExpression().getType() != JavaType.Primitive.Void || !md.hasModifier(J.Modifier.Type.Public) || !md.hasModifier(J.Modifier.Type.Static) || - md.getParameters().size() != 1 || - !(md.getParameters().get(0) instanceof J.VariableDeclarations) || md.getBody() == null) { return md; } @@ -65,6 +78,13 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex return md; } + // Do not migrate in any of these cases + if (hasSpringBootApplicationAnnotation(enclosingClass) || + !hasNoArgConstructor(enclosingClass) || + isMainMethodReferenced(md)) { + return md; + } + // Remove the parameter if unused J.Identifier variableName = param.getVariables().get(0).getName(); if (VariableReferences.findRhsReferences(md.getBody(), variableName).isEmpty()) { @@ -73,7 +93,50 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex return md.withReturnTypeExpression(md.getReturnTypeExpression().withPrefix(md.getModifiers().get(0).getPrefix())) .withModifiers(emptyList()); } + + private boolean hasSpringBootApplicationAnnotation(J.ClassDeclaration classDecl) { + return classDecl.getLeadingAnnotations().stream() + .anyMatch(ann -> TypeUtils.isOfClassType(ann.getType(), "org.springframework.boot.autoconfigure.SpringBootApplication")); + } + + private boolean hasNoArgConstructor(J.ClassDeclaration classDecl) { + List constructors = classDecl.getBody().getStatements().stream() + .filter(stmt -> stmt instanceof J.MethodDeclaration) + .map(stmt -> (J.MethodDeclaration) stmt) + .filter(J.MethodDeclaration::isConstructor) + .collect(toList()); + + // If no constructors are declared, the class has an implicit no-arg constructor + if (constructors.isEmpty()) { + return true; + } + + // Check if any explicit constructor is a no-arg constructor + return constructors.stream() + .anyMatch(ctor -> ctor.getParameters().isEmpty() || + (ctor.getParameters().size() == 1 && ctor.getParameters().get(0) instanceof J.Empty)); + } + + private boolean isMainMethodReferenced(J.MethodDeclaration mainMethod) { + J.CompilationUnit cu = getCursor().firstEnclosing(J.CompilationUnit.class); + if (cu == null) { + return false; + } + + // XXX Only picks up references in the same compilation unit; convert to scanning recipe if needed + return new JavaIsoVisitor() { + @Override + public J.MemberReference visitMemberReference(J.MemberReference memberRef, AtomicBoolean referenced) { + // Check if this is a reference to the main method + if ("main".equals(memberRef.getReference().getSimpleName()) && + memberRef.getMethodType() != null && + TypeUtils.isOfType(memberRef.getMethodType(), mainMethod.getMethodType())) { + referenced.set(true); + } + return super.visitMemberReference(memberRef, referenced); + } + }.reduce(cu, new AtomicBoolean()).get(); + } }); } - } diff --git a/src/test/java/org/openrewrite/java/migrate/lang/MigrateMainMethodToInstanceMainTest.java b/src/test/java/org/openrewrite/java/migrate/lang/MigrateMainMethodToInstanceMainTest.java index 9ceb82ab0c..c17854c9e0 100644 --- a/src/test/java/org/openrewrite/java/migrate/lang/MigrateMainMethodToInstanceMainTest.java +++ b/src/test/java/org/openrewrite/java/migrate/lang/MigrateMainMethodToInstanceMainTest.java @@ -17,6 +17,7 @@ import org.junit.jupiter.api.Test; import org.openrewrite.DocumentExample; +import org.openrewrite.java.JavaParser; import org.openrewrite.test.RecipeSpec; import org.openrewrite.test.RewriteTest; @@ -288,4 +289,88 @@ void main() { ) ); } + + @Test + void doNotMigrateMainUsedAsMethodReference() { + //language=java + rewriteRun( + java( + """ + interface MainMethod { + void run(String[] args); + } + """ + ), + java( + """ + class Application { + public static void main(String[] args) { + System.out.println("Hello from main"); + } + } + + class Runner { + void executeMain() { + MainMethod foo = Application::main; + foo.run(null); + } + } + """ + ) + ); + } + + @Test + void doNotMigrateMainWithNonDefaultConstructor() { + //language=java + rewriteRun( + java( + """ + class Application { + public static void main(String[] args) { + System.out.println("Hello!"); + } + + public Application(String config) { + // Non-default constructor + } + } + """ + ) + ); + } + + @Test + void doNotMigrateMainInSpringBootApplication() { + //language=java + rewriteRun( + spec -> spec.parser(JavaParser.fromJavaVersion().dependsOn( + """ + package org.springframework.boot.autoconfigure; + public @interface SpringBootApplication {} + """, + """ + package org.springframework.boot; + public class SpringApplication { + public static void run(Class primarySource, String... args) {} + } + """ + )), + java( + """ + package com.example.demo; + + import org.springframework.boot.SpringApplication; + import org.springframework.boot.autoconfigure.SpringBootApplication; + + @SpringBootApplication + class DemoApplication { + public static void main(String[] args) { + SpringApplication.run(DemoApplication.class, args); + } + } + """ + ) + ); + } }