Skip to content

Commit b9b66c8

Browse files
committed
TestRuleToTestInfo recipe polish. Fixes #166
1 parent 51f9ea6 commit b9b66c8

File tree

2 files changed

+118
-52
lines changed

2 files changed

+118
-52
lines changed

src/main/java/org/openrewrite/java/testing/junit5/TestRuleToTestInfo.java

Lines changed: 66 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,27 @@
1919
import org.openrewrite.Parser;
2020
import org.openrewrite.Recipe;
2121
import org.openrewrite.TreeVisitor;
22+
import org.openrewrite.internal.ListUtils;
2223
import org.openrewrite.internal.lang.Nullable;
2324
import org.openrewrite.java.*;
2425
import org.openrewrite.java.search.UsesType;
2526
import org.openrewrite.java.tree.J;
2627
import org.openrewrite.java.tree.Space;
2728
import org.openrewrite.java.tree.TypeUtils;
2829

29-
import java.util.ArrayList;
30-
import java.util.stream.Collectors;
31-
import java.util.stream.Stream;
30+
import java.util.Arrays;
3231

3332
public class TestRuleToTestInfo extends Recipe {
3433

3534
private static final String testNameType = "org.junit.rules.TestName";
3635
private static final MethodMatcher TEST_NAME_GET_NAME = new MethodMatcher(testNameType + " getMethodName()");
36+
private static final AnnotationMatcher RULE_ANNOTATION_MATCHER = new AnnotationMatcher("@org.junit.Rule");
37+
private static final AnnotationMatcher JUNIT_BEFORE_MATCHER = new AnnotationMatcher("@org.junit.Before");
38+
private static final AnnotationMatcher JUPITER_BEFORE_EACH_MATCHER = new AnnotationMatcher("@org.junit.jupiter.api.BeforeEach");
3739

3840
private static final ThreadLocal<JavaParser> TEST_INFO_PARSER = ThreadLocal.withInitial(() ->
3941
JavaParser.fromJavaVersion().dependsOn(
40-
Stream.of(
41-
Parser.Input.fromString(
42+
Arrays.asList(Parser.Input.fromString(
4243
"package org.junit.jupiter.api;\n" +
4344
"import java.lang.reflect.Method;\n" +
4445
"import java.util.Optional;\n" +
@@ -48,7 +49,10 @@ public class TestRuleToTestInfo extends Recipe {
4849
" Set<String> getTags();\n" +
4950
" Optional<Class<?>> getTestClass();\n" +
5051
" Optional<Method> getTestMethod();" +
51-
"}")).collect(Collectors.toList())
52+
"}"),
53+
Parser.Input.fromString(
54+
"package org.junit.jupiter.api; public @interface BeforeEach {}"
55+
))
5256
).build());
5357

5458
@Override
@@ -61,11 +65,6 @@ public String getDescription() {
6165
return "Replace usages of JUnit 4's `@Rule TestName` with JUnit 5's TestInfo.";
6266
}
6367

64-
@Override
65-
protected @Nullable TreeVisitor<?, ExecutionContext> getApplicableTest() {
66-
return new UsesType<>(testNameType);
67-
}
68-
6968
@Override
7069
protected @Nullable TreeVisitor<?, ExecutionContext> getSingleSourceApplicableTest() {
7170
return new UsesType<>(testNameType);
@@ -91,16 +90,21 @@ public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext execu
9190
}
9291
});
9392
doAfterVisit(new ChangeType(testNameType, "String"));
93+
doAfterVisit(new ChangeType("org.junit.Before", "org.junit.jupiter.api.BeforeEach"));
9494
return compilationUnit;
9595
}
9696

9797
@Override
9898
public J.VariableDeclarations visitVariableDeclarations(J.VariableDeclarations multiVariable, ExecutionContext executionContext) {
9999
J.VariableDeclarations varDecls = super.visitVariableDeclarations(multiVariable, executionContext);
100100
if (varDecls.getType() != null && TypeUtils.isOfClassType(varDecls.getType(), testNameType)) {
101-
varDecls = varDecls.withLeadingAnnotations(new ArrayList<>());
102-
//noinspection ConstantConditions
103-
doAfterVisit(new AddBeforeEachMethod(varDecls, getCursor().firstEnclosing(J.ClassDeclaration.class)));
101+
varDecls = varDecls.withLeadingAnnotations(ListUtils.map(varDecls.getLeadingAnnotations(), anno -> {
102+
if (RULE_ANNOTATION_MATCHER.matches(anno)) {
103+
return null;
104+
}
105+
return anno;
106+
}));
107+
getCursor().dropParentUntil(J.ClassDeclaration.class::isInstance).putMessage("has-testName-rule", varDecls);
104108
}
105109
return varDecls;
106110
}
@@ -115,46 +119,56 @@ public J.NewClass visitNewClass(J.NewClass newClass, ExecutionContext executionC
115119
return nc;
116120
}
117121

118-
//FIXME. add TestMethod statements.
119-
// @Override
120-
// public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext executionContext) {
121-
// J.MethodDeclaration md = super.visitMethodDeclaration(method, executionContext);
122-
// return md;
123-
// }
124-
125-
private boolean isBeforeAnnotation(J.Annotation annotation) {
126-
return TypeUtils.isOfClassType(annotation.getType(), "org.junit.Before") || TypeUtils.isOfClassType(annotation.getType(), "org.junit.jupiter.api.BeforeEach");
122+
@Override
123+
public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext executionContext) {
124+
J.MethodDeclaration md = super.visitMethodDeclaration(method, executionContext);
125+
if (md.getLeadingAnnotations().stream().anyMatch(anno -> JUNIT_BEFORE_MATCHER.matches(anno) || JUPITER_BEFORE_EACH_MATCHER.matches(anno))) {
126+
getCursor().dropParentUntil(J.ClassDeclaration.class::isInstance).putMessage("before-method", md);
127+
}
128+
return md;
127129
}
128-
};
129-
}
130-
131-
private static class AddBeforeEachMethod extends JavaIsoVisitor<ExecutionContext> {
132-
private final J.VariableDeclarations varDecls;
133-
private final J.ClassDeclaration enclosingClass;
134-
135-
public AddBeforeEachMethod(J.VariableDeclarations varDecls, J.ClassDeclaration enclosingClass) {
136-
this.varDecls = varDecls;
137-
this.enclosingClass = enclosingClass;
138-
}
139130

140-
@Override
141-
public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext executionContext) {
142-
J.ClassDeclaration cd = super.visitClassDeclaration(classDecl, executionContext);
143-
if (enclosingClass.getId().equals(cd.getId())) {
144-
String t = "@BeforeEach\n" +
145-
"public void setup(TestInfo testInfo) {\n" +
146-
" Optional<Method> testMethod = testInfo.getTestMethod();\n" +
147-
" if (testMethod.isPresent()) {\n" +
148-
" this.#{} = testMethod.get().getName();\n" +
149-
" }\n" +
150-
"}";
151-
cd = cd.withTemplate(JavaTemplate.builder(this::getCursor, t).javaParser(TEST_INFO_PARSER::get)
152-
.imports("org.junit.jupiter.api.TestInfo", "java.util.Optional", "java.lang.reflect.Method")
153-
.build(),
154-
cd.getBody().getCoordinates().lastStatement(),
155-
varDecls.getVariables().get(0).getName().getSimpleName());
131+
@Override
132+
public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext executionContext) {
133+
J.ClassDeclaration cd = super.visitClassDeclaration(classDecl, executionContext);
134+
J.VariableDeclarations varDecls = getCursor().pollMessage("has-testName-rule");
135+
if (varDecls != null) {
136+
String testMethodStatement = "Optional<Method> testMethod = testInfo.getTestMethod();\n" +
137+
"if (testMethod.isPresent()) {\n" +
138+
" this.#{} = testMethod.get().getName();\n" +
139+
"}";
140+
J.MethodDeclaration beforeMethod = getCursor().pollMessage("before-method");
141+
if (beforeMethod == null) {
142+
String t = "@BeforeEach\n" +
143+
"public void setup(TestInfo testInfo) {" + testMethodStatement + "}";
144+
cd = cd.withTemplate(JavaTemplate.builder(this::getCursor, t).javaParser(TEST_INFO_PARSER::get)
145+
.imports("org.junit.jupiter.api.TestInfo", "org.junit.jupiter.api.BeforeEach", "java.util.Optional", "java.lang.reflect.Method")
146+
.build(),
147+
cd.getBody().getCoordinates().lastStatement(),
148+
varDecls.getVariables().get(0).getName().getSimpleName());
149+
} else {
150+
doAfterVisit(new JavaIsoVisitor<ExecutionContext>() {
151+
@Override
152+
public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext executionContext) {
153+
J.MethodDeclaration md = super.visitMethodDeclaration(method, executionContext);
154+
if (md.getId().equals(beforeMethod.getId())) {
155+
md = md.withTemplate(JavaTemplate.builder(this::getCursor, "TestInfo testInfo").javaParser(TEST_INFO_PARSER::get)
156+
.imports("org.junit.jupiter.api.TestInfo", "org.junit.jupiter.api.BeforeEach", "java.util.Optional", "java.lang.reflect.Method")
157+
.build(),
158+
md.getCoordinates().replaceParameters());
159+
//noinspection ConstantConditions
160+
md = maybeAutoFormat(md, md.withTemplate(JavaTemplate.builder(this::getCursor, testMethodStatement).javaParser(TEST_INFO_PARSER::get)
161+
.imports("org.junit.jupiter.api.TestInfo", "java.util.Optional", "java.lang.reflect.Method")
162+
.build(),
163+
md.getBody().getCoordinates().lastStatement(), varDecls.getVariables().get(0).getName().getSimpleName()), executionContext, getCursor().getParent());
164+
}
165+
return md;
166+
}
167+
});
168+
}
169+
}
170+
return cd;
156171
}
157-
return cd;
158-
}
172+
};
159173
}
160174
}

src/test/kotlin/org/openrewrite/java/testing/junit5/TestRuleToTestInfoTest.kt

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class TestRuleToTestInfoTest : JavaRecipeTest {
4343
}
4444
""",
4545
after = """
46+
import org.junit.jupiter.api.BeforeEach;
4647
import org.junit.jupiter.api.TestInfo;
4748
4849
public class SomeTest {
@@ -65,4 +66,55 @@ class TestRuleToTestInfoTest : JavaRecipeTest {
6566
}
6667
"""
6768
)
69+
70+
@Test
71+
fun testRuleHasBeforeMethodToTestInfo() = assertChanged(
72+
before = """
73+
import org.junit.Before;
74+
import org.junit.Rule;
75+
import org.junit.rules.TestName;
76+
77+
public class SomeTest {
78+
protected int count;
79+
@Rule
80+
public TestName name = new TestName();
81+
protected String randomName() {
82+
return name.getMethodName();
83+
}
84+
85+
@Before
86+
public void setup() {
87+
count++;
88+
}
89+
90+
private static class SomeInnerClass {
91+
}
92+
}
93+
""",
94+
after = """
95+
import org.junit.jupiter.api.BeforeEach;
96+
import org.junit.jupiter.api.TestInfo;
97+
98+
public class SomeTest {
99+
protected int count;
100+
101+
public String name;
102+
protected String randomName() {
103+
return name;
104+
}
105+
106+
@BeforeEach
107+
public void setup(TestInfo testInfo) {
108+
count++;
109+
Optional<Method> testMethod = testInfo.getTestMethod();
110+
if (testMethod.isPresent()) {
111+
this.name = testMethod.get().getName();
112+
}
113+
}
114+
115+
private static class SomeInnerClass {
116+
}
117+
}
118+
"""
119+
)
68120
}

0 commit comments

Comments
 (0)