|
30 | 30 | from collections import defaultdict |
31 | 31 | from typing import List, Dict, Any, Iterator |
32 | 32 |
|
| 33 | +from . import name_utils |
33 | 34 | from . import utils |
34 | 35 |
|
35 | 36 | # ============================================================================= |
@@ -490,3 +491,123 @@ def analyze_source_files( |
490 | 491 | ) |
491 | 492 |
|
492 | 493 | return parsed_data, all_imports, all_types, request_arg_schema |
| 494 | + |
| 495 | + |
| 496 | +# ============================================================================= |
| 497 | +# Section 3: Code Generation |
| 498 | +# ============================================================================= |
| 499 | + |
| 500 | + |
| 501 | +def _generate_import_statement( |
| 502 | + context: List[Dict[str, Any]], key: str, path: str |
| 503 | +) -> str: |
| 504 | + """Generates a formatted import statement from a list of context dictionaries. |
| 505 | +
|
| 506 | + Args: |
| 507 | + context: A list of dictionaries containing the data. |
| 508 | + key: The key to extract from each dictionary in the context. |
| 509 | + path: The base import path (e.g., "google.cloud.bigquery_v2.services"). |
| 510 | +
|
| 511 | + Returns: |
| 512 | + A formatted, multi-line import statement string. |
| 513 | + """ |
| 514 | + names = sorted(list(set([item[key] for item in context]))) |
| 515 | + names_str = ",\n ".join(names) |
| 516 | + return f"from {path} import (\n {names_str}\n)" |
| 517 | + |
| 518 | + |
| 519 | +def generate_code(config: Dict[str, Any], analysis_results: tuple) -> None: |
| 520 | + """ |
| 521 | + Generates source code files using Jinja2 templates. |
| 522 | + """ |
| 523 | + data, all_imports, all_types, request_arg_schema = analysis_results |
| 524 | + project_root = config["project_root"] |
| 525 | + config_dir = config["config_dir"] |
| 526 | + |
| 527 | + templates_config = config.get("templates", []) |
| 528 | + for item in templates_config: |
| 529 | + template_path = os.path.join(config_dir, item["template"]) |
| 530 | + output_path = os.path.join(project_root, item["output"]) |
| 531 | + |
| 532 | + template = utils.load_template(template_path) |
| 533 | + methods_context = [] |
| 534 | + for class_name, methods in data.items(): |
| 535 | + for method_name, method_info in methods.items(): |
| 536 | + context = { |
| 537 | + "name": method_name, |
| 538 | + "class_name": class_name, |
| 539 | + "return_type": method_info["return_type"], |
| 540 | + } |
| 541 | + |
| 542 | + # Infer the request class and find its schema. |
| 543 | + inferred_request_name = name_utils.method_to_request_class_name( |
| 544 | + method_name |
| 545 | + ) |
| 546 | + |
| 547 | + # Check for a request class name override in the config. |
| 548 | + method_overrides = ( |
| 549 | + config.get("filter", {}).get("methods", {}).get("overrides", {}) |
| 550 | + ) |
| 551 | + if method_name in method_overrides: |
| 552 | + inferred_request_name = method_overrides[method_name].get( |
| 553 | + "request_class_name", inferred_request_name |
| 554 | + ) |
| 555 | + |
| 556 | + fq_request_name = "" |
| 557 | + for key in request_arg_schema.keys(): |
| 558 | + if key.endswith(f".{inferred_request_name}"): |
| 559 | + fq_request_name = key |
| 560 | + break |
| 561 | + |
| 562 | + # If found, augment the method context. |
| 563 | + if fq_request_name: |
| 564 | + context["request_class_full_name"] = fq_request_name |
| 565 | + context["request_id_args"] = request_arg_schema[fq_request_name] |
| 566 | + |
| 567 | + methods_context.append(context) |
| 568 | + |
| 569 | + # Prepare imports for the template |
| 570 | + services_context = [] |
| 571 | + client_class_names = sorted( |
| 572 | + list(set([m["class_name"] for m in methods_context])) |
| 573 | + ) |
| 574 | + |
| 575 | + for class_name in client_class_names: |
| 576 | + service_name_cluster = name_utils.generate_service_names(class_name) |
| 577 | + services_context.append(service_name_cluster) |
| 578 | + |
| 579 | + # Also need to update methods_context to include the service_name and module_name |
| 580 | + # so the template knows which client to use for each method. |
| 581 | + class_to_service_map = {s["service_client_class"]: s for s in services_context} |
| 582 | + for method in methods_context: |
| 583 | + service_info = class_to_service_map.get(method["class_name"]) |
| 584 | + if service_info: |
| 585 | + method["service_name"] = service_info["service_name"] |
| 586 | + method["service_module_name"] = service_info["service_module_name"] |
| 587 | + |
| 588 | + # Prepare new imports |
| 589 | + service_imports = [ |
| 590 | + _generate_import_statement( |
| 591 | + services_context, |
| 592 | + "service_module_name", |
| 593 | + "google.cloud.bigquery_v2.services", |
| 594 | + ) |
| 595 | + ] |
| 596 | + |
| 597 | + # Prepare type imports |
| 598 | + type_imports = [ |
| 599 | + _generate_import_statement( |
| 600 | + services_context, "service_name", "google.cloud.bigquery_v2.types" |
| 601 | + ) |
| 602 | + ] |
| 603 | + |
| 604 | + final_code = template.render( |
| 605 | + service_name=config.get("service_name"), |
| 606 | + methods=methods_context, |
| 607 | + services=services_context, |
| 608 | + service_imports=service_imports, |
| 609 | + type_imports=type_imports, |
| 610 | + request_arg_schema=request_arg_schema, |
| 611 | + ) |
| 612 | + |
| 613 | + utils.write_code_to_file(output_path, final_code) |
0 commit comments