diff --git a/src/main/java/org/openrewrite/staticanalysis/MinimumSwitchCases.java b/src/main/java/org/openrewrite/staticanalysis/MinimumSwitchCases.java index c2aa5a339c..7c18c0a0fd 100644 --- a/src/main/java/org/openrewrite/staticanalysis/MinimumSwitchCases.java +++ b/src/main/java/org/openrewrite/staticanalysis/MinimumSwitchCases.java @@ -31,10 +31,7 @@ import org.openrewrite.marker.Markers; import org.openrewrite.staticanalysis.csharp.CSharpFileChecker; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; -import java.util.UUID; +import java.util.*; import static java.util.Collections.singleton; import static java.util.Collections.singletonList; @@ -192,6 +189,26 @@ public J visitSwitch(J.Switch switch_, ExecutionContext ctx) { } } + Map firstCaseDeclarations = new JavaIsoVisitor>() { + @Override + public J.VariableDeclarations visitVariableDeclarations(J.VariableDeclarations multiVariable, Map variableDeclarations) { + for (J.VariableDeclarations.NamedVariable var_ : multiVariable.getVariables()) { + variableDeclarations.put(var_.getSimpleName(), multiVariable.withVariables(ListUtils.filter(multiVariable.getVariables(), v -> v == var_))); + } + return multiVariable; + } + + @Override + public J.Block visitBlock(J.Block block, Map variableDeclarations) { + for (Statement statement : block.getStatements()) { + if (statement instanceof J.VariableDeclarations) { + visitVariableDeclarations((J.VariableDeclarations) statement, variableDeclarations); + } + } + return block; + } + }.reduce(cases[0], new HashMap<>()); + // move first case to "if" List thenStatements = getStatements(cases[0]); @@ -201,13 +218,36 @@ public J visitSwitch(J.Switch switch_, ExecutionContext ctx) { // move second case to "else" if (cases[1] != null) { assert generatedIf.getElsePart() != null; + JavaVisitor> firstAssignmentAsDeclarationVisitor = new JavaVisitor>() { + + @Override + public J visitStatement(Statement statement, Map stringVariableDeclarationsMap) { + if (statement instanceof J.Assignment) { + return visitAssignment((J.Assignment) statement, stringVariableDeclarationsMap); + } + return super.visitStatement(statement, stringVariableDeclarationsMap); + } + + @Override + public J visitAssignment(J.Assignment assignment, Map firstCaseDeclarations) { + if (assignment.getVariable() instanceof J.Identifier) { + String varName = ((J.Identifier) assignment.getVariable()).getSimpleName(); + if (firstCaseDeclarations.containsKey(varName)) { + J.VariableDeclarations originalDecl = firstCaseDeclarations.remove(varName); + return originalDecl.withVariables(ListUtils.map(originalDecl.getVariables(), v -> v.withInitializer(assignment.getAssignment()))); + } + } + return super.visitAssignment(assignment, firstCaseDeclarations); + } + }; + if (isDefault(cases[1])) { generatedIf = generatedIf.withElsePart(generatedIf.getElsePart().withBody(((J.Block) generatedIf.getElsePart().getBody()).withStatements(ListUtils.map(getStatements(cases[1]), - s -> s instanceof J.Break ? null : s)))); + s -> s instanceof J.Break || s == null ? null : (Statement) firstAssignmentAsDeclarationVisitor.visitStatement(s, firstCaseDeclarations))))); } else { J.If elseIf = (J.If) generatedIf.getElsePart().getBody(); generatedIf = generatedIf.withElsePart(generatedIf.getElsePart().withBody(elseIf.withThenPart(((J.Block) elseIf.getThenPart()).withStatements(ListUtils.map(getStatements(cases[1]), - s -> s instanceof J.Break ? null : s))))); + s -> s instanceof J.Break || s == null ? null : (Statement) firstAssignmentAsDeclarationVisitor.visitStatement(s, firstCaseDeclarations)))))); } } @@ -250,7 +290,6 @@ private boolean switchesOnEnum(J.Switch switch_) { return selectorType instanceof JavaType.Class && ((JavaType.Class) selectorType).getKind() == JavaType.Class.Kind.Enum; } - }); } diff --git a/src/test/java/org/openrewrite/staticanalysis/MinimumSwitchCasesTest.java b/src/test/java/org/openrewrite/staticanalysis/MinimumSwitchCasesTest.java index 0a37ff4e6c..acc5808466 100644 --- a/src/test/java/org/openrewrite/staticanalysis/MinimumSwitchCasesTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/MinimumSwitchCasesTest.java @@ -887,4 +887,108 @@ void doSomethingElse() {} ) ); } + + @Issue("https://github.com/openrewrite/rewrite-static-analysis/issues/687") + @Test + void variableRedeclarationInNewBlocks() { + rewriteRun( + //language=java + java( + """ + class Test { + int someInt; + void test() { + switch (someInt) { + case 1: + String data = getSomeString(); + doThingOneWith(data); + break; + case 2: + data = getSomeOtherString(); + doThingTwoWith(data); + break; + } + } + String getSomeString() { return "one"; } + String getSomeOtherString() { return "two"; } + void doThingOneWith(String data) {} + void doThingTwoWith(String data) {} + } + """, + """ + class Test { + int someInt; + void test() { + if (someInt == 1) { + String data = getSomeString(); + doThingOneWith(data); + } else if (someInt == 2) { + String data = getSomeOtherString(); + doThingTwoWith(data); + } + } + String getSomeString() { return "one"; } + String getSomeOtherString() { return "two"; } + void doThingOneWith(String data) {} + void doThingTwoWith(String data) {} + } + """ + ) + ); + } + + @Issue("https://github.com/openrewrite/rewrite-static-analysis/issues/687") + @Test + void MultipleVariableRedeclarationInNewBlocks() { + rewriteRun( + //language=java + java( + """ + class Test { + int someInt; + void test() { + switch (someInt) { + case 1: + String data = getSomeString(), otherString = getSomeString(); + doThingOneWith(data); + doThingOneWith(otherString); + break; + case 2: + data = getSomeOtherString(); + otherString = getSomeOtherString(); + doThingTwoWith(data); + doThingTwoWith(otherString); + break; + } + } + String getSomeString() { return "one"; } + String getSomeOtherString() { return "two"; } + void doThingOneWith(String data) {} + void doThingTwoWith(String data) {} + } + """, + """ + class Test { + int someInt; + void test() { + if (someInt == 1) { + String data = getSomeString(), otherString = getSomeString(); + doThingOneWith(data); + doThingOneWith(otherString); + } else if (someInt == 2) { + String data = getSomeOtherString(); + String otherString = getSomeOtherString(); + doThingTwoWith(data); + doThingTwoWith(otherString); + } + } + String getSomeString() { return "one"; } + String getSomeOtherString() { return "two"; } + void doThingOneWith(String data) {} + void doThingTwoWith(String data) {} + } + """ + ) + ); + } }