@@ -63,6 +63,7 @@ public TreeVisitor<?, ExecutionContext> getVisitor() {
6363 private static class ExpectedExceptionToAssertThrowsVisitor extends JavaIsoVisitor <ExecutionContext > {
6464
6565 private static final String FIRST_EXPECTED_EXCEPTION_METHOD_INVOCATION = "firstExpectedExceptionMethodInvocation" ;
66+ private static final String STATEMENTS_BEFORE_EXPECT_EXCEPTION = "statementsBeforeExpectException" ;
6667 private static final String STATEMENTS_AFTER_EXPECT_EXCEPTION = "statementsAfterExpectException" ;
6768 private static final String HAS_MATCHER = "hasMatcher" ;
6869 private static final String EXCEPTION_CLASS = "exceptionClass" ;
@@ -100,13 +101,73 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex
100101 if (getCursor ().pollMessage ("hasExpectException" ) != null ) {
101102 List <NameTree > thrown = m .getThrows ();
102103 if (thrown != null && !thrown .isEmpty ()) {
104+ List <Statement > statementsBeforeExpect = getCursor ().pollMessage (STATEMENTS_BEFORE_EXPECT_EXCEPTION );
105+ if (statementsBeforeExpectThrowCheckedException (statementsBeforeExpect )) {
106+ return m ;
107+ }
103108 assert m .getBody () != null ;
104109 return m .withBody (m .getBody ().withPrefix (thrown .get (0 ).getPrefix ())).withThrows (emptyList ());
105110 }
106111 }
107112 return m ;
108113 }
109114
115+ private boolean statementsBeforeExpectThrowCheckedException (List <Statement > statements ) {
116+ return statements .stream ().anyMatch (this ::statementThrowsCheckedException );
117+ }
118+
119+ private boolean statementThrowsCheckedException (Statement statement ) {
120+ AtomicBoolean throwsChecked = new AtomicBoolean (false );
121+ new JavaIsoVisitor <AtomicBoolean >() {
122+ @ Override
123+ public J .MethodInvocation visitMethodInvocation (J .MethodInvocation method , AtomicBoolean found ) {
124+ if (found .get ()) {
125+ return method ;
126+ }
127+ JavaType .Method methodType = method .getMethodType ();
128+ if (methodType == null ) {
129+ return super .visitMethodInvocation (method , found );
130+ }
131+ List <JavaType > thrownExceptions = methodType .getThrownExceptions ();
132+ for (JavaType thrownException : thrownExceptions ) {
133+ if (isCheckedException (thrownException )) {
134+ found .set (true );
135+ return method ;
136+ }
137+ }
138+ return super .visitMethodInvocation (method , found );
139+ }
140+
141+ @ Override
142+ public J .NewClass visitNewClass (J .NewClass newClass , AtomicBoolean found ) {
143+ if (found .get ()) {
144+ return newClass ;
145+ }
146+ JavaType .Method constructorType = newClass .getConstructorType ();
147+ if (constructorType == null ) {
148+ return super .visitNewClass (newClass , found );
149+ }
150+ List <JavaType > thrownExceptions = constructorType .getThrownExceptions ();
151+ for (JavaType thrownException : thrownExceptions ) {
152+ if (isCheckedException (thrownException )) {
153+ found .set (true );
154+ return newClass ;
155+ }
156+ }
157+ return super .visitNewClass (newClass , found );
158+ }
159+ }.visit (statement , throwsChecked );
160+ return throwsChecked .get ();
161+ }
162+
163+ private boolean isCheckedException (JavaType exceptionType ) {
164+ if (exceptionType == null ) {
165+ return false ;
166+ }
167+ return !TypeUtils .isAssignableTo ("java.lang.RuntimeException" , exceptionType ) &&
168+ !TypeUtils .isAssignableTo ("java.lang.Error" , exceptionType );
169+ }
170+
110171 @ Override
111172 public J .Block visitBlock (J .Block block , ExecutionContext ctx ) {
112173 J .Block b = super .visitBlock (block , ctx );
@@ -175,7 +236,13 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
175236 return method ;
176237 }
177238 getCursor ().dropParentUntil (J .MethodDeclaration .class ::isInstance ).putMessage ("hasExpectException" , true );
178- getCursor ().dropParentUntil (J .Block .class ::isInstance ).computeMessageIfAbsent (FIRST_EXPECTED_EXCEPTION_METHOD_INVOCATION , k -> method );
239+ Cursor blockCursor = getCursor ().dropParentUntil (J .Block .class ::isInstance );
240+ blockCursor .computeMessageIfAbsent (FIRST_EXPECTED_EXCEPTION_METHOD_INVOCATION , k -> method );
241+
242+ List <Statement > predecessorStatements = findPredecessorStatements (getCursor ());
243+ getCursor ().dropParentUntil (J .MethodDeclaration .class ::isInstance )
244+ .computeMessageIfAbsent (STATEMENTS_BEFORE_EXPECT_EXCEPTION , k -> predecessorStatements );
245+
179246 List <Statement > successorStatements = findSuccessorStatements (getCursor ());
180247 getCursor ().putMessageOnFirstEnclosing (J .Block .class , STATEMENTS_AFTER_EXPECT_EXCEPTION , successorStatements );
181248 if (EXPECTED_EXCEPTION_CLASS_MATCHER .matches (method )) {
@@ -186,6 +253,25 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
186253 return method ;
187254 }
188255
256+ /**
257+ * From the current cursor point find all preceding statements in the method body.
258+ */
259+ private List <Statement > findPredecessorStatements (Cursor cursor ) {
260+ J .MethodDeclaration methodDecl = cursor .firstEnclosing (J .MethodDeclaration .class );
261+ if (methodDecl == null || methodDecl .getBody () == null ) {
262+ return emptyList ();
263+ }
264+ List <Statement > predecessorStatements = new ArrayList <>();
265+ Statement currentStatement = cursor .firstEnclosing (Statement .class );
266+ for (Statement statement : methodDecl .getBody ().getStatements ()) {
267+ if (statement == currentStatement ) {
268+ break ;
269+ }
270+ predecessorStatements .add (statement );
271+ }
272+ return predecessorStatements ;
273+ }
274+
189275 /**
190276 * From the current cursor point find all the next statements that can be executed in the current path.
191277 */
0 commit comments