1616package org .openrewrite .java .testing .jmockit ;
1717
1818import java .util .*;
19- import java .util .regex .Pattern ;
2019
2120import lombok .EqualsAndHashCode ;
2221import lombok .Value ;
@@ -69,6 +68,14 @@ private static class RewriteExpectationsVisitor extends JavaIsoVisitor<Execution
6968 JMOCKIT_ARGUMENT_MATCHERS .add ("anyShort" );
7069 JMOCKIT_ARGUMENT_MATCHERS .add ("any" );
7170 }
71+ private static final Map <String , String > MOCKITO_COLLECTION_MATCHERS = new HashMap <>();
72+ static {
73+ MOCKITO_COLLECTION_MATCHERS .put ("java.util.List" , "anyList" );
74+ MOCKITO_COLLECTION_MATCHERS .put ("java.util.Set" , "anySet" );
75+ MOCKITO_COLLECTION_MATCHERS .put ("java.util.Collection" , "anyCollection" );
76+ MOCKITO_COLLECTION_MATCHERS .put ("java.util.Iterable" , "anyIterable" );
77+ MOCKITO_COLLECTION_MATCHERS .put ("java.util.Map" , "anyMap" );
78+ }
7279
7380 @ Override
7481 public J .MethodDeclaration visitMethodDeclaration (J .MethodDeclaration methodDeclaration , ExecutionContext ctx ) {
@@ -81,77 +88,86 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration methodDecl
8188 J .Block newBody = md .getBody ();
8289 List <Statement > statements = md .getBody ().getStatements ();
8390
84- // iterate over each statement in the method body, find Expectations blocks and rewrite them
85- for (int bodyStatementIndex = 0 ; bodyStatementIndex < statements .size (); bodyStatementIndex ++) {
86- Statement s = statements .get (bodyStatementIndex );
87- if (!(s instanceof J .NewClass )) {
88- continue ;
89- }
90- J .NewClass nc = (J .NewClass ) s ;
91- if (!(nc .getClazz () instanceof J .Identifier )) {
92- continue ;
93- }
94- J .Identifier clazz = (J .Identifier ) nc .getClazz ();
95- if (!TypeUtils .isAssignableTo ("mockit.Expectations" , clazz .getType ())) {
96- continue ;
97- }
98- // empty Expectations block is considered invalid
99- assert nc .getBody () != null && !nc .getBody ().getStatements ().isEmpty () : "Expectations block is empty" ;
100- // Expectations block should be composed of a block within another block
101- assert nc .getBody ().getStatements ().size () == 1 : "Expectations block is malformed" ;
102-
103- // we have a valid Expectations block, update imports and rewrite with Mockito statements
104- maybeRemoveImport ("mockit.Expectations" );
105-
106- // the first coordinates are the coordinates of the Expectations block, replacing it
107- JavaCoordinates coordinates = nc .getCoordinates ().replace ();
108- J .Block expectationsBlock = (J .Block ) nc .getBody ().getStatements ().get (0 );
109- List <Object > templateParams = new ArrayList <>();
110-
111- // iterate over the expectations statements and rebuild the method body
112- int mockitoStatementIndex = 0 ;
113- for (Statement expectationStatement : expectationsBlock .getStatements ()) {
114- // TODO: handle additional jmockit expectations features
115-
116- if (expectationStatement instanceof J .MethodInvocation ) {
117- if (!templateParams .isEmpty ()) {
118- // apply template to build new method body
119- newBody = applyTemplate (ctx , templateParams , cursorLocation , coordinates );
120-
121- // next statement coordinates are immediately after the statement just added
122- int newStatementIndex = bodyStatementIndex + mockitoStatementIndex ;
123- coordinates = newBody .getStatements ().get (newStatementIndex ).getCoordinates ().after ();
124-
125- // cursor location is now the new body
126- cursorLocation = newBody ;
127-
128- // reset template params for next expectation
129- templateParams = new ArrayList <>();
130- mockitoStatementIndex += 1 ;
91+ try {
92+ // iterate over each statement in the method body, find Expectations blocks and rewrite them
93+ for (int bodyStatementIndex = 0 ; bodyStatementIndex < statements .size (); bodyStatementIndex ++) {
94+ Statement s = statements .get (bodyStatementIndex );
95+ if (!(s instanceof J .NewClass )) {
96+ continue ;
97+ }
98+ J .NewClass nc = (J .NewClass ) s ;
99+ if (!(nc .getClazz () instanceof J .Identifier )) {
100+ continue ;
101+ }
102+ J .Identifier clazz = (J .Identifier ) nc .getClazz ();
103+ if (!TypeUtils .isAssignableTo ("mockit.Expectations" , clazz .getType ())) {
104+ continue ;
105+ }
106+ // empty Expectations block is considered invalid
107+ assert nc .getBody () != null && !nc .getBody ().getStatements ().isEmpty () : "Expectations block is empty" ;
108+ // Expectations block should be composed of a block within another block
109+ assert nc .getBody ().getStatements ().size () == 1 : "Expectations block is malformed" ;
110+
111+ // we have a valid Expectations block, update imports and rewrite with Mockito statements
112+ maybeRemoveImport ("mockit.Expectations" );
113+
114+ // the first coordinates are the coordinates of the Expectations block, replacing it
115+ JavaCoordinates coordinates = nc .getCoordinates ().replace ();
116+ J .Block expectationsBlock = (J .Block ) nc .getBody ().getStatements ().get (0 );
117+ List <Object > templateParams = new ArrayList <>();
118+
119+ // iterate over the expectations statements and rebuild the method body
120+ int mockitoStatementIndex = 0 ;
121+ for (Statement expectationStatement : expectationsBlock .getStatements ()) {
122+ // TODO: handle additional jmockit expectations features
123+
124+ if (expectationStatement instanceof J .MethodInvocation ) {
125+ if (!templateParams .isEmpty ()) {
126+ // apply template to build new method body
127+ newBody = rewriteMethodBody (ctx , templateParams , cursorLocation , coordinates );
128+
129+ // next statement coordinates are immediately after the statement just added
130+ int newStatementIndex = bodyStatementIndex + mockitoStatementIndex ;
131+ coordinates = newBody .getStatements ().get (newStatementIndex ).getCoordinates ().after ();
132+
133+ // cursor location is now the new body
134+ cursorLocation = newBody ;
135+
136+ // reset template params for next expectation
137+ templateParams = new ArrayList <>();
138+ mockitoStatementIndex += 1 ;
139+ }
140+ templateParams .add (expectationStatement );
141+ } else {
142+ // assignment
143+ templateParams .add (((J .Assignment ) expectationStatement ).getAssignment ());
131144 }
132- templateParams .add (expectationStatement );
133- } else {
134- // assignment
135- templateParams .add (((J .Assignment ) expectationStatement ).getAssignment ());
136145 }
137- }
138146
139- // handle the last statement
140- if (!templateParams .isEmpty ()) {
141- newBody = applyTemplate (ctx , templateParams , cursorLocation , coordinates );
147+ // handle the last statement
148+ if (!templateParams .isEmpty ()) {
149+ newBody = rewriteMethodBody (ctx , templateParams , cursorLocation , coordinates );
150+ }
142151 }
152+ } catch (Exception e ) {
153+ // if anything goes wrong, just return the original method declaration
154+ return md ;
143155 }
144156
145157 return md .withBody (newBody );
146158 }
147159
148- private J .Block applyTemplate (ExecutionContext ctx , List <Object > templateParams , Object cursorLocation ,
149- JavaCoordinates coordinates ) {
160+ private J .Block rewriteMethodBody (ExecutionContext ctx , List <Object > templateParams , Object cursorLocation ,
161+ JavaCoordinates coordinates ) {
150162 Expression result = null ;
151- String methodName = "doNothing" ;
152- if (templateParams .size () > 1 ) {
163+ String methodName ;
164+ if (templateParams .size () == 1 ) {
165+ methodName = "doNothing" ;
166+ } else if (templateParams .size () == 2 ) {
153167 methodName = "when" ;
154168 result = (Expression ) templateParams .get (1 );
169+ } else {
170+ throw new IllegalStateException ("Unexpected number of template params: " + templateParams .size ());
155171 }
156172 maybeAddImport ("org.mockito.Mockito" , methodName );
157173 rewriteArgumentMatchers (ctx , templateParams );
@@ -166,29 +182,78 @@ private J.Block applyTemplate(ExecutionContext ctx, List<Object> templateParams,
166182 );
167183 }
168184
169- private void rewriteArgumentMatchers (ExecutionContext ctx , List <Object > templateParams ) {
170- J .MethodInvocation invocation = (J .MethodInvocation ) templateParams .get (0 );
185+ private void rewriteArgumentMatchers (ExecutionContext ctx , List <Object > bodyTemplateParams ) {
186+ J .MethodInvocation invocation = (J .MethodInvocation ) bodyTemplateParams .get (0 );
171187 List <Expression > newArguments = new ArrayList <>(invocation .getArguments ().size ());
172188 for (Expression methodArgument : invocation .getArguments ()) {
173189 if (!isArgumentMatcher (methodArgument )) {
174190 newArguments .add (methodArgument );
175191 continue ;
176192 }
177- String argumentMatcher = ((J .Identifier ) methodArgument ).getSimpleName ();
178- maybeAddImport ("org.mockito.Mockito" , argumentMatcher );
179- newArguments .add (JavaTemplate .builder (argumentMatcher + "()" )
180- .javaParser (JavaParser .fromJavaVersion ().classpathFromResources (ctx , "mockito-core-3.12" ))
181- .staticImports ("org.mockito.Mockito." + argumentMatcher )
182- .build ()
183- .apply (
184- new Cursor (getCursor (), methodArgument ),
185- methodArgument .getCoordinates ().replace ()
186- ));
193+ String argumentMatcher , template ;
194+ List <Object > argumentTemplateParams = new ArrayList <>();
195+ if (!(methodArgument instanceof J .TypeCast )) {
196+ argumentMatcher = ((J .Identifier ) methodArgument ).getSimpleName ();
197+ template = argumentMatcher + "()" ;
198+ newArguments .add (rewriteMethodArgument (ctx , argumentMatcher , template , methodArgument ,
199+ argumentTemplateParams ));
200+ continue ;
201+ }
202+ J .TypeCast tc = (J .TypeCast ) methodArgument ;
203+ argumentMatcher = ((J .Identifier ) tc .getExpression ()).getSimpleName ();
204+ String className , fqn ;
205+ JavaType typeCastType = tc .getType ();
206+ if (typeCastType instanceof JavaType .Parameterized ) {
207+ // strip the raw type from the parameterized type
208+ className = ((JavaType .Parameterized ) typeCastType ).getType ().getClassName ();
209+ fqn = ((JavaType .Parameterized ) typeCastType ).getType ().getFullyQualifiedName ();
210+ } else if (typeCastType instanceof JavaType .FullyQualified ) {
211+ className = ((JavaType .FullyQualified ) typeCastType ).getClassName ();
212+ fqn = ((JavaType .FullyQualified ) typeCastType ).getFullyQualifiedName ();
213+ } else {
214+ throw new IllegalStateException ("Unexpected J.TypeCast type: " + typeCastType );
215+ }
216+ if (MOCKITO_COLLECTION_MATCHERS .containsKey (fqn )) {
217+ // mockito has specific argument matchers for collections
218+ argumentMatcher = MOCKITO_COLLECTION_MATCHERS .get (fqn );
219+ template = argumentMatcher + "()" ;
220+ } else {
221+ // rewrite parameter from ((<type>) any) to <type>.class
222+ argumentTemplateParams .add (JavaTemplate .builder ("#{}.class" )
223+ .javaParser (JavaParser .fromJavaVersion ())
224+ .imports (fqn )
225+ .build ()
226+ .apply (
227+ new Cursor (getCursor (), tc ),
228+ tc .getCoordinates ().replace (),
229+ className
230+ ));
231+ template = argumentMatcher + "(#{any(java.lang.Class)})" ;
232+ }
233+ newArguments .add (rewriteMethodArgument (ctx , argumentMatcher , template , methodArgument ,
234+ argumentTemplateParams ));
187235 }
188- templateParams .set (0 , invocation .withArguments (newArguments ));
236+ bodyTemplateParams .set (0 , invocation .withArguments (newArguments ));
237+ }
238+
239+ private Expression rewriteMethodArgument (ExecutionContext ctx , String argumentMatcher , String template ,
240+ Expression methodArgument , List <Object > templateParams ) {
241+ maybeAddImport ("org.mockito.Mockito" , argumentMatcher );
242+ return JavaTemplate .builder (template )
243+ .javaParser (JavaParser .fromJavaVersion ().classpathFromResources (ctx , "mockito-core-3.12" ))
244+ .staticImports ("org.mockito.Mockito." + argumentMatcher )
245+ .build ()
246+ .apply (
247+ new Cursor (getCursor (), methodArgument ),
248+ methodArgument .getCoordinates ().replace (),
249+ templateParams .toArray ()
250+ );
189251 }
190252
191253 private static boolean isArgumentMatcher (Expression expression ) {
254+ if (expression instanceof J .TypeCast ) {
255+ expression = ((J .TypeCast ) expression ).getExpression ();
256+ }
192257 if (!(expression instanceof J .Identifier )) {
193258 return false ;
194259 }
@@ -209,7 +274,7 @@ private static String getMockitoStatementTemplate(Expression result) {
209274 ? THROWABLE_RESULT_TEMPLATE
210275 : OBJECT_RESULT_TEMPLATE ;
211276 } else {
212- throw new IllegalStateException ("Unexpected value : " + result .getType ());
277+ throw new IllegalStateException ("Unexpected expression type for template : " + result .getType ());
213278 }
214279 return template ;
215280 }
0 commit comments