diff --git a/src/main/java/org/openrewrite/java/migrate/lang/SwitchCaseReturnsToSwitchExpression.java b/src/main/java/org/openrewrite/java/migrate/lang/SwitchCaseReturnsToSwitchExpression.java new file mode 100644 index 0000000000..cbff93f8f3 --- /dev/null +++ b/src/main/java/org/openrewrite/java/migrate/lang/SwitchCaseReturnsToSwitchExpression.java @@ -0,0 +1,209 @@ +/* + * Copyright 2025 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java.migrate.lang; + +import lombok.EqualsAndHashCode; +import lombok.Value; +import org.jspecify.annotations.Nullable; +import org.openrewrite.ExecutionContext; +import org.openrewrite.Preconditions; +import org.openrewrite.Recipe; +import org.openrewrite.TreeVisitor; +import org.openrewrite.internal.ListUtils; +import org.openrewrite.java.JavaIsoVisitor; +import org.openrewrite.java.search.UsesJavaVersion; +import org.openrewrite.java.tree.*; +import org.openrewrite.marker.Markers; +import org.openrewrite.staticanalysis.groovy.GroovyFileChecker; +import org.openrewrite.staticanalysis.kotlin.KotlinFileChecker; + +import java.util.List; + +import static org.openrewrite.Tree.randomId; + +@Value +@EqualsAndHashCode(callSuper = false) +public class SwitchCaseReturnsToSwitchExpression extends Recipe { + @Override + public String getDisplayName() { + return "Convert switch cases where every case returns into a returned switch expression"; + } + + @Override + public String getDescription() { + return "Switch statements where each case returns a value can be converted to a switch expression that returns the value directly."; + } + + @Override + public TreeVisitor getVisitor() { + TreeVisitor preconditions = Preconditions.and( + new UsesJavaVersion<>(14), + Preconditions.not(new KotlinFileChecker<>()), + Preconditions.not(new GroovyFileChecker<>()) + ); + return Preconditions.check(preconditions, new JavaIsoVisitor() { + @Override + public J.Block visitBlock(J.Block block, ExecutionContext ctx) { + J.Block b = super.visitBlock(block, ctx); + return b.withStatements(ListUtils.map(b.getStatements(), statement -> { + if (statement instanceof J.Switch) { + J.Switch sw = (J.Switch) statement; + if (canConvertToSwitchExpression(sw)) { + J.SwitchExpression switchExpression = convertToSwitchExpression(sw); + return new J.Return( + randomId(), + sw.getPrefix(), + Markers.EMPTY, + switchExpression + ); + } + } + return statement; + })); + } + + private boolean canConvertToSwitchExpression(J.Switch switchStatement) { + for (Statement statement : switchStatement.getCases().getStatements()) { + if (!(statement instanceof J.Case)) { + return false; + } + + J.Case caseStatement = (J.Case) statement; + if (caseStatement.getBody() != null) { + // Arrow case + J body = caseStatement.getBody(); + if (body instanceof J.Block) { + if (!isReturnCase(((J.Block) body).getStatements())) { + return false; + } + } else if (!(body instanceof J.Return)) { + return false; + } + } else { + // Colon case + if (!isReturnCase(caseStatement.getStatements())) { + return false; + } + } + } + + // We need either a default case or the switch to cover all possible values + return SwitchUtils.coversAllPossibleValues(switchStatement); + } + + private boolean isReturnCase(List statements) { + if (statements.size() != 1) { + return false; + } + // Handle block containing a single return + if (statements.get(0) instanceof J.Block) { + return isReturnCase(((J.Block) statements.get(0)).getStatements()); + } + // Direct return statement + return statements.get(0) instanceof J.Return; + } + + private J.SwitchExpression convertToSwitchExpression(J.Switch switchStatement) { + JavaType returnType = extractReturnType(switchStatement); + + List convertedCases = ListUtils.map(switchStatement.getCases().getStatements(), statement -> { + J.Case caseStatement = (J.Case) statement; + if (caseStatement.getBody() != null) { + // Arrow case + J body = caseStatement.getBody(); + if (body instanceof J.Block && ((J.Block) body).getStatements().size() == 1) { + body = ((J.Block) body).getStatements().get(0); + } + if (body instanceof J.Return) { + J.Return ret = (J.Return) body; + if (ret.getExpression() != null) { + return caseStatement.withBody(ret.getExpression()); + } + } + } else { + // Colon case - convert to arrow case + Expression returnExpression = extractReturnExpression(caseStatement.getStatements()); + if (returnExpression != null) { + // When converting from colon to arrow syntax, we need to ensure proper spacing + JContainer caseLabels = caseStatement.getPadding().getCaseLabels(); + JContainer updatedLabels = caseLabels.getPadding().withElements( + ListUtils.mapLast(caseLabels.getPadding().getElements(), + elem -> elem.withAfter(Space.SINGLE_SPACE))); + return caseStatement + .withStatements(null) + .withBody(returnExpression.withPrefix(Space.SINGLE_SPACE)) + .withType(J.Case.Type.Rule) + .getPadding() + .withCaseLabels(updatedLabels); + } + } + return caseStatement; + }); + return new J.SwitchExpression( + randomId(), + Space.SINGLE_SPACE, + Markers.EMPTY, + switchStatement.getSelector(), + switchStatement.getCases().withStatements(convertedCases), + returnType + ); + } + + private @Nullable JavaType extractReturnType(J.Switch switchStatement) { + for (Statement statement : switchStatement.getCases().getStatements()) { + J.Case caseStatement = (J.Case) statement; + if (caseStatement.getBody() != null) { + J body = caseStatement.getBody(); + if (body instanceof J.Block && ((J.Block) body).getStatements().size() == 1) { + body = ((J.Block) body).getStatements().get(0); + } + if (body instanceof J.Return) { + J.Return ret = (J.Return) body; + if (ret.getExpression() != null && ret.getExpression().getType() != null) { + return ret.getExpression().getType(); + } + } + } else { + Expression returnExpression = extractReturnExpression(caseStatement.getStatements()); + if (returnExpression != null && returnExpression.getType() != null) { + return returnExpression.getType(); + } + } + } + return null; + } + + + private @Nullable Expression extractReturnExpression(List statements) { + if (statements.size() != 1) { + return null; + } + // Handle block containing a single return + if (statements.get(0) instanceof J.Block) { + J.Block block = (J.Block) statements.get(0); + if (block.getStatements().size() == 1 && block.getStatements().get(0) instanceof J.Return) { + return ((J.Return) block.getStatements().get(0)).getExpression(); + } + } + // Direct return statement + if (statements.get(0) instanceof J.Return) { + return ((J.Return) statements.get(0)).getExpression(); + } + return null; + } + }); + } +} diff --git a/src/main/resources/META-INF/rewrite/java-version-17.yml b/src/main/resources/META-INF/rewrite/java-version-17.yml index 770a980df2..a7630276c6 100644 --- a/src/main/resources/META-INF/rewrite/java-version-17.yml +++ b/src/main/resources/META-INF/rewrite/java-version-17.yml @@ -60,6 +60,7 @@ recipeList: newVersion: 1.17.x - org.openrewrite.java.migrate.AddLombokMapstructBinding - org.openrewrite.java.migrate.lang.SwitchCaseAssignmentsToSwitchExpression + - org.openrewrite.java.migrate.lang.SwitchCaseReturnsToSwitchExpression - org.openrewrite.java.migrate.lang.SwitchExpressionYieldToArrow --- diff --git a/src/test/java/org/openrewrite/java/migrate/lang/SwitchCaseReturnsToSwitchExpressionTest.java b/src/test/java/org/openrewrite/java/migrate/lang/SwitchCaseReturnsToSwitchExpressionTest.java new file mode 100644 index 0000000000..d644db1a15 --- /dev/null +++ b/src/test/java/org/openrewrite/java/migrate/lang/SwitchCaseReturnsToSwitchExpressionTest.java @@ -0,0 +1,304 @@ +/* + * Copyright 2025 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java.migrate.lang; + +import org.junit.jupiter.api.Test; +import org.openrewrite.DocumentExample; +import org.openrewrite.Issue; +import org.openrewrite.test.RecipeSpec; +import org.openrewrite.test.RewriteTest; + +import static org.openrewrite.java.Assertions.java; +import static org.openrewrite.java.Assertions.javaVersion; + +class SwitchCaseReturnsToSwitchExpressionTest implements RewriteTest { + + @Override + public void defaults(RecipeSpec spec) { + spec.recipe(new SwitchCaseReturnsToSwitchExpression()) + .allSources(s -> s.markers(javaVersion(17))); + } + + @DocumentExample + @Issue("https://github.com/openrewrite/rewrite-migrate-java/issues/800") + @Test + void convertSimpleSwitchWithReturns() { + rewriteRun( + //language=java + java( + """ + class Test { + String doFormat(String str) { + switch (str) { + case "foo": return "Foo"; + case "bar": return "Bar"; + case null, default: return "Other"; + } + } + } + """, + """ + class Test { + String doFormat(String str) { + return switch (str) { + case "foo" -> "Foo"; + case "bar" -> "Bar"; + case null, default -> "Other"; + }; + } + } + """ + ) + ); + } + + @Test + void convertSimpleSwitchWithReturnsAfterOtherStatements() { + rewriteRun( + //language=java + java( + """ + class Test { + String doFormat(String str) { + System.out.println("Formatting: " + str); + switch (str) { + case "foo": return "Foo"; + case "bar": return "Bar"; + case null, default: return "Other"; + } + } + } + """, + """ + class Test { + String doFormat(String str) { + System.out.println("Formatting: " + str); + return switch (str) { + case "foo" -> "Foo"; + case "bar" -> "Bar"; + case null, default -> "Other"; + }; + } + } + """ + ) + ); + } + + @Test + void convertSwitchWithColonCases() { + rewriteRun( + //language=java + java( + """ + class Test { + int getValue(String str) { + switch (str) { + case "one": + return 1; + case "two": + return 2; + default: + return 0; + } + } + } + """, + """ + class Test { + int getValue(String str) { + return switch (str) { + case "one" -> 1; + case "two" -> 2; + default -> 0; + }; + } + } + """ + ) + ); + } + + @Test + void convertSwitchWithBlocksContainingReturns() { + rewriteRun( + //language=java + java( + """ + class Test { + String process(int value) { + switch (value) { + case 1: { + return "One"; + } + case 2: { + return "Two"; + } + default: { + return "Many"; + } + } + } + } + """, + """ + class Test { + String process(int value) { + return switch (value) { + case 1 -> "One"; + case 2 -> "Two"; + default -> "Many"; + }; + } + } + """ + ) + ); + } + + @Test + void convertSwitchWithArrowCases() { + rewriteRun( + //language=java + java( + """ + class Test { + String format(String str) { + switch (str) { + case "foo" -> { return "Foo"; } + case "bar" -> { return "Bar"; } + default -> { return "Other"; } + } + } + } + """, + """ + class Test { + String format(String str) { + return switch (str) { + case "foo" -> "Foo"; + case "bar" -> "Bar"; + default -> "Other"; + }; + } + } + """ + ) + ); + } + + @Test + void convertEnumSwitchThatIsExhaustive() { + rewriteRun( + //language=java + java( + """ + class Test { + enum Color { RED, GREEN, BLUE } + + String colorName(Color color) { + switch (color) { + case RED: return "Red"; + case GREEN: return "Green"; + case BLUE: return "Blue"; + } + } + } + """, + """ + class Test { + enum Color { RED, GREEN, BLUE } + + String colorName(Color color) { + return switch (color) { + case RED -> "Red"; + case GREEN -> "Green"; + case BLUE -> "Blue"; + }; + } + } + """ + ) + ); + } + + @Test + void doNotConvertWhenNotAllCasesReturn() { + rewriteRun( + //language=java + java( + """ + class Test { + String process(String str) { + switch (str) { + case "foo": + return "Foo"; + case "bar": + System.out.println("Bar case"); + break; + default: + return "Other"; + } + return "End"; + } + } + """ + ) + ); + } + + @Test + void doNotConvertWhenNoDefaultAndNotExhaustive() { + rewriteRun( + //language=java + java( + """ + class Test { + String format(String str) { + switch (str) { + case "foo": return "Foo"; + case "bar": return "Bar"; + } + return "Not found"; + } + } + """ + ) + ); + } + + @Test + void doNotConvertIfNotOnlyStatementInBlock() { + rewriteRun( + //language=java + java( + """ + class Test { + String process(String str) { + switch (str) { + case "foo": + System.out.println("Processing: " + str); + return "Foo"; + case "bar": return "Bar"; + default: return "Other"; + } + } + } + """ + ) + ); + } +}