diff --git a/src/main/java/org/openrewrite/java/migrate/lang/NullCheckAsSwitchCase.java b/src/main/java/org/openrewrite/java/migrate/lang/NullCheckAsSwitchCase.java index 9257b6feb6..aa2226fd7d 100644 --- a/src/main/java/org/openrewrite/java/migrate/lang/NullCheckAsSwitchCase.java +++ b/src/main/java/org/openrewrite/java/migrate/lang/NullCheckAsSwitchCase.java @@ -24,19 +24,22 @@ import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.search.SemanticallyEqual; import org.openrewrite.java.search.UsesJavaVersion; -import org.openrewrite.java.tree.*; +import org.openrewrite.java.tree.Expression; +import org.openrewrite.java.tree.J; +import org.openrewrite.java.tree.Space; +import org.openrewrite.java.tree.Statement; import org.openrewrite.staticanalysis.groovy.GroovyFileChecker; import org.openrewrite.staticanalysis.kotlin.KotlinFileChecker; import java.time.Duration; import java.util.ArrayList; import java.util.List; -import java.util.Objects; import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; import static java.util.Objects.requireNonNull; import static org.openrewrite.java.migrate.lang.NullCheck.Matcher.nullCheck; +import static org.openrewrite.java.migrate.lang.SwitchUtils.coversAllPossibleValues; @EqualsAndHashCode(callSuper = false) @Value @@ -184,36 +187,6 @@ private J.Case createCaseStatement(J.Switch aSwitch, Statement whenNull, J.Case return nullCase.withStatements(ListUtils.mapFirst(nullCase.getStatements(), s -> s == null ? null : s.withPrefix(currentFirstCaseIndentation))); } - - private boolean coversAllPossibleValues(J.Switch switch_) { - List labels = new ArrayList<>(); - for (Statement statement : switch_.getCases().getStatements()) { - for (J j : ((J.Case) statement).getCaseLabels()) { - if (j instanceof J.Identifier && "default".equals(((J.Identifier) j).getSimpleName())) { - return true; - } - labels.add(j); - } - } - JavaType javaType = switch_.getSelector().getTree().getType(); - if (javaType instanceof JavaType.Class && ((JavaType.Class) javaType).getKind() == JavaType.FullyQualified.Kind.Enum) { - // Every enum value must be present in the switch - return ((JavaType.Class) javaType).getMembers().stream().allMatch(variable -> - labels.stream().anyMatch(label -> { - if (!(label instanceof TypeTree && TypeUtils.isOfType(((TypeTree) label).getType(), javaType))) { - return false; - } - J.Identifier enumName = null; - if (label instanceof J.Identifier) { - enumName = (J.Identifier) label; - } else if (label instanceof J.FieldAccess) { - enumName = ((J.FieldAccess) label).getName(); - } - return enumName != null && Objects.equals(variable.getName(), enumName.getSimpleName()); - })); - } - return false; - } }); } } diff --git a/src/main/java/org/openrewrite/java/migrate/lang/SwitchCaseAssignmentsToSwitchExpression.java b/src/main/java/org/openrewrite/java/migrate/lang/SwitchCaseAssignmentsToSwitchExpression.java new file mode 100644 index 0000000000..a4f437afd0 --- /dev/null +++ b/src/main/java/org/openrewrite/java/migrate/lang/SwitchCaseAssignmentsToSwitchExpression.java @@ -0,0 +1,283 @@ +/* + * Copyright 2025 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java.migrate.lang; + +import lombok.EqualsAndHashCode; +import lombok.Value; +import org.jspecify.annotations.Nullable; +import org.openrewrite.*; +import org.openrewrite.internal.ListUtils; +import org.openrewrite.java.JavaIsoVisitor; +import org.openrewrite.java.JavaTemplate; +import org.openrewrite.java.search.SemanticallyEqual; +import org.openrewrite.java.search.UsesJavaVersion; +import org.openrewrite.java.tree.*; +import org.openrewrite.marker.Markers; +import org.openrewrite.staticanalysis.InlineVariable; +import org.openrewrite.staticanalysis.groovy.GroovyFileChecker; +import org.openrewrite.staticanalysis.kotlin.KotlinFileChecker; + +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import static java.util.Collections.singletonList; +import static java.util.Objects.requireNonNull; +import static org.openrewrite.Tree.randomId; + +@Value +@EqualsAndHashCode(callSuper = false) +public class SwitchCaseAssignmentsToSwitchExpression extends Recipe { + @Override + public String getDisplayName() { + return "Convert assigning Switch statements to Switch expressions"; + } + + @Override + public String getDescription() { + return "Switch statements for which each case is assigning a value to the same variable can be converted to a switch expression that returns the value of the variable. " + + "This is only applicable for Java 17 and later."; + } + + @Override + public TreeVisitor getVisitor() { + TreeVisitor preconditions = Preconditions.and( + new UsesJavaVersion<>(17), + Preconditions.not(new KotlinFileChecker<>()), + Preconditions.not(new GroovyFileChecker<>()) + ); + return Preconditions.check(preconditions, new JavaIsoVisitor() { + @Override + public J.Block visitBlock(J.Block originalBlock, ExecutionContext ctx) { + J.Block block = super.visitBlock(originalBlock, ctx); + + AtomicReference originalSwitch = new AtomicReference<>(); + + int lastIndex = block.getStatements().size() - 1; + return block.withStatements(ListUtils.map(block.getStatements(), (index, statement) -> { + if (statement == originalSwitch.getAndSet(null)) { + doAfterVisit(new InlineVariable().getVisitor()); + // We've already converted the switch/assignments to an assignment with a switch expression. + return null; + } + + if (index < lastIndex && + statement instanceof J.VariableDeclarations && + ((J.VariableDeclarations) statement).getVariables().size() == 1 && + !canHaveSideEffects(((J.VariableDeclarations) statement).getVariables().get(0).getInitializer()) && + block.getStatements().get(index + 1) instanceof J.Switch) { + J.VariableDeclarations vd = (J.VariableDeclarations) statement; + J.Switch nextStatementSwitch = (J.Switch) block.getStatements().get(index + 1); + + J.VariableDeclarations.NamedVariable originalVariable = vd.getVariables().get(0); + J.SwitchExpression newSwitchExpression = buildNewSwitchExpression(nextStatementSwitch, originalVariable); + if (newSwitchExpression != null) { + originalSwitch.set(nextStatementSwitch); + return vd + .withVariables(singletonList(originalVariable.getPadding().withInitializer( + JLeftPadded.build(newSwitchExpression).withBefore(Space.SINGLE_SPACE)))) + .withComments(ListUtils.concatAll(vd.getComments(), nextStatementSwitch.getComments())); + } + } + return statement; + })); + } + + private J.@Nullable SwitchExpression buildNewSwitchExpression(J.Switch originalSwitch, J.VariableDeclarations.NamedVariable originalVariable) { + J.Identifier originalVariableId = originalVariable.getName(); + AtomicBoolean isQualified = new AtomicBoolean(true); + AtomicBoolean isDefaultCaseAbsent = new AtomicBoolean(true); + AtomicBoolean isUsingArrows = new AtomicBoolean(true); + AtomicBoolean isLastCaseEmpty = new AtomicBoolean(false); + + List updatedCases = ListUtils.map(originalSwitch.getCases().getStatements(), (index, s) -> { + if (!isQualified.get()) { + return null; + } + + J.Case caseItem = (J.Case) s; + if (caseItem.getCaseLabels().get(0) instanceof J.Identifier && + "default".equals(((J.Identifier) caseItem.getCaseLabels().get(0)).getSimpleName())) { + isDefaultCaseAbsent.set(false); + } + + if (caseItem.getBody() != null) { // arrow cases + J caseBody = caseItem.getBody(); + if (caseBody instanceof J.Block && ((J.Block) caseBody).getStatements().size() == 1) { + caseBody = ((J.Block) caseBody).getStatements().get(0); + } + J.Assignment assignment = extractAssignmentOfVariable(caseBody, originalVariableId); + if (assignment != null) { + return caseItem.withBody(assignment.getAssignment()); + } + } else { // colon cases + isUsingArrows.set(false); + boolean isLastCase = index + 1 == originalSwitch.getCases().getStatements().size(); + + List caseStatements = caseItem.getStatements(); + if (caseStatements.isEmpty()) { + if (isLastCase) { + isLastCaseEmpty.set(true); + } + return caseItem; + } + + J.Assignment assignment = extractAssignmentFromColonCase(caseStatements, isLastCase, originalVariableId); + if (assignment != null) { + J.Yield yieldStatement = new J.Yield( + randomId(), + assignment.getPrefix().withWhitespace(" "), + Markers.EMPTY, + false, + assignment.getAssignment() + ); + return caseItem.withStatements(singletonList(yieldStatement)); + } + } + + isQualified.set(false); + return null; + }); + if (!isQualified.get()) { + return null; + } + + boolean shouldAddDefaultCase = isDefaultCaseAbsent.get() && !SwitchUtils.coversAllPossibleValues(originalSwitch); + Expression originalInitializer = originalVariable.getInitializer(); + if ((originalInitializer == null && shouldAddDefaultCase) || + (isLastCaseEmpty.get() && !shouldAddDefaultCase)) { + return null; + } + + if (shouldAddDefaultCase) { + updatedCases.add(createDefaultCase(originalSwitch, originalInitializer.withPrefix(Space.SINGLE_SPACE), isUsingArrows.get())); + } + + return new J.SwitchExpression( + randomId(), + Space.SINGLE_SPACE, + Markers.EMPTY, + originalSwitch.getSelector(), + originalSwitch.getCases().withStatements(updatedCases), + originalVariable.getType()); + } + + private J.@Nullable Assignment extractAssignmentFromColonCase(List caseStatements, boolean isLastCase, J.Identifier variableId) { + if (caseStatements.size() == 1 && caseStatements.get(0) instanceof J.Block) { + caseStatements = ((J.Block) caseStatements.get(0)).getStatements(); + } + if ((caseStatements.size() == 2 && caseStatements.get(1) instanceof J.Break) || (caseStatements.size() == 1 && isLastCase)) { + return extractAssignmentOfVariable(caseStatements.get(0), variableId); + } + return null; + } + + private J.@Nullable Assignment extractAssignmentOfVariable(J maybeAssignment, J.Identifier variableId) { + if (maybeAssignment instanceof J.Assignment) { + J.Assignment assignment = (J.Assignment) maybeAssignment; + if (assignment.getVariable() instanceof J.Identifier) { + J.Identifier variable = (J.Identifier) assignment.getVariable(); + if (SemanticallyEqual.areEqual(variable, variableId) && + !containsIdentifier(variableId, assignment.getAssignment())) { + return assignment; + } + } + } + return null; + } + + private J.Case createDefaultCase(J.Switch originalSwitch, Expression returnedExpression, boolean arrow) { + J.Switch switchStatement = JavaTemplate.apply( + "switch(1) { default" + (arrow ? " ->" : ": yield") + " #{any()}; }", + new Cursor(getCursor(), originalSwitch), + originalSwitch.getCoordinates().replace(), + returnedExpression + ); + return (J.Case) switchStatement.getCases().getStatements().get(0); + } + + private boolean containsIdentifier(J.Identifier identifier, Expression expression) { + return new JavaIsoVisitor() { + @Override + public J.Identifier visitIdentifier(J.Identifier id, AtomicBoolean found) { + if (SemanticallyEqual.areEqual(id, identifier)) { + found.set(true); + return id; + } + return super.visitIdentifier(id, found); + } + }.reduce(expression, new AtomicBoolean()).get(); + } + + // Might the initializer affect the input or output of the switch expression? + private boolean canHaveSideEffects(@Nullable Expression expression) { + if (expression == null) { + return false; + } + + return new JavaIsoVisitor() { + @Override + public J.Assignment visitAssignment(J.Assignment assignment, AtomicBoolean found) { + found.set(true); + return super.visitAssignment(assignment, found); + } + + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, AtomicBoolean found) { + found.set(true); + return method; + } + + @Override + public J.NewClass visitNewClass(J.NewClass newClass, AtomicBoolean found) { + found.set(true); + return newClass; + } + + @Override + public J.Unary visitUnary(J.Unary unary, AtomicBoolean found) { + found.set(true); + return super.visitUnary(unary, found); + } + + private boolean isToStringImplicitlyCalled(Expression a, Expression b) { + // Assuming an implicit `.toString()` call could have a side effect, but excluding + // the java.lang.* classes from that rule. + if (TypeUtils.isAssignableTo("java.lang.String", a.getType()) && + TypeUtils.isAssignableTo("java.lang.String", b.getType())) { + return false; + } + + return a.getType() == JavaType.Primitive.String && + (!(b.getType() instanceof JavaType.Primitive || requireNonNull(b.getType()).toString().startsWith("java.lang")) && + !TypeUtils.isAssignableTo("java.lang.String", b.getType())); + } + + @Override + public J.Binary visitBinary(J.Binary binary, AtomicBoolean found) { + if (isToStringImplicitlyCalled(binary.getLeft(), binary.getRight()) || + isToStringImplicitlyCalled(binary.getRight(), binary.getLeft())) { + found.set(true); + return binary; + } + return super.visitBinary(binary, found); + } + }.reduce(expression, new AtomicBoolean()).get(); + } + } + ); + } +} diff --git a/src/main/java/org/openrewrite/java/migrate/lang/SwitchUtils.java b/src/main/java/org/openrewrite/java/migrate/lang/SwitchUtils.java new file mode 100644 index 0000000000..baf2ce77b6 --- /dev/null +++ b/src/main/java/org/openrewrite/java/migrate/lang/SwitchUtils.java @@ -0,0 +1,71 @@ +/* + * Copyright 2025 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java.migrate.lang; + +import org.openrewrite.java.tree.*; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import static java.util.stream.Collectors.toSet; + +class SwitchUtils { + /** + * Checks if a switch statement covers all possible values of its selector. + * This is typically used to determine if a switch statement is "exhaustive" as per the Java language specification. + *

+ * NOTE: Missing support for sealed classes/interfaces. + * + * @param switch_ the switch statement to check + * @return true if the switch covers all possible values, false otherwise + * @see Switch Expressions in Java 21 + */ + public static boolean coversAllPossibleValues(J.Switch switch_) { + List labels = new ArrayList<>(); + for (Statement statement : switch_.getCases().getStatements()) { + for (J j : ((J.Case) statement).getCaseLabels()) { + if (j instanceof J.Identifier && "default".equals(((J.Identifier) j).getSimpleName())) { + return true; + } + labels.add(j); + } + } + JavaType javaType = switch_.getSelector().getTree().getType(); + if (javaType instanceof JavaType.Class) { + JavaType.Class javaTypeClass = (JavaType.Class) javaType; + if (javaTypeClass.hasFlags(Flag.Enum)) { + Collection labelValues = labels.stream() + .filter(label -> label instanceof J.Identifier || label instanceof J.FieldAccess) + .filter(label -> TypeUtils.isOfType(((TypeTree) label).getType(), javaType)) + .map(label -> label instanceof J.Identifier ? + ((J.Identifier) label).getSimpleName() : + ((J.FieldAccess) label).getName().getSimpleName()) + .collect(toSet()); + if (labelValues.isEmpty()) { + return false; + } + Collection enumValues = javaTypeClass.getMembers().stream() + .filter(member -> member.hasFlags(Flag.Enum)) + .map(JavaType.Variable::getName) + .collect(toSet()); + // Every enum value must be present in the switch + return !enumValues.isEmpty() && labelValues.containsAll(enumValues); + } + } + return false; + } +} diff --git a/src/main/resources/META-INF/rewrite/examples.yml b/src/main/resources/META-INF/rewrite/examples.yml index 1af0177a1d..951bcb4e13 100644 --- a/src/main/resources/META-INF/rewrite/examples.yml +++ b/src/main/resources/META-INF/rewrite/examples.yml @@ -6235,6 +6235,34 @@ examples: language: java --- type: specs.openrewrite.org/v1beta/example +recipeName: org.openrewrite.java.migrate.lang.SwitchCaseAssigningToSwitchExpression +examples: +- description: '' + sources: + - before: | + class Test { + void doFormat(Object obj) { + String formatted = "initialValue"; + switch (obj) { + case Integer i -> formatted = String.format("int %d", i); + case Long l -> formatted = String.format("long %d", l); + default -> formatted = "unknown"; + } + } + } + after: | + class Test { + void doFormat(Object obj) { + String formatted = switch (obj) { + case Integer i -> String.format("int %d", i); + case Long l -> String.format("long %d", l); + default -> "unknown"; + }; + } + } + language: java +--- +type: specs.openrewrite.org/v1beta/example recipeName: org.openrewrite.java.migrate.lang.SwitchCaseEnumGuardToLabel examples: - description: '' diff --git a/src/main/resources/META-INF/rewrite/java-version-17.yml b/src/main/resources/META-INF/rewrite/java-version-17.yml index 35ba237f18..20f9c20bc3 100644 --- a/src/main/resources/META-INF/rewrite/java-version-17.yml +++ b/src/main/resources/META-INF/rewrite/java-version-17.yml @@ -59,6 +59,7 @@ recipeList: artifactId: commons-codec newVersion: 1.17.x - org.openrewrite.java.migrate.AddLombokMapstructBinding + - org.openrewrite.java.migrate.lang.SwitchCaseAssignmentsToSwitchExpression --- type: specs.openrewrite.org/v1beta/recipe diff --git a/src/test/java/org/openrewrite/java/migrate/lang/SwitchCaseAssignmentsToSwitchExpressionTest.java b/src/test/java/org/openrewrite/java/migrate/lang/SwitchCaseAssignmentsToSwitchExpressionTest.java new file mode 100644 index 0000000000..cda324c392 --- /dev/null +++ b/src/test/java/org/openrewrite/java/migrate/lang/SwitchCaseAssignmentsToSwitchExpressionTest.java @@ -0,0 +1,921 @@ +/* + * Copyright 2025 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java.migrate.lang; + +import org.junit.jupiter.api.Test; +import org.openrewrite.DocumentExample; +import org.openrewrite.test.RecipeSpec; +import org.openrewrite.test.RewriteTest; + +import static org.openrewrite.java.Assertions.java; +import static org.openrewrite.java.Assertions.version; + +@SuppressWarnings({"EnhancedSwitchMigration", "RedundantLabeledSwitchRuleCodeBlock", "StringOperationCanBeSimplified", "SwitchStatementWithTooFewBranches", "UnnecessaryReturnStatement", "UnusedAssignment"}) +class SwitchCaseAssignmentsToSwitchExpressionTest implements RewriteTest { + @Override + public void defaults(RecipeSpec spec) { + spec + .recipe(new SwitchCaseAssignmentsToSwitchExpression()) + .allSources(source -> version(source, 17)); + } + + @DocumentExample + @Test + void convertSimpleArrowCasesAssignment() { + rewriteRun( + //language=java + java( + """ + class Test { + void doFormat(Object obj) { + String formatted = "initialValue"; + switch (obj) { + case Integer i -> formatted = String.format("int %d", i); + case Long l -> formatted = String.format("long %d", l); + default -> formatted = "unknown"; + } + } + } + """, + """ + class Test { + void doFormat(Object obj) { + String formatted = switch (obj) { + case Integer i -> String.format("int %d", i); + case Long l -> String.format("long %d", l); + default -> "unknown"; + }; + } + } + """ + ) + ); + } + + @Test + void convertSimpleColonCasesAssignment() { + rewriteRun( + //language=java + java( + """ + class Test { + void doFormat(Object obj) { + String formatted = "initialValue"; + switch (obj) { + case Integer i: formatted = String.format("int %d", i); break; + case Long l: formatted = String.format("long %d", l); break; + default: formatted = "unknown"; break; + } + } + } + """, + """ + class Test { + void doFormat(Object obj) { + String formatted = switch (obj) { + case Integer i: yield String.format("int %d", i); + case Long l: yield String.format("long %d", l); + default: yield "unknown"; + }; + } + } + """ + ) + ); + } + + @Test + void notConvertSimpleColonCasesAssignmentWithExtraCodeInBlock() { + // Only one statement [+break;] per case is currently supported + rewriteRun( + //language=java + java( + """ + class Test { + void doFormat(Object obj) { + String formatted = "initialValue"; + switch (obj) { + case Integer i: formatted = String.format("int %d", i); break; + case Long l: System.out.println("long"); formatted = String.format("long %d", l); break; + default: formatted = "unknown"; break; + } + } + } + """ + ) + ); + } + + @Test + void convertColonCasesSimpleAssignmentInBlockToSingleYield() { + rewriteRun( + //language=java + java( + """ + class Test { + void doFormat(Object obj) { + String formatted = "initialValue"; + switch (obj) { + case Integer i: formatted = String.format("int %d", i); break; + case Long l: { + formatted = String.format("long %d", l); + break; + } + default: formatted = "unknown"; break; + } + } + } + """, + """ + class Test { + void doFormat(Object obj) { + String formatted = switch (obj) { + case Integer i: yield String.format("int %d", i); + case Long l: yield String.format("long %d", l); + default: yield "unknown"; + }; + } + } + """ + ) + ); + } + + @Test + void convertColonCasesSimpleAssignmentInBlockToSingleYieldWithoutFinalCaseBreak() { + rewriteRun( + //language=java + java( + """ + class Test { + void doFormat(Object obj) { + String formatted = "initialValue"; + switch (obj) { + case Integer i: formatted = String.format("int %d", i); break; + case Long l: { + formatted = String.format("long %d", l); + break; + } + default: formatted = "unknown"; + } + } + } + """, + """ + class Test { + void doFormat(Object obj) { + String formatted = switch (obj) { + case Integer i: yield String.format("int %d", i); + case Long l: yield String.format("long %d", l); + default: yield "unknown"; + }; + } + } + """ + ) + ); + } + + @Test + void convertArrowCasesSimpleAssignmentInBlockToSingleValue() { + rewriteRun( + //language=java + java( + """ + class Test { + void doFormat(Object obj) { + String formatted = "initialValue"; + switch (obj) { + case Integer i -> formatted = String.format("int %d", i); + case Long l -> { + formatted = String.format("long %d", l); + } + default -> formatted = "unknown"; + } + } + } + """, + """ + class Test { + void doFormat(Object obj) { + String formatted = switch (obj) { + case Integer i -> String.format("int %d", i); + case Long l -> String.format("long %d", l); + default -> "unknown"; + }; + } + } + """ + ) + ); + } + + @Test + void notConvertCasesWithMissingAssignment() { + rewriteRun( + //language=java + java( + """ + class Test { + void doFormat(Object obj) { + String formatted = "initialValue"; + switch (obj) { + case String s: formatted = String.format("String %s", s); break; + case Integer i: System.out.println("Integer!"); break; + default: formatted = "unknown"; break; + } + } + } + """ + ) + ); + } + + @Test + void notConvertCasesWithAssignmentToDifferentVariables() { + rewriteRun( + //language=java + java( + """ + class Test { + void doFormat(Object obj) { + String formatted = "initialValue"; + String formatted2 = "anotherInitialValue"; + switch (obj) { + case String s: formatted = String.format("String %s", s); break; + case Integer i: formatted2 = String.format("Integer %d", i); break; + default: formatted = "unknown"; break; + } + } + } + """ + ) + ); + } + + @Test + void notConvertCasesWhenColonCaseHasNoStatementsAndNextCaseIsntAssignment() { + rewriteRun( + //language=java + java( + """ + class Test { + void doFormat(Object obj) { + String status = "initialValue"; + switch (obj) { + case null: + default: System.out.println("default"); break; + } + } + } + """ + ) + ); + } + + @Test + void convertCasesWhenColonCaseHasNoStatementsAndNextCaseIsAssignment() { + rewriteRun( + //language=java + java( + """ + class Test { + enum TrafficLight { + RED, GREEN, YELLOW + } + void doFormat(TrafficLight light) { + String status = "initialValue"; + switch (light) { + case RED: + case GREEN: + case YELLOW: status = "unsure"; break; + default: status = "unknown"; break; + } + } + } + """, + """ + class Test { + enum TrafficLight { + RED, GREEN, YELLOW + } + void doFormat(TrafficLight light) { + String status = switch (light) { + case RED: + case GREEN: + case YELLOW: yield "unsure"; + default: yield "unknown"; + }; + } + } + """ + ) + ); + } + + @Test + void convertCasesWithAddedDefault() { + rewriteRun( + //language=java + java( + """ + class Test { + enum TrafficLight { + RED, GREEN, YELLOW + } + void doFormat(TrafficLight light) { + String status = "initialValue"; + switch (light) { + case RED: status = "stop"; break; + case GREEN: status = "go"; break; + } + } + } + """, + """ + class Test { + enum TrafficLight { + RED, GREEN, YELLOW + } + void doFormat(TrafficLight light) { + String status = switch (light) { + case RED: yield "stop"; + case GREEN: yield "go"; + default: yield "initialValue"; + }; + } + } + """ + ) + ); + } + + @Test + void notConvertColonCasesWithMultipleBlocks() { + // More than one block statement per case is not yet supported. + rewriteRun( + //language=java + java( + """ + class Test { + void doFormat(Object obj) { + String status = "initialValue"; + switch (obj) { + case null: { + status = "none"; + } + { + break; + } + default: status = "default status"; break; + } + } + } + """ + ) + ); + } + + @Test + void noDefaultAddedIfAlreadyExhaustive() { + rewriteRun( + //language=java + java( + """ + class Test { + enum TrafficLight { + RED, GREEN, YELLOW + } + void doFormat(TrafficLight light) { + String status = "initialValue"; + switch (light) { + case RED: status = "stop"; break; + case GREEN: status = "go"; break; + case YELLOW: status = "unsure"; break; + } + } + } + """, + """ + class Test { + enum TrafficLight { + RED, GREEN, YELLOW + } + void doFormat(TrafficLight light) { + String status = switch (light) { + case RED: yield "stop"; + case GREEN: yield "go"; + case YELLOW: yield "unsure"; + }; + } + } + """ + ) + ); + } + + @Test + void notConvertWhenOriginalVariableIsUsedInCaseAssignment() { + rewriteRun( + //language=java + java( + """ + class Test { + void doFormat(int i) { + String orig = "initialValue"; + switch (i) { + default: orig = orig.toLowerCase(); break; + } + } + + void doFormat2(int i) { + String orig = "initialValue"; + switch (i) { + default: orig = String.format("%s %s", orig, "foo"); break; + } + } + + void doFormat3(int i) { + String orig = "initialValue"; + switch (i) { + default: orig = "foo" + orig; break; + } + } + } + """ + ) + ); + } + + @Test + void notConvertWhenOriginalVariableAssignmentHasSideEffects() { + rewriteRun( + //language=java + java( + """ + class Test { + void methodInvocation(int i) { + String orig = "initialValue".toLowerCase(); + switch (i) { + default: orig = "hello"; break; + } + } + + void newClass(int i) { + String orig = new String("initialValue"); + switch (i) { + default: orig = "hello"; break; + } + } + + void newClassInBinaryExpression(int i) { + String orig = "initialValue" + new String("more"); + switch (i) { + default: orig = "hello"; break; + } + } + + void implicitToStringInvocation(int i, Test o) { + String orig = "initialValue" + o; + switch (i) { + default: orig = "hello"; break; + } + } + + void incrementOperator(int i) { + int n = i++; + switch (i) { + default: n = 5; break; + } + } + + void assignment(int i) { + int n = ( i = 2 ); + switch (i) { + default: n = 5; break; + } + } + } + """ + ) + ); + } + + @Test + void convertWhenOriginalVariableAssignmentIsComplexExpressionButNoSideEffects() { + rewriteRun( + //language=java + java( + """ + class Test { + String field = "strawberry"; + + void doFormat(int i) { + String orig = "initialValue" + "test" + 45 + true + field + this.field; + switch (i) { + default: orig = "hello"; break; + } + } + } + """, + """ + class Test { + String field = "strawberry"; + + void doFormat(int i) { + String orig = switch (i) { + default: yield "hello"; + }; + } + } + """ + ) + ); + } + + @Test + void notConvertColonSwitchWithEmptyLastCase() { + rewriteRun( + //language=java + java( + """ + class Test { + enum TrafficLight { + RED, GREEN, YELLOW + } + void doFormat(TrafficLight light) { + String status = "initialValue"; + switch (light) { + case RED: status = "stop"; break; + case GREEN: status = "go"; break; + case YELLOW: + } + } + } + """ + ) + ); + } + + @Test + void notConvertSwitchOnUninitializedOriginalVariableAndNonExhaustiveSwitch() { + rewriteRun( + //language=java + java( + """ + class Test { + enum TrafficLight { + RED, GREEN, YELLOW + } + void exhaustiveButCantAddDefaultAndMissingAssignment(TrafficLight light) { + String status; + switch (light) { + case RED: status = "stop"; break; + case GREEN: status = "go"; break; + case YELLOW: + } + } + + void exhaustiveButCantAddDefaultAfterEmptyLabel(TrafficLight light) { + String status = "initialValue"; + switch (light) { + case RED: status = "stop"; break; + case GREEN: status = "go"; break; + case YELLOW: + } + } + + void exhaustiveButMissingAssignment(TrafficLight light) { + String status; + switch (light) { + case RED: status = "stop"; break; + case GREEN: status = "go"; break; + case YELLOW: + default: System.out.println("foo"); + } + } + } + """ + ) + ); + } + + @Test + void inlineWhenVariableOnlyToBeReturned() { + rewriteRun( + //language=java + java( + """ + class Test { + String doFormat() { + String formatted; + switch (1) { + default: formatted = "foo"; break; + } + return formatted; + } + } + """, + """ + class Test { + String doFormat() { + return switch (1) { + default: yield "foo"; + }; + } + } + """ + ) + ); + } + + @Test + void doNotInlineWhenInappropriate() { + rewriteRun( + //language=java + java( + """ + class Test { + String originalVariableNotReturned() { + String formatted; + switch (1) { + default: formatted = "foo"; break; + } + return "string"; + } + + String codeBetweenSwitchAndReturn() { + String formatted; + switch (1) { + default: formatted = "foo"; break; + } + System.out.println("Hey"); + return formatted; + } + + void noReturnedExpression() { + String formatted; + switch (1) { + default: formatted = "foo"; break; + } + return; + } + } + """, + """ + class Test { + String originalVariableNotReturned() { + String formatted = switch (1) { + default: yield "foo"; + }; + return "string"; + } + + String codeBetweenSwitchAndReturn() { + String formatted = switch (1) { + default: yield "foo"; + }; + System.out.println("Hey"); + return formatted; + } + + void noReturnedExpression() { + String formatted = switch (1) { + default: yield "foo"; + }; + return; + } + } + """ + ) + ); + } + + @Test + void defaultAsSecondLabelColonCase() { + rewriteRun( + //language=java + java( + """ + class A { + void doFormat(String str) { + String formatted = "initialValue"; + switch (str) { + case "foo": formatted = "Foo"; break; + case "bar": formatted = "Bar"; break; + case null, default: formatted = "unknown"; + } + } + } + """, + """ + class A { + void doFormat(String str) { + String formatted = switch (str) { + case "foo": yield "Foo"; + case "bar": yield "Bar"; + case null, default: yield "unknown"; + }; + } + } + """ + ) + ); + } + + @Test + void defaultAsSecondLabelArrowCase() { + rewriteRun( + java( + """ + class B { + void doFormat(String str) { + String formatted = "initialValue"; + switch (str) { + case "foo" -> formatted = "Foo"; + case "bar" -> formatted = "Bar"; + case null, default -> formatted = "Other"; + } + } + } + """, + """ + class B { + void doFormat(String str) { + String formatted = switch (str) { + case "foo" -> "Foo"; + case "bar" -> "Bar"; + case null, default -> "Other"; + }; + } + } + """ + ) + ); + } + + @Test + void whitespaceAddedWhenNoOriginalAssignment() { + rewriteRun( + //language=java + java( + """ + class Test { + void doFormat() { + String formatted; + switch (1) { + default: formatted = "foo"; break; + } + } + } + """, + """ + class Test { + void doFormat() { + String formatted = switch (1) { + default: yield "foo"; + }; + } + } + """ + ) + ); + } + + @Test + void commentsArePreserved() { + rewriteRun( + //language=java + java( + """ + class Test { + void doFormat(Object obj) { + // line before the original variable + String formatted = "initialValue"; // original variable after code + // line before the switch + switch (obj) { // first line of the switch + // before the cases + case Integer i -> formatted = String.format("int %d", i); // first case + // between the 1st and 2nd case + /* before the 2nd case */ case Long l -> formatted = String.format("long %d", l); + default -> formatted = "unknown"; + // after the last case + } // last line of the switch + } + } + """, + """ + class Test { + void doFormat(Object obj) { + // line before the original variable + // original variable after code + // line before the switch + String formatted = switch (obj) { // first line of the switch + // before the cases + case Integer i -> String.format("int %d", i); // first case + // between the 1st and 2nd case + /* before the 2nd case */ case Long l -> String.format("long %d", l); + default -> "unknown"; + // after the last case + }; // last line of the switch + } + } + """ + ) + ); + } + + @Test + void commentsArePreservedWhenInlining() { + rewriteRun( + //language=java + java( + """ + class Test { + String doFormat() { + // line before original variable + String formatted; // original variable after code + // between the original variable and the switch + switch (1) { // on the switch after code + default: formatted = "foo"; break; + } // last line of the switch + // between switch and return + return formatted; // after return on the same line + } + } + """, + """ + class Test { + String doFormat() { + // line before original variable + // original variable after code + // between the original variable and the switch + // last line of the switch + // between switch and return + return switch (1) { // on the switch after code + default: yield "foo"; + }; // after return on the same line + } + } + """ + ) + ); + } + + @Test + void notConvertWhenFallThrough() { + rewriteRun( + //language=java + java( + """ + class Test { + void doFormat(String str) { + String formatted = "initialValue"; + switch (str) { + case "A": formatted = "A"; // no break + case "B": formatted = "B"; // no break + case "C": formatted = "C"; // no break + default: formatted = "Z"; + } + } + } + """ + ) + ); + } + + @Test + void notConvertWhenFallThroughAppends() { + rewriteRun( + //language=java + java( + """ + class Test { + void doFormat(String str) { + String formatted = "initialValue"; + switch (str) { + case "A": formatted = "A"; + case "B": formatted = formatted + "B"; + case "C": formatted = formatted + "C"; + default: formatted = "Z"; break; + } + } + } + """ + ) + ); + } +} diff --git a/src/test/java/org/openrewrite/java/migrate/lang/SwitchUtilsTest.java b/src/test/java/org/openrewrite/java/migrate/lang/SwitchUtilsTest.java new file mode 100644 index 0000000000..10904df328 --- /dev/null +++ b/src/test/java/org/openrewrite/java/migrate/lang/SwitchUtilsTest.java @@ -0,0 +1,190 @@ +/* + * Copyright 2025 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java.migrate.lang; + +import org.intellij.lang.annotations.Language; +import org.junit.jupiter.api.Test; +import org.openrewrite.java.JavaIsoVisitor; +import org.openrewrite.java.JavaParser; +import org.openrewrite.java.tree.J; + +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class SwitchUtilsTest { + private static J.Switch extractSwitch(@Language("java") String code) { + J.CompilationUnit cu = (J.CompilationUnit) JavaParser.fromJavaVersion().build().parse(code).findFirst().get(); + return new JavaIsoVisitor>() { + @Override + public J.Switch visitSwitch(J.Switch _switch, AtomicReference switchAtomicReference) { + switchAtomicReference.set(_switch); + return _switch; + } + }.reduce(cu, new AtomicReference<>()).get(); + } + + @Test + void coversAllCasesAllEnums() { + assertTrue( + SwitchUtils.coversAllPossibleValues( + extractSwitch( + """ + class Test { + void method(TrafficLight light) { + switch (light) { + case RED -> System.out.println("stop"); + case YELLOW -> System.out.println("caution"); + case GREEN -> System.out.println("go"); + } + } + enum TrafficLight { RED, YELLOW, GREEN } + } + """ + ) + ) + ); + } + + @Test + void coversAllCasesMissingEnums() { + assertFalse( + SwitchUtils.coversAllPossibleValues( + extractSwitch( + """ + class Test { + void method(TrafficLight light) { + switch (light) { + case RED -> System.out.println("stop"); + case YELLOW -> System.out.println("caution"); + } + } + enum TrafficLight { RED, YELLOW, GREEN } + } + """ + ) + ) + ); + } + + @Test + void coversAllCasesMissingEnumsWithDefault() { + assertTrue( + SwitchUtils.coversAllPossibleValues( + extractSwitch( + """ + class Test { + void method(TrafficLight light) { + switch (light) { + case RED -> System.out.println("stop"); + case YELLOW -> System.out.println("caution"); + default -> System.out.println("unknown"); + } + } + enum TrafficLight { RED, YELLOW, GREEN } + } + """ + ) + ) + ); + } + + @Test + void coversAllCasesEnumOnlyDefault() { + assertTrue( + SwitchUtils.coversAllPossibleValues( + extractSwitch( + """ + class Test { + void method(TrafficLight light) { + switch (light) { + default -> System.out.println("unknown"); + } + } + enum TrafficLight { RED, YELLOW, GREEN } + } + """ + ) + ) + ); + } + + @Test + void coversAllCasesObjectOnlyDefault() { + assertTrue( + SwitchUtils.coversAllPossibleValues( + extractSwitch( + """ + class Test { + void method(Object obj) { + switch (obj) { + default -> System.out.println("default"); + } + } + } + """ + ) + ) + ); + } + + @Test + void coversAllCasesAllSealedClasses() { + assertFalse( + SwitchUtils.coversAllPossibleValues( + extractSwitch( + """ + class Test { + sealed abstract class Shape permits Circle, Square, Rectangle {} + void method(Shape shape) { + switch (shape) { + case Circle c -> System.out.println("circle"); + case Square s -> System.out.println("square"); + case Rectangle r -> System.out.println("rectangle"); + } + } + } + """ + ) + ), "Not implemented yet for sealed classes" + ); + } + + @Test + void coversAllCasesEnumWithExtraMembers() { + assertTrue( + SwitchUtils.coversAllPossibleValues( + extractSwitch( + """ + class Test { + enum EnumWithExtraMembers { + ONE, TWO; + public static final EnumWithExtraMembers THREE = ONE; + } + void method(EnumWithExtraMembers e) { + switch (e) { + case ONE -> System.out.println("one"); + case TWO -> System.out.println("two"); + } + } + } + """ + ) + ) + ); + } +}