Skip to content

Commit ca93760

Browse files
authored
JUnit 4 AssertArrayEquals/AssertNotEquals to AssertJ (#15)
1 parent 9e6d5b6 commit ca93760

File tree

10 files changed

+1164
-69
lines changed

10 files changed

+1164
-69
lines changed
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
/*
2+
* Copyright 2020 the original author or authors.
3+
* <p>
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
* <p>
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
* <p>
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.openrewrite.java.testing.junitassertj;
17+
18+
import org.openrewrite.AutoConfigure;
19+
import org.openrewrite.java.AutoFormat;
20+
import org.openrewrite.java.JavaIsoRefactorVisitor;
21+
import org.openrewrite.java.MethodMatcher;
22+
import org.openrewrite.java.tree.*;
23+
24+
import java.util.List;
25+
26+
import static org.openrewrite.java.tree.MethodTypeBuilder.newMethodType;
27+
28+
/**
29+
* This is a refactoring visitor that will convert JUnit-style assertArrayEquals() to assertJ's assertThat().containsExactly().
30+
*
31+
* This visitor will handle the following JUnit 5 method signatures:
32+
*
33+
* <PRE>
34+
* Two parameter variants:
35+
*
36+
* assertArrayEquals(expected,actual) -> assertThat(actual).containsExactly(expected)
37+
*
38+
* Three parameter variant where the third argument is a String:
39+
*
40+
* assertArrayEquals(expected, actual, "message") -> assertThat(actual).as("message").containsExactly(expected)
41+
*
42+
* Three parameter variant where the third argument is a String supplier:
43+
*
44+
* assertArrayEquals(expected, actual, () -> "message") -> assertThat(actual).withFailureMessage("message").containsExactly(expected)
45+
*
46+
* Three parameter variant where args are all floating point numbers.
47+
*
48+
* assertArrayEquals(expected, actual, delta) -> assertThat(actual).containsExactly(expected, within(delta));
49+
*
50+
* Four parameter variant when comparing floating point numbers with a delta and a message:
51+
*
52+
* assertArrayEquals(expected, actual, delta, "message") -> assertThat(actual).withFailureMessage("message").containsExactly(expected, within(delta));
53+
*
54+
* </PRE>
55+
*/
56+
@AutoConfigure
57+
public class AssertArrayEqualsToAssertThat extends JavaIsoRefactorVisitor {
58+
59+
private static final String JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME = "org.junit.jupiter.api.Assertions";
60+
61+
private static final String ASSERTJ_QUALIFIED_ASSERTIONS_CLASS_NAME = "org.assertj.core.api.Assertions";
62+
private static final String ASSERTJ_ASSERT_THAT_METHOD_NAME = "assertThat";
63+
private static final String ASSERTJ_WITHIN_METHOD_NAME = "within";
64+
65+
/**
66+
* This matcher finds the junit methods that will be migrated by this visitor.
67+
*/
68+
private static final MethodMatcher JUNIT_ASSERT_EQUALS_MATCHER = new MethodMatcher(
69+
JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME + " assertArrayEquals(..)"
70+
);
71+
72+
private static final JavaType ASSERTJ_ASSERTIONS_WILDCARD_STATIC_IMPORT = newMethodType()
73+
.declaringClass("org.assertj.core.api.Assertions")
74+
.flags(Flag.Public, Flag.Static)
75+
.name("*")
76+
.build();
77+
78+
public AssertArrayEqualsToAssertThat() {
79+
setCursoringOn();
80+
}
81+
82+
@Override
83+
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method) {
84+
85+
J.MethodInvocation original = super.visitMethodInvocation(method);
86+
87+
if (!JUNIT_ASSERT_EQUALS_MATCHER.matches(method)) {
88+
return original;
89+
}
90+
91+
List<Expression> originalArgs = original.getArgs().getArgs();
92+
93+
Expression expected = originalArgs.get(0);
94+
Expression actual = originalArgs.get(1);
95+
96+
J.MethodInvocation replacement;
97+
if (originalArgs.size() == 2) {
98+
//assertThat(actual).isEqualTo(expected)
99+
replacement = assertSimple(actual, expected);
100+
} else if (originalArgs.size() == 3 && !isFloatingPointType(originalArgs.get(2))) {
101+
//assertThat(actual).as(message).isEqualTo(expected)
102+
replacement = assertWithMessage(actual, expected, originalArgs.get(2));
103+
} else if (originalArgs.size() == 3) {
104+
//assert is using floating points with a delta and no message.
105+
replacement = assertFloatingPointDelta(actual, expected, originalArgs.get(2));
106+
maybeAddImport(ASSERTJ_QUALIFIED_ASSERTIONS_CLASS_NAME, ASSERTJ_WITHIN_METHOD_NAME);
107+
108+
} else {
109+
//The assertEquals is using a floating point with a delta argument and a message.
110+
replacement = assertFloatingPointDeltaWithMessage(actual, expected, originalArgs.get(2), originalArgs.get(3));
111+
maybeAddImport(ASSERTJ_QUALIFIED_ASSERTIONS_CLASS_NAME, ASSERTJ_WITHIN_METHOD_NAME);
112+
}
113+
114+
//Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat"
115+
maybeAddImport(ASSERTJ_QUALIFIED_ASSERTIONS_CLASS_NAME, ASSERTJ_ASSERT_THAT_METHOD_NAME);
116+
//And if there are no longer references to the JUnit assertions class, we can remove the import.
117+
maybeRemoveImport(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME);
118+
119+
//Format the replacement method invocation in the context of where it is called.
120+
andThen(new AutoFormat(replacement));
121+
return replacement;
122+
}
123+
124+
private J.MethodInvocation assertSimple(Expression actual, Expression expected) {
125+
126+
List<J.MethodInvocation> statements = treeBuilder.buildSnippet(getCursor(),
127+
String.format("assertThat(%s).containsExactly(%s);", actual.printTrimmed(), expected.printTrimmed()),
128+
ASSERTJ_ASSERTIONS_WILDCARD_STATIC_IMPORT
129+
);
130+
return statements.get(0);
131+
}
132+
133+
private J.MethodInvocation assertWithMessage(Expression actual, Expression expected, Expression message) {
134+
135+
// In assertJ the "as" method has a more informative error message, but doesn't accept String suppliers
136+
// so we're using "as" if the message is a string and "withFailMessage" if it is a supplier.
137+
String messageAs = TypeUtils.isString(message.getType())?"as":"withFailMessage";
138+
139+
List<J.MethodInvocation> statements = treeBuilder.buildSnippet(getCursor(),
140+
String.format("assertThat(%s).%s(%s).containsExactly(%s);",
141+
actual.printTrimmed(), messageAs, message.printTrimmed(), expected.printTrimmed()),
142+
ASSERTJ_ASSERTIONS_WILDCARD_STATIC_IMPORT
143+
);
144+
return statements.get(0);
145+
}
146+
147+
private J.MethodInvocation assertFloatingPointDelta(Expression actual, Expression expected, Expression delta) {
148+
List<J.MethodInvocation> statements = treeBuilder.buildSnippet(getCursor(),
149+
String.format("assertThat(%s).containsExactly(%s, within(%s));",
150+
actual.printTrimmed(), expected.printTrimmed(), delta.printTrimmed()),
151+
ASSERTJ_ASSERTIONS_WILDCARD_STATIC_IMPORT
152+
);
153+
return statements.get(0);
154+
}
155+
156+
private J.MethodInvocation assertFloatingPointDeltaWithMessage(Expression actual, Expression expected,
157+
Expression delta, Expression message) {
158+
159+
//If the message is a string use "as", if it is a supplier use "withFailMessage"
160+
String messageAs = TypeUtils.isString(message.getType())?"as":"withFailMessage";
161+
162+
List<J.MethodInvocation> statements = treeBuilder.buildSnippet(getCursor(),
163+
String.format("assertThat(%s).%s(%s).containsExactly(%s, within(%s));",
164+
actual.printTrimmed(), messageAs, message.printTrimmed(), expected.printTrimmed(), delta.printTrimmed()),
165+
ASSERTJ_ASSERTIONS_WILDCARD_STATIC_IMPORT
166+
);
167+
return statements.get(0);
168+
}
169+
170+
/**
171+
* Returns true if the expression's type is either a primitive float/double or their object forms Float/Double
172+
*
173+
* @param expression The expression parsed from the original AST.
174+
* @return true if the type is a floating point number.
175+
*/
176+
private boolean isFloatingPointType(Expression expression) {
177+
178+
JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(expression.getType());
179+
if (fullyQualified != null) {
180+
String typeName = fullyQualified.getFullyQualifiedName();
181+
return (typeName.equals("java.lang.Double") || typeName.equals("java.lang.Float"));
182+
}
183+
184+
JavaType.Primitive parameterType = TypeUtils.asPrimitive(expression.getType());
185+
return parameterType == JavaType.Primitive.Double || parameterType == JavaType.Primitive.Float;
186+
}
187+
}

src/main/java/org/openrewrite/java/testing/junitassertj/AssertEqualsToAssertThat.java

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,21 @@
3535
*
3636
* assertEquals(expected,actual) -> assertThat(actual).isEqualTo(expected)
3737
*
38-
* Three parameter variant where the third argument is either a String or String Supplier:
38+
* Three parameter variant where the third argument is a String:
3939
*
40-
* assertEquals(expected, actual, "message") -> assertThat(actual).withFailureMessage("message").isEqualTo(expected)
40+
* assertEquals(expected, actual, "message") -> assertThat(actual).as("message").isEqualTo(expected)
41+
*
42+
* Three parameter variant where the third argument is a String supplier:
43+
*
44+
* assertEquals(expected, actual, () -> "message") -> assertThat(actual).withFailureMessage("message").isEqualTo(expected)
4145
*
4246
* Three parameter variant where args are all floating point numbers.
4347
*
4448
* assertEquals(expected, actual, delta) -> assertThat(actual).isCloseTo(expected, within(delta));
4549
*
4650
* Four parameter variant when comparing floating point numbers with a delta and a message:
4751
*
48-
* assertEquals(expected, actual, delta) -> assertThat(actual).withFailureMessage("message").isCloseTo(expected, within(delta));
52+
* assertEquals(expected, actual, delta, "message") -> assertThat(actual).withFailureMessage("message").isCloseTo(expected, within(delta));
4953
*
5054
* </PRE>
5155
*/
@@ -59,7 +63,7 @@ public class AssertEqualsToAssertThat extends JavaIsoRefactorVisitor {
5963
private static final String ASSERTJ_WITHIN_METHOD_NAME = "within";
6064

6165
/**
62-
* This matcher uses a pointcut expression to find the matching junit methods that will be migrated by this visitor
66+
* This matcher finds the junit methods that will be migrated by this visitor.
6367
*/
6468
private static final MethodMatcher JUNIT_ASSERT_EQUALS_MATCHER = new MethodMatcher(
6569
JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME + " assertEquals(..)"
@@ -75,12 +79,6 @@ public AssertEqualsToAssertThat() {
7579
setCursoringOn();
7680
}
7781

78-
@Override
79-
public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu) {
80-
maybeRemoveImport(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME);
81-
return super.visitCompilationUnit(cu);
82-
}
83-
8482
@Override
8583
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method) {
8684

@@ -107,14 +105,15 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method) {
107105
maybeAddImport(ASSERTJ_QUALIFIED_ASSERTIONS_CLASS_NAME, ASSERTJ_WITHIN_METHOD_NAME);
108106

109107
} else {
110-
//The assertEquals is using a primitive floating point with a delta argument. (There may be an optional)
111-
//fourth argument that contains the message.
108+
//The assertEquals is using a floating point with a delta argument and a message.
112109
replacement = assertFloatingPointDeltaWithMessage(actual, expected, originalArgs.get(2), originalArgs.get(3));
113110
maybeAddImport(ASSERTJ_QUALIFIED_ASSERTIONS_CLASS_NAME, ASSERTJ_WITHIN_METHOD_NAME);
114111
}
115112

116113
//Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat"
117114
maybeAddImport(ASSERTJ_QUALIFIED_ASSERTIONS_CLASS_NAME, ASSERTJ_ASSERT_THAT_METHOD_NAME);
115+
//And if there are no longer references to the JUnit assertions class, we can remove the import.
116+
maybeRemoveImport(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME);
118117

119118
//Format the replacement method invocation in the context of where it is called.
120119
andThen(new AutoFormat(replacement));
@@ -131,9 +130,14 @@ private J.MethodInvocation assertSimple(Expression actual, Expression expected)
131130
}
132131

133132
private J.MethodInvocation assertWithMessage(Expression actual, Expression expected, Expression message) {
133+
134+
// In assertJ the "as" method has a more informative error message, but doesn't accept String suppliers
135+
// so we're using "as" if the message is a string and "withFailMessage" if it is a supplier.
136+
String messageAs = TypeUtils.isString(message.getType())?"as":"withFailMessage";
137+
134138
List<J.MethodInvocation> statements = treeBuilder.buildSnippet(getCursor(),
135-
String.format("assertThat(%s).withFailMessage(%s).isEqualTo(%s);",
136-
actual.printTrimmed(), message.printTrimmed(), expected.printTrimmed()),
139+
String.format("assertThat(%s).%s(%s).isEqualTo(%s);",
140+
actual.printTrimmed(), messageAs, message.printTrimmed(), expected.printTrimmed()),
137141
ASSERTJ_ASSERTIONS_WILDCARD_STATIC_IMPORT
138142
);
139143
return statements.get(0);
@@ -150,21 +154,32 @@ private J.MethodInvocation assertFloatingPointDelta(Expression actual, Expressio
150154

151155
private J.MethodInvocation assertFloatingPointDeltaWithMessage(Expression actual, Expression expected,
152156
Expression delta, Expression message) {
157+
158+
//If the message is a string use "as", if it is a supplier use "withFailMessage"
159+
String messageAs = TypeUtils.isString(message.getType())?"as":"withFailMessage";
160+
153161
List<J.MethodInvocation> statements = treeBuilder.buildSnippet(getCursor(),
154-
String.format("assertThat(%s).withFailMessage(%s).isCloseTo(%s, within(%s));",
155-
actual.printTrimmed(), message.printTrimmed(), expected.printTrimmed(), delta.printTrimmed()),
162+
String.format("assertThat(%s).%s(%s).isCloseTo(%s, within(%s));",
163+
actual.printTrimmed(), messageAs, message.printTrimmed(), expected.printTrimmed(), delta.printTrimmed()),
156164
ASSERTJ_ASSERTIONS_WILDCARD_STATIC_IMPORT
157165
);
158166
return statements.get(0);
159167
}
160168

161169
/**
162-
* Returns true if the expression's type is either a primitive float or double.
170+
* Returns true if the expression's type is either a primitive float/double or their object forms Float/Double
163171
*
164172
* @param expression The expression parsed from the original AST.
165173
* @return true if the type is a floating point number.
166174
*/
167175
private boolean isFloatingPointType(Expression expression) {
176+
177+
JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(expression.getType());
178+
if (fullyQualified != null) {
179+
String typeName = fullyQualified.getFullyQualifiedName();
180+
return (typeName.equals("java.lang.Double") || typeName.equals("java.lang.Float"));
181+
}
182+
168183
JavaType.Primitive parameterType = TypeUtils.asPrimitive(expression.getType());
169184
return parameterType == JavaType.Primitive.Double || parameterType == JavaType.Primitive.Float;
170185
}

0 commit comments

Comments
 (0)