1313from collections import defaultdict
1414from typing import Dict , List
1515
16+ import black
17+
1618
1719class Refactorer :
1820 """
@@ -41,17 +43,30 @@ def suggest_dimensional_split(self) -> str:
4143 new_func_name = f"_{ self .function_node .name } _{ dimension } "
4244 new_func = self ._create_new_function (new_func_name , nodes )
4345 new_functions .append (new_func )
44- new_body_calls . append (
45- ast . Expr (
46- value = ast . Call (
47- func = ast . Name ( id = new_func_name , ctx = ast . Load ()),
48- args = [
49- ast . Name ( id = arg . arg , ctx = ast . Load ())
50- for arg in self . function_node . args . args
51- ] ,
52- keywords = [] ,
53- )
46+ # Handle 'self' for method calls
47+ is_method = (
48+ self . function_node . args . args
49+ and self . function_node . args . args [ 0 ]. arg == "self"
50+ )
51+ if is_method :
52+ call_func = ast . Attribute (
53+ value = ast . Name ( id = "self" , ctx = ast . Load ()) ,
54+ attr = new_func_name ,
55+ ctx = ast . Load (),
5456 )
57+ call_args = [
58+ ast .Name (id = arg .arg , ctx = ast .Load ())
59+ for arg in self .function_node .args .args [1 :]
60+ ]
61+ else :
62+ call_func = ast .Name (id = new_func_name , ctx = ast .Load ())
63+ call_args = [
64+ ast .Name (id = arg .arg , ctx = ast .Load ())
65+ for arg in self .function_node .args .args
66+ ]
67+
68+ new_body_calls .append (
69+ ast .Expr (value = ast .Call (func = call_func , args = call_args , keywords = []))
5570 )
5671
5772 original_func_rewritten = ast .FunctionDef (
@@ -70,13 +85,27 @@ def suggest_dimensional_split(self) -> str:
7085
7186 # Fix missing location info and unparse the entire module
7287 ast .fix_missing_locations (new_module )
73- final_code = ast .unparse (new_module )
88+ unformatted_code = ast .unparse (new_module )
89+
90+ # Format the generated code using black
91+ try :
92+ final_code = black .format_str (
93+ unformatted_code , mode = black .FileMode ()
94+ ).strip ()
95+ except black .NothingChanged :
96+ final_code = unformatted_code .strip ()
7497
7598 return "# --- Suggested Refactoring: Dimensional Split ---\n \n " + final_code
7699
77100 def _group_nodes_by_dimension (self ) -> Dict [str , List [ast .AST ]]:
78- """Groups the function's body nodes by their semantic dimension."""
101+ """
102+ Groups the function's body nodes by their semantic dimension,
103+ keeping control flow blocks together.
104+ """
79105 groups = defaultdict (list )
106+
107+ # This is a simplified approach. A more robust solution would
108+ # build a dependency graph.
80109 for node , dimension in self .execution_map .items ():
81110 groups [dimension ].append (node )
82111 return groups
0 commit comments