|
26 | 26 | import os |
27 | 27 | import glob |
28 | 28 | import logging |
29 | | -import re |
30 | 29 | from collections import defaultdict |
31 | | -from typing import List, Dict, Any, Iterator |
| 30 | +from pathlib import Path |
| 31 | +from typing import List, Dict, Any |
32 | 32 |
|
33 | 33 | from . import name_utils |
34 | 34 | from . import utils |
@@ -492,3 +492,129 @@ def analyze_source_files( |
492 | 492 |
|
493 | 493 | return parsed_data, all_imports, all_types, request_arg_schema |
494 | 494 |
|
| 495 | + |
| 496 | +# ============================================================================= |
| 497 | +# Section 3: Code Generation |
| 498 | +# ============================================================================= |
| 499 | + |
| 500 | + |
| 501 | +def _generate_import_statement( |
| 502 | + context: List[Dict[str, Any]], key: str, package: str |
| 503 | +) -> str: |
| 504 | + """Generates a formatted import statement from a list of context dictionaries. |
| 505 | +
|
| 506 | + Args: |
| 507 | + context: A list of dictionaries containing the data. |
| 508 | + key: The key to extract from each dictionary in the context. |
| 509 | + package: The base import package (e.g., "google.cloud.bigquery_v2.services"). |
| 510 | +
|
| 511 | + Returns: |
| 512 | + A formatted, multi-line import statement string. |
| 513 | + """ |
| 514 | + |
| 515 | + names = sorted(list(set([item[key] for item in context]))) |
| 516 | + names_str = ",\n ".join(names) |
| 517 | + return f"from {package} import (\n {names_str}\n)" |
| 518 | + |
| 519 | + |
| 520 | +def _get_request_class_name(method_name: str, config: Dict[str, Any]) -> str: |
| 521 | + """Gets the inferred request class name, applying overrides from config.""" |
| 522 | + inferred_request_name = name_utils.method_to_request_class_name(method_name) |
| 523 | + method_overrides = config.get("filter", {}).get("methods", {}).get("overrides", {}) |
| 524 | + if method_name in method_overrides: |
| 525 | + return method_overrides[method_name].get( |
| 526 | + "request_class_name", inferred_request_name |
| 527 | + ) |
| 528 | + return inferred_request_name |
| 529 | + |
| 530 | + |
| 531 | +def _find_fq_request_name( |
| 532 | + request_name: str, request_arg_schema: Dict[str, List[str]] |
| 533 | +) -> str: |
| 534 | + """Finds the fully qualified request name in the schema.""" |
| 535 | + for key in request_arg_schema.keys(): |
| 536 | + if key.endswith(f".{request_name}"): |
| 537 | + return key |
| 538 | + return "" |
| 539 | + |
| 540 | + |
| 541 | +def generate_code(config: Dict[str, Any], analysis_results: tuple) -> None: |
| 542 | + """ |
| 543 | + Generates source code files using Jinja2 templates. |
| 544 | + """ |
| 545 | + |
| 546 | + data, all_imports, all_types, request_arg_schema = analysis_results |
| 547 | + project_root = config["project_root"] |
| 548 | + config_dir = config["config_dir"] |
| 549 | + |
| 550 | + templates_config = config.get("templates", []) |
| 551 | + for item in templates_config: |
| 552 | + template_path = str(Path(config_dir) / item["template"]) |
| 553 | + output_path = str(Path(project_root) / item["output"]) |
| 554 | + |
| 555 | + template = utils.load_template(template_path) |
| 556 | + methods_context = [] |
| 557 | + for class_name, methods in data.items(): |
| 558 | + for method_name, method_info in methods.items(): |
| 559 | + context = { |
| 560 | + "name": method_name, |
| 561 | + "class_name": class_name, |
| 562 | + "return_type": method_info["return_type"], |
| 563 | + } |
| 564 | + |
| 565 | + request_name = _get_request_class_name(method_name, config) |
| 566 | + fq_request_name = _find_fq_request_name( |
| 567 | + request_name, request_arg_schema |
| 568 | + ) |
| 569 | + |
| 570 | + if fq_request_name: |
| 571 | + context["request_class_full_name"] = fq_request_name |
| 572 | + context["request_id_args"] = request_arg_schema[fq_request_name] |
| 573 | + |
| 574 | + methods_context.append(context) |
| 575 | + |
| 576 | + # Prepare imports for the template |
| 577 | + services_context = [] |
| 578 | + client_class_names = sorted( |
| 579 | + list(set([m["class_name"] for m in methods_context])) |
| 580 | + ) |
| 581 | + |
| 582 | + for class_name in client_class_names: |
| 583 | + service_name_cluster = name_utils.generate_service_names(class_name) |
| 584 | + services_context.append(service_name_cluster) |
| 585 | + |
| 586 | + # Also need to update methods_context to include the service_name and module_name |
| 587 | + # so the template knows which client to use for each method. |
| 588 | + class_to_service_map = {s["service_client_class"]: s for s in services_context} |
| 589 | + for method in methods_context: |
| 590 | + service_info = class_to_service_map.get(method["class_name"]) |
| 591 | + if service_info: |
| 592 | + method["service_name"] = service_info["service_name"] |
| 593 | + method["service_module_name"] = service_info["service_module_name"] |
| 594 | + |
| 595 | + # Prepare new imports |
| 596 | + service_imports = [ |
| 597 | + _generate_import_statement( |
| 598 | + services_context, |
| 599 | + "service_module_name", |
| 600 | + "google.cloud.bigquery_v2.services", |
| 601 | + ) |
| 602 | + ] |
| 603 | + |
| 604 | + # Prepare type imports |
| 605 | + type_imports = [ |
| 606 | + _generate_import_statement( |
| 607 | + services_context, "service_name", "google.cloud.bigquery_v2.types" |
| 608 | + ) |
| 609 | + ] |
| 610 | + |
| 611 | + final_code = template.render( |
| 612 | + service_name=config.get("service_name"), |
| 613 | + methods=methods_context, |
| 614 | + services=services_context, |
| 615 | + service_imports=service_imports, |
| 616 | + type_imports=type_imports, |
| 617 | + request_arg_schema=request_arg_schema, |
| 618 | + ) |
| 619 | + |
| 620 | + utils.write_code_to_file(output_path, final_code) |
0 commit comments