@@ -26,13 +26,14 @@ def visit_statement(self, statement: Statement, p: P) -> J:
2626
2727 parent_cursor = self .cursor .parent_tree_cursor ()
2828 top_level = isinstance (parent_cursor .value , CompilationUnit )
29- if top_level and isinstance (statement , (Import , MultiImport )):
30- parent_cursor .put_message ('previous_import' , True )
29+
30+ if isinstance (statement , (Import , MultiImport )):
31+ parent_cursor .put_message ('prev_import' , True )
3132 prev_import = False
3233 else :
33- prev_import = top_level and parent_cursor .get_message ('previous_import ' , False )
34+ prev_import = parent_cursor .get_message ('prev_import ' , False )
3435 if prev_import :
35- parent_cursor .put_message ('previous_import ' , False )
36+ parent_cursor .put_message ('prev_import ' , False )
3637
3738 if top_level :
3839 if statement == cast (CompilationUnit , parent_cursor .value ).statements [0 ]:
@@ -44,12 +45,16 @@ def visit_statement(self, statement: Statement, p: P) -> J:
4445 else :
4546 in_block = isinstance (parent_cursor .value , Block )
4647 in_class = in_block and isinstance (parent_cursor .parent_tree_cursor ().value , ClassDeclaration )
48+ min_lines = 0
4749 if in_class :
4850 is_first = cast (Block , parent_cursor .value ).statements [0 ] is statement
4951 if not is_first and isinstance (statement , MethodDeclaration ):
50- statement = minimum_lines_for_tree ( statement , self ._style .minimum ._around_method )
52+ min_lines = max ( min_lines , self ._style .minimum .around_method )
5153 elif not is_first and isinstance (statement , ClassDeclaration ):
52- statement = minimum_lines_for_tree (statement , self ._style .minimum ._around_class )
54+ min_lines = max (min_lines , self ._style .minimum .around_class )
55+ if prev_import :
56+ min_lines = max (min_lines , self ._style .minimum .after_local_imports )
57+ statement = minimum_lines_for_tree (statement , min_lines )
5358 return statement
5459
5560 def post_visit (self , tree : T , p : P ) -> Optional [T ]:
@@ -66,6 +71,8 @@ def minimum_lines_for_right_padded(tree: JRightPadded[J2], min_lines) -> JRightP
6671
6772
6873def minimum_lines_for_tree (tree : J , min_lines ) -> J :
74+ if min_lines == 0 :
75+ return tree
6976 return tree .with_prefix (minimum_lines_for_space (tree .prefix , min_lines ))
7077
7178
0 commit comments