|
24 | 24 |
|
25 | 25 | import ast |
26 | 26 | import os |
| 27 | +import argparse |
27 | 28 | import glob |
28 | 29 | import logging |
| 30 | +import re |
29 | 31 | from collections import defaultdict |
30 | | -from pathlib import Path |
31 | | -from typing import List, Dict, Any |
| 32 | +from typing import List, Dict, Any, Iterator |
32 | 33 |
|
33 | 34 | from . import name_utils |
34 | 35 | from . import utils |
@@ -511,7 +512,6 @@ def _generate_import_statement( |
511 | 512 | Returns: |
512 | 513 | A formatted, multi-line import statement string. |
513 | 514 | """ |
514 | | - |
515 | 515 | names = sorted(list(set([item[key] for item in context]))) |
516 | 516 | names_str = ",\n ".join(names) |
517 | 517 | return f"from {package} import (\n {names_str}\n)" |
@@ -542,7 +542,6 @@ def generate_code(config: Dict[str, Any], analysis_results: tuple) -> None: |
542 | 542 | """ |
543 | 543 | Generates source code files using Jinja2 templates. |
544 | 544 | """ |
545 | | - |
546 | 545 | data, all_imports, all_types, request_arg_schema = analysis_results |
547 | 546 | project_root = config["project_root"] |
548 | 547 | config_dir = config["config_dir"] |
@@ -618,3 +617,135 @@ def generate_code(config: Dict[str, Any], analysis_results: tuple) -> None: |
618 | 617 | ) |
619 | 618 |
|
620 | 619 | utils.write_code_to_file(output_path, final_code) |
| 620 | + |
| 621 | + |
| 622 | +# ============================================================================= |
| 623 | +# Section 4: Main Execution |
| 624 | +# ============================================================================= |
| 625 | + |
| 626 | + |
| 627 | +def setup_config_and_paths(config_path: str) -> Dict[str, Any]: |
| 628 | + """Loads the configuration and sets up necessary paths. |
| 629 | +
|
| 630 | + Args: |
| 631 | + config_path: The path to the YAML configuration file. |
| 632 | +
|
| 633 | + Returns: |
| 634 | + A dictionary containing the loaded configuration and paths. |
| 635 | + """ |
| 636 | + |
| 637 | + def find_project_root(start_path: str, markers: list[str]) -> str | None: |
| 638 | + """Finds the project root by searching upwards for a marker.""" |
| 639 | + current_path = os.path.abspath(start_path) |
| 640 | + while True: |
| 641 | + for marker in markers: |
| 642 | + if os.path.exists(os.path.join(current_path, marker)): |
| 643 | + return current_path |
| 644 | + parent_path = os.path.dirname(current_path) |
| 645 | + if parent_path == current_path: # Filesystem root |
| 646 | + return None |
| 647 | + current_path = parent_path |
| 648 | + |
| 649 | + # Load configuration from the YAML file. |
| 650 | + config = utils.load_config(config_path) |
| 651 | + |
| 652 | + # Determine the project root. |
| 653 | + script_dir = os.path.dirname(os.path.abspath(__file__)) |
| 654 | + project_root = find_project_root(script_dir, ["setup.py", ".git"]) |
| 655 | + if not project_root: |
| 656 | + project_root = os.getcwd() # Fallback to current directory |
| 657 | + |
| 658 | + # Set paths in the config dictionary. |
| 659 | + config["project_root"] = project_root |
| 660 | + config["config_dir"] = os.path.dirname(os.path.abspath(config_path)) |
| 661 | + |
| 662 | + return config |
| 663 | + |
| 664 | + |
| 665 | +def _execute_post_processing(config: Dict[str, Any]): |
| 666 | + """ |
| 667 | + Executes post-processing steps, such as patching existing files. |
| 668 | + """ |
| 669 | + project_root = config["project_root"] |
| 670 | + post_processing_jobs = config.get("post_processing_templates", []) |
| 671 | + |
| 672 | + for job in post_processing_jobs: |
| 673 | + template_path = os.path.join(config["config_dir"], job["template"]) |
| 674 | + target_file_path = os.path.join(project_root, job["target_file"]) |
| 675 | + |
| 676 | + if not os.path.exists(target_file_path): |
| 677 | + logging.warning( |
| 678 | + f"Target file {target_file_path} not found, skipping post-processing job." |
| 679 | + ) |
| 680 | + continue |
| 681 | + |
| 682 | + # Read the target file |
| 683 | + with open(target_file_path, "r") as f: |
| 684 | + lines = f.readlines() |
| 685 | + |
| 686 | + # --- Extract existing imports and __all__ members --- |
| 687 | + imports = [] |
| 688 | + all_list = [] |
| 689 | + all_start_index = -1 |
| 690 | + all_end_index = -1 |
| 691 | + |
| 692 | + for i, line in enumerate(lines): |
| 693 | + if line.strip().startswith("from ."): |
| 694 | + imports.append(line.strip()) |
| 695 | + if line.strip() == "__all__ = (": |
| 696 | + all_start_index = i |
| 697 | + if all_start_index != -1 and line.strip() == ")": |
| 698 | + all_end_index = i |
| 699 | + |
| 700 | + if all_start_index != -1 and all_end_index != -1: |
| 701 | + for i in range(all_start_index + 1, all_end_index): |
| 702 | + member = lines[i].strip().replace('"', "").replace(",", "") |
| 703 | + if member: |
| 704 | + all_list.append(member) |
| 705 | + |
| 706 | + # --- Add new items and sort --- |
| 707 | + for new_import in job.get("add_imports", []): |
| 708 | + if new_import not in imports: |
| 709 | + imports.append(new_import) |
| 710 | + imports.sort() |
| 711 | + imports = [f"{imp}\n" for imp in imports] # re-add newlines |
| 712 | + |
| 713 | + for new_member in job.get("add_to_all", []): |
| 714 | + if new_member not in all_list: |
| 715 | + all_list.append(new_member) |
| 716 | + all_list.sort() |
| 717 | + |
| 718 | + # --- Render the new file content --- |
| 719 | + template = utils.load_template(template_path) |
| 720 | + new_content = template.render( |
| 721 | + imports=imports, |
| 722 | + all_list=all_list, |
| 723 | + ) |
| 724 | + |
| 725 | + # --- Overwrite the target file --- |
| 726 | + with open(target_file_path, "w") as f: |
| 727 | + f.write(new_content) |
| 728 | + |
| 729 | + logging.info(f"Successfully post-processed and overwrote {target_file_path}") |
| 730 | + |
| 731 | + |
| 732 | +if __name__ == "__main__": |
| 733 | + parser = argparse.ArgumentParser( |
| 734 | + description="A generic Python code generator for clients." |
| 735 | + ) |
| 736 | + parser.add_argument("config", help="Path to the YAML configuration file.") |
| 737 | + args = parser.parse_args() |
| 738 | + |
| 739 | + # Load config and set up paths. |
| 740 | + config = setup_config_and_paths(args.config) |
| 741 | + |
| 742 | + # Analyze the source code. |
| 743 | + analysis_results = analyze_source_files(config) |
| 744 | + |
| 745 | + # Generate the new client code. |
| 746 | + generate_code(config, analysis_results) |
| 747 | + |
| 748 | + # Run post-processing steps. |
| 749 | + _execute_post_processing(config) |
| 750 | + |
| 751 | + # TODO: Ensure blacken gets called on the generated source files as a final step. |
0 commit comments