From eece5ed1b4c61e387cf1c633f8d5b37434966dc8 Mon Sep 17 00:00:00 2001 From: Greg Oledzki Date: Wed, 23 Jul 2025 09:27:55 +0200 Subject: [PATCH 1/9] PreferEarlyReturn stub --- .../staticanalysis/PreferEarlyReturn.java | 240 ++++++++++ .../staticanalysis/PreferEarlyReturnTest.java | 440 ++++++++++++++++++ 2 files changed, 680 insertions(+) create mode 100644 src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java create mode 100644 src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java diff --git a/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java b/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java new file mode 100644 index 0000000000..698249c8b6 --- /dev/null +++ b/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java @@ -0,0 +1,240 @@ +/* + * Copyright 2025 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.staticanalysis; + +import org.openrewrite.ExecutionContext; +import org.openrewrite.Recipe; +import org.openrewrite.Tree; +import org.openrewrite.TreeVisitor; +import org.openrewrite.internal.ListUtils; +import org.openrewrite.java.JavaVisitor; +import org.openrewrite.java.tree.Expression; +import org.openrewrite.java.tree.J; +import org.openrewrite.java.tree.JLeftPadded; +import org.openrewrite.java.tree.JavaType; +import org.openrewrite.java.tree.Space; +import org.openrewrite.java.tree.Statement; +import org.openrewrite.marker.Markers; + +import java.time.Duration; +import java.util.Collections; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.openrewrite.Tree.randomId; + +public class PreferEarlyReturn extends Recipe { + + @Override + public String getDisplayName() { + return "Prefer early returns"; + } + + @Override + public String getDescription() { + return "Refactors methods to use early returns for error/edge cases, reducing nesting and improving readability. " + + "The recipe identifies if-else statements where the if block contains the main logic (≥5 statements) and the " + + "else block contains a simple return (≤2 statements). It then inverts the condition and moves the else block " + + "to the beginning of the method with an early return, allowing the main logic to be un-indented."; + } + + @Override + public Set getTags() { + return Collections.emptySet(); + } + + @Override + public Duration getEstimatedEffortPerOccurrence() { + return Duration.ofMinutes(2); + } + + @Override + public TreeVisitor getVisitor() { + return new PreferEarlyReturnVisitor(); + } + + private static class PreferEarlyReturnVisitor extends JavaVisitor { + + @Override + public J visitIf(J.If ifStatement, ExecutionContext ctx) { + J.If if_ = (J.If) super.visitIf(ifStatement, ctx); + + // TODO: Implement the logic to: + // 1. Check if this if-else statement is eligible for early return refactoring + // 2. Count statements in if block (must be >= 5) + // 3. Count statements in else block (must be <= 2) + // 4. Check that else block contains a return statement + // 5. Invert the condition (handle De Morgan's laws) + // 6. Move else block content to before the if statement + // 7. Unwrap the if block content + // 8. Remove the now-empty if-else structure + + return if_; + } + + private boolean isEligibleForEarlyReturn(J.If ifStatement) { + // Must have an else block + if (ifStatement.getElsePart() == null) { + return false; + } + + // The then part must be a block + if (!(ifStatement.getThenPart() instanceof J.Block)) { + return false; + } + + // The else part must be a block (not another if) + if (!(ifStatement.getElsePart().getBody() instanceof J.Block)) { + return false; + } + + J.Block thenBlock = (J.Block) ifStatement.getThenPart(); + J.Block elseBlock = (J.Block) ifStatement.getElsePart().getBody(); + + // Count statements + int thenStatements = countStatements(thenBlock); + int elseStatements = countStatements(elseBlock); + + // Check heuristics: then block >= 5 statements, else block <= 2 statements + if (thenStatements < 5 || elseStatements > 2) { + return false; + } + + // Else block must contain a return statement + return hasReturnStatement(elseBlock); + } + + private int countStatements(J.Block block) { + if (block == null || block.getStatements() == null) { + return 0; + } + + AtomicInteger count = new AtomicInteger(0); + new JavaVisitor() { + @Override + public J visitBlock(J.Block block, AtomicInteger counter) { + // Don't visit nested blocks + return block; + } + + @Override + public J visitStatement(Statement statement, AtomicInteger counter) { + // Count each statement, but don't count block statements themselves + if (!(statement instanceof J.Block)) { + counter.incrementAndGet(); + } + return super.visitStatement(statement, counter); + } + }.visit(block, count); + + return count.get(); + } + + private boolean hasReturnStatement(J.Block block) { + if (block == null || block.getStatements() == null) { + return false; + } + + AtomicBoolean hasReturn = new AtomicBoolean(false); + new JavaVisitor() { + @Override + public J visitReturn(J.Return return_, AtomicBoolean hasReturnFlag) { + hasReturnFlag.set(true); + return return_; + } + }.visit(block, hasReturn); + + return hasReturn.get(); + } + + private J.ControlParentheses invertCondition(J.ControlParentheses condition) { + if (condition == null || !(condition.getTree() instanceof Expression)) { + return condition; + } + + Expression expr = (Expression) condition.getTree(); + Expression inverted = invertExpression(expr); + + return condition.withTree(inverted); + } + + private Expression invertExpression(Expression expr) { + if (expr instanceof J.Binary) { + J.Binary binary = (J.Binary) expr; + + // Handle AND/OR with De Morgan's laws + if (binary.getOperator() == J.Binary.Type.And) { + // A && B becomes !A || !B + Expression leftInverted = invertExpression(binary.getLeft()); + Expression rightInverted = invertExpression(binary.getRight()); + return binary.withOperator(J.Binary.Type.Or) + .withLeft(leftInverted) + .withRight(rightInverted.withPrefix(Space.SINGLE_SPACE)); + } else if (binary.getOperator() == J.Binary.Type.Or) { + // A || B becomes !A && !B + Expression leftInverted = invertExpression(binary.getLeft()); + Expression rightInverted = invertExpression(binary.getRight()); + return binary.withOperator(J.Binary.Type.And) + .withLeft(leftInverted) + .withRight(rightInverted.withPrefix(Space.SINGLE_SPACE)); + } else if (binary.getOperator() == J.Binary.Type.Equal) { + // == becomes != + return binary.withOperator(J.Binary.Type.NotEqual); + } else if (binary.getOperator() == J.Binary.Type.NotEqual) { + // != becomes == + return binary.withOperator(J.Binary.Type.Equal); + } else if (binary.getOperator() == J.Binary.Type.LessThan) { + // < becomes >= + return binary.withOperator(J.Binary.Type.GreaterThanOrEqual); + } else if (binary.getOperator() == J.Binary.Type.LessThanOrEqual) { + // <= becomes > + return binary.withOperator(J.Binary.Type.GreaterThan); + } else if (binary.getOperator() == J.Binary.Type.GreaterThan) { + // > becomes <= + return binary.withOperator(J.Binary.Type.LessThanOrEqual); + } else if (binary.getOperator() == J.Binary.Type.GreaterThanOrEqual) { + // >= becomes < + return binary.withOperator(J.Binary.Type.LessThan); + } + } else if (expr instanceof J.Unary) { + J.Unary unary = (J.Unary) expr; + if (unary.getOperator() == J.Unary.Type.Not) { + // Double negation: !!expr becomes expr + return unary.getExpression(); + } + } else if (expr instanceof J.Parentheses) { + @SuppressWarnings("unchecked") + J.Parentheses parens = (J.Parentheses) expr; + // Invert the expression inside parentheses + if (parens.getTree() instanceof Expression) { + Expression innerInverted = invertExpression(parens.getTree()); + return parens.withTree(innerInverted); + } + } + + // For all other expressions, add a NOT operator + return new J.Unary( + randomId(), + expr.getPrefix(), + Markers.EMPTY, + new JLeftPadded<>(Space.EMPTY, J.Unary.Type.Not, Markers.EMPTY), + expr.withPrefix(Space.EMPTY), + JavaType.Primitive.Boolean + ); + } + } +} \ No newline at end of file diff --git a/src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java b/src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java new file mode 100644 index 0000000000..8aaee4b726 --- /dev/null +++ b/src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java @@ -0,0 +1,440 @@ +/* + * Copyright 2025 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.staticanalysis; + +import org.junit.jupiter.api.Test; +import org.openrewrite.DocumentExample; +import org.openrewrite.test.RecipeSpec; +import org.openrewrite.test.RewriteTest; + +import static org.openrewrite.java.Assertions.java; + +@SuppressWarnings("ConstantConditions") +class PreferEarlyReturnTest implements RewriteTest { + @Override + public void defaults(RecipeSpec spec) { + spec.recipe(new PreferEarlyReturn()); + } + + @DocumentExample + @Test + void basicIfElseWithEarlyReturn() { + rewriteRun( + //language=java + java( + """ + class Test { + void processOrder(Order order) { + if (order != null && order.isValid()) { + // Process the order + order.validate(); + order.calculateTax(); + order.applyDiscount(); + order.processPayment(); + order.sendConfirmation(); + } else { + logError("Invalid order"); + return; + } + } + + void logError(String message) {} + + class Order { + boolean isValid() { return true; } + void validate() {} + void calculateTax() {} + void applyDiscount() {} + void processPayment() {} + void sendConfirmation() {} + } + } + """, + """ + class Test { + void processOrder(Order order) { + if (order == null || !order.isValid()) { + logError("Invalid order"); + return; + } + // Process the order + order.validate(); + order.calculateTax(); + order.applyDiscount(); + order.processPayment(); + order.sendConfirmation(); + } + + void logError(String message) {} + + class Order { + boolean isValid() { return true; } + void validate() {} + void calculateTax() {} + void applyDiscount() {} + void processPayment() {} + void sendConfirmation() {} + } + } + """ + ) + ); + } + + @Test + void multipleConditionsWithAndOperator() { + rewriteRun( + //language=java + java( + """ + class Test { + void processUser(User user) { + if (user != null && user.isActive() && !user.isSuspended()) { + // Main processing logic + user.updateLastLogin(); + user.incrementLoginCount(); + user.loadPreferences(); + user.initializeSession(); + user.logActivity(); + } else { + return; + } + } + + class User { + boolean isActive() { return true; } + boolean isSuspended() { return false; } + void updateLastLogin() {} + void incrementLoginCount() {} + void loadPreferences() {} + void initializeSession() {} + void logActivity() {} + } + } + """, + """ + class Test { + void processUser(User user) { + if (user == null || !user.isActive() || user.isSuspended()) { + return; + } + // Main processing logic + user.updateLastLogin(); + user.incrementLoginCount(); + user.loadPreferences(); + user.initializeSession(); + user.logActivity(); + } + + class User { + boolean isActive() { return true; } + boolean isSuspended() { return false; } + void updateLastLogin() {} + void incrementLoginCount() {} + void loadPreferences() {} + void initializeSession() {} + void logActivity() {} + } + } + """ + ) + ); + } + + @Test + void methodWithReturnValue() { + rewriteRun( + //language=java + java( + """ + class Test { + String processData(Data data) { + if (data != null && data.isValid()) { + // Process the data + String result = data.transform(); + result = result.trim(); + result = result.toUpperCase(); + data.log(result); + return result; + } else { + return null; + } + } + + class Data { + boolean isValid() { return true; } + String transform() { return "test"; } + void log(String s) {} + } + } + """, + """ + class Test { + String processData(Data data) { + if (data == null || !data.isValid()) { + return null; + } + // Process the data + String result = data.transform(); + result = result.trim(); + result = result.toUpperCase(); + data.log(result); + return result; + } + + class Data { + boolean isValid() { return true; } + String transform() { return "test"; } + void log(String s) {} + } + } + """ + ) + ); + } + + @Test + void noChangeWhenIfBlockTooSmall() { + rewriteRun( + //language=java + java( + """ + class Test { + void processItem(Item item) { + if (item != null) { + // Too few statements (less than 5) + item.process(); + item.save(); + } else { + return; + } + } + + class Item { + void process() {} + void save() {} + } + } + """ + ) + ); + } + + @Test + void noChangeWhenElseBlockTooLarge() { + rewriteRun( + //language=java + java( + """ + class Test { + void processRequest(Request request) { + if (request != null && request.isValid()) { + // Process the request + request.validate(); + request.authorize(); + request.execute(); + request.logSuccess(); + request.notifyClients(); + } else { + // Too many statements in else block (more than 2) + logError("Invalid request"); + notifyAdmin(); + incrementErrorCounter(); + return; + } + } + + void logError(String message) {} + void notifyAdmin() {} + void incrementErrorCounter() {} + + class Request { + boolean isValid() { return true; } + void validate() {} + void authorize() {} + void execute() {} + void logSuccess() {} + void notifyClients() {} + } + } + """ + ) + ); + } + + @Test + void noChangeWhenNoElseBlock() { + rewriteRun( + //language=java + java( + """ + class Test { + void processEvent(Event event) { + if (event != null && event.isActive()) { + // Process the event + event.handle(); + event.dispatch(); + event.complete(); + event.cleanup(); + event.logCompletion(); + } + // No else block, so no early return to add + } + + class Event { + boolean isActive() { return true; } + void handle() {} + void dispatch() {} + void complete() {} + void cleanup() {} + void logCompletion() {} + } + } + """ + ) + ); + } + + @Test + void preserveCommentsAndFormatting() { + rewriteRun( + //language=java + java( + """ + class Test { + void processPayment(Payment payment) { + // Check if payment is valid + if (payment != null && payment.isAuthorized()) { + // Process the payment + payment.validate(); // Validate payment details + payment.checkFraud(); // Check for fraud + payment.deductAmount(); // Deduct from account + payment.recordTransaction(); // Record in database + payment.sendReceipt(); // Send receipt to customer + } else { + // Payment is invalid + logError("Unauthorized payment"); + return; + } + } + + void logError(String message) {} + + class Payment { + boolean isAuthorized() { return true; } + void validate() {} + void checkFraud() {} + void deductAmount() {} + void recordTransaction() {} + void sendReceipt() {} + } + } + """, + """ + class Test { + void processPayment(Payment payment) { + // Check if payment is valid + if (payment == null || !payment.isAuthorized()) { + // Payment is invalid + logError("Unauthorized payment"); + return; + } + // Process the payment + payment.validate(); // Validate payment details + payment.checkFraud(); // Check for fraud + payment.deductAmount(); // Deduct from account + payment.recordTransaction(); // Record in database + payment.sendReceipt(); // Send receipt to customer + } + + void logError(String message) {} + + class Payment { + boolean isAuthorized() { return true; } + void validate() {} + void checkFraud() {} + void deductAmount() {} + void recordTransaction() {} + void sendReceipt() {} + } + } + """ + ) + ); + } + + @Test + void complexConditionWithParentheses() { + rewriteRun( + //language=java + java( + """ + class Test { + void processTransaction(Transaction tx) { + if (tx != null && (tx.isValid() || tx.isPending()) && !tx.isExpired()) { + // Process transaction + tx.authorize(); + tx.validate(); + tx.execute(); + tx.commit(); + tx.notifyParties(); + } else { + return; + } + } + + class Transaction { + boolean isValid() { return true; } + boolean isPending() { return false; } + boolean isExpired() { return false; } + void authorize() {} + void validate() {} + void execute() {} + void commit() {} + void notifyParties() {} + } + } + """, + """ + class Test { + void processTransaction(Transaction tx) { + if (tx == null || (!tx.isValid() && !tx.isPending()) || tx.isExpired()) { + return; + } + // Process transaction + tx.authorize(); + tx.validate(); + tx.execute(); + tx.commit(); + tx.notifyParties(); + } + + class Transaction { + boolean isValid() { return true; } + boolean isPending() { return false; } + boolean isExpired() { return false; } + void authorize() {} + void validate() {} + void execute() {} + void commit() {} + void notifyParties() {} + } + } + """ + ) + ); + } +} \ No newline at end of file From c378cde782a78b7ec9e54343453024b721175613 Mon Sep 17 00:00:00 2001 From: Greg Oledzki Date: Wed, 23 Jul 2025 09:30:21 +0200 Subject: [PATCH 2/9] Simple impl --- .../staticanalysis/PreferEarlyReturn.java | 34 +++++++++++++------ 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java b/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java index 698249c8b6..2d39b88a5e 100644 --- a/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java +++ b/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java @@ -24,6 +24,7 @@ import org.openrewrite.java.tree.Expression; import org.openrewrite.java.tree.J; import org.openrewrite.java.tree.JLeftPadded; +import org.openrewrite.java.tree.JRightPadded; import org.openrewrite.java.tree.JavaType; import org.openrewrite.java.tree.Space; import org.openrewrite.java.tree.Statement; @@ -73,17 +74,28 @@ private static class PreferEarlyReturnVisitor extends JavaVisitor= 5) - // 3. Count statements in else block (must be <= 2) - // 4. Check that else block contains a return statement - // 5. Invert the condition (handle De Morgan's laws) - // 6. Move else block content to before the if statement - // 7. Unwrap the if block content - // 8. Remove the now-empty if-else structure - - return if_; + // Check if this if-else statement is eligible for early return refactoring + if (!isEligibleForEarlyReturn(if_)) { + return if_; + } + + // Invert the condition + J.ControlParentheses invertedCondition = invertCondition(if_.getIfCondition()); + + // Create a new if statement with the inverted condition and the else block content + J.If newIf = if_.withIfCondition(invertedCondition) + .withThenPart(if_.getElsePart().getBody()) + .withElsePart(new J.If.Else( + randomId(), + if_.getElsePart().getPrefix(), + Markers.EMPTY, + JRightPadded.build(if_.getThenPart()) + )); + + // Apply UnwrapElseAfterReturn to handle the unwrapping + newIf = (J.If) new UnwrapElseAfterReturn().getVisitor().visit(newIf, ctx); + + return newIf; } private boolean isEligibleForEarlyReturn(J.If ifStatement) { From cbf9f3a10c2c9402acb7a7465e856fe17cadf8a4 Mon Sep 17 00:00:00 2001 From: Greg Oledzki Date: Wed, 23 Jul 2025 09:58:30 +0200 Subject: [PATCH 3/9] Better count of statements --- .../staticanalysis/PreferEarlyReturn.java | 25 +++---------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java b/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java index 2d39b88a5e..42ba1be9a0 100644 --- a/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java +++ b/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java @@ -92,8 +92,8 @@ public J visitIf(J.If ifStatement, ExecutionContext ctx) { JRightPadded.build(if_.getThenPart()) )); - // Apply UnwrapElseAfterReturn to handle the unwrapping - newIf = (J.If) new UnwrapElseAfterReturn().getVisitor().visit(newIf, ctx); + // Mark that we need to apply UnwrapElseAfterReturn in a second pass + doAfterVisit(new UnwrapElseAfterReturn().getVisitor()); return newIf; } @@ -135,25 +135,8 @@ private int countStatements(J.Block block) { return 0; } - AtomicInteger count = new AtomicInteger(0); - new JavaVisitor() { - @Override - public J visitBlock(J.Block block, AtomicInteger counter) { - // Don't visit nested blocks - return block; - } - - @Override - public J visitStatement(Statement statement, AtomicInteger counter) { - // Count each statement, but don't count block statements themselves - if (!(statement instanceof J.Block)) { - counter.incrementAndGet(); - } - return super.visitStatement(statement, counter); - } - }.visit(block, count); - - return count.get(); + // Simply count the direct statements in the block + return block.getStatements().size(); } private boolean hasReturnStatement(J.Block block) { From 47f42b4a31ef341f34f158a86a33bc9598342273 Mon Sep 17 00:00:00 2001 From: Greg Oledzki Date: Wed, 23 Jul 2025 10:08:59 +0200 Subject: [PATCH 4/9] Testing for non-void methods too --- .../staticanalysis/PreferEarlyReturn.java | 22 ++-- .../staticanalysis/PreferEarlyReturnTest.java | 110 ++++++++++++++---- 2 files changed, 101 insertions(+), 31 deletions(-) diff --git a/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java b/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java index 42ba1be9a0..6d48cdee68 100644 --- a/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java +++ b/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java @@ -126,8 +126,8 @@ private boolean isEligibleForEarlyReturn(J.If ifStatement) { return false; } - // Else block must contain a return statement - return hasReturnStatement(elseBlock); + // Else block must contain a return or throw statement + return hasReturnOrThrowStatement(elseBlock); } private int countStatements(J.Block block) { @@ -139,21 +139,27 @@ private int countStatements(J.Block block) { return block.getStatements().size(); } - private boolean hasReturnStatement(J.Block block) { + private boolean hasReturnOrThrowStatement(J.Block block) { if (block == null || block.getStatements() == null) { return false; } - AtomicBoolean hasReturn = new AtomicBoolean(false); + AtomicBoolean hasReturnOrThrow = new AtomicBoolean(false); new JavaVisitor() { @Override - public J visitReturn(J.Return return_, AtomicBoolean hasReturnFlag) { - hasReturnFlag.set(true); + public J visitReturn(J.Return return_, AtomicBoolean flag) { + flag.set(true); return return_; } - }.visit(block, hasReturn); + + @Override + public J visitThrow(J.Throw thrown, AtomicBoolean flag) { + flag.set(true); + return thrown; + } + }.visit(block, hasReturnOrThrow); - return hasReturn.get(); + return hasReturnOrThrow.get(); } private J.ControlParentheses invertCondition(J.ControlParentheses condition) { diff --git a/src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java b/src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java index 8aaee4b726..00aa7446ae 100644 --- a/src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java @@ -50,9 +50,9 @@ void processOrder(Order order) { return; } } - + void logError(String message) {} - + class Order { boolean isValid() { return true; } void validate() {} @@ -77,9 +77,9 @@ void processOrder(Order order) { order.processPayment(); order.sendConfirmation(); } - + void logError(String message) {} - + class Order { boolean isValid() { return true; } void validate() {} @@ -101,7 +101,7 @@ void multipleConditionsWithAndOperator() { java( """ class Test { - void processUser(User user) { + boolean processUser(User user) { if (user != null && user.isActive() && !user.isSuspended()) { // Main processing logic user.updateLastLogin(); @@ -109,11 +109,12 @@ void processUser(User user) { user.loadPreferences(); user.initializeSession(); user.logActivity(); + return true; } else { - return; + return false; } } - + class User { boolean isActive() { return true; } boolean isSuspended() { return false; } @@ -127,9 +128,9 @@ void logActivity() {} """, """ class Test { - void processUser(User user) { + boolean processUser(User user) { if (user == null || !user.isActive() || user.isSuspended()) { - return; + return false; } // Main processing logic user.updateLastLogin(); @@ -137,8 +138,9 @@ void processUser(User user) { user.loadPreferences(); user.initializeSession(); user.logActivity(); + return true; } - + class User { boolean isActive() { return true; } boolean isSuspended() { return false; } @@ -173,7 +175,7 @@ String processData(Data data) { return null; } } - + class Data { boolean isValid() { return true; } String transform() { return "test"; } @@ -194,7 +196,7 @@ String processData(Data data) { data.log(result); return result; } - + class Data { boolean isValid() { return true; } String transform() { return "test"; } @@ -222,7 +224,7 @@ void processItem(Item item) { return; } } - + class Item { void process() {} void save() {} @@ -256,11 +258,11 @@ void processRequest(Request request) { return; } } - + void logError(String message) {} void notifyAdmin() {} void incrementErrorCounter() {} - + class Request { boolean isValid() { return true; } void validate() {} @@ -293,7 +295,7 @@ void processEvent(Event event) { } // No else block, so no early return to add } - + class Event { boolean isActive() { return true; } void handle() {} @@ -330,9 +332,9 @@ void processPayment(Payment payment) { return; } } - + void logError(String message) {} - + class Payment { boolean isAuthorized() { return true; } void validate() {} @@ -359,9 +361,9 @@ void processPayment(Payment payment) { payment.recordTransaction(); // Record in database payment.sendReceipt(); // Send receipt to customer } - + void logError(String message) {} - + class Payment { boolean isAuthorized() { return true; } void validate() {} @@ -395,7 +397,7 @@ void processTransaction(Transaction tx) { return; } } - + class Transaction { boolean isValid() { return true; } boolean isPending() { return false; } @@ -421,7 +423,7 @@ void processTransaction(Transaction tx) { tx.commit(); tx.notifyParties(); } - + class Transaction { boolean isValid() { return true; } boolean isPending() { return false; } @@ -437,4 +439,66 @@ void notifyParties() {} ) ); } -} \ No newline at end of file + + @Test + void methodThrowingExceptionInElseBlock() { + rewriteRun( + //language=java + java( + """ + class Test { + String validateAndProcess(Input input) { + if (input != null && input.isValid() && input.hasRequiredFields()) { + // Process the input + String normalized = input.normalize(); + String validated = input.validate(); + String transformed = input.transform(); + String encrypted = input.encrypt(); + String result = input.format(normalized, validated, transformed, encrypted); + return result; + } else { + throw new IllegalArgumentException("Invalid input"); + } + } + + class Input { + boolean isValid() { return true; } + boolean hasRequiredFields() { return true; } + String normalize() { return "normalized"; } + String validate() { return "validated"; } + String transform() { return "transformed"; } + String encrypt() { return "encrypted"; } + String format(String... parts) { return String.join("-", parts); } + } + } + """, + """ + class Test { + String validateAndProcess(Input input) { + if (input == null || !input.isValid() || !input.hasRequiredFields()) { + throw new IllegalArgumentException("Invalid input"); + } + // Process the input + String normalized = input.normalize(); + String validated = input.validate(); + String transformed = input.transform(); + String encrypted = input.encrypt(); + String result = input.format(normalized, validated, transformed, encrypted); + return result; + } + + class Input { + boolean isValid() { return true; } + boolean hasRequiredFields() { return true; } + String normalize() { return "normalized"; } + String validate() { return "validated"; } + String transform() { return "transformed"; } + String encrypt() { return "encrypted"; } + String format(String... parts) { return String.join("-", parts); } + } + } + """ + ) + ); + } +} From 76406dfe343b7a19a6787a1b9bea7a24ff3aa494 Mon Sep 17 00:00:00 2001 From: Greg Oledzki Date: Wed, 23 Jul 2025 15:14:40 +0200 Subject: [PATCH 5/9] Tests to use interfaces, not classes --- .../staticanalysis/PreferEarlyReturnTest.java | 206 +++++++++--------- 1 file changed, 103 insertions(+), 103 deletions(-) diff --git a/src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java b/src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java index 00aa7446ae..0962525a87 100644 --- a/src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java @@ -53,13 +53,13 @@ void processOrder(Order order) { void logError(String message) {} - class Order { - boolean isValid() { return true; } - void validate() {} - void calculateTax() {} - void applyDiscount() {} - void processPayment() {} - void sendConfirmation() {} + interface Order { + boolean isValid(); + void validate(); + void calculateTax(); + void applyDiscount(); + void processPayment(); + void sendConfirmation(); } } """, @@ -80,13 +80,13 @@ void processOrder(Order order) { void logError(String message) {} - class Order { - boolean isValid() { return true; } - void validate() {} - void calculateTax() {} - void applyDiscount() {} - void processPayment() {} - void sendConfirmation() {} + interface Order { + boolean isValid(); + void validate(); + void calculateTax(); + void applyDiscount(); + void processPayment(); + void sendConfirmation(); } } """ @@ -115,14 +115,14 @@ boolean processUser(User user) { } } - class User { - boolean isActive() { return true; } - boolean isSuspended() { return false; } - void updateLastLogin() {} - void incrementLoginCount() {} - void loadPreferences() {} - void initializeSession() {} - void logActivity() {} + interface User { + boolean isActive(); + boolean isSuspended(); + void updateLastLogin(); + void incrementLoginCount(); + void loadPreferences(); + void initializeSession(); + void logActivity(); } } """, @@ -141,14 +141,14 @@ boolean processUser(User user) { return true; } - class User { - boolean isActive() { return true; } - boolean isSuspended() { return false; } - void updateLastLogin() {} - void incrementLoginCount() {} - void loadPreferences() {} - void initializeSession() {} - void logActivity() {} + interface User { + boolean isActive(); + boolean isSuspended(); + void updateLastLogin(); + void incrementLoginCount(); + void loadPreferences(); + void initializeSession(); + void logActivity(); } } """ @@ -176,10 +176,10 @@ String processData(Data data) { } } - class Data { - boolean isValid() { return true; } - String transform() { return "test"; } - void log(String s) {} + interface Data { + boolean isValid(); + String transform(); + void log(String s); } } """, @@ -197,10 +197,10 @@ String processData(Data data) { return result; } - class Data { - boolean isValid() { return true; } - String transform() { return "test"; } - void log(String s) {} + interface Data { + boolean isValid(); + String transform(); + void log(String s); } } """ @@ -225,9 +225,9 @@ void processItem(Item item) { } } - class Item { - void process() {} - void save() {} + interface Item { + void process(); + void save(); } } """ @@ -263,13 +263,13 @@ void logError(String message) {} void notifyAdmin() {} void incrementErrorCounter() {} - class Request { - boolean isValid() { return true; } - void validate() {} - void authorize() {} - void execute() {} - void logSuccess() {} - void notifyClients() {} + interface Request { + boolean isValid(); + void validate(); + void authorize(); + void execute(); + void logSuccess(); + void notifyClients(); } } """ @@ -296,13 +296,13 @@ void processEvent(Event event) { // No else block, so no early return to add } - class Event { - boolean isActive() { return true; } - void handle() {} - void dispatch() {} - void complete() {} - void cleanup() {} - void logCompletion() {} + interface Event { + boolean isActive(); + void handle(); + void dispatch(); + void complete(); + void cleanup(); + void logCompletion(); } } """ @@ -335,13 +335,13 @@ void processPayment(Payment payment) { void logError(String message) {} - class Payment { - boolean isAuthorized() { return true; } - void validate() {} - void checkFraud() {} - void deductAmount() {} - void recordTransaction() {} - void sendReceipt() {} + interface Payment { + boolean isAuthorized(); + void validate(); + void checkFraud(); + void deductAmount(); + void recordTransaction(); + void sendReceipt(); } } """, @@ -364,13 +364,13 @@ void processPayment(Payment payment) { void logError(String message) {} - class Payment { - boolean isAuthorized() { return true; } - void validate() {} - void checkFraud() {} - void deductAmount() {} - void recordTransaction() {} - void sendReceipt() {} + interface Payment { + boolean isAuthorized(); + void validate(); + void checkFraud(); + void deductAmount(); + void recordTransaction(); + void sendReceipt(); } } """ @@ -398,15 +398,15 @@ void processTransaction(Transaction tx) { } } - class Transaction { - boolean isValid() { return true; } - boolean isPending() { return false; } - boolean isExpired() { return false; } - void authorize() {} - void validate() {} - void execute() {} - void commit() {} - void notifyParties() {} + interface Transaction { + boolean isValid(); + boolean isPending(); + boolean isExpired(); + void authorize(); + void validate(); + void execute(); + void commit(); + void notifyParties(); } } """, @@ -424,15 +424,15 @@ void processTransaction(Transaction tx) { tx.notifyParties(); } - class Transaction { - boolean isValid() { return true; } - boolean isPending() { return false; } - boolean isExpired() { return false; } - void authorize() {} - void validate() {} - void execute() {} - void commit() {} - void notifyParties() {} + interface Transaction { + boolean isValid(); + boolean isPending(); + boolean isExpired(); + void authorize(); + void validate(); + void execute(); + void commit(); + void notifyParties(); } } """ @@ -461,14 +461,14 @@ String validateAndProcess(Input input) { } } - class Input { - boolean isValid() { return true; } - boolean hasRequiredFields() { return true; } - String normalize() { return "normalized"; } - String validate() { return "validated"; } - String transform() { return "transformed"; } - String encrypt() { return "encrypted"; } - String format(String... parts) { return String.join("-", parts); } + interface Input { + boolean isValid(); + boolean hasRequiredFields(); + String normalize(); + String validate(); + String transform(); + String encrypt(); + String format(String... parts); } } """, @@ -487,14 +487,14 @@ String validateAndProcess(Input input) { return result; } - class Input { - boolean isValid() { return true; } - boolean hasRequiredFields() { return true; } - String normalize() { return "normalized"; } - String validate() { return "validated"; } - String transform() { return "transformed"; } - String encrypt() { return "encrypted"; } - String format(String... parts) { return String.join("-", parts); } + interface Input { + boolean isValid(); + boolean hasRequiredFields(); + String normalize(); + String validate(); + String transform(); + String encrypt(); + String format(String... parts); } } """ From 3fd529b165427bee1974214add403168985f44bc Mon Sep 17 00:00:00 2001 From: Greg Oledzki Date: Wed, 23 Jul 2025 15:21:57 +0200 Subject: [PATCH 6/9] Remove comments --- .../staticanalysis/PreferEarlyReturn.java | 24 ------------------- 1 file changed, 24 deletions(-) diff --git a/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java b/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java index 6d48cdee68..822762b488 100644 --- a/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java +++ b/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java @@ -74,15 +74,11 @@ private static class PreferEarlyReturnVisitor extends JavaVisitor= 5 statements, else block <= 2 statements if (thenStatements < 5 || elseStatements > 2) { return false; } - // Else block must contain a return or throw statement return hasReturnOrThrowStatement(elseBlock); } @@ -135,7 +124,6 @@ private int countStatements(J.Block block) { return 0; } - // Simply count the direct statements in the block return block.getStatements().size(); } @@ -177,57 +165,45 @@ private Expression invertExpression(Expression expr) { if (expr instanceof J.Binary) { J.Binary binary = (J.Binary) expr; - // Handle AND/OR with De Morgan's laws if (binary.getOperator() == J.Binary.Type.And) { - // A && B becomes !A || !B Expression leftInverted = invertExpression(binary.getLeft()); Expression rightInverted = invertExpression(binary.getRight()); return binary.withOperator(J.Binary.Type.Or) .withLeft(leftInverted) .withRight(rightInverted.withPrefix(Space.SINGLE_SPACE)); } else if (binary.getOperator() == J.Binary.Type.Or) { - // A || B becomes !A && !B Expression leftInverted = invertExpression(binary.getLeft()); Expression rightInverted = invertExpression(binary.getRight()); return binary.withOperator(J.Binary.Type.And) .withLeft(leftInverted) .withRight(rightInverted.withPrefix(Space.SINGLE_SPACE)); } else if (binary.getOperator() == J.Binary.Type.Equal) { - // == becomes != return binary.withOperator(J.Binary.Type.NotEqual); } else if (binary.getOperator() == J.Binary.Type.NotEqual) { - // != becomes == return binary.withOperator(J.Binary.Type.Equal); } else if (binary.getOperator() == J.Binary.Type.LessThan) { - // < becomes >= return binary.withOperator(J.Binary.Type.GreaterThanOrEqual); } else if (binary.getOperator() == J.Binary.Type.LessThanOrEqual) { - // <= becomes > return binary.withOperator(J.Binary.Type.GreaterThan); } else if (binary.getOperator() == J.Binary.Type.GreaterThan) { - // > becomes <= return binary.withOperator(J.Binary.Type.LessThanOrEqual); } else if (binary.getOperator() == J.Binary.Type.GreaterThanOrEqual) { - // >= becomes < return binary.withOperator(J.Binary.Type.LessThan); } } else if (expr instanceof J.Unary) { J.Unary unary = (J.Unary) expr; if (unary.getOperator() == J.Unary.Type.Not) { - // Double negation: !!expr becomes expr return unary.getExpression(); } } else if (expr instanceof J.Parentheses) { @SuppressWarnings("unchecked") J.Parentheses parens = (J.Parentheses) expr; - // Invert the expression inside parentheses if (parens.getTree() instanceof Expression) { Expression innerInverted = invertExpression(parens.getTree()); return parens.withTree(innerInverted); } } - // For all other expressions, add a NOT operator return new J.Unary( randomId(), expr.getPrefix(), From 00a00eeee80377d648036934b494bae97ec2b02c Mon Sep 17 00:00:00 2001 From: Greg Oledzki Date: Wed, 23 Jul 2025 15:28:58 +0200 Subject: [PATCH 7/9] Simplify by removing special handling of negations --- .../staticanalysis/PreferEarlyReturn.java | 101 ++++++------------ .../staticanalysis/PreferEarlyReturnTest.java | 12 +-- 2 files changed, 40 insertions(+), 73 deletions(-) diff --git a/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java b/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java index 822762b488..c4a1137145 100644 --- a/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java +++ b/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java @@ -69,15 +69,15 @@ public TreeVisitor getVisitor() { } private static class PreferEarlyReturnVisitor extends JavaVisitor { - + @Override public J visitIf(J.If ifStatement, ExecutionContext ctx) { J.If if_ = (J.If) super.visitIf(ifStatement, ctx); - + if (!isEligibleForEarlyReturn(if_)) { return if_; } - + J.ControlParentheses invertedCondition = invertCondition(if_.getIfCondition()); J.If newIf = if_.withIfCondition(invertedCondition) .withThenPart(if_.getElsePart().getBody()) @@ -87,51 +87,51 @@ public J visitIf(J.If ifStatement, ExecutionContext ctx) { Markers.EMPTY, JRightPadded.build(if_.getThenPart()) )); - + doAfterVisit(new UnwrapElseAfterReturn().getVisitor()); - + return newIf; } - + private boolean isEligibleForEarlyReturn(J.If ifStatement) { if (ifStatement.getElsePart() == null) { return false; } - + if (!(ifStatement.getThenPart() instanceof J.Block)) { return false; } - + if (!(ifStatement.getElsePart().getBody() instanceof J.Block)) { return false; } - + J.Block thenBlock = (J.Block) ifStatement.getThenPart(); J.Block elseBlock = (J.Block) ifStatement.getElsePart().getBody(); - + int thenStatements = countStatements(thenBlock); int elseStatements = countStatements(elseBlock); - + if (thenStatements < 5 || elseStatements > 2) { return false; } - + return hasReturnOrThrowStatement(elseBlock); } - + private int countStatements(J.Block block) { if (block == null || block.getStatements() == null) { return 0; } - + return block.getStatements().size(); } - + private boolean hasReturnOrThrowStatement(J.Block block) { if (block == null || block.getStatements() == null) { return false; } - + AtomicBoolean hasReturnOrThrow = new AtomicBoolean(false); new JavaVisitor() { @Override @@ -139,79 +139,46 @@ public J visitReturn(J.Return return_, AtomicBoolean flag) { flag.set(true); return return_; } - + @Override public J visitThrow(J.Throw thrown, AtomicBoolean flag) { flag.set(true); return thrown; } }.visit(block, hasReturnOrThrow); - + return hasReturnOrThrow.get(); } - + private J.ControlParentheses invertCondition(J.ControlParentheses condition) { if (condition == null || !(condition.getTree() instanceof Expression)) { return condition; } - - Expression expr = (Expression) condition.getTree(); - Expression inverted = invertExpression(expr); - + + Expression inverted = invertExpression((Expression) condition.getTree()); + return condition.withTree(inverted); } - + private Expression invertExpression(Expression expr) { + Expression toNegate = expr; if (expr instanceof J.Binary) { - J.Binary binary = (J.Binary) expr; - - if (binary.getOperator() == J.Binary.Type.And) { - Expression leftInverted = invertExpression(binary.getLeft()); - Expression rightInverted = invertExpression(binary.getRight()); - return binary.withOperator(J.Binary.Type.Or) - .withLeft(leftInverted) - .withRight(rightInverted.withPrefix(Space.SINGLE_SPACE)); - } else if (binary.getOperator() == J.Binary.Type.Or) { - Expression leftInverted = invertExpression(binary.getLeft()); - Expression rightInverted = invertExpression(binary.getRight()); - return binary.withOperator(J.Binary.Type.And) - .withLeft(leftInverted) - .withRight(rightInverted.withPrefix(Space.SINGLE_SPACE)); - } else if (binary.getOperator() == J.Binary.Type.Equal) { - return binary.withOperator(J.Binary.Type.NotEqual); - } else if (binary.getOperator() == J.Binary.Type.NotEqual) { - return binary.withOperator(J.Binary.Type.Equal); - } else if (binary.getOperator() == J.Binary.Type.LessThan) { - return binary.withOperator(J.Binary.Type.GreaterThanOrEqual); - } else if (binary.getOperator() == J.Binary.Type.LessThanOrEqual) { - return binary.withOperator(J.Binary.Type.GreaterThan); - } else if (binary.getOperator() == J.Binary.Type.GreaterThan) { - return binary.withOperator(J.Binary.Type.LessThanOrEqual); - } else if (binary.getOperator() == J.Binary.Type.GreaterThanOrEqual) { - return binary.withOperator(J.Binary.Type.LessThan); - } - } else if (expr instanceof J.Unary) { - J.Unary unary = (J.Unary) expr; - if (unary.getOperator() == J.Unary.Type.Not) { - return unary.getExpression(); - } - } else if (expr instanceof J.Parentheses) { - @SuppressWarnings("unchecked") - J.Parentheses parens = (J.Parentheses) expr; - if (parens.getTree() instanceof Expression) { - Expression innerInverted = invertExpression(parens.getTree()); - return parens.withTree(innerInverted); - } + toNegate = new J.Parentheses<>( + randomId(), + expr.getPrefix(), + Markers.EMPTY, + JRightPadded.build(expr.withPrefix(Space.EMPTY)) + ); } - + return new J.Unary( randomId(), - expr.getPrefix(), + toNegate.getPrefix(), Markers.EMPTY, new JLeftPadded<>(Space.EMPTY, J.Unary.Type.Not, Markers.EMPTY), - expr.withPrefix(Space.EMPTY), + toNegate.withPrefix(Space.EMPTY), JavaType.Primitive.Boolean ); } } -} \ No newline at end of file +} diff --git a/src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java b/src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java index 0962525a87..72860024dc 100644 --- a/src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java @@ -66,7 +66,7 @@ interface Order { """ class Test { void processOrder(Order order) { - if (order == null || !order.isValid()) { + if (!(order != null && order.isValid())) { logError("Invalid order"); return; } @@ -129,7 +129,7 @@ interface User { """ class Test { boolean processUser(User user) { - if (user == null || !user.isActive() || user.isSuspended()) { + if (!(user != null && user.isActive() && !user.isSuspended())) { return false; } // Main processing logic @@ -186,7 +186,7 @@ interface Data { """ class Test { String processData(Data data) { - if (data == null || !data.isValid()) { + if (!(data != null && data.isValid())) { return null; } // Process the data @@ -349,7 +349,7 @@ interface Payment { class Test { void processPayment(Payment payment) { // Check if payment is valid - if (payment == null || !payment.isAuthorized()) { + if (!(payment != null && payment.isAuthorized())) { // Payment is invalid logError("Unauthorized payment"); return; @@ -413,7 +413,7 @@ interface Transaction { """ class Test { void processTransaction(Transaction tx) { - if (tx == null || (!tx.isValid() && !tx.isPending()) || tx.isExpired()) { + if (!(tx != null && (tx.isValid() || tx.isPending()) && !tx.isExpired())) { return; } // Process transaction @@ -475,7 +475,7 @@ interface Input { """ class Test { String validateAndProcess(Input input) { - if (input == null || !input.isValid() || !input.hasRequiredFields()) { + if (!(input != null && input.isValid() && input.hasRequiredFields())) { throw new IllegalArgumentException("Invalid input"); } // Process the input From 28fdc550bbbd8c4205acc087512ade7cb50624e6 Mon Sep 17 00:00:00 2001 From: Greg Oledzki Date: Wed, 23 Jul 2025 15:38:14 +0200 Subject: [PATCH 8/9] Simplify code by inlining some methods --- .../staticanalysis/PreferEarlyReturn.java | 54 ++++--------------- 1 file changed, 11 insertions(+), 43 deletions(-) diff --git a/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java b/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java index c4a1137145..20121d05eb 100644 --- a/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java +++ b/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java @@ -17,24 +17,15 @@ import org.openrewrite.ExecutionContext; import org.openrewrite.Recipe; -import org.openrewrite.Tree; import org.openrewrite.TreeVisitor; -import org.openrewrite.internal.ListUtils; import org.openrewrite.java.JavaVisitor; -import org.openrewrite.java.tree.Expression; -import org.openrewrite.java.tree.J; -import org.openrewrite.java.tree.JLeftPadded; -import org.openrewrite.java.tree.JRightPadded; -import org.openrewrite.java.tree.JavaType; -import org.openrewrite.java.tree.Space; -import org.openrewrite.java.tree.Statement; +import org.openrewrite.java.tree.*; import org.openrewrite.marker.Markers; import java.time.Duration; import java.util.Collections; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import static org.openrewrite.Tree.randomId; @@ -48,8 +39,8 @@ public String getDisplayName() { @Override public String getDescription() { return "Refactors methods to use early returns for error/edge cases, reducing nesting and improving readability. " + - "The recipe identifies if-else statements where the if block contains the main logic (≥5 statements) and the " + - "else block contains a simple return (≤2 statements). It then inverts the condition and moves the else block " + + "The recipe heuristically identifies if-else statements where the if block contains the main logic and the " + + "else block contains a simple return. It then inverts the condition and moves the else block " + "to the beginning of the method with an early return, allowing the main logic to be un-indented."; } @@ -78,7 +69,7 @@ public J visitIf(J.If ifStatement, ExecutionContext ctx) { return if_; } - J.ControlParentheses invertedCondition = invertCondition(if_.getIfCondition()); + J.ControlParentheses invertedCondition = if_.getIfCondition().withTree(invertExpression(if_.getIfCondition().getTree())); J.If newIf = if_.withIfCondition(invertedCondition) .withThenPart(if_.getElsePart().getBody()) .withElsePart(new J.If.Else( @@ -94,39 +85,26 @@ public J visitIf(J.If ifStatement, ExecutionContext ctx) { } private boolean isEligibleForEarlyReturn(J.If ifStatement) { - if (ifStatement.getElsePart() == null) { - return false; - } - - if (!(ifStatement.getThenPart() instanceof J.Block)) { - return false; - } - - if (!(ifStatement.getElsePart().getBody() instanceof J.Block)) { + if (ifStatement.getElsePart() == null || + !(ifStatement.getThenPart() instanceof J.Block) || + !(ifStatement.getElsePart().getBody() instanceof J.Block)) { return false; } J.Block thenBlock = (J.Block) ifStatement.getThenPart(); J.Block elseBlock = (J.Block) ifStatement.getElsePart().getBody(); - int thenStatements = countStatements(thenBlock); - int elseStatements = countStatements(elseBlock); + int thenStatements = (thenBlock == null || thenBlock.getStatements() == null) ? 0 : thenBlock.getStatements().size(); + int elseStatements = (elseBlock == null || elseBlock.getStatements() == null) ? 0 : elseBlock.getStatements().size(); - if (thenStatements < 5 || elseStatements > 2) { + if (thenStatements < 5 || (thenStatements - elseStatements) < 2) { + // heuristics for determining if the then block is the "main flow" over the else block return false; } return hasReturnOrThrowStatement(elseBlock); } - private int countStatements(J.Block block) { - if (block == null || block.getStatements() == null) { - return 0; - } - - return block.getStatements().size(); - } - private boolean hasReturnOrThrowStatement(J.Block block) { if (block == null || block.getStatements() == null) { return false; @@ -150,16 +128,6 @@ public J visitThrow(J.Throw thrown, AtomicBoolean flag) { return hasReturnOrThrow.get(); } - private J.ControlParentheses invertCondition(J.ControlParentheses condition) { - if (condition == null || !(condition.getTree() instanceof Expression)) { - return condition; - } - - Expression inverted = invertExpression((Expression) condition.getTree()); - - return condition.withTree(inverted); - } - private Expression invertExpression(Expression expr) { Expression toNegate = expr; if (expr instanceof J.Binary) { From 31818ec5c3f5e5c9e51850fd8391502e036b0b79 Mon Sep 17 00:00:00 2001 From: Greg Oledzki Date: Thu, 24 Jul 2025 12:38:11 +0200 Subject: [PATCH 9/9] Autoformat + UnwrapElseAfterReturn in postVisit --- .../staticanalysis/PreferEarlyReturn.java | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java b/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java index 20121d05eb..fae34064a7 100644 --- a/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java +++ b/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java @@ -15,6 +15,8 @@ */ package org.openrewrite.staticanalysis; +import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.Nullable; import org.openrewrite.ExecutionContext; import org.openrewrite.Recipe; import org.openrewrite.TreeVisitor; @@ -61,6 +63,15 @@ public TreeVisitor getVisitor() { private static class PreferEarlyReturnVisitor extends JavaVisitor { + @Override + public @Nullable J postVisit(@NonNull J tree, ExecutionContext executionContext) { + J ret = super.postVisit(tree, executionContext); + if (getCursor().pollMessage("PREFER_EARLY_RETURN") != null) { + ret = (J) new UnwrapElseAfterReturn().getVisitor().visit(ret, executionContext, getCursor().getParent()); + } + return ret; + } + @Override public J visitIf(J.If ifStatement, ExecutionContext ctx) { J.If if_ = (J.If) super.visitIf(ifStatement, ctx); @@ -79,8 +90,8 @@ public J visitIf(J.If ifStatement, ExecutionContext ctx) { JRightPadded.build(if_.getThenPart()) )); - doAfterVisit(new UnwrapElseAfterReturn().getVisitor()); - + newIf = maybeAutoFormat(if_, newIf, ctx); + getCursor().dropParentUntil(J.Block.class::isInstance).putMessage("PREFER_EARLY_RETURN", "unwrap"); return newIf; }