1515 */
1616package org .openrewrite .java .testing .mockito ;
1717
18+ import lombok .AllArgsConstructor ;
1819import org .jspecify .annotations .Nullable ;
1920import org .openrewrite .*;
2021import org .openrewrite .internal .ListUtils ;
2829
2930import static java .util .Collections .emptyList ;
3031import static java .util .Objects .requireNonNull ;
31- import static java .util .stream .Collectors .toList ;
3232import static org .openrewrite .java .VariableNameUtils .GenerationStrategy .INCREMENT_NUMBER ;
3333import static org .openrewrite .java .VariableNameUtils .generateVariableName ;
3434import static org .openrewrite .java .tree .Flag .Static ;
@@ -81,9 +81,9 @@ private List<Statement> maybeStatementsToMockedStatic(J.Block m, List<Statement>
8181 for (Statement statement : statements ) {
8282 J .MethodInvocation whenArg = getWhenArg (statement );
8383 if (whenArg != null ) {
84- String className = getClassNameFromInvocation (whenArg );
85- if (className != null ) {
86- list .addAll (mockedStatic (m , (J .MethodInvocation ) statement , className , whenArg , ctx ));
84+ JavaType . @ Nullable Class invokedType = getTypeFromInvocation (whenArg );
85+ if (invokedType != null ) {
86+ list .addAll (mockedStatic (m , (J .MethodInvocation ) statement , invokedType . getClassName () , whenArg , ctx ));
8787 }
8888 } else {
8989 list .add (statement );
@@ -102,18 +102,18 @@ private List<Statement> maybeWrapStatementsInTryWithResourcesMockedStatic(J.Bloc
102102
103103 J .MethodInvocation whenArg = getWhenArg (statement );
104104 if (whenArg != null ) {
105- String className = getClassNameFromInvocation (whenArg );
106- if (className != null ) {
107- Optional <String > nameOfWrappingMockedStatic = tryGetMatchedWrappingResourceName (getCursor (), className );
105+ JavaType . @ Nullable Class invokedType = getTypeFromInvocation (whenArg );
106+ if (invokedType != null ) {
107+ Optional <String > nameOfWrappingMockedStatic = tryGetMatchedWrappingResourceName (getCursor (), invokedType );
108108 if (nameOfWrappingMockedStatic .isPresent ()) {
109109 return reuseMockedStatic (block , (J .MethodInvocation ) statement , nameOfWrappingMockedStatic .get (), whenArg , ctx );
110110 }
111- J .Identifier staticMockedVariable = findMockedStaticVariable (getCursor (), className );
111+ J .Identifier staticMockedVariable = findMockedStaticVariable (getCursor (), invokedType );
112112 if (staticMockedVariable != null ) {
113113 return reuseMockedStatic (block , (J .MethodInvocation ) statement , staticMockedVariable , whenArg , ctx );
114114 }
115115 restInTry .set (true );
116- return tryWithMockedStatic (block , statements , index , (J .MethodInvocation ) statement , className , whenArg , ctx );
116+ return tryWithMockedStatic (block , statements , index , (J .MethodInvocation ) statement , invokedType . getClassName () , whenArg , ctx );
117117 }
118118 }
119119 return statement ;
@@ -133,15 +133,15 @@ private List<Statement> maybeWrapStatementsInTryWithResourcesMockedStatic(J.Bloc
133133 return null ;
134134 }
135135
136- private @ Nullable String getClassNameFromInvocation (J .MethodInvocation whenArg ) {
136+ private JavaType . @ Nullable Class getTypeFromInvocation (J .MethodInvocation whenArg ) {
137137 J .Identifier clazz = null ;
138138 // Having a fieldType implies that something is a field rather than a class itself
139139 if (whenArg .getSelect () instanceof J .Identifier && ((J .Identifier ) whenArg .getSelect ()).getFieldType () == null ) {
140140 clazz = (J .Identifier ) whenArg .getSelect ();
141141 } else if (whenArg .getSelect () instanceof J .FieldAccess && ((J .FieldAccess ) whenArg .getSelect ()).getTarget () instanceof J .Identifier ) {
142142 clazz = (J .Identifier ) ((J .FieldAccess ) whenArg .getSelect ()).getTarget ();
143143 }
144- return clazz != null && clazz .getType () != null ? clazz .getSimpleName () : null ;
144+ return clazz != null && clazz .getType () != null ? ( JavaType . Class ) clazz .getType () : null ;
145145 }
146146
147147 private J .Try tryWithMockedStatic (J .Block block , List <Statement > statements , Integer index ,
@@ -265,17 +265,22 @@ private JavaTemplate javaTemplateMockStatic(String code, ExecutionContext ctx) {
265265 });
266266 }
267267
268- private static List <J .Try .Resource > getMatchingFilteredResources (@ Nullable List <J .Try .Resource > resources , String className ) {
269- if (resources != null ) {
270- return resources .stream ().filter (res -> {
271- J .VariableDeclarations vds = (J .VariableDeclarations ) res .getVariableDeclarations ();
272- return TypeUtils .isAssignableTo ("org.mockito.MockedStatic<" + className + ">" , vds .getTypeAsFullyQualified ());
273- }).collect (toList ());
268+ private static List <J .Try .Resource > getMatchingFilteredResources (@ Nullable List <J .Try .Resource > resources , JavaType className ) {
269+ if (resources == null ) {
270+ return emptyList ();
274271 }
275- return emptyList ( );
272+ return ListUtils . filter ( resources , res -> isMockedStaticOfType ( className , (( J . VariableDeclarations ) res . getVariableDeclarations ()). getTypeAsFullyQualified ()) );
276273 }
277274
278- private static Optional <String > tryGetMatchedWrappingResourceName (Cursor cursor , String className ) {
275+ private static boolean isMockedStaticOfType (JavaType mockedType , @ Nullable JavaType comparisonType ) {
276+ if (comparisonType != null && MOCKED_STATIC .matches (comparisonType ) && comparisonType instanceof JavaType .Parameterized ) {
277+ JavaType .Parameterized parameterizedType = requireNonNull (TypeUtils .asParameterized (comparisonType ));
278+ return parameterizedType .getTypeParameters ().size () == 1 && TypeUtils .isAssignableTo (mockedType , parameterizedType .getTypeParameters ().get (0 ));
279+ }
280+ return false ;
281+ }
282+
283+ private static Optional <String > tryGetMatchedWrappingResourceName (Cursor cursor , JavaType className ) {
279284 try {
280285 Cursor foundParentCursor = cursor .dropParentUntil (val -> {
281286 if (val instanceof J .Try ) {
@@ -334,7 +339,7 @@ private static String getSafeAfterMethodName(String baseName, List<Statement> ex
334339 .orElse (baseName );
335340 }
336341
337- private static J .@ Nullable Identifier findMockedStaticVariable (Cursor scope , String className ) {
342+ private static J .@ Nullable Identifier findMockedStaticVariable (Cursor scope , JavaType className ) {
338343 JavaSourceFile compilationUnit = scope .firstEnclosing (JavaSourceFile .class );
339344 if (compilationUnit == null ) {
340345 return null ;
@@ -352,11 +357,8 @@ public J.Block visitBlock(J.Block block, AtomicReference<J.Identifier> mockedSta
352357 @ Override
353358 public J .VariableDeclarations .NamedVariable visitVariable (J .VariableDeclarations .NamedVariable variable , AtomicReference <J .Identifier > mockedStaticVar ) {
354359 J .Identifier identifier = variable .getName ();
355- if (MOCKED_STATIC .matches (identifier ) && identifier .getType () instanceof JavaType .Parameterized ) {
356- JavaType .Parameterized parameterizedType = (JavaType .Parameterized ) identifier .getType ();
357- if (parameterizedType .getTypeParameters ().size () == 1 && TypeUtils .isAssignableTo (className , parameterizedType .getTypeParameters ().get (0 ))) {
358- mockedStaticVar .set (identifier );
359- }
360+ if (isMockedStaticOfType (className , identifier .getType ())) {
361+ mockedStaticVar .set (identifier );
360362 }
361363
362364 return super .visitVariable (variable , mockedStaticVar );
0 commit comments