Skip to content

Commit d58d2c7

Browse files
committed
Fix UpdateTestAnnotation to properly type attribute assertThrows
1 parent 6a831c2 commit d58d2c7

File tree

1 file changed

+39
-42
lines changed

1 file changed

+39
-42
lines changed

src/main/java/org/openrewrite/java/testing/junit5/UpdateTestAnnotation.java

Lines changed: 39 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -21,44 +21,16 @@
2121
import org.openrewrite.TreeVisitor;
2222
import org.openrewrite.java.*;
2323
import org.openrewrite.java.search.UsesType;
24-
import org.openrewrite.java.tree.*;
24+
import org.openrewrite.java.tree.Expression;
25+
import org.openrewrite.java.tree.J;
26+
import org.openrewrite.java.tree.JavaType;
2527

26-
import java.util.ArrayList;
2728
import java.util.Collections;
2829
import java.util.Comparator;
2930
import java.util.List;
30-
import java.util.function.Predicate;
3131

3232
public class UpdateTestAnnotation extends Recipe {
3333

34-
private static List<Parser.Input> assertThrowsDependsOn(Expression e) {
35-
List<Parser.Input> dependsOn = new ArrayList<>(3);
36-
37-
dependsOn.add(Parser.Input.fromString("package org.junit.jupiter.api.function;\n" +
38-
"public interface Executable {\n" +
39-
" void execute() throws Throwable;\n" +
40-
"}"));
41-
42-
dependsOn.add(Parser.Input.fromString("package org.junit.jupiter.api;\n" +
43-
"import org.junit.jupiter.api.function.Executable;\n" +
44-
"public class Assertions {\n" +
45-
" public static <T extends Throwable> T assertThrows(Class<T> expectedType, Executable executable) {\n" +
46-
" return null;\n" +
47-
" }\n" +
48-
"}"));
49-
50-
if (e instanceof J.FieldAccess) {
51-
JavaType.FullyQualified type = TypeUtils.asFullyQualified(((J.FieldAccess) e).getTarget().getType());
52-
if (type != null) {
53-
String source = (type.getPackageName().isEmpty() ? "" : "package " + type.getPackageName() + ";\n") +
54-
"public class " + type.getClassName() + " extends Exception {}";
55-
dependsOn.add(Parser.Input.fromString(source));
56-
}
57-
}
58-
59-
return dependsOn;
60-
}
61-
6234
@Override
6335
public String getDisplayName() {
6436
return "Migrate JUnit 4 `@Test` annotations to JUnit5";
@@ -93,7 +65,8 @@ public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, ExecutionCon
9365
public J.Annotation visitAnnotation(J.Annotation annotation, ExecutionContext ctx) {
9466
J.Annotation ann = super.visitAnnotation(annotation, ctx);
9567
if (JUNIT4_TEST.matches(ann)) {
96-
getCursor().dropParentUntil(J.MethodDeclaration.class::isInstance).putMessage(JUNIT4_TEST_ANNOTATION_ARGUMENTS, ann.getArguments());
68+
getCursor().dropParentUntil(J.MethodDeclaration.class::isInstance).putMessage(JUNIT4_TEST_ANNOTATION_ARGUMENTS,
69+
ann.getArguments());
9770
ann = ann.withArguments(null);
9871
}
9972
return ann;
@@ -119,6 +92,26 @@ private static class ChangeTestMethodBodyStep extends JavaIsoVisitor<ExecutionCo
11992
private final J.MethodDeclaration scope;
12093
private final List<Expression> arguments;
12194

95+
private final JavaTemplate assertThrows = JavaTemplate.builder(this::getCursor, "assertThrows(#{any(java.lang.Class)}, #{any(org.junit.jupiter.api.function.Executable)});")
96+
.javaParser(() -> JavaParser.fromJavaVersion()
97+
.logCompilationWarningsAndErrors(true)
98+
.dependsOn(
99+
"package org.junit.jupiter.api.function;" +
100+
"public interface Executable {" +
101+
" void execute() throws Throwable;" +
102+
"}",
103+
"package org.junit.jupiter.api;" +
104+
"import org.junit.jupiter.api.function.Executable;" +
105+
"public class Assertions {" +
106+
" public static <T extends Throwable> T assertThrows(Class<T> expectedType, Executable executable) {" +
107+
" return null;" +
108+
" }" +
109+
"}"
110+
)
111+
.build())
112+
.staticImports("org.junit.jupiter.api.Assertions.assertThrows")
113+
.build();
114+
122115
public ChangeTestMethodBodyStep(J.MethodDeclaration scope, List<Expression> arguments) {
123116
this.scope = scope;
124117
this.arguments = arguments;
@@ -137,17 +130,21 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex
137130
if (assignParamName.equals("expected")) {
138131
assert e instanceof J.FieldAccess;
139132

140-
m = m.withTemplate(
141-
JavaTemplate.builder(this::getCursor, "assertThrows(#{any()}, () -> #{});")
142-
.javaParser(() -> JavaParser.fromJavaVersion()
143-
.dependsOn(assertThrowsDependsOn(e))
144-
.build())
145-
.staticImports("org.junit.jupiter.api.Assertions.assertThrows")
146-
.build(),
133+
m = m.withTemplate(JavaTemplate.builder(this::getCursor, "Object o = () -> #{}").build(),
147134
m.getCoordinates().replaceBody(),
148-
e,
149-
m.getBody()
150-
);
135+
m.getBody());
136+
137+
assert m.getBody() != null;
138+
J.Lambda lambda = (J.Lambda) ((J.VariableDeclarations) m.getBody().getStatements().get(0))
139+
.getVariables().get(0).getInitializer();
140+
141+
assert lambda != null;
142+
lambda = lambda.withType(JavaType.Class.build("org.junit.jupiter.api.function.Executable"));
143+
144+
m = m.withTemplate(assertThrows,
145+
m.getCoordinates().replaceBody(),
146+
e, lambda);
147+
151148
maybeAddImport("org.junit.jupiter.api.Assertions", "assertThrows");
152149
} else if (assignParamName.equals("timeout")) {
153150
doAfterVisit(new AddTimeoutAnnotationStep(m, e));

0 commit comments

Comments
 (0)