diff --git a/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java b/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java new file mode 100644 index 0000000000..fae34064a7 --- /dev/null +++ b/src/main/java/org/openrewrite/staticanalysis/PreferEarlyReturn.java @@ -0,0 +1,163 @@ +/* + * Copyright 2025 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.staticanalysis; + +import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.Nullable; +import org.openrewrite.ExecutionContext; +import org.openrewrite.Recipe; +import org.openrewrite.TreeVisitor; +import org.openrewrite.java.JavaVisitor; +import org.openrewrite.java.tree.*; +import org.openrewrite.marker.Markers; + +import java.time.Duration; +import java.util.Collections; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.openrewrite.Tree.randomId; + +public class PreferEarlyReturn extends Recipe { + + @Override + public String getDisplayName() { + return "Prefer early returns"; + } + + @Override + public String getDescription() { + return "Refactors methods to use early returns for error/edge cases, reducing nesting and improving readability. " + + "The recipe heuristically identifies if-else statements where the if block contains the main logic and the " + + "else block contains a simple return. It then inverts the condition and moves the else block " + + "to the beginning of the method with an early return, allowing the main logic to be un-indented."; + } + + @Override + public Set getTags() { + return Collections.emptySet(); + } + + @Override + public Duration getEstimatedEffortPerOccurrence() { + return Duration.ofMinutes(2); + } + + @Override + public TreeVisitor getVisitor() { + return new PreferEarlyReturnVisitor(); + } + + private static class PreferEarlyReturnVisitor extends JavaVisitor { + + @Override + public @Nullable J postVisit(@NonNull J tree, ExecutionContext executionContext) { + J ret = super.postVisit(tree, executionContext); + if (getCursor().pollMessage("PREFER_EARLY_RETURN") != null) { + ret = (J) new UnwrapElseAfterReturn().getVisitor().visit(ret, executionContext, getCursor().getParent()); + } + return ret; + } + + @Override + public J visitIf(J.If ifStatement, ExecutionContext ctx) { + J.If if_ = (J.If) super.visitIf(ifStatement, ctx); + + if (!isEligibleForEarlyReturn(if_)) { + return if_; + } + + J.ControlParentheses invertedCondition = if_.getIfCondition().withTree(invertExpression(if_.getIfCondition().getTree())); + J.If newIf = if_.withIfCondition(invertedCondition) + .withThenPart(if_.getElsePart().getBody()) + .withElsePart(new J.If.Else( + randomId(), + if_.getElsePart().getPrefix(), + Markers.EMPTY, + JRightPadded.build(if_.getThenPart()) + )); + + newIf = maybeAutoFormat(if_, newIf, ctx); + getCursor().dropParentUntil(J.Block.class::isInstance).putMessage("PREFER_EARLY_RETURN", "unwrap"); + return newIf; + } + + private boolean isEligibleForEarlyReturn(J.If ifStatement) { + if (ifStatement.getElsePart() == null || + !(ifStatement.getThenPart() instanceof J.Block) || + !(ifStatement.getElsePart().getBody() instanceof J.Block)) { + return false; + } + + J.Block thenBlock = (J.Block) ifStatement.getThenPart(); + J.Block elseBlock = (J.Block) ifStatement.getElsePart().getBody(); + + int thenStatements = (thenBlock == null || thenBlock.getStatements() == null) ? 0 : thenBlock.getStatements().size(); + int elseStatements = (elseBlock == null || elseBlock.getStatements() == null) ? 0 : elseBlock.getStatements().size(); + + if (thenStatements < 5 || (thenStatements - elseStatements) < 2) { + // heuristics for determining if the then block is the "main flow" over the else block + return false; + } + + return hasReturnOrThrowStatement(elseBlock); + } + + private boolean hasReturnOrThrowStatement(J.Block block) { + if (block == null || block.getStatements() == null) { + return false; + } + + AtomicBoolean hasReturnOrThrow = new AtomicBoolean(false); + new JavaVisitor() { + @Override + public J visitReturn(J.Return return_, AtomicBoolean flag) { + flag.set(true); + return return_; + } + + @Override + public J visitThrow(J.Throw thrown, AtomicBoolean flag) { + flag.set(true); + return thrown; + } + }.visit(block, hasReturnOrThrow); + + return hasReturnOrThrow.get(); + } + + private Expression invertExpression(Expression expr) { + Expression toNegate = expr; + if (expr instanceof J.Binary) { + toNegate = new J.Parentheses<>( + randomId(), + expr.getPrefix(), + Markers.EMPTY, + JRightPadded.build(expr.withPrefix(Space.EMPTY)) + ); + } + + return new J.Unary( + randomId(), + toNegate.getPrefix(), + Markers.EMPTY, + new JLeftPadded<>(Space.EMPTY, J.Unary.Type.Not, Markers.EMPTY), + toNegate.withPrefix(Space.EMPTY), + JavaType.Primitive.Boolean + ); + } + } +} diff --git a/src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java b/src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java new file mode 100644 index 0000000000..72860024dc --- /dev/null +++ b/src/test/java/org/openrewrite/staticanalysis/PreferEarlyReturnTest.java @@ -0,0 +1,504 @@ +/* + * Copyright 2025 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.staticanalysis; + +import org.junit.jupiter.api.Test; +import org.openrewrite.DocumentExample; +import org.openrewrite.test.RecipeSpec; +import org.openrewrite.test.RewriteTest; + +import static org.openrewrite.java.Assertions.java; + +@SuppressWarnings("ConstantConditions") +class PreferEarlyReturnTest implements RewriteTest { + @Override + public void defaults(RecipeSpec spec) { + spec.recipe(new PreferEarlyReturn()); + } + + @DocumentExample + @Test + void basicIfElseWithEarlyReturn() { + rewriteRun( + //language=java + java( + """ + class Test { + void processOrder(Order order) { + if (order != null && order.isValid()) { + // Process the order + order.validate(); + order.calculateTax(); + order.applyDiscount(); + order.processPayment(); + order.sendConfirmation(); + } else { + logError("Invalid order"); + return; + } + } + + void logError(String message) {} + + interface Order { + boolean isValid(); + void validate(); + void calculateTax(); + void applyDiscount(); + void processPayment(); + void sendConfirmation(); + } + } + """, + """ + class Test { + void processOrder(Order order) { + if (!(order != null && order.isValid())) { + logError("Invalid order"); + return; + } + // Process the order + order.validate(); + order.calculateTax(); + order.applyDiscount(); + order.processPayment(); + order.sendConfirmation(); + } + + void logError(String message) {} + + interface Order { + boolean isValid(); + void validate(); + void calculateTax(); + void applyDiscount(); + void processPayment(); + void sendConfirmation(); + } + } + """ + ) + ); + } + + @Test + void multipleConditionsWithAndOperator() { + rewriteRun( + //language=java + java( + """ + class Test { + boolean processUser(User user) { + if (user != null && user.isActive() && !user.isSuspended()) { + // Main processing logic + user.updateLastLogin(); + user.incrementLoginCount(); + user.loadPreferences(); + user.initializeSession(); + user.logActivity(); + return true; + } else { + return false; + } + } + + interface User { + boolean isActive(); + boolean isSuspended(); + void updateLastLogin(); + void incrementLoginCount(); + void loadPreferences(); + void initializeSession(); + void logActivity(); + } + } + """, + """ + class Test { + boolean processUser(User user) { + if (!(user != null && user.isActive() && !user.isSuspended())) { + return false; + } + // Main processing logic + user.updateLastLogin(); + user.incrementLoginCount(); + user.loadPreferences(); + user.initializeSession(); + user.logActivity(); + return true; + } + + interface User { + boolean isActive(); + boolean isSuspended(); + void updateLastLogin(); + void incrementLoginCount(); + void loadPreferences(); + void initializeSession(); + void logActivity(); + } + } + """ + ) + ); + } + + @Test + void methodWithReturnValue() { + rewriteRun( + //language=java + java( + """ + class Test { + String processData(Data data) { + if (data != null && data.isValid()) { + // Process the data + String result = data.transform(); + result = result.trim(); + result = result.toUpperCase(); + data.log(result); + return result; + } else { + return null; + } + } + + interface Data { + boolean isValid(); + String transform(); + void log(String s); + } + } + """, + """ + class Test { + String processData(Data data) { + if (!(data != null && data.isValid())) { + return null; + } + // Process the data + String result = data.transform(); + result = result.trim(); + result = result.toUpperCase(); + data.log(result); + return result; + } + + interface Data { + boolean isValid(); + String transform(); + void log(String s); + } + } + """ + ) + ); + } + + @Test + void noChangeWhenIfBlockTooSmall() { + rewriteRun( + //language=java + java( + """ + class Test { + void processItem(Item item) { + if (item != null) { + // Too few statements (less than 5) + item.process(); + item.save(); + } else { + return; + } + } + + interface Item { + void process(); + void save(); + } + } + """ + ) + ); + } + + @Test + void noChangeWhenElseBlockTooLarge() { + rewriteRun( + //language=java + java( + """ + class Test { + void processRequest(Request request) { + if (request != null && request.isValid()) { + // Process the request + request.validate(); + request.authorize(); + request.execute(); + request.logSuccess(); + request.notifyClients(); + } else { + // Too many statements in else block (more than 2) + logError("Invalid request"); + notifyAdmin(); + incrementErrorCounter(); + return; + } + } + + void logError(String message) {} + void notifyAdmin() {} + void incrementErrorCounter() {} + + interface Request { + boolean isValid(); + void validate(); + void authorize(); + void execute(); + void logSuccess(); + void notifyClients(); + } + } + """ + ) + ); + } + + @Test + void noChangeWhenNoElseBlock() { + rewriteRun( + //language=java + java( + """ + class Test { + void processEvent(Event event) { + if (event != null && event.isActive()) { + // Process the event + event.handle(); + event.dispatch(); + event.complete(); + event.cleanup(); + event.logCompletion(); + } + // No else block, so no early return to add + } + + interface Event { + boolean isActive(); + void handle(); + void dispatch(); + void complete(); + void cleanup(); + void logCompletion(); + } + } + """ + ) + ); + } + + @Test + void preserveCommentsAndFormatting() { + rewriteRun( + //language=java + java( + """ + class Test { + void processPayment(Payment payment) { + // Check if payment is valid + if (payment != null && payment.isAuthorized()) { + // Process the payment + payment.validate(); // Validate payment details + payment.checkFraud(); // Check for fraud + payment.deductAmount(); // Deduct from account + payment.recordTransaction(); // Record in database + payment.sendReceipt(); // Send receipt to customer + } else { + // Payment is invalid + logError("Unauthorized payment"); + return; + } + } + + void logError(String message) {} + + interface Payment { + boolean isAuthorized(); + void validate(); + void checkFraud(); + void deductAmount(); + void recordTransaction(); + void sendReceipt(); + } + } + """, + """ + class Test { + void processPayment(Payment payment) { + // Check if payment is valid + if (!(payment != null && payment.isAuthorized())) { + // Payment is invalid + logError("Unauthorized payment"); + return; + } + // Process the payment + payment.validate(); // Validate payment details + payment.checkFraud(); // Check for fraud + payment.deductAmount(); // Deduct from account + payment.recordTransaction(); // Record in database + payment.sendReceipt(); // Send receipt to customer + } + + void logError(String message) {} + + interface Payment { + boolean isAuthorized(); + void validate(); + void checkFraud(); + void deductAmount(); + void recordTransaction(); + void sendReceipt(); + } + } + """ + ) + ); + } + + @Test + void complexConditionWithParentheses() { + rewriteRun( + //language=java + java( + """ + class Test { + void processTransaction(Transaction tx) { + if (tx != null && (tx.isValid() || tx.isPending()) && !tx.isExpired()) { + // Process transaction + tx.authorize(); + tx.validate(); + tx.execute(); + tx.commit(); + tx.notifyParties(); + } else { + return; + } + } + + interface Transaction { + boolean isValid(); + boolean isPending(); + boolean isExpired(); + void authorize(); + void validate(); + void execute(); + void commit(); + void notifyParties(); + } + } + """, + """ + class Test { + void processTransaction(Transaction tx) { + if (!(tx != null && (tx.isValid() || tx.isPending()) && !tx.isExpired())) { + return; + } + // Process transaction + tx.authorize(); + tx.validate(); + tx.execute(); + tx.commit(); + tx.notifyParties(); + } + + interface Transaction { + boolean isValid(); + boolean isPending(); + boolean isExpired(); + void authorize(); + void validate(); + void execute(); + void commit(); + void notifyParties(); + } + } + """ + ) + ); + } + + @Test + void methodThrowingExceptionInElseBlock() { + rewriteRun( + //language=java + java( + """ + class Test { + String validateAndProcess(Input input) { + if (input != null && input.isValid() && input.hasRequiredFields()) { + // Process the input + String normalized = input.normalize(); + String validated = input.validate(); + String transformed = input.transform(); + String encrypted = input.encrypt(); + String result = input.format(normalized, validated, transformed, encrypted); + return result; + } else { + throw new IllegalArgumentException("Invalid input"); + } + } + + interface Input { + boolean isValid(); + boolean hasRequiredFields(); + String normalize(); + String validate(); + String transform(); + String encrypt(); + String format(String... parts); + } + } + """, + """ + class Test { + String validateAndProcess(Input input) { + if (!(input != null && input.isValid() && input.hasRequiredFields())) { + throw new IllegalArgumentException("Invalid input"); + } + // Process the input + String normalized = input.normalize(); + String validated = input.validate(); + String transformed = input.transform(); + String encrypted = input.encrypt(); + String result = input.format(normalized, validated, transformed, encrypted); + return result; + } + + interface Input { + boolean isValid(); + boolean hasRequiredFields(); + String normalize(); + String validate(); + String transform(); + String encrypt(); + String format(String... parts); + } + } + """ + ) + ); + } +}