|
| 1 | +/* |
| 2 | + * Copyright 2024 the original author or authors. |
| 3 | + * <p> |
| 4 | + * Licensed under the Moderne Source Available License (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * <p> |
| 8 | + * https://docs.moderne.io/licensing/moderne-source-available-license |
| 9 | + * <p> |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | +package org.openrewrite.java.testing.mockito; |
| 17 | + |
| 18 | +import lombok.RequiredArgsConstructor; |
| 19 | +import org.jspecify.annotations.Nullable; |
| 20 | +import org.openrewrite.*; |
| 21 | +import org.openrewrite.internal.ListUtils; |
| 22 | +import org.openrewrite.java.*; |
| 23 | +import org.openrewrite.java.search.UsesMethod; |
| 24 | +import org.openrewrite.java.tree.*; |
| 25 | +import org.openrewrite.marker.Markers; |
| 26 | + |
| 27 | +import java.util.ArrayList; |
| 28 | +import java.util.List; |
| 29 | +import java.util.concurrent.atomic.AtomicBoolean; |
| 30 | + |
| 31 | +import static java.util.Collections.emptyList; |
| 32 | +import static org.openrewrite.Tree.randomId; |
| 33 | +import static org.openrewrite.java.VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER; |
| 34 | +import static org.openrewrite.java.VariableNameUtils.findNamesInScope; |
| 35 | +import static org.openrewrite.java.VariableNameUtils.generateVariableName; |
| 36 | +import static org.openrewrite.java.trait.Traits.annotated; |
| 37 | + |
| 38 | +/** |
| 39 | + * Ensures that all mockStatic calls are properly closed. |
| 40 | + * If mockStatic is in lifecycle methods like @BeforeEach or @BeforeAll, |
| 41 | + * creates a class variable and closes it in @AfterEach or @AfterAll. |
| 42 | + * If mockStatic is inside a test method, wraps it in a try-with-resources block. |
| 43 | + */ |
| 44 | +public class CloseUnclosedStaticMocks extends Recipe { |
| 45 | + |
| 46 | + private static final String LIFECYCLE_METHOD = "lifecycle_method"; |
| 47 | + private static final MethodMatcher MOCK_STATIC_MATCHER = new MethodMatcher("org.mockito.Mockito mockStatic(..)"); |
| 48 | + private static final AnnotationMatcher BEFORE_MATCHER = new AnnotationMatcher("@org.junit.jupiter.api.Before*"); |
| 49 | + private static final AnnotationMatcher AFTER_EACH_MATCHER = new AnnotationMatcher("@org.junit.jupiter.api.AfterEach"); |
| 50 | + private static final AnnotationMatcher AFTER_ALL_MATCHER = new AnnotationMatcher("@org.junit.jupiter.api.AfterAll"); |
| 51 | + |
| 52 | + @Override |
| 53 | + public String getDisplayName() { |
| 54 | + return "Close unclosed static mocks"; |
| 55 | + } |
| 56 | + |
| 57 | + @Override |
| 58 | + public String getDescription() { |
| 59 | + return "Ensures that all `mockStatic` calls are properly closed. " + |
| 60 | + "If `mockStatic` is in lifecycle methods like `@BeforeEach` or `@BeforeAll`, " + |
| 61 | + "creates a class variable and closes it in `@AfterEach` or `@AfterAll`. " + |
| 62 | + "If `mockStatic` is inside a test method, wraps it in a try-with-resources block."; |
| 63 | + } |
| 64 | + |
| 65 | + @Override |
| 66 | + public TreeVisitor<?, ExecutionContext> getVisitor() { |
| 67 | + return Preconditions.check(new UsesMethod<>(MOCK_STATIC_MATCHER), new JavaVisitor<ExecutionContext>() { |
| 68 | + @Override |
| 69 | + public J visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) { |
| 70 | + J j = super.visitCompilationUnit(cu, ctx); |
| 71 | + maybeAddImport("org.mockito.MockedStatic"); |
| 72 | + maybeAddImport("org.junit.jupiter.api.AfterEach"); |
| 73 | + maybeAddImport("org.junit.jupiter.api.AfterAll"); |
| 74 | + return j; |
| 75 | + } |
| 76 | + |
| 77 | + @Override |
| 78 | + public J visitTryResource(J.Try.Resource tryResource, ExecutionContext ctx) { |
| 79 | + return tryResource; // skip try resource |
| 80 | + } |
| 81 | + |
| 82 | + @Override |
| 83 | + public J visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) { |
| 84 | + Cursor cursor = getCursor(); |
| 85 | + annotated(BEFORE_MATCHER).asVisitor(a -> { |
| 86 | + cursor.putMessage(LIFECYCLE_METHOD, a.getTree().getSimpleName()); |
| 87 | + return a.getTree(); |
| 88 | + }).visit(method, ctx); |
| 89 | + return super.visitMethodDeclaration(method, ctx); |
| 90 | + } |
| 91 | + |
| 92 | + @Override |
| 93 | + public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { |
| 94 | + J.MethodInvocation mi = (J.MethodInvocation) super.visitMethodInvocation(method, ctx); |
| 95 | + String lifeCycleMethod = getCursor().getNearestMessage(LIFECYCLE_METHOD); |
| 96 | + if (!MOCK_STATIC_MATCHER.matches(mi) || lifeCycleMethod == null) { |
| 97 | + return mi; |
| 98 | + } |
| 99 | + if (getCursor().getParentTreeCursor().getValue() instanceof J.Block) { |
| 100 | + String mockedClassName = getMockedClassName(mi); |
| 101 | + if (mockedClassName != null) { |
| 102 | + Cursor classCursor = getCursor().dropParentUntil(J.ClassDeclaration.class::isInstance); |
| 103 | + String varName = generateVariableName("mockedStatic" + mockedClassName, classCursor, INCREMENT_NUMBER); |
| 104 | + J.Assignment assignment = JavaTemplate.builder(varName + " = #{any()}") |
| 105 | + .build() |
| 106 | + .apply(updateCursor(mi), mi.getCoordinates().replace(), mi); |
| 107 | + doAfterVisit(new DeclareMockVarAndClose(varName, mockedClassName, lifeCycleMethod.equals("BeforeAll"))); |
| 108 | + return assignment; |
| 109 | + } |
| 110 | + } |
| 111 | + return mi; |
| 112 | + } |
| 113 | + |
| 114 | + @Override |
| 115 | + public J visitAssignment(J.Assignment assignment, ExecutionContext ctx) { |
| 116 | + if (MOCK_STATIC_MATCHER.matches(assignment.getAssignment())) { |
| 117 | + if (assignment.getVariable() instanceof J.Identifier) { |
| 118 | + JavaType.Variable varType = ((J.Identifier) assignment.getVariable()).getFieldType(); |
| 119 | + if (varType != null && varType.getOwner() instanceof JavaType.Class) { |
| 120 | + doAfterVisit(new DeclareMockVarAndClose(varType.getName(), null, varType.getFlags().contains(Flag.Static))); |
| 121 | + } |
| 122 | + } |
| 123 | + } |
| 124 | + return super.visitAssignment(assignment, ctx); |
| 125 | + } |
| 126 | + |
| 127 | + @Override |
| 128 | + public J visitVariableDeclarations(J.VariableDeclarations variableDeclarations, ExecutionContext ctx) { |
| 129 | + J.VariableDeclarations vd = (J.VariableDeclarations) super.visitVariableDeclarations(variableDeclarations, ctx); |
| 130 | + J.VariableDeclarations.NamedVariable namedVariable = vd.getVariables().get(0); |
| 131 | + String lifeCycleMethod = getCursor().getNearestMessage(LIFECYCLE_METHOD); |
| 132 | + if (!MOCK_STATIC_MATCHER.matches(namedVariable.getInitializer()) || lifeCycleMethod == null) { |
| 133 | + return vd; |
| 134 | + } |
| 135 | + String varName = namedVariable.getSimpleName(); |
| 136 | + String mockedClassName = getMockedClassName((J.MethodInvocation) namedVariable.getInitializer()); |
| 137 | + if (mockedClassName != null) { |
| 138 | + doAfterVisit(new DeclareMockVarAndClose(varName, mockedClassName, lifeCycleMethod.equals("BeforeAll"))); |
| 139 | + return JavaTemplate.builder(varName + " = #{any()}").contextSensitive().build() |
| 140 | + .apply(updateCursor(vd), vd.getCoordinates().replace(), namedVariable.getInitializer()); |
| 141 | + } |
| 142 | + return vd; |
| 143 | + } |
| 144 | + |
| 145 | + @Override |
| 146 | + public J visitBlock(J.Block block, ExecutionContext ctx) { |
| 147 | + J.Block b = (J.Block) super.visitBlock(block, ctx); |
| 148 | + if (getCursor().getNearestMessage(LIFECYCLE_METHOD) != null) { |
| 149 | + return b; |
| 150 | + } |
| 151 | + AtomicBoolean removeStatement = new AtomicBoolean(false); |
| 152 | + J.Block b1 = block.withStatements(ListUtils.map(b.getStatements(), statement -> { |
| 153 | + if (!removeStatement.get() && shouldUseTryWithResources(statement)) { |
| 154 | + J.Try tryWithResource = toTryWithResource(b, statement, ctx); |
| 155 | + if (tryWithResource != null) { |
| 156 | + removeStatement.set(true); |
| 157 | + return (J.Try) super.visitTry(tryWithResource, ctx); |
| 158 | + } |
| 159 | + } |
| 160 | + return removeStatement.get() ? null : statement; |
| 161 | + })); |
| 162 | + return maybeAutoFormat(b, b1, ctx); |
| 163 | + } |
| 164 | + |
| 165 | + private J.@Nullable Try toTryWithResource(J.Block block, Statement statement, ExecutionContext ctx) { |
| 166 | + String code = null; |
| 167 | + if (statement instanceof J.MethodInvocation) { |
| 168 | + String mockedClassName = getMockedClassName((J.MethodInvocation) statement); |
| 169 | + if (mockedClassName != null) { |
| 170 | + String varName = generateVariableName("mockedStatic" + mockedClassName, getCursor(), INCREMENT_NUMBER); |
| 171 | + code = String.format("try(MockedStatic<%s> %s = #{any()}) {}", mockedClassName, varName); |
| 172 | + } |
| 173 | + |
| 174 | + } else if (statement instanceof J.VariableDeclarations || statement instanceof J.Assignment) { |
| 175 | + code = "try(#{any()}) {}"; |
| 176 | + } |
| 177 | + if (code == null) { |
| 178 | + return null; |
| 179 | + } |
| 180 | + J.Try tryWithResources = JavaTemplate.builder(code) |
| 181 | + .contextSensitive() |
| 182 | + .imports("org.mockito.MockedStatic") |
| 183 | + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "mockito-core-5")) |
| 184 | + .build().apply(new Cursor(getCursor(), statement), statement.getCoordinates().replace(), statement); |
| 185 | + return maybeAutoFormat(tryWithResources, tryWithResources.withBody(findSuccessorStatements(statement, block)), ctx); |
| 186 | + } |
| 187 | + |
| 188 | + private @Nullable String getMockedClassName(J.MethodInvocation methodInvocation) { |
| 189 | + JavaType.Parameterized type = TypeUtils.asParameterized(methodInvocation.getType()); |
| 190 | + if (type != null && type.getTypeParameters().size() == 1) { |
| 191 | + JavaType.FullyQualified mockedClass = TypeUtils.asFullyQualified(type.getTypeParameters().get(0)); |
| 192 | + if (mockedClass != null) { |
| 193 | + return mockedClass.getClassName(); |
| 194 | + } |
| 195 | + } |
| 196 | + return null; |
| 197 | + } |
| 198 | + |
| 199 | + private boolean shouldUseTryWithResources(@Nullable Statement statement) { |
| 200 | + if (statement instanceof J.VariableDeclarations) { |
| 201 | + J.VariableDeclarations varDecl = (J.VariableDeclarations) statement; |
| 202 | + return MOCK_STATIC_MATCHER.matches(varDecl.getVariables().get(0).getInitializer()); |
| 203 | + } |
| 204 | + if (statement instanceof J.Assignment) { |
| 205 | + J.Assignment assignment = (J.Assignment) statement; |
| 206 | + return MOCK_STATIC_MATCHER.matches(assignment.getAssignment()); |
| 207 | + } |
| 208 | + if (statement instanceof J.MethodInvocation) { |
| 209 | + return MOCK_STATIC_MATCHER.matches((J.MethodInvocation) statement); |
| 210 | + } |
| 211 | + return false; |
| 212 | + } |
| 213 | + |
| 214 | + private J.Block findSuccessorStatements(Statement statement, J.Block block) { |
| 215 | + List<Statement> successors = new ArrayList<>(); |
| 216 | + boolean found = false; |
| 217 | + for (Statement successor : block.getStatements()) { |
| 218 | + if (found) { |
| 219 | + successors.add(successor); |
| 220 | + } |
| 221 | + found = found || successor == statement; |
| 222 | + } |
| 223 | + return new J.Block(randomId(), Space.EMPTY, Markers.EMPTY, |
| 224 | + new JRightPadded<>(false, Space.EMPTY, Markers.EMPTY), emptyList(), |
| 225 | + Space.format(" ")).withStatements(successors); |
| 226 | + } |
| 227 | + }); |
| 228 | + } |
| 229 | + |
| 230 | + @RequiredArgsConstructor |
| 231 | + private static class DeclareMockVarAndClose extends JavaIsoVisitor<ExecutionContext> { |
| 232 | + private final String varName; |
| 233 | + private final String mockedClassName; |
| 234 | + private final boolean isStatic; |
| 235 | + |
| 236 | + private boolean closed = false; |
| 237 | + |
| 238 | + @Override |
| 239 | + public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext ctx) { |
| 240 | + J.ClassDeclaration cd = classDecl; |
| 241 | + if (!findNamesInScope(getCursor()).contains(varName)) { |
| 242 | + String modifier = isStatic ? "static " : ""; |
| 243 | + String varTemplate = "private " + modifier + "MockedStatic<" + mockedClassName + "> " + varName + ";"; |
| 244 | + cd = JavaTemplate.builder(varTemplate) |
| 245 | + .contextSensitive() |
| 246 | + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "mockito-core-5")) |
| 247 | + .imports("org.mockito.MockedStatic") |
| 248 | + .build().apply(getCursor(), classDecl.getBody().getCoordinates().firstStatement()); |
| 249 | + } |
| 250 | + cd = super.visitClassDeclaration(cd, ctx); |
| 251 | + if (closed) { |
| 252 | + return cd; |
| 253 | + } |
| 254 | + String methodName = tearDownMethodName(cd); |
| 255 | + String methodTemplate = String.format("%s void %s() { %s.close(); }", isStatic ? "@AfterAll public static" : "@AfterEach public", methodName, varName); |
| 256 | + return JavaTemplate.builder(methodTemplate) |
| 257 | + .contextSensitive() |
| 258 | + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "junit-jupiter-api-5")) |
| 259 | + .imports("org.junit.jupiter.api.AfterEach", "org.junit.jupiter.api.AfterAll") |
| 260 | + .build().apply(updateCursor(cd), classDecl.getBody().getCoordinates().lastStatement()); |
| 261 | + } |
| 262 | + |
| 263 | + @Override |
| 264 | + public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration methodDecl, ExecutionContext ctx) { |
| 265 | + J.MethodDeclaration md = super.visitMethodDeclaration(methodDecl, ctx); |
| 266 | + if (closed) { |
| 267 | + return md; |
| 268 | + } |
| 269 | + AnnotationMatcher annotationMatcher = isStatic ? AFTER_ALL_MATCHER : AFTER_EACH_MATCHER; |
| 270 | + boolean matched = annotated(annotationMatcher).<AtomicBoolean>asVisitor((a, found) -> { |
| 271 | + found.set(true); |
| 272 | + return a.getTree(); |
| 273 | + }).reduce(md, new AtomicBoolean()).get(); |
| 274 | + if (!matched) { |
| 275 | + return md; |
| 276 | + } |
| 277 | + closed = true; |
| 278 | + return JavaTemplate.builder(varName + ".close()").contextSensitive().build() |
| 279 | + .apply(updateCursor(md), md.getBody().getCoordinates().lastStatement()); |
| 280 | + } |
| 281 | + |
| 282 | + @Override |
| 283 | + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation methodInvocation, ExecutionContext ctx) { |
| 284 | + if (methodInvocation.getSelect() instanceof J.Identifier) { |
| 285 | + String selector = ((J.Identifier) methodInvocation.getSelect()).getSimpleName(); |
| 286 | + if (selector.equals(varName) && methodInvocation.getSimpleName().equals("close")) { |
| 287 | + closed = true; |
| 288 | + } |
| 289 | + } |
| 290 | + return super.visitMethodInvocation(methodInvocation, ctx); |
| 291 | + } |
| 292 | + |
| 293 | + private String tearDownMethodName(J.ClassDeclaration cd) { |
| 294 | + String methodName = "tearDown"; |
| 295 | + int suffix = 0; |
| 296 | + String updatedMethodName = methodName; |
| 297 | + for (Statement st : cd.getBody().getStatements()) { |
| 298 | + if (st instanceof J.MethodDeclaration && ((J.MethodDeclaration) st).getSimpleName().equals(updatedMethodName)) { |
| 299 | + updatedMethodName = methodName + suffix++; |
| 300 | + } |
| 301 | + } |
| 302 | + return updatedMethodName; |
| 303 | + } |
| 304 | + } |
| 305 | +} |
0 commit comments