1919import org .openrewrite .Parser ;
2020import org .openrewrite .Recipe ;
2121import org .openrewrite .TreeVisitor ;
22+ import org .openrewrite .internal .ListUtils ;
2223import org .openrewrite .internal .lang .Nullable ;
2324import org .openrewrite .java .*;
2425import org .openrewrite .java .search .UsesType ;
2526import org .openrewrite .java .tree .J ;
2627import org .openrewrite .java .tree .Space ;
2728import org .openrewrite .java .tree .TypeUtils ;
2829
29- import java .util .ArrayList ;
30- import java .util .stream .Collectors ;
31- import java .util .stream .Stream ;
30+ import java .util .Arrays ;
3231
3332public class TestRuleToTestInfo extends Recipe {
3433
3534 private static final String testNameType = "org.junit.rules.TestName" ;
3635 private static final MethodMatcher TEST_NAME_GET_NAME = new MethodMatcher (testNameType + " getMethodName()" );
36+ private static final AnnotationMatcher RULE_ANNOTATION_MATCHER = new AnnotationMatcher ("@org.junit.Rule" );
37+ private static final AnnotationMatcher JUNIT_BEFORE_MATCHER = new AnnotationMatcher ("@org.junit.Before" );
38+ private static final AnnotationMatcher JUPITER_BEFORE_EACH_MATCHER = new AnnotationMatcher ("@org.junit.jupiter.api.BeforeEach" );
3739
3840 private static final ThreadLocal <JavaParser > TEST_INFO_PARSER = ThreadLocal .withInitial (() ->
3941 JavaParser .fromJavaVersion ().dependsOn (
40- Stream .of (
41- Parser .Input .fromString (
42+ Arrays .asList (Parser .Input .fromString (
4243 "package org.junit.jupiter.api;\n " +
4344 "import java.lang.reflect.Method;\n " +
4445 "import java.util.Optional;\n " +
@@ -48,7 +49,10 @@ public class TestRuleToTestInfo extends Recipe {
4849 " Set<String> getTags();\n " +
4950 " Optional<Class<?>> getTestClass();\n " +
5051 " Optional<Method> getTestMethod();" +
51- "}" )).collect (Collectors .toList ())
52+ "}" ),
53+ Parser .Input .fromString (
54+ "package org.junit.jupiter.api; public @interface BeforeEach {}"
55+ ))
5256 ).build ());
5357
5458 @ Override
@@ -61,11 +65,6 @@ public String getDescription() {
6165 return "Replace usages of JUnit 4's `@Rule TestName` with JUnit 5's TestInfo." ;
6266 }
6367
64- @ Override
65- protected @ Nullable TreeVisitor <?, ExecutionContext > getApplicableTest () {
66- return new UsesType <>(testNameType );
67- }
68-
6968 @ Override
7069 protected @ Nullable TreeVisitor <?, ExecutionContext > getSingleSourceApplicableTest () {
7170 return new UsesType <>(testNameType );
@@ -91,16 +90,21 @@ public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext execu
9190 }
9291 });
9392 doAfterVisit (new ChangeType (testNameType , "String" ));
93+ doAfterVisit (new ChangeType ("org.junit.Before" , "org.junit.jupiter.api.BeforeEach" ));
9494 return compilationUnit ;
9595 }
9696
9797 @ Override
9898 public J .VariableDeclarations visitVariableDeclarations (J .VariableDeclarations multiVariable , ExecutionContext executionContext ) {
9999 J .VariableDeclarations varDecls = super .visitVariableDeclarations (multiVariable , executionContext );
100100 if (varDecls .getType () != null && TypeUtils .isOfClassType (varDecls .getType (), testNameType )) {
101- varDecls = varDecls .withLeadingAnnotations (new ArrayList <>());
102- //noinspection ConstantConditions
103- doAfterVisit (new AddBeforeEachMethod (varDecls , getCursor ().firstEnclosing (J .ClassDeclaration .class )));
101+ varDecls = varDecls .withLeadingAnnotations (ListUtils .map (varDecls .getLeadingAnnotations (), anno -> {
102+ if (RULE_ANNOTATION_MATCHER .matches (anno )) {
103+ return null ;
104+ }
105+ return anno ;
106+ }));
107+ getCursor ().dropParentUntil (J .ClassDeclaration .class ::isInstance ).putMessage ("has-testName-rule" , varDecls );
104108 }
105109 return varDecls ;
106110 }
@@ -115,46 +119,56 @@ public J.NewClass visitNewClass(J.NewClass newClass, ExecutionContext executionC
115119 return nc ;
116120 }
117121
118- //FIXME. add TestMethod statements.
119- // @Override
120- // public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext executionContext) {
121- // J.MethodDeclaration md = super.visitMethodDeclaration(method, executionContext);
122- // return md;
123- // }
124-
125- private boolean isBeforeAnnotation (J .Annotation annotation ) {
126- return TypeUtils .isOfClassType (annotation .getType (), "org.junit.Before" ) || TypeUtils .isOfClassType (annotation .getType (), "org.junit.jupiter.api.BeforeEach" );
122+ @ Override
123+ public J .MethodDeclaration visitMethodDeclaration (J .MethodDeclaration method , ExecutionContext executionContext ) {
124+ J .MethodDeclaration md = super .visitMethodDeclaration (method , executionContext );
125+ if (md .getLeadingAnnotations ().stream ().anyMatch (anno -> JUNIT_BEFORE_MATCHER .matches (anno ) || JUPITER_BEFORE_EACH_MATCHER .matches (anno ))) {
126+ getCursor ().dropParentUntil (J .ClassDeclaration .class ::isInstance ).putMessage ("before-method" , md );
127+ }
128+ return md ;
127129 }
128- };
129- }
130-
131- private static class AddBeforeEachMethod extends JavaIsoVisitor <ExecutionContext > {
132- private final J .VariableDeclarations varDecls ;
133- private final J .ClassDeclaration enclosingClass ;
134-
135- public AddBeforeEachMethod (J .VariableDeclarations varDecls , J .ClassDeclaration enclosingClass ) {
136- this .varDecls = varDecls ;
137- this .enclosingClass = enclosingClass ;
138- }
139130
140- @ Override
141- public J .ClassDeclaration visitClassDeclaration (J .ClassDeclaration classDecl , ExecutionContext executionContext ) {
142- J .ClassDeclaration cd = super .visitClassDeclaration (classDecl , executionContext );
143- if (enclosingClass .getId ().equals (cd .getId ())) {
144- String t = "@BeforeEach\n " +
145- "public void setup(TestInfo testInfo) {\n " +
146- " Optional<Method> testMethod = testInfo.getTestMethod();\n " +
147- " if (testMethod.isPresent()) {\n " +
148- " this.#{} = testMethod.get().getName();\n " +
149- " }\n " +
150- "}" ;
151- cd = cd .withTemplate (JavaTemplate .builder (this ::getCursor , t ).javaParser (TEST_INFO_PARSER ::get )
152- .imports ("org.junit.jupiter.api.TestInfo" , "java.util.Optional" , "java.lang.reflect.Method" )
153- .build (),
154- cd .getBody ().getCoordinates ().lastStatement (),
155- varDecls .getVariables ().get (0 ).getName ().getSimpleName ());
131+ @ Override
132+ public J .ClassDeclaration visitClassDeclaration (J .ClassDeclaration classDecl , ExecutionContext executionContext ) {
133+ J .ClassDeclaration cd = super .visitClassDeclaration (classDecl , executionContext );
134+ J .VariableDeclarations varDecls = getCursor ().pollMessage ("has-testName-rule" );
135+ if (varDecls != null ) {
136+ String testMethodStatement = "Optional<Method> testMethod = testInfo.getTestMethod();\n " +
137+ "if (testMethod.isPresent()) {\n " +
138+ " this.#{} = testMethod.get().getName();\n " +
139+ "}" ;
140+ J .MethodDeclaration beforeMethod = getCursor ().pollMessage ("before-method" );
141+ if (beforeMethod == null ) {
142+ String t = "@BeforeEach\n " +
143+ "public void setup(TestInfo testInfo) {" + testMethodStatement + "}" ;
144+ cd = cd .withTemplate (JavaTemplate .builder (this ::getCursor , t ).javaParser (TEST_INFO_PARSER ::get )
145+ .imports ("org.junit.jupiter.api.TestInfo" , "org.junit.jupiter.api.BeforeEach" , "java.util.Optional" , "java.lang.reflect.Method" )
146+ .build (),
147+ cd .getBody ().getCoordinates ().lastStatement (),
148+ varDecls .getVariables ().get (0 ).getName ().getSimpleName ());
149+ } else {
150+ doAfterVisit (new JavaIsoVisitor <ExecutionContext >() {
151+ @ Override
152+ public J .MethodDeclaration visitMethodDeclaration (J .MethodDeclaration method , ExecutionContext executionContext ) {
153+ J .MethodDeclaration md = super .visitMethodDeclaration (method , executionContext );
154+ if (md .getId ().equals (beforeMethod .getId ())) {
155+ md = md .withTemplate (JavaTemplate .builder (this ::getCursor , "TestInfo testInfo" ).javaParser (TEST_INFO_PARSER ::get )
156+ .imports ("org.junit.jupiter.api.TestInfo" , "org.junit.jupiter.api.BeforeEach" , "java.util.Optional" , "java.lang.reflect.Method" )
157+ .build (),
158+ md .getCoordinates ().replaceParameters ());
159+ //noinspection ConstantConditions
160+ md = maybeAutoFormat (md , md .withTemplate (JavaTemplate .builder (this ::getCursor , testMethodStatement ).javaParser (TEST_INFO_PARSER ::get )
161+ .imports ("org.junit.jupiter.api.TestInfo" , "java.util.Optional" , "java.lang.reflect.Method" )
162+ .build (),
163+ md .getBody ().getCoordinates ().lastStatement (), varDecls .getVariables ().get (0 ).getName ().getSimpleName ()), executionContext , getCursor ().getParent ());
164+ }
165+ return md ;
166+ }
167+ });
168+ }
169+ }
170+ return cd ;
156171 }
157- return cd ;
158- }
172+ };
159173 }
160174}
0 commit comments