1717
1818import lombok .AllArgsConstructor ;
1919import lombok .NoArgsConstructor ;
20+ import org .jetbrains .annotations .NotNull ;
2021import org .openrewrite .*;
2122import org .openrewrite .internal .lang .Nullable ;
2223import org .openrewrite .java .JavaIsoVisitor ;
@@ -75,76 +76,64 @@ private class MigrateToAssertJVisitor extends JavaIsoVisitor<ExecutionContext> {
7576 public J .MethodInvocation visitMethodInvocation (J .MethodInvocation method , ExecutionContext ctx ) {
7677 J .MethodInvocation mi = super .visitMethodInvocation (method , ctx );
7778 if (assertThatMatcher .matches (mi )) {
78- if (mi .getArguments ().size () == 2 ) {
79- return handleTwoArgumentCase (mi , ctx );
80- }
81- if (mi .getArguments ().size () == 3 ) {
82- return handleThreeArgumentCase (mi , ctx );
83- }
79+ return replace (mi , ctx );
8480 }
8581 return mi ;
8682 }
8783
88- private J .MethodInvocation handleTwoArgumentCase (J .MethodInvocation mi , ExecutionContext ctx ) {
89- Expression actualArgument = mi .getArguments ().get (0 );
90- Expression matcherArgument = mi .getArguments ().get (1 );
84+ private J .MethodInvocation replace (J .MethodInvocation mi , ExecutionContext ctx ) {
85+ List <Expression > mia = mi .getArguments ();
86+ Expression reasonArgument = mia .size () == 3 ? mia .get (0 ) : null ;
87+ Expression actualArgument = mia .get (mia .size () - 2 );
88+ Expression matcherArgument = mia .get (mia .size () - 1 );
9189 if (!matchersMatcher .matches (matcherArgument ) || subMatcher .matches (matcherArgument )) {
9290 return mi ;
9391 }
9492 String actual = typeToIndicator (actualArgument .getType ());
95- List <Expression > originalArguments = ((J .MethodInvocation ) matcherArgument ).getArguments ().stream ()
96- .filter (a -> !(a instanceof J .Empty ))
97- .collect (Collectors .toList ());
98- String argumentsTemplate = originalArguments .stream ()
99- .map (a -> typeToIndicator (a .getType ()))
100- .collect (Collectors .joining (", " ));
101- argumentsTemplate = applySpecialCases ((J .MethodInvocation ) matcherArgument , argumentsTemplate );
102-
103- JavaTemplate template = JavaTemplate .builder (String .format ("assertThat(%s).%s(%s)" ,
104- actual , assertion , argumentsTemplate ))
93+ J .MethodInvocation matcherArgumentMethod = (J .MethodInvocation ) matcherArgument ;
94+ JavaTemplate template = JavaTemplate .builder (String .format (
95+ "assertThat(%s)" +
96+ (reasonArgument != null ? ".as(#{any(String)})" : "" ) +
97+ ".%s(%s)" ,
98+ actual , assertion , getArgumentsTemplate (matcherArgumentMethod )))
10599 .javaParser (JavaParser .fromJavaVersion ().classpathFromResources (ctx , "assertj-core-3.24" ))
106- .staticImports ("org.assertj.core.api.Assertions.assertThat" , "org.assertj.core.api.Assertions.within" )
100+ .staticImports (
101+ "org.assertj.core.api.Assertions.assertThat" ,
102+ "org.assertj.core.api.Assertions.within" )
107103 .build ();
108104 maybeAddImport ("org.assertj.core.api.Assertions" , "assertThat" );
109105 maybeAddImport ("org.assertj.core.api.Assertions" , "within" );
110106 maybeRemoveImport ("org.hamcrest.Matchers." + matcher );
107+ maybeRemoveImport ("org.hamcrest.MatcherAssert" );
111108 maybeRemoveImport ("org.hamcrest.MatcherAssert.assertThat" );
112109
113110 List <Expression > templateArguments = new ArrayList <>();
114111 templateArguments .add (actualArgument );
115- templateArguments .addAll (originalArguments );
112+ if (reasonArgument != null ) {
113+ templateArguments .add (reasonArgument );
114+ }
115+ for (Expression originalArgument : matcherArgumentMethod .getArguments ()) {
116+ if (!(originalArgument instanceof J .Empty )) {
117+ templateArguments .add (originalArgument );
118+ }
119+ }
116120 return template .apply (getCursor (), mi .getCoordinates ().replace (), templateArguments .toArray ());
117121 }
118122
119- private J .MethodInvocation handleThreeArgumentCase (J .MethodInvocation mi , ExecutionContext ctx ) {
120- Expression reasonArgument = mi .getArguments ().get (0 );
121- Expression actualArgument = mi .getArguments ().get (1 );
122- Expression matcherArgument = mi .getArguments ().get (2 );
123- if (!matchersMatcher .matches (matcherArgument ) || subMatcher .matches (matcherArgument )) {
124- return mi ;
123+ private final MethodMatcher CLOSE_TO_MATCHER = new MethodMatcher ("org.hamcrest.Matchers closeTo(..)" );
124+
125+ @ NotNull
126+ private String getArgumentsTemplate (J .MethodInvocation matcherArgument ) {
127+ List <Expression > methodArguments = matcherArgument .getArguments ();
128+ if (CLOSE_TO_MATCHER .matches (matcherArgument )) {
129+ return String .format ("%s, within(%s)" ,
130+ typeToIndicator (methodArguments .get (0 ).getType ()),
131+ typeToIndicator (methodArguments .get (1 ).getType ()));
125132 }
126- String actual = typeToIndicator (actualArgument .getType ());
127- List <Expression > originalArguments = ((J .MethodInvocation ) matcherArgument ).getArguments ().stream ()
133+ return methodArguments .stream ()
128134 .filter (a -> !(a instanceof J .Empty ))
129- .collect (Collectors .toList ());
130- String argumentsTemplate = originalArguments .stream ()
131135 .map (a -> typeToIndicator (a .getType ()))
132136 .collect (Collectors .joining (", " ));
133- JavaTemplate template = JavaTemplate .builder (String .format ("assertThat(%s).as(#{any(String)}).%s(%s)" ,
134- actual , assertion , argumentsTemplate ))
135- .javaParser (JavaParser .fromJavaVersion ().classpathFromResources (ctx , "assertj-core-3.24" ))
136- .staticImports ("org.assertj.core.api.Assertions.assertThat" )
137- .build ();
138- maybeAddImport ("org.assertj.core.api.Assertions" , "assertThat" );
139- maybeRemoveImport ("org.hamcrest.Matchers." + matcher );
140- maybeRemoveImport ("org.hamcrest.MatcherAssert" );
141- maybeRemoveImport ("org.hamcrest.MatcherAssert.assertThat" );
142-
143- List <Expression > templateArguments = new ArrayList <>();
144- templateArguments .add (actualArgument );
145- templateArguments .add (reasonArgument );
146- templateArguments .addAll (originalArguments );
147- return template .apply (getCursor (), mi .getCoordinates ().replace (), templateArguments .toArray ());
148137 }
149138
150139 private String typeToIndicator (JavaType type ) {
@@ -159,24 +148,5 @@ private String typeToIndicator(JavaType type) {
159148 return String .format ("#{any(%s)}" , str );
160149 }
161150 }
162-
163- private String applySpecialCases (J .MethodInvocation mi , String template ) {
164- final MethodMatcher CLOSE_TO_MATCHER = new MethodMatcher ("org.hamcrest.Matchers closeTo(..)" );
165- String [] splitTemplate = template .split (", " );
166-
167- if (CLOSE_TO_MATCHER .matches (mi )) {
168- List <String > newTemplateArr = new ArrayList <>();
169- for (int i = 0 ; i < splitTemplate .length ; i ++) {
170- // within needs to placed on the second argument of isCloseTo
171- if (i == 1 ) {
172- newTemplateArr .add (String .format ("within(%s)" , splitTemplate [i ]));
173- continue ;
174- }
175- newTemplateArr .add (splitTemplate [i ]);
176- }
177- return String .join (", " , newTemplateArr );
178- }
179- return template ;
180- }
181151 }
182152}
0 commit comments