@@ -67,23 +67,34 @@ private Supplier<JavaParser> assertionsParser(ExecutionContext ctx) {
6767 }
6868 private static final MethodMatcher ASSERT_THROWS_MATCHER = new MethodMatcher ("org.junit.jupiter.api.Assertions assertThrows(..)" );
6969
70+ private static final JavaType THROWING_CALLABLE_TYPE = JavaType .buildType ("org.assertj.core.api.ThrowableAssert.ThrowingCallable" );
71+
7072 @ Override
7173 public J .MethodInvocation visitMethodInvocation (J .MethodInvocation method , ExecutionContext ctx ) {
7274 J .MethodInvocation mi = super .visitMethodInvocation (method , ctx );
7375 if (ASSERT_THROWS_MATCHER .matches (mi ) && mi .getArguments ().size () == 2 ) {
74- J .Lambda lambdaArg = (J .Lambda ) mi .getArguments ().get (1 );
75- lambdaArg = lambdaArg .withType (JavaType .buildType ("org.assertj.core.api.ThrowableAssert.ThrowingCallable" ));
76- mi = mi .withTemplate (
77- JavaTemplate
78- .builder (this ::getCursor ,
79- "assertThatExceptionOfType(#{any(java.lang.Class)}).isThrownBy(#{any(org.assertj.core.api.ThrowableAssert.ThrowingCallable)})" )
80- .javaParser (assertionsParser (ctx ))
81- .staticImports ("org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType" )
82- .build (),
83- mi .getCoordinates ().replace (),
84- mi .getArguments ().get (0 ), lambdaArg );
85- maybeAddImport ("org.assertj.core.api.AssertionsForClassTypes" , "assertThatExceptionOfType" );
86- maybeRemoveImport ("org.junit.jupiter.api.Assertions.assertThrows" );
76+ J executable = mi .getArguments ().get (1 );
77+ if (executable instanceof J .Lambda ) {
78+ executable = ((J .Lambda ) executable ).withType (THROWING_CALLABLE_TYPE );
79+ } else if (executable instanceof J .MemberReference ) {
80+ executable = ((J .MemberReference ) executable ).withType (THROWING_CALLABLE_TYPE );
81+ } else {
82+ executable = null ;
83+ }
84+
85+ if (executable != null ) {
86+ mi = mi .withTemplate (
87+ JavaTemplate
88+ .builder (this ::getCursor ,
89+ "assertThatExceptionOfType(#{any(java.lang.Class)}).isThrownBy(#{any(org.assertj.core.api.ThrowableAssert.ThrowingCallable)})" )
90+ .javaParser (assertionsParser (ctx ))
91+ .staticImports ("org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType" )
92+ .build (),
93+ mi .getCoordinates ().replace (),
94+ mi .getArguments ().get (0 ), executable );
95+ maybeAddImport ("org.assertj.core.api.AssertionsForClassTypes" , "assertThatExceptionOfType" );
96+ maybeRemoveImport ("org.junit.jupiter.api.Assertions.assertThrows" );
97+ }
8798 }
8899 return mi ;
89100 }
0 commit comments