1515 */
1616package org .openrewrite .java .testing .junit5 ;
1717
18+ import org .intellij .lang .annotations .Language ;
1819import org .openrewrite .ExecutionContext ;
19- import org .openrewrite .Parser ;
2020import org .openrewrite .Recipe ;
2121import org .openrewrite .TreeVisitor ;
2222import org .openrewrite .internal .lang .Nullable ;
2525import org .openrewrite .java .tree .*;
2626
2727import java .time .Duration ;
28- import java .util .Collections ;
2928import java .util .Comparator ;
3029
3130public class UpdateTestAnnotation extends Recipe {
@@ -52,7 +51,6 @@ protected TreeVisitor<?, ExecutionContext> getVisitor() {
5251
5352 private static class UpdateTestAnnotationVisitor extends JavaIsoVisitor <ExecutionContext > {
5453 private static final AnnotationMatcher JUNIT4_TEST = new AnnotationMatcher ("@org.junit.Test" );
55-
5654 @ Override
5755 public J .MethodDeclaration visitMethodDeclaration (J .MethodDeclaration method , ExecutionContext ctx ) {
5856 ChangeTestAnnotation cta = new ChangeTestAnnotation ();
@@ -72,43 +70,52 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex
7270 .getVariables ().get (0 ).getInitializer ();
7371
7472 assert lambda != null ;
75- lambda = lambda .withType (JavaType .Class .build ("org.junit.jupiter.api.function.Executable" ));
76-
77- m = m .withTemplate (JavaTemplate .builder (this ::getCursor ,
78- "assertThrows(#{any(java.lang.Class)}, #{any(org.junit.jupiter.api.function.Executable)});" )
79- .javaParser (() -> JavaParser .fromJavaVersion ()
80- .dependsOn (
81- "package org.junit.jupiter.api.function;" +
82- "public interface Executable {" +
83- " void execute() throws Throwable;" +
84- "}" ,
85- "package org.junit.jupiter.api;" +
86- "import org.junit.jupiter.api.function.Executable;" +
87- "public class Assertions {" +
88- " public static <T extends Throwable> T assertThrows(Class<T> expectedType, Executable executable) {" +
89- " return null;" +
90- " }" +
91- "}"
92- )
93- .build ())
94- .staticImports ("org.junit.jupiter.api.Assertions.assertThrows" )
95- .build (),
96- m .getCoordinates ().replaceBody (),
97- cta .expectedException , lambda );
98- maybeAddImport ("org.junit.jupiter.api.Assertions" , "assertThrows" );
73+ lambda = lambda .withType (JavaType .ShallowClass .build ("org.junit.jupiter.api.function.Executable" ));
74+
75+ @ Language ("java" ) String [] assertionShims = {
76+ "package org.junit.jupiter.api.function;" +
77+ "public interface Executable {" +
78+ " void execute() throws Throwable;" +
79+ "}" ,
80+ "package org.junit.jupiter.api;" +
81+ "import org.junit.jupiter.api.function.Executable;" +
82+ "public class Assertions {" +
83+ " public static <T extends Throwable> T assertThrows(Class<T> expectedType, Executable executable) {" +
84+ " return null;" +
85+ " }" +
86+ " public static void assertDoesNotThrow(Executable executable) {}" +
87+ "}"
88+ };
89+
90+ if (cta .expectedException instanceof J .FieldAccess
91+ && TypeUtils .isAssignableTo ("org.junit.Test$None" , ((J .FieldAccess ) cta .expectedException ).getTarget ().getType ())) {
92+ m = m .withTemplate (JavaTemplate .builder (this ::getCursor , "assertDoesNotThrow(#{any(org.junit.jupiter.api.function.Executable)});" )
93+ .javaParser (() -> JavaParser .fromJavaVersion ().dependsOn (assertionShims ).build ())
94+ .staticImports ("org.junit.jupiter.api.Assertions.assertDoesNotThrow" )
95+ .build (),
96+ m .getCoordinates ().replaceBody (), lambda );
97+ maybeAddImport ("org.junit.jupiter.api.Assertions" , "assertDoesNotThrow" );
98+ } else {
99+ m = m .withTemplate (JavaTemplate .builder (this ::getCursor , "assertThrows(#{any(java.lang.Class)}, #{any(org.junit.jupiter.api.function.Executable)});" )
100+ .javaParser (() -> JavaParser .fromJavaVersion ().dependsOn (assertionShims ).build ())
101+ .staticImports ("org.junit.jupiter.api.Assertions.assertThrows" )
102+ .build (),
103+ m .getCoordinates ().replaceBody (), cta .expectedException , lambda );
104+ maybeAddImport ("org.junit.jupiter.api.Assertions" , "assertThrows" );
105+ }
99106 }
100107 if (cta .timeout != null ) {
101108 m = m .withTemplate (
102109 JavaTemplate .builder (this ::getCursor , "@Timeout(#{any(long)})" )
103110 .javaParser (() -> JavaParser .fromJavaVersion ()
104- .dependsOn (Collections . singletonList ( Parser . Input . fromString (
105- "package org.junit.jupiter.api;\n " +
106- "import java.util.concurrent.TimeUnit;\n " +
107- "public @interface Timeout {\n " +
108- " long value();\n " +
109- " TimeUnit unit() default TimeUnit.SECONDS;\n " +
110- "} \n "
111- )) )
111+ .dependsOn (new String []{
112+ "package org.junit.jupiter.api;" +
113+ "import java.util.concurrent.TimeUnit;" +
114+ "public @interface Timeout {" +
115+ " long value();" +
116+ " TimeUnit unit() default TimeUnit.SECONDS;" +
117+ "} "
118+ } )
112119 .build ())
113120 .imports ("org.junit.jupiter.api.Timeout" )
114121 .build (),
@@ -155,7 +162,7 @@ public J.Annotation visitAnnotation(J.Annotation a, ExecutionContext context) {
155162 }
156163 }
157164 a = a .withArguments (null )
158- .withType (JavaType .Class .build ("org.junit.jupiter.api.Test" ));
165+ .withType (JavaType .ShallowClass .build ("org.junit.jupiter.api.Test" ));
159166 }
160167 return a ;
161168 }
0 commit comments