Skip to content

Commit 779623e

Browse files
committed
Preserve throws clause when checked exceptions exist outside assertThrows.
1 parent efa7b6f commit 779623e

File tree

3 files changed

+147
-2
lines changed

3 files changed

+147
-2
lines changed

src/main/java/org/openrewrite/java/testing/assertj/JUnitTryFailToAssertThatThrownBy.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ public J visitTry(J.Try tryBlock, ExecutionContext ctx) {
111111
return JavaTemplate.builder(template)
112112
.contextSensitive()
113113
.staticImports("org.assertj.core.api.Assertions.assertThatThrownBy")
114-
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3"))
114+
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "junit-jupiter-api-5", "assertj-core-3"))
115115
.build()
116116
.<J.MethodInvocation>apply(getCursor(), try_.getCoordinates().replace(), lambdaStatements.toArray());
117117
}

src/main/java/org/openrewrite/java/testing/junit5/ExpectedExceptionToAssertThrows.java

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ public TreeVisitor<?, ExecutionContext> getVisitor() {
6363
private static class ExpectedExceptionToAssertThrowsVisitor extends JavaIsoVisitor<ExecutionContext> {
6464

6565
private static final String FIRST_EXPECTED_EXCEPTION_METHOD_INVOCATION = "firstExpectedExceptionMethodInvocation";
66+
private static final String STATEMENTS_BEFORE_EXPECT_EXCEPTION = "statementsBeforeExpectException";
6667
private static final String STATEMENTS_AFTER_EXPECT_EXCEPTION = "statementsAfterExpectException";
6768
private static final String HAS_MATCHER = "hasMatcher";
6869
private static final String EXCEPTION_CLASS = "exceptionClass";
@@ -100,13 +101,73 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex
100101
if (getCursor().pollMessage("hasExpectException") != null) {
101102
List<NameTree> thrown = m.getThrows();
102103
if (thrown != null && !thrown.isEmpty()) {
104+
List<Statement> statementsBeforeExpect = getCursor().pollMessage(STATEMENTS_BEFORE_EXPECT_EXCEPTION);
105+
if (statementsBeforeExpectThrowCheckedException(statementsBeforeExpect)) {
106+
return m;
107+
}
103108
assert m.getBody() != null;
104109
return m.withBody(m.getBody().withPrefix(thrown.get(0).getPrefix())).withThrows(emptyList());
105110
}
106111
}
107112
return m;
108113
}
109114

115+
private boolean statementsBeforeExpectThrowCheckedException(List<Statement> statements) {
116+
return statements.stream().anyMatch(this::statementThrowsCheckedException);
117+
}
118+
119+
private boolean statementThrowsCheckedException(Statement statement) {
120+
AtomicBoolean throwsChecked = new AtomicBoolean(false);
121+
new JavaIsoVisitor<AtomicBoolean>() {
122+
@Override
123+
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, AtomicBoolean found) {
124+
if (found.get()) {
125+
return method;
126+
}
127+
JavaType.Method methodType = method.getMethodType();
128+
if (methodType == null) {
129+
return super.visitMethodInvocation(method, found);
130+
}
131+
List<JavaType> thrownExceptions = methodType.getThrownExceptions();
132+
for (JavaType thrownException : thrownExceptions) {
133+
if (isCheckedException(thrownException)) {
134+
found.set(true);
135+
return method;
136+
}
137+
}
138+
return super.visitMethodInvocation(method, found);
139+
}
140+
141+
@Override
142+
public J.NewClass visitNewClass(J.NewClass newClass, AtomicBoolean found) {
143+
if (found.get()) {
144+
return newClass;
145+
}
146+
JavaType.Method constructorType = newClass.getConstructorType();
147+
if (constructorType == null) {
148+
return super.visitNewClass(newClass, found);
149+
}
150+
List<JavaType> thrownExceptions = constructorType.getThrownExceptions();
151+
for (JavaType thrownException : thrownExceptions) {
152+
if (isCheckedException(thrownException)) {
153+
found.set(true);
154+
return newClass;
155+
}
156+
}
157+
return super.visitNewClass(newClass, found);
158+
}
159+
}.visit(statement, throwsChecked);
160+
return throwsChecked.get();
161+
}
162+
163+
private boolean isCheckedException(JavaType exceptionType) {
164+
if (exceptionType == null) {
165+
return false;
166+
}
167+
return !TypeUtils.isAssignableTo("java.lang.RuntimeException", exceptionType) &&
168+
!TypeUtils.isAssignableTo("java.lang.Error", exceptionType);
169+
}
170+
110171
@Override
111172
public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
112173
J.Block b = super.visitBlock(block, ctx);
@@ -175,7 +236,13 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
175236
return method;
176237
}
177238
getCursor().dropParentUntil(J.MethodDeclaration.class::isInstance).putMessage("hasExpectException", true);
178-
getCursor().dropParentUntil(J.Block.class::isInstance).computeMessageIfAbsent(FIRST_EXPECTED_EXCEPTION_METHOD_INVOCATION, k -> method);
239+
Cursor blockCursor = getCursor().dropParentUntil(J.Block.class::isInstance);
240+
blockCursor.computeMessageIfAbsent(FIRST_EXPECTED_EXCEPTION_METHOD_INVOCATION, k -> method);
241+
242+
List<Statement> predecessorStatements = findPredecessorStatements(getCursor());
243+
getCursor().dropParentUntil(J.MethodDeclaration.class::isInstance)
244+
.computeMessageIfAbsent(STATEMENTS_BEFORE_EXPECT_EXCEPTION, k -> predecessorStatements);
245+
179246
List<Statement> successorStatements = findSuccessorStatements(getCursor());
180247
getCursor().putMessageOnFirstEnclosing(J.Block.class, STATEMENTS_AFTER_EXPECT_EXCEPTION, successorStatements);
181248
if (EXPECTED_EXCEPTION_CLASS_MATCHER.matches(method)) {
@@ -186,6 +253,25 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
186253
return method;
187254
}
188255

256+
/**
257+
* From the current cursor point find all preceding statements in the method body.
258+
*/
259+
private List<Statement> findPredecessorStatements(Cursor cursor) {
260+
J.MethodDeclaration methodDecl = cursor.firstEnclosing(J.MethodDeclaration.class);
261+
if (methodDecl == null || methodDecl.getBody() == null) {
262+
return emptyList();
263+
}
264+
List<Statement> predecessorStatements = new ArrayList<>();
265+
Statement currentStatement = cursor.firstEnclosing(Statement.class);
266+
for (Statement statement : methodDecl.getBody().getStatements()) {
267+
if (statement == currentStatement) {
268+
break;
269+
}
270+
predecessorStatements.add(statement);
271+
}
272+
return predecessorStatements;
273+
}
274+
189275
/**
190276
* From the current cursor point find all the next statements that can be executed in the current path.
191277
*/

src/test/java/org/openrewrite/java/testing/junit5/ExpectedExceptionToAssertThrowsTest.java

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,65 @@ public void expectExceptionUseCases() {
424424
);
425425
}
426426

427+
@Issue("https://github.com/openrewrite/rewrite-testing-frameworks/issues/55")
428+
@Test
429+
void preserveThrowsWhenCodeBeforeExpectThrowsCheckedException() {
430+
//language=java
431+
rewriteRun(
432+
java(
433+
"""
434+
import org.junit.Rule;
435+
import org.junit.Test;
436+
import org.junit.rules.ExpectedException;
437+
438+
class MyTest {
439+
440+
@Rule
441+
ExpectedException thrown = ExpectedException.none();
442+
443+
@Test
444+
public void testMethod() throws InterruptedException {
445+
setup();
446+
this.thrown.expect(IllegalArgumentException.class);
447+
doSomething();
448+
}
449+
450+
void setup() throws InterruptedException {
451+
Thread.sleep(100);
452+
}
453+
454+
void doSomething() {
455+
throw new IllegalArgumentException();
456+
}
457+
}
458+
""",
459+
"""
460+
import org.junit.Test;
461+
462+
import static org.junit.jupiter.api.Assertions.assertThrows;
463+
464+
class MyTest {
465+
466+
@Test
467+
public void testMethod() throws InterruptedException {
468+
setup();
469+
assertThrows(IllegalArgumentException.class, () ->
470+
doSomething());
471+
}
472+
473+
void setup() throws InterruptedException {
474+
Thread.sleep(100);
475+
}
476+
477+
void doSomething() {
478+
throw new IllegalArgumentException();
479+
}
480+
}
481+
"""
482+
)
483+
);
484+
}
485+
427486
@Issue("https://github.com/openrewrite/rewrite-testing-frameworks/issues/563")
428487
@Test
429488
void expectedCheckedExceptionThrowsRemoved() {

0 commit comments

Comments
 (0)