2727import argparse
2828import glob
2929import logging
30+ import re
3031from collections import defaultdict
3132from pathlib import Path
3233from typing import List , Dict , Any
@@ -46,7 +47,7 @@ class CodeAnalyzer(ast.NodeVisitor):
4647 """
4748
4849 def __init__ (self ):
49- self .structure : List [Dict [str , Any ]] = []
50+ self .analyzed_classes : List [Dict [str , Any ]] = []
5051 self .imports : set [str ] = set ()
5152 self .types : set [str ] = set ()
5253 self ._current_class_info : Dict [str , Any ] | None = None
@@ -105,13 +106,19 @@ def _collect_types_from_node(self, node: ast.AST | None) -> None:
105106 if type_str :
106107 self .types .add (type_str )
107108 elif isinstance (node , ast .Subscript ):
108- self ._collect_types_from_node (node .value )
109+ # Add the base type of the subscript (e.g., "List", "Dict")
110+ if isinstance (node .value , ast .Name ):
111+ self .types .add (node .value .id )
112+ self ._collect_types_from_node (node .value ) # Recurse on value just in case
109113 self ._collect_types_from_node (node .slice )
110114 elif isinstance (node , (ast .Tuple , ast .List )):
111115 for elt in node .elts :
112116 self ._collect_types_from_node (elt )
113- elif isinstance (node , ast .Constant ) and isinstance (node .value , str ):
114- self .types .add (node .value )
117+ elif isinstance (node , ast .Constant ):
118+ if isinstance (node .value , str ): # Forward references
119+ self .types .add (node .value )
120+ elif node .value is None : # None type
121+ self .types .add ("None" )
115122 elif isinstance (node , ast .BinOp ) and isinstance (
116123 node .op , ast .BitOr
117124 ): # For | union type
@@ -163,7 +170,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
163170 type_str = self ._get_type_str (item .annotation )
164171 class_info ["attributes" ].append ({"name" : attr_name , "type" : type_str })
165172
166- self .structure .append (class_info )
173+ self .analyzed_classes .append (class_info )
167174 self ._current_class_info = class_info
168175 self ._depth += 1
169176 self .generic_visit (node )
@@ -259,6 +266,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
259266 # directly within the class body, not inside a method.
260267 elif isinstance (target , ast .Name ) and not self ._is_in_method :
261268 self ._add_attribute (target .id , self ._get_type_str (node .annotation ))
269+ self ._collect_types_from_node (node .annotation )
262270 self .generic_visit (node )
263271
264272
@@ -279,7 +287,7 @@ def parse_code(code: str) -> tuple[List[Dict[str, Any]], set[str], set[str]]:
279287 tree = ast .parse (code )
280288 analyzer = CodeAnalyzer ()
281289 analyzer .visit (tree )
282- return analyzer .structure , analyzer .imports , analyzer .types
290+ return analyzer .analyzed_classes , analyzer .imports , analyzer .types
283291
284292
285293def parse_file (file_path : str ) -> tuple [List [Dict [str , Any ]], set [str ], set [str ]]:
@@ -331,10 +339,10 @@ def list_code_objects(
331339 all_class_keys = []
332340
333341 def process_structure (
334- structure : List [Dict [str , Any ]], file_name : str | None = None
342+ analyzed_classes : List [Dict [str , Any ]], file_name : str | None = None
335343 ):
336344 """Populates the results dictionary from the parsed AST structure."""
337- for class_info in structure :
345+ for class_info in analyzed_classes :
338346 key = class_info ["class_name" ]
339347 if file_name :
340348 key = f"{ key } (in { file_name } )"
@@ -360,13 +368,13 @@ def process_structure(
360368
361369 # Determine if the path is a file or directory and process accordingly
362370 if os .path .isfile (path ) and path .endswith (".py" ):
363- structure , _ , _ = parse_file (path )
364- process_structure (structure )
371+ analyzed_classes , _ , _ = parse_file (path )
372+ process_structure (analyzed_classes )
365373 elif os .path .isdir (path ):
366374 # This assumes `utils.walk_codebase` is defined elsewhere.
367375 for file_path in utils .walk_codebase (path ):
368- structure , _ , _ = parse_file (file_path )
369- process_structure (structure , file_name = os .path .basename (file_path ))
376+ analyzed_classes , _ , _ = parse_file (file_path )
377+ process_structure (analyzed_classes , file_name = os .path .basename (file_path ))
370378
371379 # Return the data in the desired format based on the flags
372380 if not show_methods and not show_attributes :
@@ -466,11 +474,11 @@ def _build_request_arg_schema(
466474 module_name = os .path .splitext (relative_path )[0 ].replace (os .path .sep , "." )
467475
468476 try :
469- structure , _ , _ = parse_file (file_path )
470- if not structure :
477+ analyzed_classes , _ , _ = parse_file (file_path )
478+ if not analyzed_classes :
471479 continue
472480
473- for class_info in structure :
481+ for class_info in analyzed_classes :
474482 class_name = class_info .get ("class_name" , "Unknown" )
475483 if class_name .endswith ("Request" ):
476484 full_class_name = f"{ module_name } .{ class_name } "
@@ -498,11 +506,11 @@ def _process_service_clients(
498506 if "/services/" not in file_path :
499507 continue
500508
501- structure , imports , types = parse_file (file_path )
509+ analyzed_classes , imports , types = parse_file (file_path )
502510 all_imports .update (imports )
503511 all_types .update (types )
504512
505- for class_info in structure :
513+ for class_info in analyzed_classes :
506514 class_name = class_info ["class_name" ]
507515 if not _should_include_class (class_name , class_filters ):
508516 continue
@@ -546,7 +554,6 @@ def analyze_source_files(
546554 # Make the pattern absolute
547555 absolute_pattern = os .path .join (project_root , pattern )
548556 source_files .extend (glob .glob (absolute_pattern , recursive = True ))
549-
550557 # PASS 1: Build the request argument schema from the types files.
551558 request_arg_schema = _build_request_arg_schema (source_files , project_root )
552559
@@ -607,14 +614,14 @@ def generate_code(config: Dict[str, Any], analysis_results: tuple) -> None:
607614 Generates source code files using Jinja2 templates.
608615 """
609616 data , all_imports , all_types , request_arg_schema = analysis_results
610- project_root = config ["project_root" ]
611- config_dir = config ["config_dir" ]
612617
613618 templates_config = config .get ("templates" , [])
614619 for item in templates_config :
615- template_path = str ( Path ( config_dir ) / item ["template" ])
616- output_path = str ( Path ( project_root ) / item ["output" ])
620+ template_name = item ["template" ]
621+ output_name = item ["output" ]
617622
623+ template_path = str (Path (config ["config_dir" ]) / template_name )
624+ output_path = str (Path (config ["project_root" ]) / output_name )
618625 template = utils .load_template (template_path )
619626 methods_context = []
620627 for class_name , methods in data .items ():
@@ -710,18 +717,20 @@ def find_project_root(start_path: str, markers: list[str]) -> str | None:
710717 return None
711718 current_path = parent_path
712719
713- # Load configuration from the YAML file.
714- config = utils .load_config (config_path )
720+ # Get the absolute path of the config file
721+ abs_config_path = os .path .abspath (config_path )
722+ config = utils .load_config (abs_config_path )
715723
716- # Determine the project root.
717- script_dir = os .path .dirname (os .path .abspath (__file__ ))
718- project_root = find_project_root (script_dir , ["setup.py" , ".git" ])
724+ # Determine the project root
725+ # Start searching from the directory of this script file
726+ script_file_dir = os .path .dirname (os .path .abspath (__file__ ))
727+ project_root = find_project_root (script_file_dir , [".git" ])
719728 if not project_root :
720- project_root = os .getcwd () # Fallback to current directory
729+ # Fallback to the directory from which the script was invoked
730+ project_root = os .getcwd ()
721731
722- # Set paths in the config dictionary.
723732 config ["project_root" ] = project_root
724- config ["config_dir" ] = os .path .dirname (os . path . abspath ( config_path ) )
733+ config ["config_dir" ] = os .path .dirname (abs_config_path )
725734
726735 return config
727736
@@ -762,9 +771,12 @@ def _execute_post_processing(config: Dict[str, Any]):
762771 all_end_index = i
763772
764773 if all_start_index != - 1 and all_end_index != - 1 :
765- for i in range (all_start_index + 1 , all_end_index ):
766- member = lines [i ].strip ().replace ('"' , "" ).replace ("," , "" )
767- if member :
774+ all_content = "" .join (lines [all_start_index + 1 : all_end_index ])
775+
776+ # Extract quoted strings
777+ found_members = re .findall (r'"([^"]+)"' , all_content )
778+ for member in found_members :
779+ if member not in all_list :
768780 all_list .append (member )
769781
770782 # --- Add new items and sort ---
@@ -777,7 +789,9 @@ def _execute_post_processing(config: Dict[str, Any]):
777789 for new_member in job .get ("add_to_all" , []):
778790 if new_member not in all_list :
779791 all_list .append (new_member )
780- all_list .sort ()
792+ all_list = sorted (list (set (all_list ))) # Ensure unique and sorted
793+ # Format for the template
794+ all_list = [f' "{ item } ",\n ' for item in all_list ]
781795
782796 # --- Render the new file content ---
783797 template = utils .load_template (template_path )
0 commit comments