Commit c4bfb09

Merge branch 'autogen' into test/test-utils-load_resource
2 parents 22017bd + dae58b7 commit c4bfb09

File tree

1 file changed (+257, -0)

scripts/microgenerator/generate.py

Lines changed: 257 additions & 0 deletions
@@ -24,6 +24,7 @@

 import ast
 import os
+import argparse
 import glob
 import logging
 import re
@@ -492,3 +493,259 @@ def analyze_source_files(

    return parsed_data, all_imports, all_types, request_arg_schema


# =============================================================================
# Section 3: Code Generation
# =============================================================================


def _generate_import_statement(
    context: List[Dict[str, Any]], key: str, package: str
) -> str:
    """Generates a formatted import statement from a list of context dictionaries.

    Args:
        context: A list of dictionaries containing the data.
        key: The key to extract from each dictionary in the context.
        package: The base import package (e.g., "google.cloud.bigquery_v2.services").

    Returns:
        A formatted, multi-line import statement string.
    """
    names = sorted(list(set([item[key] for item in context])))
    names_str = ",\n ".join(names)
    return f"from {package} import (\n {names_str}\n)"


def _get_request_class_name(method_name: str, config: Dict[str, Any]) -> str:
    """Gets the inferred request class name, applying overrides from config."""
    inferred_request_name = name_utils.method_to_request_class_name(method_name)
    method_overrides = config.get("filter", {}).get("methods", {}).get("overrides", {})
    if method_name in method_overrides:
        return method_overrides[method_name].get(
            "request_class_name", inferred_request_name
        )
    return inferred_request_name


def _find_fq_request_name(
    request_name: str, request_arg_schema: Dict[str, List[str]]
) -> str:
    """Finds the fully qualified request name in the schema."""
    for key in request_arg_schema.keys():
        if key.endswith(f".{request_name}"):
            return key
    return ""


def generate_code(config: Dict[str, Any], analysis_results: tuple) -> None:
    """
    Generates source code files using Jinja2 templates.
    """
    data, all_imports, all_types, request_arg_schema = analysis_results
    project_root = config["project_root"]
    config_dir = config["config_dir"]

    templates_config = config.get("templates", [])
    for item in templates_config:
        template_path = str(Path(config_dir) / item["template"])
        output_path = str(Path(project_root) / item["output"])

        template = utils.load_template(template_path)
        methods_context = []
        for class_name, methods in data.items():
            for method_name, method_info in methods.items():
                context = {
                    "name": method_name,
                    "class_name": class_name,
                    "return_type": method_info["return_type"],
                }

                request_name = _get_request_class_name(method_name, config)
                fq_request_name = _find_fq_request_name(
                    request_name, request_arg_schema
                )

                if fq_request_name:
                    context["request_class_full_name"] = fq_request_name
                    context["request_id_args"] = request_arg_schema[fq_request_name]

                methods_context.append(context)

        # Prepare imports for the template
        services_context = []
        client_class_names = sorted(
            list(set([m["class_name"] for m in methods_context]))
        )

        for class_name in client_class_names:
            service_name_cluster = name_utils.generate_service_names(class_name)
            services_context.append(service_name_cluster)

        # Also need to update methods_context to include the service_name and module_name
        # so the template knows which client to use for each method.
        class_to_service_map = {s["service_client_class"]: s for s in services_context}
        for method in methods_context:
            service_info = class_to_service_map.get(method["class_name"])
            if service_info:
                method["service_name"] = service_info["service_name"]
                method["service_module_name"] = service_info["service_module_name"]

        # Prepare new imports
        service_imports = [
            _generate_import_statement(
                services_context,
                "service_module_name",
                "google.cloud.bigquery_v2.services",
            )
        ]

        # Prepare type imports
        type_imports = [
            _generate_import_statement(
                services_context, "service_name", "google.cloud.bigquery_v2.types"
            )
        ]

        final_code = template.render(
            service_name=config.get("service_name"),
            methods=methods_context,
            services=services_context,
            service_imports=service_imports,
            type_imports=type_imports,
            request_arg_schema=request_arg_schema,
        )

        utils.write_code_to_file(output_path, final_code)


# =============================================================================
# Section 4: Main Execution
# =============================================================================


def setup_config_and_paths(config_path: str) -> Dict[str, Any]:
    """Loads the configuration and sets up necessary paths.

    Args:
        config_path: The path to the YAML configuration file.

    Returns:
        A dictionary containing the loaded configuration and paths.
    """

    def find_project_root(start_path: str, markers: list[str]) -> str | None:
        """Finds the project root by searching upwards for a marker."""
        current_path = os.path.abspath(start_path)
        while True:
            for marker in markers:
                if os.path.exists(os.path.join(current_path, marker)):
                    return current_path
            parent_path = os.path.dirname(current_path)
            if parent_path == current_path:  # Filesystem root
                return None
            current_path = parent_path

    # Load configuration from the YAML file.
    config = utils.load_config(config_path)

    # Determine the project root.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = find_project_root(script_dir, ["setup.py", ".git"])
    if not project_root:
        project_root = os.getcwd()  # Fallback to current directory

    # Set paths in the config dictionary.
    config["project_root"] = project_root
    config["config_dir"] = os.path.dirname(os.path.abspath(config_path))

    return config


def _execute_post_processing(config: Dict[str, Any]):
    """
    Executes post-processing steps, such as patching existing files.
    """
    project_root = config["project_root"]
    post_processing_jobs = config.get("post_processing_templates", [])

    for job in post_processing_jobs:
        template_path = os.path.join(config["config_dir"], job["template"])
        target_file_path = os.path.join(project_root, job["target_file"])

        if not os.path.exists(target_file_path):
            logging.warning(
                f"Target file {target_file_path} not found, skipping post-processing job."
            )
            continue

        # Read the target file
        with open(target_file_path, "r") as f:
            lines = f.readlines()

        # --- Extract existing imports and __all__ members ---
        imports = []
        all_list = []
        all_start_index = -1
        all_end_index = -1

        for i, line in enumerate(lines):
            if line.strip().startswith("from ."):
                imports.append(line.strip())
            if line.strip() == "__all__ = (":
                all_start_index = i
            if all_start_index != -1 and line.strip() == ")":
                all_end_index = i

        if all_start_index != -1 and all_end_index != -1:
            for i in range(all_start_index + 1, all_end_index):
                member = lines[i].strip().replace('"', "").replace(",", "")
                if member:
                    all_list.append(member)

        # --- Add new items and sort ---
        for new_import in job.get("add_imports", []):
            if new_import not in imports:
                imports.append(new_import)
        imports.sort()
        imports = [f"{imp}\n" for imp in imports]  # re-add newlines

        for new_member in job.get("add_to_all", []):
            if new_member not in all_list:
                all_list.append(new_member)
        all_list.sort()

        # --- Render the new file content ---
        template = utils.load_template(template_path)
        new_content = template.render(
            imports=imports,
            all_list=all_list,
        )

        # --- Overwrite the target file ---
        with open(target_file_path, "w") as f:
            f.write(new_content)

        logging.info(f"Successfully post-processed and overwrote {target_file_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="A generic Python code generator for clients."
    )
    parser.add_argument("config", help="Path to the YAML configuration file.")
    args = parser.parse_args()

    # Load config and set up paths.
    config = setup_config_and_paths(args.config)

    # Analyze the source code.
    analysis_results = analyze_source_files(config)

    # Generate the new client code.
    generate_code(config, analysis_results)

    # Run post-processing steps.
    _execute_post_processing(config)

    # TODO: Ensure blacken gets called on the generated source files as a final step.
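Note: the sketch below is not part of the commit. It illustrates how the generator might be invoked and the shape of configuration the new code reads; the keys mirror the config lookups above, and every file name and value is a hypothetical placeholder.

# Hypothetical invocation -- the script takes one positional argument,
# the path to the YAML configuration file:
#
#     python scripts/microgenerator/generate.py path/to/config.yaml
#
# Shape of that configuration, written as the Python dict that
# utils.load_config() would return. Keys match the lookups in generate_code(),
# _get_request_class_name() and _execute_post_processing(); all values are
# illustrative placeholders.
example_config = {
    "service_name": "bigquery",
    # Each entry pairs a Jinja2 template (relative to the config file) with an
    # output path (relative to the project root).
    "templates": [
        {"template": "client.py.j2", "output": "google/cloud/bigquery_v2/client.py"},
    ],
    # Optional per-method overrides consulted by _get_request_class_name().
    "filter": {
        "methods": {
            "overrides": {
                "get_dataset": {"request_class_name": "GetDatasetRequest"},
            },
        },
    },
    # Post-processing jobs that patch existing files, e.g. a package __init__.py.
    "post_processing_templates": [
        {
            "template": "init.py.j2",
            "target_file": "google/cloud/bigquery_v2/__init__.py",
            "add_imports": ["from .client import Client"],
            "add_to_all": ["Client"],
        },
    ],
}
# "project_root" and "config_dir" are not read from the file; they are added at
# runtime by setup_config_and_paths().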

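Similarly illustrative (the module names below are made up, not from the commit): _generate_import_statement() deduplicates and sorts the values it pulls out of the context dicts before wrapping them in a single parenthesized import.

services_context = [
    {"service_module_name": "dataset_service"},
    {"service_module_name": "table_service"},
    {"service_module_name": "dataset_service"},  # duplicate, dropped by set()
]
stmt = _generate_import_statement(
    services_context, "service_module_name", "google.cloud.bigquery_v2.services"
)
# stmt is a multi-line string of the form
# "from google.cloud.bigquery_v2.services import ( ... )" that lists
# dataset_service and table_service once each, one per line.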