Skip to content

Commit 97be007

Browse files
amishra-utimtebeek
andauthored
Recipe for closing unclosed static mocks (#739)
* Recipe for closing unclosed static mocks * add msal * Polish ahead of further changes * Include `CloseUnclosedStaticMocks` with `Mockito1to4Migration` & `JUnit4to5Migration` --------- Co-authored-by: Tim te Beek <[email protected]>
1 parent 0da7676 commit 97be007

File tree

3 files changed

+735
-0
lines changed

3 files changed

+735
-0
lines changed
Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
/*
2+
* Copyright 2024 the original author or authors.
3+
* <p>
4+
* Licensed under the Moderne Source Available License (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
* <p>
8+
* https://docs.moderne.io/licensing/moderne-source-available-license
9+
* <p>
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.openrewrite.java.testing.mockito;
17+
18+
import lombok.RequiredArgsConstructor;
19+
import org.jspecify.annotations.Nullable;
20+
import org.openrewrite.*;
21+
import org.openrewrite.internal.ListUtils;
22+
import org.openrewrite.java.*;
23+
import org.openrewrite.java.search.UsesMethod;
24+
import org.openrewrite.java.tree.*;
25+
import org.openrewrite.marker.Markers;
26+
27+
import java.util.ArrayList;
28+
import java.util.List;
29+
import java.util.concurrent.atomic.AtomicBoolean;
30+
31+
import static java.util.Collections.emptyList;
32+
import static org.openrewrite.Tree.randomId;
33+
import static org.openrewrite.java.VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER;
34+
import static org.openrewrite.java.VariableNameUtils.findNamesInScope;
35+
import static org.openrewrite.java.VariableNameUtils.generateVariableName;
36+
import static org.openrewrite.java.trait.Traits.annotated;
37+
38+
/**
39+
* Ensures that all mockStatic calls are properly closed.
40+
* If mockStatic is in lifecycle methods like @BeforeEach or @BeforeAll,
41+
* creates a class variable and closes it in @AfterEach or @AfterAll.
42+
* If mockStatic is inside a test method, wraps it in a try-with-resources block.
43+
*/
44+
public class CloseUnclosedStaticMocks extends Recipe {
45+
46+
private static final String LIFECYCLE_METHOD = "lifecycle_method";
47+
private static final MethodMatcher MOCK_STATIC_MATCHER = new MethodMatcher("org.mockito.Mockito mockStatic(..)");
48+
private static final AnnotationMatcher BEFORE_MATCHER = new AnnotationMatcher("@org.junit.jupiter.api.Before*");
49+
private static final AnnotationMatcher AFTER_EACH_MATCHER = new AnnotationMatcher("@org.junit.jupiter.api.AfterEach");
50+
private static final AnnotationMatcher AFTER_ALL_MATCHER = new AnnotationMatcher("@org.junit.jupiter.api.AfterAll");
51+
52+
@Override
53+
public String getDisplayName() {
54+
return "Close unclosed static mocks";
55+
}
56+
57+
@Override
58+
public String getDescription() {
59+
return "Ensures that all `mockStatic` calls are properly closed. " +
60+
"If `mockStatic` is in lifecycle methods like `@BeforeEach` or `@BeforeAll`, " +
61+
"creates a class variable and closes it in `@AfterEach` or `@AfterAll`. " +
62+
"If `mockStatic` is inside a test method, wraps it in a try-with-resources block.";
63+
}
64+
65+
@Override
66+
public TreeVisitor<?, ExecutionContext> getVisitor() {
67+
return Preconditions.check(new UsesMethod<>(MOCK_STATIC_MATCHER), new JavaVisitor<ExecutionContext>() {
68+
@Override
69+
public J visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
70+
J j = super.visitCompilationUnit(cu, ctx);
71+
maybeAddImport("org.mockito.MockedStatic");
72+
maybeAddImport("org.junit.jupiter.api.AfterEach");
73+
maybeAddImport("org.junit.jupiter.api.AfterAll");
74+
return j;
75+
}
76+
77+
@Override
78+
public J visitTryResource(J.Try.Resource tryResource, ExecutionContext ctx) {
79+
return tryResource; // skip try resource
80+
}
81+
82+
@Override
83+
public J visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) {
84+
Cursor cursor = getCursor();
85+
annotated(BEFORE_MATCHER).asVisitor(a -> {
86+
cursor.putMessage(LIFECYCLE_METHOD, a.getTree().getSimpleName());
87+
return a.getTree();
88+
}).visit(method, ctx);
89+
return super.visitMethodDeclaration(method, ctx);
90+
}
91+
92+
@Override
93+
public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
94+
J.MethodInvocation mi = (J.MethodInvocation) super.visitMethodInvocation(method, ctx);
95+
String lifeCycleMethod = getCursor().getNearestMessage(LIFECYCLE_METHOD);
96+
if (!MOCK_STATIC_MATCHER.matches(mi) || lifeCycleMethod == null) {
97+
return mi;
98+
}
99+
if (getCursor().getParentTreeCursor().getValue() instanceof J.Block) {
100+
String mockedClassName = getMockedClassName(mi);
101+
if (mockedClassName != null) {
102+
Cursor classCursor = getCursor().dropParentUntil(J.ClassDeclaration.class::isInstance);
103+
String varName = generateVariableName("mockedStatic" + mockedClassName, classCursor, INCREMENT_NUMBER);
104+
J.Assignment assignment = JavaTemplate.builder(varName + " = #{any()}")
105+
.build()
106+
.apply(updateCursor(mi), mi.getCoordinates().replace(), mi);
107+
doAfterVisit(new DeclareMockVarAndClose(varName, mockedClassName, lifeCycleMethod.equals("BeforeAll")));
108+
return assignment;
109+
}
110+
}
111+
return mi;
112+
}
113+
114+
@Override
115+
public J visitAssignment(J.Assignment assignment, ExecutionContext ctx) {
116+
if (MOCK_STATIC_MATCHER.matches(assignment.getAssignment())) {
117+
if (assignment.getVariable() instanceof J.Identifier) {
118+
JavaType.Variable varType = ((J.Identifier) assignment.getVariable()).getFieldType();
119+
if (varType != null && varType.getOwner() instanceof JavaType.Class) {
120+
doAfterVisit(new DeclareMockVarAndClose(varType.getName(), null, varType.getFlags().contains(Flag.Static)));
121+
}
122+
}
123+
}
124+
return super.visitAssignment(assignment, ctx);
125+
}
126+
127+
@Override
128+
public J visitVariableDeclarations(J.VariableDeclarations variableDeclarations, ExecutionContext ctx) {
129+
J.VariableDeclarations vd = (J.VariableDeclarations) super.visitVariableDeclarations(variableDeclarations, ctx);
130+
J.VariableDeclarations.NamedVariable namedVariable = vd.getVariables().get(0);
131+
String lifeCycleMethod = getCursor().getNearestMessage(LIFECYCLE_METHOD);
132+
if (!MOCK_STATIC_MATCHER.matches(namedVariable.getInitializer()) || lifeCycleMethod == null) {
133+
return vd;
134+
}
135+
String varName = namedVariable.getSimpleName();
136+
String mockedClassName = getMockedClassName((J.MethodInvocation) namedVariable.getInitializer());
137+
if (mockedClassName != null) {
138+
doAfterVisit(new DeclareMockVarAndClose(varName, mockedClassName, lifeCycleMethod.equals("BeforeAll")));
139+
return JavaTemplate.builder(varName + " = #{any()}").contextSensitive().build()
140+
.apply(updateCursor(vd), vd.getCoordinates().replace(), namedVariable.getInitializer());
141+
}
142+
return vd;
143+
}
144+
145+
@Override
146+
public J visitBlock(J.Block block, ExecutionContext ctx) {
147+
J.Block b = (J.Block) super.visitBlock(block, ctx);
148+
if (getCursor().getNearestMessage(LIFECYCLE_METHOD) != null) {
149+
return b;
150+
}
151+
AtomicBoolean removeStatement = new AtomicBoolean(false);
152+
J.Block b1 = block.withStatements(ListUtils.map(b.getStatements(), statement -> {
153+
if (!removeStatement.get() && shouldUseTryWithResources(statement)) {
154+
J.Try tryWithResource = toTryWithResource(b, statement, ctx);
155+
if (tryWithResource != null) {
156+
removeStatement.set(true);
157+
return (J.Try) super.visitTry(tryWithResource, ctx);
158+
}
159+
}
160+
return removeStatement.get() ? null : statement;
161+
}));
162+
return maybeAutoFormat(b, b1, ctx);
163+
}
164+
165+
private J.@Nullable Try toTryWithResource(J.Block block, Statement statement, ExecutionContext ctx) {
166+
String code = null;
167+
if (statement instanceof J.MethodInvocation) {
168+
String mockedClassName = getMockedClassName((J.MethodInvocation) statement);
169+
if (mockedClassName != null) {
170+
String varName = generateVariableName("mockedStatic" + mockedClassName, getCursor(), INCREMENT_NUMBER);
171+
code = String.format("try(MockedStatic<%s> %s = #{any()}) {}", mockedClassName, varName);
172+
}
173+
174+
} else if (statement instanceof J.VariableDeclarations || statement instanceof J.Assignment) {
175+
code = "try(#{any()}) {}";
176+
}
177+
if (code == null) {
178+
return null;
179+
}
180+
J.Try tryWithResources = JavaTemplate.builder(code)
181+
.contextSensitive()
182+
.imports("org.mockito.MockedStatic")
183+
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "mockito-core-5"))
184+
.build().apply(new Cursor(getCursor(), statement), statement.getCoordinates().replace(), statement);
185+
return maybeAutoFormat(tryWithResources, tryWithResources.withBody(findSuccessorStatements(statement, block)), ctx);
186+
}
187+
188+
private @Nullable String getMockedClassName(J.MethodInvocation methodInvocation) {
189+
JavaType.Parameterized type = TypeUtils.asParameterized(methodInvocation.getType());
190+
if (type != null && type.getTypeParameters().size() == 1) {
191+
JavaType.FullyQualified mockedClass = TypeUtils.asFullyQualified(type.getTypeParameters().get(0));
192+
if (mockedClass != null) {
193+
return mockedClass.getClassName();
194+
}
195+
}
196+
return null;
197+
}
198+
199+
private boolean shouldUseTryWithResources(@Nullable Statement statement) {
200+
if (statement instanceof J.VariableDeclarations) {
201+
J.VariableDeclarations varDecl = (J.VariableDeclarations) statement;
202+
return MOCK_STATIC_MATCHER.matches(varDecl.getVariables().get(0).getInitializer());
203+
}
204+
if (statement instanceof J.Assignment) {
205+
J.Assignment assignment = (J.Assignment) statement;
206+
return MOCK_STATIC_MATCHER.matches(assignment.getAssignment());
207+
}
208+
if (statement instanceof J.MethodInvocation) {
209+
return MOCK_STATIC_MATCHER.matches((J.MethodInvocation) statement);
210+
}
211+
return false;
212+
}
213+
214+
private J.Block findSuccessorStatements(Statement statement, J.Block block) {
215+
List<Statement> successors = new ArrayList<>();
216+
boolean found = false;
217+
for (Statement successor : block.getStatements()) {
218+
if (found) {
219+
successors.add(successor);
220+
}
221+
found = found || successor == statement;
222+
}
223+
return new J.Block(randomId(), Space.EMPTY, Markers.EMPTY,
224+
new JRightPadded<>(false, Space.EMPTY, Markers.EMPTY), emptyList(),
225+
Space.format(" ")).withStatements(successors);
226+
}
227+
});
228+
}
229+
230+
@RequiredArgsConstructor
231+
private static class DeclareMockVarAndClose extends JavaIsoVisitor<ExecutionContext> {
232+
private final String varName;
233+
private final String mockedClassName;
234+
private final boolean isStatic;
235+
236+
private boolean closed = false;
237+
238+
@Override
239+
public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext ctx) {
240+
J.ClassDeclaration cd = classDecl;
241+
if (!findNamesInScope(getCursor()).contains(varName)) {
242+
String modifier = isStatic ? "static " : "";
243+
String varTemplate = "private " + modifier + "MockedStatic<" + mockedClassName + "> " + varName + ";";
244+
cd = JavaTemplate.builder(varTemplate)
245+
.contextSensitive()
246+
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "mockito-core-5"))
247+
.imports("org.mockito.MockedStatic")
248+
.build().apply(getCursor(), classDecl.getBody().getCoordinates().firstStatement());
249+
}
250+
cd = super.visitClassDeclaration(cd, ctx);
251+
if (closed) {
252+
return cd;
253+
}
254+
String methodName = tearDownMethodName(cd);
255+
String methodTemplate = String.format("%s void %s() { %s.close(); }", isStatic ? "@AfterAll public static" : "@AfterEach public", methodName, varName);
256+
return JavaTemplate.builder(methodTemplate)
257+
.contextSensitive()
258+
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "junit-jupiter-api-5"))
259+
.imports("org.junit.jupiter.api.AfterEach", "org.junit.jupiter.api.AfterAll")
260+
.build().apply(updateCursor(cd), classDecl.getBody().getCoordinates().lastStatement());
261+
}
262+
263+
@Override
264+
public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration methodDecl, ExecutionContext ctx) {
265+
J.MethodDeclaration md = super.visitMethodDeclaration(methodDecl, ctx);
266+
if (closed) {
267+
return md;
268+
}
269+
AnnotationMatcher annotationMatcher = isStatic ? AFTER_ALL_MATCHER : AFTER_EACH_MATCHER;
270+
boolean matched = annotated(annotationMatcher).<AtomicBoolean>asVisitor((a, found) -> {
271+
found.set(true);
272+
return a.getTree();
273+
}).reduce(md, new AtomicBoolean()).get();
274+
if (!matched) {
275+
return md;
276+
}
277+
closed = true;
278+
return JavaTemplate.builder(varName + ".close()").contextSensitive().build()
279+
.apply(updateCursor(md), md.getBody().getCoordinates().lastStatement());
280+
}
281+
282+
@Override
283+
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation methodInvocation, ExecutionContext ctx) {
284+
if (methodInvocation.getSelect() instanceof J.Identifier) {
285+
String selector = ((J.Identifier) methodInvocation.getSelect()).getSimpleName();
286+
if (selector.equals(varName) && methodInvocation.getSimpleName().equals("close")) {
287+
closed = true;
288+
}
289+
}
290+
return super.visitMethodInvocation(methodInvocation, ctx);
291+
}
292+
293+
private String tearDownMethodName(J.ClassDeclaration cd) {
294+
String methodName = "tearDown";
295+
int suffix = 0;
296+
String updatedMethodName = methodName;
297+
for (Statement st : cd.getBody().getStatements()) {
298+
if (st instanceof J.MethodDeclaration && ((J.MethodDeclaration) st).getSimpleName().equals(updatedMethodName)) {
299+
updatedMethodName = methodName + suffix++;
300+
}
301+
}
302+
return updatedMethodName;
303+
}
304+
}
305+
}

src/main/resources/META-INF/rewrite/mockito.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ recipeList:
8282
artifactId: byte-buddy*
8383
newVersion: 1.12.19
8484
- org.openrewrite.java.testing.mockito.ReplaceInitMockToOpenMock
85+
- org.openrewrite.java.testing.mockito.CloseUnclosedStaticMocks
8586
---
8687
type: specs.openrewrite.org/v1beta/recipe
8788
name: org.openrewrite.java.testing.mockito.Mockito1to3Migration

0 commit comments

Comments
 (0)