Skip to content

Commit c62e629

Browse files
committed
Preserve else-if chains in UnwrapElseAfterReturn recipe
- Preserve `else if` structure when unwrapping else blocks after return/throw - Only unwrap the final `else` block in an else-if chain, not intermediate `else if` branches - Add `findInnermostIfWithElse()` helper to locate the last if in a chain - Add `removeInnermostElse()` helper to recursively remove only the innermost else - Update tests to verify else-if preservation while final else is still unwrapped
1 parent 6d3fb00 commit c62e629

File tree

2 files changed

+89
-31
lines changed

2 files changed

+89
-31
lines changed

src/main/java/org/openrewrite/staticanalysis/UnwrapElseAfterReturn.java

Lines changed: 81 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package org.openrewrite.staticanalysis;
1717

1818
import lombok.Getter;
19+
import org.jspecify.annotations.Nullable;
1920
import org.openrewrite.ExecutionContext;
2021
import org.openrewrite.Recipe;
2122
import org.openrewrite.Repeat;
@@ -50,41 +51,103 @@ public TreeVisitor<?, ExecutionContext> getVisitor() {
5051
@Override
5152
public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
5253
J.Block b = visitAndCast(block, ctx, super::visitBlock);
53-
AtomicReference<Space> endWhitespace = new AtomicReference<>(null);
54+
AtomicReference<@Nullable Space> endWhitespace = new AtomicReference<>(null);
5455
J.Block alteredBlock = b.withStatements(ListUtils.flatMap(b.getStatements(), statement -> {
5556
if (statement instanceof J.If) {
5657
J.If ifStatement = (J.If) statement;
5758
if (ifStatement.getElsePart() != null && endsWithReturnOrThrow(ifStatement.getThenPart())) {
58-
J.If newIf = ifStatement.withElsePart(null);
5959
Statement elsePart = ifStatement.getElsePart().getBody();
60-
if (elsePart instanceof J.Block) {
61-
J.Block elseBlock = (J.Block) elsePart;
62-
endWhitespace.set(elseBlock.getEnd());
63-
return ListUtils.concat(newIf, ListUtils.mapFirst(elseBlock.getStatements(), elseStmt -> {
64-
// Combine comments from the else block itself and the first statement
65-
List<Comment> elseComments = elseBlock.getPrefix().getComments();
66-
List<Comment> stmtComments = elseStmt.getPrefix().getComments();
67-
if (!elseComments.isEmpty() || !stmtComments.isEmpty()) {
68-
return elseStmt.withComments(ListUtils.concatAll(elseComments, stmtComments));
60+
if (elsePart instanceof J.If) {
61+
// Else-if chain: find and unwrap the innermost else
62+
J.If innermost = findInnermostIfWithElse((J.If) elsePart);
63+
if (innermost != null &&
64+
innermost.getElsePart() != null &&
65+
endsWithReturnOrThrow(innermost.getThenPart()) &&
66+
!(innermost.getElsePart().getBody() instanceof J.If)) {
67+
// Unwrap the innermost else
68+
J.If modifiedChain = removeInnermostElse(ifStatement);
69+
Statement innermostElseBody = innermost.getElsePart().getBody();
70+
if (innermostElseBody instanceof J.Block) {
71+
J.Block elseBlock = (J.Block) innermostElseBody;
72+
endWhitespace.set(elseBlock.getEnd());
73+
return ListUtils.concat(modifiedChain, ListUtils.mapFirst(elseBlock.getStatements(), elseStmt -> {
74+
List<Comment> elseComments = elseBlock.getPrefix().getComments();
75+
List<Comment> stmtComments = elseStmt.getPrefix().getComments();
76+
if (!elseComments.isEmpty() || !stmtComments.isEmpty()) {
77+
return elseStmt.withComments(ListUtils.concatAll(elseComments, stmtComments));
78+
}
79+
String whitespace = innermost.getElsePart().getPrefix().getWhitespace();
80+
return elseStmt.withPrefix(elseStmt.getPrefix().withWhitespace(whitespace));
81+
}));
6982
}
70-
String whitespace = ifStatement.getElsePart().getPrefix().getWhitespace();
71-
return elseStmt.withPrefix(elseStmt.getPrefix().withWhitespace(whitespace));
72-
}));
83+
return Arrays.asList(modifiedChain, innermostElseBody.<Statement>withPrefix(innermost.getElsePart().getPrefix()));
84+
}
85+
} else {
86+
// Plain else block: unwrap directly
87+
J.If newIf = ifStatement.withElsePart(null);
88+
if (elsePart instanceof J.Block) {
89+
J.Block elseBlock = (J.Block) elsePart;
90+
endWhitespace.set(elseBlock.getEnd());
91+
return ListUtils.concat(newIf, ListUtils.mapFirst(elseBlock.getStatements(), elseStmt -> {
92+
List<Comment> elseComments = elseBlock.getPrefix().getComments();
93+
List<Comment> stmtComments = elseStmt.getPrefix().getComments();
94+
if (!elseComments.isEmpty() || !stmtComments.isEmpty()) {
95+
return elseStmt.withComments(ListUtils.concatAll(elseComments, stmtComments));
96+
}
97+
String whitespace = ifStatement.getElsePart().getPrefix().getWhitespace();
98+
return elseStmt.withPrefix(elseStmt.getPrefix().withWhitespace(whitespace));
99+
}));
100+
}
101+
return Arrays.asList(newIf, elsePart.<Statement>withPrefix(ifStatement.getElsePart().getPrefix()));
73102
}
74-
return Arrays.asList(newIf, elsePart.<Statement>withPrefix(ifStatement.getElsePart().getPrefix()));
75103
}
76104
}
77105
return statement;
78106
}));
79107

80-
if (endWhitespace.get() != null) {
81-
List<Comment> mergedComments = ListUtils.concatAll(endWhitespace.get().getComments(), b.getEnd().getComments());
82-
alteredBlock = alteredBlock.withEnd(b.getEnd().withComments(mergedComments).withWhitespace(endWhitespace.get().getWhitespace()));
108+
Space end = endWhitespace.get();
109+
if (end != null) {
110+
List<Comment> mergedComments = ListUtils.concatAll(end.getComments(), b.getEnd().getComments());
111+
alteredBlock = alteredBlock.withEnd(b.getEnd().withComments(mergedComments).withWhitespace(end.getWhitespace()));
83112
}
84113

85114
return maybeAutoFormat(b, alteredBlock, ctx);
86115
}
87116

117+
private J.@Nullable If findInnermostIfWithElse(J.If ifStatement) {
118+
if (ifStatement.getElsePart() == null) {
119+
return null;
120+
}
121+
Statement elseBody = ifStatement.getElsePart().getBody();
122+
if (elseBody instanceof J.If) {
123+
J.If result = findInnermostIfWithElse((J.If) elseBody);
124+
return result != null ? result : ifStatement;
125+
}
126+
return ifStatement;
127+
}
128+
129+
private J.If removeInnermostElse(J.If ifStatement) {
130+
if (ifStatement.getElsePart() == null) {
131+
return ifStatement;
132+
}
133+
Statement elseBody = ifStatement.getElsePart().getBody();
134+
if (elseBody instanceof J.If) {
135+
J.If innerIf = (J.If) elseBody;
136+
if (innerIf.getElsePart() != null && !(innerIf.getElsePart().getBody() instanceof J.If)) {
137+
// This is the innermost if with a non-if else, remove its else
138+
return ifStatement.withElsePart(
139+
ifStatement.getElsePart().withBody(innerIf.withElsePart(null))
140+
);
141+
}
142+
// Recurse deeper into the chain
143+
return ifStatement.withElsePart(
144+
ifStatement.getElsePart().withBody(removeInnermostElse(innerIf))
145+
);
146+
}
147+
// Direct else (not else-if), remove it
148+
return ifStatement.withElsePart(null);
149+
}
150+
88151
private boolean endsWithReturnOrThrow(Statement statement) {
89152
if (statement instanceof J.Return || statement instanceof J.Throw) {
90153
return true;

src/test/java/org/openrewrite/staticanalysis/UnwrapElseAfterReturnTest.java

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ int foo(boolean condition) {
334334
}
335335

336336
@Test
337-
void chainedIfElseIfElse() {
337+
void preserveElseIfButUnwrapFinalElse() {
338338
rewriteRun(
339339
java(
340340
"""
@@ -357,11 +357,9 @@ class Test {
357357
int foo(String str) {
358358
if ("one".equals(str)) {
359359
return 1;
360-
}
361-
if ("two".equals(str)) {
360+
} else if ("two".equals(str)) {
362361
return 2;
363-
}
364-
if ("three".equals(str)) {
362+
} else if ("three".equals(str)) {
365363
return 3;
366364
}
367365
return Integer.MAX_VALUE;
@@ -373,7 +371,7 @@ int foo(String str) {
373371
}
374372

375373
@Test
376-
void chainedIfElseIfElseWithMissingReturn() {
374+
void preserveElseIfButUnwrapFinalElseWithMissingReturn() {
377375
rewriteRun(
378376
java(
379377
"""
@@ -396,14 +394,12 @@ class Test {
396394
int foo(String str) {
397395
if ("one".equals(str)) {
398396
return 1;
399-
}
400-
if ("two".equals(str)) {
397+
} else if ("two".equals(str)) {
401398
System.out.println("two");
402399
} else if ("three".equals(str)) {
403400
return 3;
404-
} else {
405-
return Integer.MAX_VALUE;
406401
}
402+
return Integer.MAX_VALUE;
407403
}
408404
}
409405
"""
@@ -498,7 +494,7 @@ void validateInput(String input) {
498494
}
499495

500496
@Test
501-
void mixedReturnAndThrow() {
497+
void preserveElseIfButUnwrapFinalElseWithMixedReturnAndThrow() {
502498
rewriteRun(
503499
java(
504500
"""
@@ -519,8 +515,7 @@ class Test {
519515
String process(int value) {
520516
if (value < 0) {
521517
throw new IllegalArgumentException("Negative value");
522-
}
523-
if (value == 0) {
518+
} else if (value == 0) {
524519
return "zero";
525520
}
526521
return "positive";

0 commit comments

Comments
 (0)