Skip to content

Commit b0b496a

Browse files
authored
Null check as switch case (#748)
* Null check as switch case
1 parent 94f75ce commit b0b496a

File tree

4 files changed

+949
-0
lines changed

4 files changed

+949
-0
lines changed
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
* <p>
4+
* Licensed under the Moderne Source Available License (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
* <p>
8+
* https://docs.moderne.io/licensing/moderne-source-available-license
9+
* <p>
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.openrewrite.java.migrate.lang;
17+
18+
import lombok.Value;
19+
import org.jspecify.annotations.Nullable;
20+
import org.openrewrite.Cursor;
21+
import org.openrewrite.java.JavaIsoVisitor;
22+
import org.openrewrite.java.search.SemanticallyEqual;
23+
import org.openrewrite.java.tree.Expression;
24+
import org.openrewrite.java.tree.J;
25+
import org.openrewrite.java.tree.JavaType;
26+
import org.openrewrite.java.tree.Statement;
27+
import org.openrewrite.trait.SimpleTraitMatcher;
28+
import org.openrewrite.trait.Trait;
29+
30+
import java.util.concurrent.atomic.AtomicBoolean;
31+
32+
import static org.openrewrite.java.tree.J.Binary.Type.Equal;
33+
34+
@Value
35+
public class NullCheck implements Trait<J.If> {
36+
Cursor cursor;
37+
Expression nullCheckedParameter;
38+
39+
public Statement whenNull() {
40+
return getTree().getThenPart();
41+
}
42+
43+
public @Nullable Statement whenNotNull() {
44+
J.If.Else else_ = getTree().getElsePart();
45+
return else_ == null ? null : else_.getBody();
46+
}
47+
48+
public boolean returns() {
49+
Statement statement = whenNull();
50+
if (statement instanceof J.Block) {
51+
for (Statement s : ((J.Block) statement).getStatements()) {
52+
if (s instanceof J.Return) {
53+
return true;
54+
}
55+
}
56+
return false;
57+
}
58+
return statement instanceof J.Return;
59+
}
60+
61+
/**
62+
* Calculates few potential cases where the null checked variable gets reassigned and only returns false if these cases DO NOT match.
63+
* In any other case this returns true as we do not know that particular situation yet -> no harm -> assume it could be altered in the block.
64+
* @return false only if we are 100% sure the block does not reassigns/changes the null checked variable.
65+
*/
66+
public boolean couldModifyNullCheckedValue() {
67+
Statement statement = whenNull();
68+
if (statement instanceof J.Block || statement instanceof Expression || statement instanceof J.Throw) {
69+
return couldModifyNullCheckedValue(statement, nullCheckedParameter);
70+
}
71+
// Cautious by default
72+
return true;
73+
}
74+
private static boolean couldModifyNullCheckedValue(J expression, Expression nullChecked) {
75+
if (nullChecked instanceof J.FieldAccess && couldModifyNullCheckedValue(expression, ((J.FieldAccess) nullChecked).getTarget())) {
76+
return true;
77+
}
78+
if (nullChecked instanceof J.MethodInvocation &&
79+
((J.MethodInvocation) nullChecked).getSelect() != null &&
80+
couldModifyNullCheckedValue(expression, ((J.MethodInvocation) nullChecked).getSelect())) {
81+
return true;
82+
}
83+
return new JavaIsoVisitor<AtomicBoolean>() {
84+
85+
private final boolean isCertainlyImmutable = nullChecked.getType() != null && JavaType.Primitive.fromClassName(nullChecked.getType().toString()) != null;
86+
87+
@Override
88+
public J.Identifier visitIdentifier(J.Identifier identifier, AtomicBoolean couldModifyValue) {
89+
J.Identifier id = super.visitIdentifier(identifier, couldModifyValue);
90+
if (!isCertainlyImmutable && SemanticallyEqual.areEqual(id, nullChecked)) {
91+
couldModifyValue.set(true);
92+
}
93+
return id;
94+
}
95+
@Override
96+
public J.Assignment visitAssignment(J.Assignment assignment, AtomicBoolean couldModifyValue) {
97+
J.Assignment as = super.visitAssignment(assignment, couldModifyValue);
98+
if (SemanticallyEqual.areEqual(as.getVariable(), nullChecked)) {
99+
couldModifyValue.set(true);
100+
}
101+
return as;
102+
}
103+
@Override
104+
public J.FieldAccess visitFieldAccess(J.FieldAccess fieldAccess, AtomicBoolean couldModifyValue) {
105+
J.FieldAccess fa = super.visitFieldAccess(fieldAccess, couldModifyValue);
106+
if (SemanticallyEqual.areEqual(fa, nullChecked) ||
107+
SemanticallyEqual.areEqual(fa.getTarget(), nullChecked)) {
108+
couldModifyValue.set(true);
109+
}
110+
return fa;
111+
}
112+
@Override
113+
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, AtomicBoolean couldModifyValue) {
114+
J.MethodInvocation mi = super.visitMethodInvocation(method, couldModifyValue);
115+
if (SemanticallyEqual.areEqual(mi, nullChecked) ||
116+
((mi.getSelect() != null) && SemanticallyEqual.areEqual(mi.getSelect(), nullChecked))) {
117+
couldModifyValue.set(true);
118+
}
119+
return mi;
120+
}
121+
}.reduce(expression, new AtomicBoolean(false)).get();
122+
}
123+
124+
public static class Matcher extends SimpleTraitMatcher<NullCheck> {
125+
126+
public static Matcher nullCheck() {
127+
return new Matcher();
128+
}
129+
130+
@Override
131+
protected @Nullable NullCheck test(Cursor cursor) {
132+
if (cursor.getValue() instanceof J.If) {
133+
J.If if_ = cursor.getValue();
134+
if (if_.getIfCondition().getTree() instanceof J.Binary) {
135+
J.Binary binary = (J.Binary) if_.getIfCondition().getTree();
136+
if (binary.getOperator() == Equal) {
137+
if (J.Literal.isLiteralValue(binary.getLeft(), null)) {
138+
return new NullCheck(cursor, binary.getRight());
139+
} else if (J.Literal.isLiteralValue(binary.getRight(), null)) {
140+
return new NullCheck(cursor, binary.getLeft());
141+
}
142+
}
143+
}
144+
}
145+
return null;
146+
}
147+
}
148+
}
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
* <p>
4+
* Licensed under the Moderne Source Available License (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
* <p>
8+
* https://docs.moderne.io/licensing/moderne-source-available-license
9+
* <p>
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.openrewrite.java.migrate.lang;
17+
18+
import lombok.EqualsAndHashCode;
19+
import lombok.Value;
20+
import org.jspecify.annotations.Nullable;
21+
import org.openrewrite.*;
22+
import org.openrewrite.internal.ListUtils;
23+
import org.openrewrite.java.JavaTemplate;
24+
import org.openrewrite.java.JavaVisitor;
25+
import org.openrewrite.java.search.SemanticallyEqual;
26+
import org.openrewrite.java.tree.Expression;
27+
import org.openrewrite.java.tree.J;
28+
import org.openrewrite.java.tree.Space;
29+
import org.openrewrite.java.tree.Statement;
30+
import org.openrewrite.staticanalysis.kotlin.KotlinFileChecker;
31+
32+
import java.time.Duration;
33+
import java.util.Optional;
34+
import java.util.concurrent.atomic.AtomicReference;
35+
36+
import static java.util.Objects.requireNonNull;
37+
import static org.openrewrite.java.migrate.lang.NullCheck.Matcher.nullCheck;
38+
39+
@Value
40+
@EqualsAndHashCode(callSuper = false)
41+
public class NullCheckAsSwitchCase extends Recipe {
42+
@Override
43+
public String getDisplayName() {
44+
return "Add null check to existing switch cases";
45+
}
46+
47+
@Override
48+
public String getDescription() {
49+
return "In later Java 21+, null checks are valid in switch cases. " +
50+
"This recipe will only add null checks to existing switch cases if there are no other statements in between them " +
51+
"or if the block in the if statement is not impacting the flow of the switch.";
52+
}
53+
54+
@Override
55+
public Duration getEstimatedEffortPerOccurrence() {
56+
return Duration.ofMinutes(3);
57+
}
58+
59+
@Override
60+
public TreeVisitor<?, ExecutionContext> getVisitor() {
61+
return Preconditions.check(Preconditions.not(new KotlinFileChecker<>()), new JavaVisitor<ExecutionContext>() {
62+
@Override
63+
public J visitBlock(J.Block block, ExecutionContext ctx) {
64+
AtomicReference<@Nullable NullCheck> nullCheck = new AtomicReference<>();
65+
J.Block b = block.withStatements(ListUtils.map(block.getStatements(), (index, statement) -> {
66+
// Maybe remove a null check preceding a switch statement
67+
Optional<NullCheck> nullCheckOpt = nullCheck().get(statement, getCursor());
68+
if (nullCheckOpt.isPresent()) {
69+
NullCheck check = nullCheckOpt.get();
70+
J nextStatement = index + 1 < block.getStatements().size() ? block.getStatements().get(index + 1) : null;
71+
if (!(nextStatement instanceof J.Switch) ||
72+
hasNullCase((J.Switch) nextStatement) ||
73+
!SemanticallyEqual.areEqual(((J.Switch) nextStatement).getSelector().getTree(), check.getNullCheckedParameter()) ||
74+
check.returns() ||
75+
check.couldModifyNullCheckedValue()) {
76+
return statement;
77+
}
78+
nullCheck.set(check);
79+
return null;
80+
}
81+
82+
// Update the switch following a removed null check
83+
NullCheck check = nullCheck.getAndSet(null);
84+
if (check != null && statement instanceof J.Switch) {
85+
J.Switch aSwitch = (J.Switch) statement;
86+
J.Case nullCase = createNullCase(aSwitch, check.whenNull());
87+
return aSwitch.withCases(aSwitch.getCases().withStatements(
88+
ListUtils.insert(aSwitch.getCases().getStatements(), nullCase, 0)));
89+
}
90+
return statement;
91+
}));
92+
return super.visitBlock(b, ctx);
93+
}
94+
95+
private boolean hasNullCase(J.Switch switch_) {
96+
for (Statement c : switch_.getCases().getStatements()) {
97+
if (c instanceof J.Case) {
98+
for (J j : ((J.Case) c).getCaseLabels()) {
99+
if (j instanceof Expression && J.Literal.isLiteralValue((Expression) j, null)) {
100+
return true;
101+
}
102+
}
103+
}
104+
}
105+
return false;
106+
}
107+
108+
private J.Case createNullCase(J.Switch aSwitch, Statement whenNull) {
109+
if (whenNull instanceof J.Block && ((J.Block) whenNull).getStatements().size() == 1) {
110+
whenNull = ((J.Block) whenNull).getStatements().get(0);
111+
}
112+
String semicolon = whenNull instanceof J.Block ? "" : ";";
113+
J.Switch switchWithNullCase = JavaTemplate.apply(
114+
"switch(#{any()}) { case null -> #{any()}" + semicolon + " }",
115+
new Cursor(getCursor(), aSwitch),
116+
aSwitch.getCoordinates().replace(),
117+
aSwitch.getSelector().getTree(),
118+
whenNull);
119+
J.Case nullCase = (J.Case) switchWithNullCase.getCases().getStatements().get(0);
120+
return nullCase.withBody(requireNonNull(nullCase.getBody()).withPrefix(Space.SINGLE_SPACE));
121+
}
122+
});
123+
}
124+
}

src/main/resources/META-INF/rewrite/java-version-21.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,5 +141,6 @@ description: ->-
141141
tags:
142142
- java21
143143
recipeList:
144+
- org.openrewrite.java.migrate.lang.NullCheckAsSwitchCase
144145
- org.openrewrite.java.migrate.lang.RefineSwitchCases
145146
- org.openrewrite.java.migrate.lang.SwitchCaseEnumGuardToLabel

0 commit comments

Comments
 (0)