Skip to content

Commit a6cecb1

Browse files
authored
Support member references in AssertJ assertThrows (#330)
1 parent f4227ff commit a6cecb1

File tree

2 files changed

+59
-13
lines changed

2 files changed

+59
-13
lines changed

src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionType.java

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,23 +67,34 @@ private Supplier<JavaParser> assertionsParser(ExecutionContext ctx) {
6767
}
6868
private static final MethodMatcher ASSERT_THROWS_MATCHER = new MethodMatcher("org.junit.jupiter.api.Assertions assertThrows(..)");
6969

70+
private static final JavaType THROWING_CALLABLE_TYPE = JavaType.buildType("org.assertj.core.api.ThrowableAssert.ThrowingCallable");
71+
7072
@Override
7173
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
7274
J.MethodInvocation mi = super.visitMethodInvocation(method, ctx);
7375
if (ASSERT_THROWS_MATCHER.matches(mi) && mi.getArguments().size() == 2) {
74-
J.Lambda lambdaArg = (J.Lambda) mi.getArguments().get(1);
75-
lambdaArg = lambdaArg.withType(JavaType.buildType("org.assertj.core.api.ThrowableAssert.ThrowingCallable"));
76-
mi = mi.withTemplate(
77-
JavaTemplate
78-
.builder(this::getCursor,
79-
"assertThatExceptionOfType(#{any(java.lang.Class)}).isThrownBy(#{any(org.assertj.core.api.ThrowableAssert.ThrowingCallable)})")
80-
.javaParser(assertionsParser(ctx))
81-
.staticImports("org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType")
82-
.build(),
83-
mi.getCoordinates().replace(),
84-
mi.getArguments().get(0), lambdaArg);
85-
maybeAddImport("org.assertj.core.api.AssertionsForClassTypes", "assertThatExceptionOfType");
86-
maybeRemoveImport("org.junit.jupiter.api.Assertions.assertThrows");
76+
J executable = mi.getArguments().get(1);
77+
if (executable instanceof J.Lambda) {
78+
executable = ((J.Lambda) executable).withType(THROWING_CALLABLE_TYPE);
79+
} else if (executable instanceof J.MemberReference) {
80+
executable = ((J.MemberReference) executable).withType(THROWING_CALLABLE_TYPE);
81+
} else {
82+
executable = null;
83+
}
84+
85+
if (executable != null) {
86+
mi = mi.withTemplate(
87+
JavaTemplate
88+
.builder(this::getCursor,
89+
"assertThatExceptionOfType(#{any(java.lang.Class)}).isThrownBy(#{any(org.assertj.core.api.ThrowableAssert.ThrowingCallable)})")
90+
.javaParser(assertionsParser(ctx))
91+
.staticImports("org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType")
92+
.build(),
93+
mi.getCoordinates().replace(),
94+
mi.getArguments().get(0), executable);
95+
maybeAddImport("org.assertj.core.api.AssertionsForClassTypes", "assertThatExceptionOfType");
96+
maybeRemoveImport("org.junit.jupiter.api.Assertions.assertThrows");
97+
}
8798
}
8899
return mi;
89100
}

src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionTypeTest.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,39 @@ public void throwsExceptionWithSpecificType() {
6363
)
6464
);
6565
}
66+
67+
@Test
68+
void memberReference() {
69+
//language=java
70+
rewriteRun(
71+
java(
72+
"""
73+
import static org.junit.jupiter.api.Assertions.assertThrows;
74+
import java.util.concurrent.CompletableFuture;
75+
import java.util.concurrent.ExecutionException;
76+
77+
public class MemberReferenceTest {
78+
79+
public void throwsWithMemberReference() {
80+
CompletableFuture<Boolean> future = new CompletableFuture<>();
81+
assertThrows(ExecutionException.class, future::get);
82+
}
83+
}
84+
""",
85+
"""
86+
import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType;
87+
import java.util.concurrent.CompletableFuture;
88+
import java.util.concurrent.ExecutionException;
89+
90+
public class MemberReferenceTest {
91+
92+
public void throwsWithMemberReference() {
93+
CompletableFuture<Boolean> future = new CompletableFuture<>();
94+
assertThatExceptionOfType(ExecutionException.class).isThrownBy(future::get);
95+
}
96+
}
97+
"""
98+
)
99+
);
100+
}
66101
}

0 commit comments

Comments
 (0)