Skip to content

Commit ee36bb8

Browse files
authored
Handle non-lambda arguments in SuccessFailureCallbackToBiConsumerVisitor (#929)
When addCallback() receives ternary expressions or other non-lambda arguments, the visitor would throw a ClassCastException attempting to cast J.Ternary to J.Lambda. Add instanceof checks before casting both the success and failure callback arguments, returning early when arguments are not lambdas.
1 parent 39acb32 commit ee36bb8

File tree

2 files changed

+46
-2
lines changed

2 files changed

+46
-2
lines changed

src/main/java/org/openrewrite/java/spring/util/concurrent/SuccessFailureCallbackToBiConsumerVisitor.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,25 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
5555
return mi;
5656
}
5757

58+
if (!(mi.getArguments().get(0) instanceof J.Lambda)) {
59+
return mi;
60+
}
5861
J.Lambda successCallback = (J.Lambda) mi.getArguments().get(0);
5962

6063
boolean isKafkaFailureCallback = false;
6164
J.Lambda failureCallback;
6265
if (mi.getArguments().get(1) instanceof J.TypeCast) {
66+
J.TypeCast typeCast = (J.TypeCast) mi.getArguments().get(1);
67+
if (!(typeCast.getExpression() instanceof J.Lambda)) {
68+
return mi;
69+
}
6370
// In this case, assume it's casted to `org.springframework.kafka.core.KafkaFailureCallback` only
64-
failureCallback = (J.Lambda) ((J.TypeCast) mi.getArguments().get(1)).getExpression();
71+
failureCallback = (J.Lambda) typeCast.getExpression();
6572
isKafkaFailureCallback = true;
66-
} else {
73+
} else if (mi.getArguments().get(1) instanceof J.Lambda) {
6774
failureCallback = (J.Lambda) mi.getArguments().get(1);
75+
} else {
76+
return mi;
6877
}
6978

7079
J.Identifier successParam = ((J.VariableDeclarations) successCallback.getParameters().getParameters().get(0)).getVariables().get(0).getName();

src/test/java/org/openrewrite/java/spring/util/concurrent/ListenableToCompletableFutureTest.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,41 @@ void test(CompletableFuture<String> future) {
263263
);
264264
}
265265

266+
@Test
267+
void addSuccessFailureCallbackWithTernaryArguments() {
268+
//language=java
269+
rewriteRun(
270+
java(
271+
"""
272+
import org.springframework.util.concurrent.ListenableFuture;
273+
import org.springframework.util.concurrent.SuccessCallback;
274+
import org.springframework.util.concurrent.FailureCallback;
275+
class A {
276+
void test(ListenableFuture<String> future, SuccessCallback<String> successCallback, FailureCallback failureCallback) {
277+
future.addCallback(
278+
successCallback != null ? result -> successCallback.onSuccess(result) : null,
279+
failureCallback != null ? ex -> failureCallback.onFailure(ex) : null);
280+
}
281+
}
282+
""",
283+
"""
284+
import org.springframework.util.concurrent.SuccessCallback;
285+
import org.springframework.util.concurrent.FailureCallback;
286+
287+
import java.util.concurrent.CompletableFuture;
288+
289+
class A {
290+
void test(CompletableFuture<String> future, SuccessCallback<String> successCallback, FailureCallback failureCallback) {
291+
future.whenComplete(
292+
successCallback != null ? result -> successCallback.onSuccess(result) : null,
293+
failureCallback != null ? ex -> failureCallback.onFailure(ex) : null);
294+
}
295+
}
296+
"""
297+
)
298+
);
299+
}
300+
266301
@Test
267302
void addSuccessFailureCallbackWithTypeCast() {
268303
//language=java

0 commit comments

Comments
 (0)