2323import org .openrewrite .java .JavaTemplate ;
2424import org .openrewrite .java .JavaVisitor ;
2525import org .openrewrite .java .search .SemanticallyEqual ;
26- import org .openrewrite .java .tree .Expression ;
27- import org .openrewrite .java .tree .J ;
28- import org .openrewrite .java .tree .Space ;
29- import org .openrewrite .java .tree .Statement ;
26+ import org .openrewrite .java .search .UsesJavaVersion ;
27+ import org .openrewrite .java .tree .*;
28+ import org .openrewrite .staticanalysis .groovy .GroovyFileChecker ;
3029import org .openrewrite .staticanalysis .kotlin .KotlinFileChecker ;
3130
3231import java .time .Duration ;
32+ import java .util .ArrayList ;
33+ import java .util .List ;
34+ import java .util .Objects ;
3335import java .util .Optional ;
3436import java .util .concurrent .atomic .AtomicReference ;
3537
@@ -58,7 +60,13 @@ public Duration getEstimatedEffortPerOccurrence() {
5860
5961 @ Override
6062 public TreeVisitor <?, ExecutionContext > getVisitor () {
61- return Preconditions .check (Preconditions .not (new KotlinFileChecker <>()), new JavaVisitor <ExecutionContext >() {
63+ TreeVisitor <?, ExecutionContext > preconditions = Preconditions .and (
64+ new UsesJavaVersion <>(21 ),
65+ Preconditions .not (new KotlinFileChecker <>()),
66+ Preconditions .not (new GroovyFileChecker <>())
67+ );
68+
69+ return Preconditions .check (preconditions , new JavaVisitor <ExecutionContext >() {
6270 @ Override
6371 public J visitBlock (J .Block block , ExecutionContext ctx ) {
6472 AtomicReference <@ Nullable NullCheck > nullCheck = new AtomicReference <>();
@@ -68,13 +76,18 @@ public J visitBlock(J.Block block, ExecutionContext ctx) {
6876 if (nullCheckOpt .isPresent ()) {
6977 NullCheck check = nullCheckOpt .get ();
7078 J nextStatement = index + 1 < block .getStatements ().size () ? block .getStatements ().get (index + 1 ) : null ;
71- if (!(nextStatement instanceof J .Switch ) ||
72- hasNullCase ((J .Switch ) nextStatement ) ||
73- !SemanticallyEqual .areEqual (((J .Switch ) nextStatement ).getSelector ().getTree (), check .getNullCheckedParameter ()) ||
74- check .returns () ||
75- check .couldModifyNullCheckedValue ()) {
79+ if (!(nextStatement instanceof J .Switch ) || check .returns () || check .couldModifyNullCheckedValue ()) {
80+ return statement ;
81+ }
82+ J .Switch nextSwitch = (J .Switch ) nextStatement ;
83+ // Only if the switch does not have a null case and switches on the same value as the null check, we can remove the null check
84+ // It must have all possible input values covered
85+ if (hasNullCase (nextSwitch ) ||
86+ !SemanticallyEqual .areEqual (nextSwitch .getSelector ().getTree (), check .getNullCheckedParameter ()) ||
87+ !coversAllPossibleValues (nextSwitch )) {
7688 return statement ;
7789 }
90+
7891 nullCheck .set (check );
7992 return null ;
8093 }
@@ -106,6 +119,16 @@ private boolean hasNullCase(J.Switch switch_) {
106119 }
107120
108121 private J .Case createNullCase (J .Switch aSwitch , Statement whenNull ) {
122+ J .Case currentFirstCase = aSwitch .getCases ().getStatements ().isEmpty () ||
123+ !(aSwitch .getCases ().getStatements ().get (0 ) instanceof J .Case ) ?
124+ null : (J .Case ) aSwitch .getCases ().getStatements ().get (0 );
125+ if (currentFirstCase == null || J .Case .Type .Rule == currentFirstCase .getType ()) {
126+ return createCaseRule (aSwitch , whenNull );
127+ }
128+ return createCaseStatement (aSwitch , whenNull , currentFirstCase );
129+ }
130+
131+ private J .Case createCaseRule (J .Switch aSwitch , Statement whenNull ) {
109132 if (whenNull instanceof J .Block && ((J .Block ) whenNull ).getStatements ().size () == 1 ) {
110133 Statement firstStatement = ((J .Block ) whenNull ).getStatements ().get (0 );
111134 if (firstStatement instanceof Expression || firstStatement instanceof J .Throw ) {
@@ -122,6 +145,75 @@ private J.Case createNullCase(J.Switch aSwitch, Statement whenNull) {
122145 J .Case nullCase = (J .Case ) switchWithNullCase .getCases ().getStatements ().get (0 );
123146 return nullCase .withBody (requireNonNull (nullCase .getBody ()).withPrefix (Space .SINGLE_SPACE ));
124147 }
148+
149+ private J .Case createCaseStatement (J .Switch aSwitch , Statement whenNull , J .Case currentFirstCase ) {
150+ List <J > statements = new ArrayList <>();
151+ statements .add (aSwitch .getSelector ().getTree ());
152+ if (whenNull instanceof J .Block ) {
153+ statements .addAll (((J .Block ) whenNull ).getStatements ());
154+ } else {
155+ statements .add (whenNull );
156+ }
157+
158+ // Check if the last statement is a throw statement
159+ Statement lastStatement = null ;
160+ if (whenNull instanceof J .Block ) {
161+ List <Statement > blockStatements = ((J .Block ) whenNull ).getStatements ();
162+ if (!blockStatements .isEmpty ()) {
163+ lastStatement = blockStatements .get (blockStatements .size () - 1 );
164+ }
165+ } else {
166+ lastStatement = whenNull ;
167+ }
168+
169+ StringBuilder template = new StringBuilder ("switch(#{any()}) {\n case null:" );
170+ for (int i = 1 ; i < statements .size (); i ++) {
171+ template .append ("\n #{any()};" );
172+ }
173+ if (!(lastStatement instanceof J .Throw )) {
174+ template .append ("\n break;" );
175+ }
176+ template .append ("\n }" );
177+ J .Switch switchWithNullCase = JavaTemplate .apply (
178+ template .toString (),
179+ new Cursor (getCursor (), aSwitch ),
180+ aSwitch .getCoordinates ().replace (),
181+ statements .toArray ());
182+ J .Case nullCase = (J .Case ) switchWithNullCase .getCases ().getStatements ().get (0 );
183+ Space currentFirstCaseIndentation = currentFirstCase .getStatements ().stream ().map (J ::getPrefix ).findFirst ().orElse (Space .SINGLE_SPACE );
184+
185+ return nullCase .withStatements (ListUtils .mapFirst (nullCase .getStatements (), s -> s == null ? null : s .withPrefix (currentFirstCaseIndentation )));
186+ }
187+
188+ private boolean coversAllPossibleValues (J .Switch switch_ ) {
189+ List <J > labels = new ArrayList <>();
190+ for (Statement statement : switch_ .getCases ().getStatements ()) {
191+ for (J j : ((J .Case ) statement ).getCaseLabels ()) {
192+ if (j instanceof J .Identifier && "default" .equals (((J .Identifier ) j ).getSimpleName ())) {
193+ return true ;
194+ }
195+ labels .add (j );
196+ }
197+ }
198+ JavaType javaType = switch_ .getSelector ().getTree ().getType ();
199+ if (javaType instanceof JavaType .Class && ((JavaType .Class ) javaType ).getKind () == JavaType .FullyQualified .Kind .Enum ) {
200+ // Every enum value must be present in the switch
201+ return ((JavaType .Class ) javaType ).getMembers ().stream ().allMatch (variable ->
202+ labels .stream ().anyMatch (label -> {
203+ if (!(label instanceof TypeTree && TypeUtils .isOfType (((TypeTree ) label ).getType (), javaType ))) {
204+ return false ;
205+ }
206+ J .Identifier enumName = null ;
207+ if (label instanceof J .Identifier ) {
208+ enumName = (J .Identifier ) label ;
209+ } else if (label instanceof J .FieldAccess ) {
210+ enumName = ((J .FieldAccess ) label ).getName ();
211+ }
212+ return enumName != null && Objects .equals (variable .getName (), enumName .getSimpleName ());
213+ }));
214+ }
215+ return false ;
216+ }
125217 });
126218 }
127219}
0 commit comments