|
24 | 24 |
|
25 | 25 | import ast |
26 | 26 | import os |
| 27 | +import argparse |
27 | 28 | import glob |
28 | 29 | import logging |
29 | 30 | import re |
@@ -492,3 +493,259 @@ def analyze_source_files( |
492 | 493 |
|
493 | 494 | return parsed_data, all_imports, all_types, request_arg_schema |
494 | 495 |
|
| 496 | + |
| 497 | +# ============================================================================= |
| 498 | +# Section 3: Code Generation |
| 499 | +# ============================================================================= |
| 500 | + |
| 501 | + |
| 502 | +def _generate_import_statement( |
| 503 | + context: List[Dict[str, Any]], key: str, package: str |
| 504 | +) -> str: |
| 505 | + """Generates a formatted import statement from a list of context dictionaries. |
| 506 | +
|
| 507 | + Args: |
| 508 | + context: A list of dictionaries containing the data. |
| 509 | + key: The key to extract from each dictionary in the context. |
| 510 | + package: The base import package (e.g., "google.cloud.bigquery_v2.services"). |
| 511 | +
|
| 512 | + Returns: |
| 513 | + A formatted, multi-line import statement string. |
| 514 | + """ |
| 515 | + names = sorted(list(set([item[key] for item in context]))) |
| 516 | + names_str = ",\n ".join(names) |
| 517 | + return f"from {package} import (\n {names_str}\n)" |
| 518 | + |
| 519 | + |
| 520 | +def _get_request_class_name(method_name: str, config: Dict[str, Any]) -> str: |
| 521 | + """Gets the inferred request class name, applying overrides from config.""" |
| 522 | + inferred_request_name = name_utils.method_to_request_class_name(method_name) |
| 523 | + method_overrides = config.get("filter", {}).get("methods", {}).get("overrides", {}) |
| 524 | + if method_name in method_overrides: |
| 525 | + return method_overrides[method_name].get( |
| 526 | + "request_class_name", inferred_request_name |
| 527 | + ) |
| 528 | + return inferred_request_name |
| 529 | + |
| 530 | + |
| 531 | +def _find_fq_request_name( |
| 532 | + request_name: str, request_arg_schema: Dict[str, List[str]] |
| 533 | +) -> str: |
| 534 | + """Finds the fully qualified request name in the schema.""" |
| 535 | + for key in request_arg_schema.keys(): |
| 536 | + if key.endswith(f".{request_name}"): |
| 537 | + return key |
| 538 | + return "" |
| 539 | + |
| 540 | + |
| 541 | +def generate_code(config: Dict[str, Any], analysis_results: tuple) -> None: |
| 542 | + """ |
| 543 | + Generates source code files using Jinja2 templates. |
| 544 | + """ |
| 545 | + data, all_imports, all_types, request_arg_schema = analysis_results |
| 546 | + project_root = config["project_root"] |
| 547 | + config_dir = config["config_dir"] |
| 548 | + |
| 549 | + templates_config = config.get("templates", []) |
| 550 | + for item in templates_config: |
| 551 | + template_path = str(Path(config_dir) / item["template"]) |
| 552 | + output_path = str(Path(project_root) / item["output"]) |
| 553 | + |
| 554 | + template = utils.load_template(template_path) |
| 555 | + methods_context = [] |
| 556 | + for class_name, methods in data.items(): |
| 557 | + for method_name, method_info in methods.items(): |
| 558 | + context = { |
| 559 | + "name": method_name, |
| 560 | + "class_name": class_name, |
| 561 | + "return_type": method_info["return_type"], |
| 562 | + } |
| 563 | + |
| 564 | + request_name = _get_request_class_name(method_name, config) |
| 565 | + fq_request_name = _find_fq_request_name( |
| 566 | + request_name, request_arg_schema |
| 567 | + ) |
| 568 | + |
| 569 | + if fq_request_name: |
| 570 | + context["request_class_full_name"] = fq_request_name |
| 571 | + context["request_id_args"] = request_arg_schema[fq_request_name] |
| 572 | + |
| 573 | + methods_context.append(context) |
| 574 | + |
| 575 | + # Prepare imports for the template |
| 576 | + services_context = [] |
| 577 | + client_class_names = sorted( |
| 578 | + list(set([m["class_name"] for m in methods_context])) |
| 579 | + ) |
| 580 | + |
| 581 | + for class_name in client_class_names: |
| 582 | + service_name_cluster = name_utils.generate_service_names(class_name) |
| 583 | + services_context.append(service_name_cluster) |
| 584 | + |
| 585 | + # Also need to update methods_context to include the service_name and module_name |
| 586 | + # so the template knows which client to use for each method. |
| 587 | + class_to_service_map = {s["service_client_class"]: s for s in services_context} |
| 588 | + for method in methods_context: |
| 589 | + service_info = class_to_service_map.get(method["class_name"]) |
| 590 | + if service_info: |
| 591 | + method["service_name"] = service_info["service_name"] |
| 592 | + method["service_module_name"] = service_info["service_module_name"] |
| 593 | + |
| 594 | + # Prepare new imports |
| 595 | + service_imports = [ |
| 596 | + _generate_import_statement( |
| 597 | + services_context, |
| 598 | + "service_module_name", |
| 599 | + "google.cloud.bigquery_v2.services", |
| 600 | + ) |
| 601 | + ] |
| 602 | + |
| 603 | + # Prepare type imports |
| 604 | + type_imports = [ |
| 605 | + _generate_import_statement( |
| 606 | + services_context, "service_name", "google.cloud.bigquery_v2.types" |
| 607 | + ) |
| 608 | + ] |
| 609 | + |
| 610 | + final_code = template.render( |
| 611 | + service_name=config.get("service_name"), |
| 612 | + methods=methods_context, |
| 613 | + services=services_context, |
| 614 | + service_imports=service_imports, |
| 615 | + type_imports=type_imports, |
| 616 | + request_arg_schema=request_arg_schema, |
| 617 | + ) |
| 618 | + |
| 619 | + utils.write_code_to_file(output_path, final_code) |
| 620 | + |
| 621 | + |
| 622 | +# ============================================================================= |
| 623 | +# Section 4: Main Execution |
| 624 | +# ============================================================================= |
| 625 | + |
| 626 | + |
| 627 | +def setup_config_and_paths(config_path: str) -> Dict[str, Any]: |
| 628 | + """Loads the configuration and sets up necessary paths. |
| 629 | +
|
| 630 | + Args: |
| 631 | + config_path: The path to the YAML configuration file. |
| 632 | +
|
| 633 | + Returns: |
| 634 | + A dictionary containing the loaded configuration and paths. |
| 635 | + """ |
| 636 | + |
| 637 | + def find_project_root(start_path: str, markers: list[str]) -> str | None: |
| 638 | + """Finds the project root by searching upwards for a marker.""" |
| 639 | + current_path = os.path.abspath(start_path) |
| 640 | + while True: |
| 641 | + for marker in markers: |
| 642 | + if os.path.exists(os.path.join(current_path, marker)): |
| 643 | + return current_path |
| 644 | + parent_path = os.path.dirname(current_path) |
| 645 | + if parent_path == current_path: # Filesystem root |
| 646 | + return None |
| 647 | + current_path = parent_path |
| 648 | + |
| 649 | + # Load configuration from the YAML file. |
| 650 | + config = utils.load_config(config_path) |
| 651 | + |
| 652 | + # Determine the project root. |
| 653 | + script_dir = os.path.dirname(os.path.abspath(__file__)) |
| 654 | + project_root = find_project_root(script_dir, ["setup.py", ".git"]) |
| 655 | + if not project_root: |
| 656 | + project_root = os.getcwd() # Fallback to current directory |
| 657 | + |
| 658 | + # Set paths in the config dictionary. |
| 659 | + config["project_root"] = project_root |
| 660 | + config["config_dir"] = os.path.dirname(os.path.abspath(config_path)) |
| 661 | + |
| 662 | + return config |
| 663 | + |
| 664 | + |
| 665 | +def _execute_post_processing(config: Dict[str, Any]): |
| 666 | + """ |
| 667 | + Executes post-processing steps, such as patching existing files. |
| 668 | + """ |
| 669 | + project_root = config["project_root"] |
| 670 | + post_processing_jobs = config.get("post_processing_templates", []) |
| 671 | + |
| 672 | + for job in post_processing_jobs: |
| 673 | + template_path = os.path.join(config["config_dir"], job["template"]) |
| 674 | + target_file_path = os.path.join(project_root, job["target_file"]) |
| 675 | + |
| 676 | + if not os.path.exists(target_file_path): |
| 677 | + logging.warning( |
| 678 | + f"Target file {target_file_path} not found, skipping post-processing job." |
| 679 | + ) |
| 680 | + continue |
| 681 | + |
| 682 | + # Read the target file |
| 683 | + with open(target_file_path, "r") as f: |
| 684 | + lines = f.readlines() |
| 685 | + |
| 686 | + # --- Extract existing imports and __all__ members --- |
| 687 | + imports = [] |
| 688 | + all_list = [] |
| 689 | + all_start_index = -1 |
| 690 | + all_end_index = -1 |
| 691 | + |
| 692 | + for i, line in enumerate(lines): |
| 693 | + if line.strip().startswith("from ."): |
| 694 | + imports.append(line.strip()) |
| 695 | + if line.strip() == "__all__ = (": |
| 696 | + all_start_index = i |
| 697 | + if all_start_index != -1 and line.strip() == ")": |
| 698 | + all_end_index = i |
| 699 | + |
| 700 | + if all_start_index != -1 and all_end_index != -1: |
| 701 | + for i in range(all_start_index + 1, all_end_index): |
| 702 | + member = lines[i].strip().replace('"', "").replace(",", "") |
| 703 | + if member: |
| 704 | + all_list.append(member) |
| 705 | + |
| 706 | + # --- Add new items and sort --- |
| 707 | + for new_import in job.get("add_imports", []): |
| 708 | + if new_import not in imports: |
| 709 | + imports.append(new_import) |
| 710 | + imports.sort() |
| 711 | + imports = [f"{imp}\n" for imp in imports] # re-add newlines |
| 712 | + |
| 713 | + for new_member in job.get("add_to_all", []): |
| 714 | + if new_member not in all_list: |
| 715 | + all_list.append(new_member) |
| 716 | + all_list.sort() |
| 717 | + |
| 718 | + # --- Render the new file content --- |
| 719 | + template = utils.load_template(template_path) |
| 720 | + new_content = template.render( |
| 721 | + imports=imports, |
| 722 | + all_list=all_list, |
| 723 | + ) |
| 724 | + |
| 725 | + # --- Overwrite the target file --- |
| 726 | + with open(target_file_path, "w") as f: |
| 727 | + f.write(new_content) |
| 728 | + |
| 729 | + logging.info(f"Successfully post-processed and overwrote {target_file_path}") |
| 730 | + |
| 731 | + |
| 732 | +if __name__ == "__main__": |
| 733 | + parser = argparse.ArgumentParser( |
| 734 | + description="A generic Python code generator for clients." |
| 735 | + ) |
| 736 | + parser.add_argument("config", help="Path to the YAML configuration file.") |
| 737 | + args = parser.parse_args() |
| 738 | + |
| 739 | + # Load config and set up paths. |
| 740 | + config = setup_config_and_paths(args.config) |
| 741 | + |
| 742 | + # Analyze the source code. |
| 743 | + analysis_results = analyze_source_files(config) |
| 744 | + |
| 745 | + # Generate the new client code. |
| 746 | + generate_code(config, analysis_results) |
| 747 | + |
| 748 | + # Run post-processing steps. |
| 749 | + _execute_post_processing(config) |
| 750 | + |
| 751 | + # TODO: Ensure blacken gets called on the generated source files as a final step. |
0 commit comments