2121import org .openrewrite .TreeVisitor ;
2222import org .openrewrite .java .*;
2323import org .openrewrite .java .search .UsesType ;
24- import org .openrewrite .java .tree .*;
24+ import org .openrewrite .java .tree .Expression ;
25+ import org .openrewrite .java .tree .J ;
26+ import org .openrewrite .java .tree .JavaType ;
2527
26- import java .util .ArrayList ;
2728import java .util .Collections ;
2829import java .util .Comparator ;
2930import java .util .List ;
30- import java .util .function .Predicate ;
3131
3232public class UpdateTestAnnotation extends Recipe {
3333
34- private static List <Parser .Input > assertThrowsDependsOn (Expression e ) {
35- List <Parser .Input > dependsOn = new ArrayList <>(3 );
36-
37- dependsOn .add (Parser .Input .fromString ("package org.junit.jupiter.api.function;\n " +
38- "public interface Executable {\n " +
39- " void execute() throws Throwable;\n " +
40- "}" ));
41-
42- dependsOn .add (Parser .Input .fromString ("package org.junit.jupiter.api;\n " +
43- "import org.junit.jupiter.api.function.Executable;\n " +
44- "public class Assertions {\n " +
45- " public static <T extends Throwable> T assertThrows(Class<T> expectedType, Executable executable) {\n " +
46- " return null;\n " +
47- " }\n " +
48- "}" ));
49-
50- if (e instanceof J .FieldAccess ) {
51- JavaType .FullyQualified type = TypeUtils .asFullyQualified (((J .FieldAccess ) e ).getTarget ().getType ());
52- if (type != null ) {
53- String source = (type .getPackageName ().isEmpty () ? "" : "package " + type .getPackageName () + ";\n " ) +
54- "public class " + type .getClassName () + " extends Exception {}" ;
55- dependsOn .add (Parser .Input .fromString (source ));
56- }
57- }
58-
59- return dependsOn ;
60- }
61-
6234 @ Override
6335 public String getDisplayName () {
6436 return "Migrate JUnit 4 `@Test` annotations to JUnit5" ;
@@ -93,7 +65,8 @@ public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, ExecutionCon
9365 public J .Annotation visitAnnotation (J .Annotation annotation , ExecutionContext ctx ) {
9466 J .Annotation ann = super .visitAnnotation (annotation , ctx );
9567 if (JUNIT4_TEST .matches (ann )) {
96- getCursor ().dropParentUntil (J .MethodDeclaration .class ::isInstance ).putMessage (JUNIT4_TEST_ANNOTATION_ARGUMENTS , ann .getArguments ());
68+ getCursor ().dropParentUntil (J .MethodDeclaration .class ::isInstance ).putMessage (JUNIT4_TEST_ANNOTATION_ARGUMENTS ,
69+ ann .getArguments ());
9770 ann = ann .withArguments (null );
9871 }
9972 return ann ;
@@ -119,6 +92,26 @@ private static class ChangeTestMethodBodyStep extends JavaIsoVisitor<ExecutionCo
11992 private final J .MethodDeclaration scope ;
12093 private final List <Expression > arguments ;
12194
95+ private final JavaTemplate assertThrows = JavaTemplate .builder (this ::getCursor , "assertThrows(#{any(java.lang.Class)}, #{any(org.junit.jupiter.api.function.Executable)});" )
96+ .javaParser (() -> JavaParser .fromJavaVersion ()
97+ .logCompilationWarningsAndErrors (true )
98+ .dependsOn (
99+ "package org.junit.jupiter.api.function;" +
100+ "public interface Executable {" +
101+ " void execute() throws Throwable;" +
102+ "}" ,
103+ "package org.junit.jupiter.api;" +
104+ "import org.junit.jupiter.api.function.Executable;" +
105+ "public class Assertions {" +
106+ " public static <T extends Throwable> T assertThrows(Class<T> expectedType, Executable executable) {" +
107+ " return null;" +
108+ " }" +
109+ "}"
110+ )
111+ .build ())
112+ .staticImports ("org.junit.jupiter.api.Assertions.assertThrows" )
113+ .build ();
114+
122115 public ChangeTestMethodBodyStep (J .MethodDeclaration scope , List <Expression > arguments ) {
123116 this .scope = scope ;
124117 this .arguments = arguments ;
@@ -137,17 +130,21 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex
137130 if (assignParamName .equals ("expected" )) {
138131 assert e instanceof J .FieldAccess ;
139132
140- m = m .withTemplate (
141- JavaTemplate .builder (this ::getCursor , "assertThrows(#{any()}, () -> #{});" )
142- .javaParser (() -> JavaParser .fromJavaVersion ()
143- .dependsOn (assertThrowsDependsOn (e ))
144- .build ())
145- .staticImports ("org.junit.jupiter.api.Assertions.assertThrows" )
146- .build (),
133+ m = m .withTemplate (JavaTemplate .builder (this ::getCursor , "Object o = () -> #{}" ).build (),
147134 m .getCoordinates ().replaceBody (),
148- e ,
149- m .getBody ()
150- );
135+ m .getBody ());
136+
137+ assert m .getBody () != null ;
138+ J .Lambda lambda = (J .Lambda ) ((J .VariableDeclarations ) m .getBody ().getStatements ().get (0 ))
139+ .getVariables ().get (0 ).getInitializer ();
140+
141+ assert lambda != null ;
142+ lambda = lambda .withType (JavaType .Class .build ("org.junit.jupiter.api.function.Executable" ));
143+
144+ m = m .withTemplate (assertThrows ,
145+ m .getCoordinates ().replaceBody (),
146+ e , lambda );
147+
151148 maybeAddImport ("org.junit.jupiter.api.Assertions" , "assertThrows" );
152149 } else if (assignParamName .equals ("timeout" )) {
153150 doAfterVisit (new AddTimeoutAnnotationStep (m , e ));
0 commit comments