22from __future__ import annotations
33
44import ast
5+ from itertools import chain
56from typing import TYPE_CHECKING , Optional
67
78import libcst as cst
@@ -119,6 +120,32 @@ def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> c
119120
120121 return updated_node
121122
123+ def _find_insertion_index (self , updated_node : cst .Module ) -> int :
124+ """Find the position of the last import statement in the top-level of the module."""
125+ insert_index = 0
126+ for i , stmt in enumerate (updated_node .body ):
127+ is_top_level_import = isinstance (stmt , cst .SimpleStatementLine ) and any (
128+ isinstance (child , (cst .Import , cst .ImportFrom )) for child in stmt .body
129+ )
130+
131+ is_conditional_import = isinstance (stmt , cst .If ) and all (
132+ isinstance (inner , cst .SimpleStatementLine )
133+ and all (isinstance (child , (cst .Import , cst .ImportFrom )) for child in inner .body )
134+ for inner in stmt .body .body
135+ )
136+
137+ if is_top_level_import or is_conditional_import :
138+ insert_index = i + 1
139+
140+ # Stop scanning once we reach a class or function definition.
141+ # Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file.
142+ # Without this check, a stray import later in the file
143+ # would incorrectly shift our insertion index below actual code definitions.
144+ if isinstance (stmt , (cst .ClassDef , cst .FunctionDef )):
145+ break
146+
147+ return insert_index
148+
122149 def leave_Module (self , original_node : cst .Module , updated_node : cst .Module ) -> cst .Module :
123150 # Add any new assignments that weren't in the original file
124151 new_statements = list (updated_node .body )
@@ -131,18 +158,26 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
131158 ]
132159
133160 if assignments_to_append :
134- # Add a blank line before appending new assignments if needed
135- if new_statements and not isinstance (new_statements [- 1 ], cst .EmptyLine ):
136- new_statements .append (cst .SimpleStatementLine ([cst .Pass ()], leading_lines = [cst .EmptyLine ()]))
137- new_statements .pop () # Remove the Pass statement but keep the empty line
138-
139- # Add the new assignments
140- new_statements .extend (
141- [
142- cst .SimpleStatementLine ([assignment ], leading_lines = [cst .EmptyLine ()])
143- for assignment in assignments_to_append
144- ]
145- )
161+ # after last top-level imports
162+ insert_index = self ._find_insertion_index (updated_node )
163+
164+ assignment_lines = [
165+ cst .SimpleStatementLine ([assignment ], leading_lines = [cst .EmptyLine ()])
166+ for assignment in assignments_to_append
167+ ]
168+
169+ new_statements = list (chain (new_statements [:insert_index ], assignment_lines , new_statements [insert_index :]))
170+
171+ # Add a blank line after the last assignment if needed
172+ after_index = insert_index + len (assignment_lines )
173+ if after_index < len (new_statements ):
174+ next_stmt = new_statements [after_index ]
175+ # If there's no empty line, add one
176+ has_empty = any (isinstance (line , cst .EmptyLine ) for line in next_stmt .leading_lines )
177+ if not has_empty :
178+ new_statements [after_index ] = next_stmt .with_changes (
179+ leading_lines = [cst .EmptyLine (), * next_stmt .leading_lines ]
180+ )
146181
147182 return updated_node .with_changes (body = new_statements )
148183
0 commit comments