Skip to content

Commit 1e97f4b

Browse files
committed
Update SpringRunnerToSpringExtension rule to also cover SpringRunnerToSpringExtension
1 parent 6b7afe3 commit 1e97f4b

File tree

2 files changed

+64
-10
lines changed

2 files changed

+64
-10
lines changed

src/main/java/org/openrewrite/java/testing/junit5/SpringRunnerToSpringExtension.java

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,8 @@ public class SpringRunnerToSpringExtension extends JavaIsoRefactorVisitor {
4747
extendWithType,
4848
EMPTY
4949
);
50-
private static final JavaType.Class springExtensionType = JavaType.Class.build("org.springframework.test.context.junit.jupiter.SpringExtension");
51-
52-
50+
private static final JavaType.Class springExtensionType =
51+
JavaType.Class.build("org.springframework.test.context.junit.jupiter.SpringExtension");
5352
// Reference @RunWith(SpringRunner.class) annotation for semantically equal to compare against
5453
private static final J.Annotation runWithSpringRunnerAnnotation = new J.Annotation(
5554
randomId(),
@@ -75,6 +74,33 @@ public class SpringRunnerToSpringExtension extends JavaIsoRefactorVisitor {
7574
EMPTY
7675
);
7776

77+
private static final JavaType.Class springJUnit4ClassRunnerType =
78+
JavaType.Class.build("org.springframework.test.context.junit4.SpringJUnit4ClassRunner");
79+
// Reference @RunWith(SpringJUnit4ClassRunner.class) annotation for semantically equal to compare against
80+
private static final J.Annotation runWithSpringJUnit4ClassRunnerAnnotation = new J.Annotation(
81+
randomId(),
82+
runWithIdent,
83+
new J.Annotation.Arguments(
84+
randomId(),
85+
Collections.singletonList(
86+
new J.FieldAccess(
87+
randomId(),
88+
J.Ident.build(
89+
randomId(),
90+
"SpringJUnit4ClassRunner",
91+
springJUnit4ClassRunnerType,
92+
EMPTY
93+
),
94+
J.Ident.build(randomId(), "class", null, EMPTY),
95+
JavaType.Class.build("java.lang.Class"),
96+
EMPTY
97+
)
98+
),
99+
EMPTY
100+
),
101+
EMPTY
102+
);
103+
78104
private static final J.Annotation extendWithSpringExtensionAnnotation = new J.Annotation(
79105
randomId(),
80106
extendWithIdent,
@@ -105,23 +131,26 @@ public SpringRunnerToSpringExtension() {
105131

106132
@Override
107133
public J.ClassDecl visitClassDecl(J.ClassDecl cd) {
108-
List<J.Annotation> annotations = cd.getAnnotations().stream()
109-
.map(this::springRunnerToSpringExtension)
110-
.collect(Collectors.toList());
134+
if(cd.getAnnotations().stream().filter(this::shouldReplaceAnnotation).findAny().isPresent()) {
135+
List<J.Annotation> annotations = cd.getAnnotations().stream()
136+
.map(this::springRunnerToSpringExtension)
137+
.collect(Collectors.toList());
111138

112-
return cd.withAnnotations(annotations);
139+
return cd.withAnnotations(annotations);
140+
}
141+
return cd;
113142
}
114143

115144
/**
116-
* Converts annotations like @RunWith(SpringRunner.class) into @ExtendWith(SpringExtension.class)
145+
* Converts annotations like @RunWith(SpringRunner.class) and @RunWith(SpringJUnit4ClassRunner.class) into @ExtendWith(SpringExtension.class)
117146
* Leaves other annotations untouched and returns as-is.
118147
*
119148
* NOT a pure function. Side effects include:
120149
* Adding imports for ExtendWith and SpringExtension
121150
* Removing imports for RunWith and SpringRunner
122151
*/
123152
private J.Annotation springRunnerToSpringExtension(J.Annotation maybeSpringRunner) {
124-
if(!new SemanticallyEqual(runWithSpringRunnerAnnotation).visit(maybeSpringRunner)) {
153+
if(!(new SemanticallyEqual(runWithSpringRunnerAnnotation).visit(maybeSpringRunner) || new SemanticallyEqual(runWithSpringJUnit4ClassRunnerAnnotation).visit(maybeSpringRunner))) {
125154
return maybeSpringRunner;
126155
}
127156
Formatting originalFormatting = maybeSpringRunner.getFormatting();
@@ -131,8 +160,14 @@ private J.Annotation springRunnerToSpringExtension(J.Annotation maybeSpringRunne
131160
maybeAddImport(extendWithType);
132161
maybeAddImport(springExtensionType);
133162
maybeRemoveImport(springRunnerType);
163+
maybeRemoveImport(springJUnit4ClassRunnerType);
134164
maybeRemoveImport(runWithType);
135165

136166
return extendWithSpringExtension;
137167
}
168+
169+
private boolean shouldReplaceAnnotation(J.Annotation maybeSpringRunner) {
170+
return new SemanticallyEqual(runWithSpringRunnerAnnotation).visit(maybeSpringRunner)
171+
|| new SemanticallyEqual(runWithSpringJUnit4ClassRunnerAnnotation).visit(maybeSpringRunner);
172+
}
138173
}

src/test/kotlin/org/openrewrite/java/testing/junit5/SpringRunnerToSpringExtensionTest.kt

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class SpringRunnerToSpringExtensionTest : RefactorVisitorTestForParser<J.Compila
2828
override val visitors = listOf(SpringRunnerToSpringExtension())
2929

3030
@Test
31-
fun basicRunnerToExtension() = assertRefactored(
31+
fun springRunnerToExtension() = assertRefactored(
3232
before = """
3333
import org.junit.runner.RunWith;
3434
import org.springframework.test.context.junit4.SpringRunner;
@@ -45,6 +45,25 @@ class SpringRunnerToSpringExtensionTest : RefactorVisitorTestForParser<J.Compila
4545
"""
4646
)
4747

48+
@Test
49+
fun springJUnit4ClassRunnerRunnerToExtension() = assertRefactored(
50+
before = """
51+
import org.junit.runner.RunWith;
52+
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
53+
54+
@RunWith(SpringJUnit4ClassRunner.class)
55+
class A {}
56+
""",
57+
after = """
58+
import org.junit.jupiter.api.extension.ExtendWith;
59+
import org.springframework.test.context.junit.jupiter.SpringExtension;
60+
61+
@ExtendWith(SpringExtension.class)
62+
class A {}
63+
"""
64+
)
65+
66+
4867
@Test
4968
fun leavesOtherRunnersAlone() = assertUnchanged(
5069
before = """

0 commit comments

Comments
 (0)