2323import org .openrewrite .java .JavaTemplate ;
2424import org .openrewrite .java .JavaVisitor ;
2525import org .openrewrite .java .search .SemanticallyEqual ;
26- import org .openrewrite .java .tree .Expression ;
27- import org .openrewrite .java .tree .J ;
28- import org .openrewrite .java .tree .Space ;
29- import org .openrewrite .java .tree .Statement ;
26+ import org .openrewrite .java .search .UsesJavaVersion ;
27+ import org .openrewrite .java .tree .*;
28+ import org .openrewrite .staticanalysis .groovy .GroovyFileChecker ;
3029import org .openrewrite .staticanalysis .kotlin .KotlinFileChecker ;
3130
3231import java .time .Duration ;
32+ import java .util .ArrayList ;
33+ import java .util .List ;
34+ import java .util .Objects ;
3335import java .util .Optional ;
3436import java .util .concurrent .atomic .AtomicReference ;
3537
@@ -58,7 +60,13 @@ public Duration getEstimatedEffortPerOccurrence() {
5860
5961 @ Override
6062 public TreeVisitor <?, ExecutionContext > getVisitor () {
61- return Preconditions .check (Preconditions .not (new KotlinFileChecker <>()), new JavaVisitor <ExecutionContext >() {
63+ TreeVisitor <?, ExecutionContext > preconditions = Preconditions .and (
64+ new UsesJavaVersion <>(21 ),
65+ Preconditions .not (new KotlinFileChecker <>()),
66+ Preconditions .not (new GroovyFileChecker <>())
67+ );
68+
69+ return Preconditions .check (preconditions , new JavaVisitor <ExecutionContext >() {
6270 @ Override
6371 public J visitBlock (J .Block block , ExecutionContext ctx ) {
6472 AtomicReference <@ Nullable NullCheck > nullCheck = new AtomicReference <>();
@@ -68,13 +76,18 @@ public J visitBlock(J.Block block, ExecutionContext ctx) {
6876 if (nullCheckOpt .isPresent ()) {
6977 NullCheck check = nullCheckOpt .get ();
7078 J nextStatement = index + 1 < block .getStatements ().size () ? block .getStatements ().get (index + 1 ) : null ;
71- if (!(nextStatement instanceof J .Switch ) ||
72- hasNullCase ((J .Switch ) nextStatement ) ||
73- !SemanticallyEqual .areEqual (((J .Switch ) nextStatement ).getSelector ().getTree (), check .getNullCheckedParameter ()) ||
74- check .returns () ||
75- check .couldModifyNullCheckedValue ()) {
79+ if (!(nextStatement instanceof J .Switch ) || check .returns () || check .couldModifyNullCheckedValue ()) {
80+ return statement ;
81+ }
82+ J .Switch nextSwitch = (J .Switch ) nextStatement ;
83+ // Only if the switch does not have a null case and switches on the same value as the null check, we can remove the null check
84+ // It must have all possible input values covered
85+ if (hasNullCase (nextSwitch ) ||
86+ !SemanticallyEqual .areEqual (nextSwitch .getSelector ().getTree (), check .getNullCheckedParameter ()) ||
87+ !coversAllPossibleValues (nextSwitch )) {
7688 return statement ;
7789 }
90+
7891 nullCheck .set (check );
7992 return null ;
8093 }
@@ -106,6 +119,16 @@ private boolean hasNullCase(J.Switch switch_) {
106119 }
107120
108121 private J .Case createNullCase (J .Switch aSwitch , Statement whenNull ) {
122+ J .Case currentFirstCase = aSwitch .getCases ().getStatements ().isEmpty () ||
123+ !(aSwitch .getCases ().getStatements ().get (0 ) instanceof J .Case ) ?
124+ null : (J .Case ) aSwitch .getCases ().getStatements ().get (0 );
125+ if (currentFirstCase == null || J .Case .Type .Rule == currentFirstCase .getType ()) {
126+ return createCaseRule (aSwitch , whenNull );
127+ }
128+ return createCaseStatement (aSwitch , whenNull , currentFirstCase );
129+ }
130+
131+ private J .Case createCaseRule (J .Switch aSwitch , Statement whenNull ) {
109132 if (whenNull instanceof J .Block && ((J .Block ) whenNull ).getStatements ().size () == 1 ) {
110133 Statement firstStatement = ((J .Block ) whenNull ).getStatements ().get (0 );
111134 if (firstStatement instanceof Expression || firstStatement instanceof J .Throw ) {
@@ -122,6 +145,60 @@ private J.Case createNullCase(J.Switch aSwitch, Statement whenNull) {
122145 J .Case nullCase = (J .Case ) switchWithNullCase .getCases ().getStatements ().get (0 );
123146 return nullCase .withBody (requireNonNull (nullCase .getBody ()).withPrefix (Space .SINGLE_SPACE ));
124147 }
148+
149+ private J .Case createCaseStatement (J .Switch aSwitch , Statement whenNull , J .Case currentFirstCase ) {
150+ List <J > statements = new ArrayList <>();
151+ statements .add (aSwitch .getSelector ().getTree ());
152+ if (whenNull instanceof J .Block ) {
153+ statements .addAll (((J .Block ) whenNull ).getStatements ());
154+ } else {
155+ statements .add (whenNull );
156+ }
157+ StringBuilder template = new StringBuilder ("switch(#{any()}) {\n case null:" );
158+ for (int i = 1 ; i < statements .size (); i ++) {
159+ template .append ("\n #{any()};" );
160+ }
161+ template .append ("\n break;\n }" );
162+ J .Switch switchWithNullCase = JavaTemplate .apply (
163+ template .toString (),
164+ new Cursor (getCursor (), aSwitch ),
165+ aSwitch .getCoordinates ().replace (),
166+ statements .toArray ());
167+ J .Case nullCase = (J .Case ) switchWithNullCase .getCases ().getStatements ().get (0 );
168+ Space currentFirstCaseIndentation = currentFirstCase .getStatements ().stream ().map (J ::getPrefix ).findFirst ().orElse (Space .SINGLE_SPACE );
169+
170+ return nullCase .withStatements (ListUtils .mapFirst (nullCase .getStatements (), s -> s == null ? null : s .withPrefix (currentFirstCaseIndentation )));
171+ }
172+
173+ private boolean coversAllPossibleValues (J .Switch switch_ ) {
174+ List <J > labels = new ArrayList <>();
175+ for (Statement statement : switch_ .getCases ().getStatements ()) {
176+ for (J j : ((J .Case ) statement ).getCaseLabels ()) {
177+ if (j instanceof J .Identifier && "default" .equals (((J .Identifier ) j ).getSimpleName ())) {
178+ return true ;
179+ }
180+ labels .add (j );
181+ }
182+ }
183+ JavaType javaType = switch_ .getSelector ().getTree ().getType ();
184+ if (javaType instanceof JavaType .Class && ((JavaType .Class ) javaType ).getKind () == JavaType .FullyQualified .Kind .Enum ) {
185+ // Every enum value must be present in the switch
186+ return ((JavaType .Class ) javaType ).getMembers ().stream ().allMatch (variable ->
187+ labels .stream ().anyMatch (label -> {
188+ if (!(label instanceof TypeTree && TypeUtils .isOfType (((TypeTree ) label ).getType (), javaType ))) {
189+ return false ;
190+ }
191+ J .Identifier enumName = null ;
192+ if (label instanceof J .Identifier ) {
193+ enumName = (J .Identifier ) label ;
194+ } else if (label instanceof J .FieldAccess ) {
195+ enumName = ((J .FieldAccess ) label ).getName ();
196+ }
197+ return enumName != null && Objects .equals (variable .getName (), enumName .getSimpleName ());
198+ }));
199+ }
200+ return false ;
201+ }
125202 });
126203 }
127204}
0 commit comments