Skip to content

Commit 3e9ade6

Browse files
committed
feat: adds code generation logic
1 parent 44a0777 commit 3e9ade6

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed

scripts/microgenerator/generate.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from collections import defaultdict
3131
from typing import List, Dict, Any, Iterator
3232

33+
from . import name_utils
3334
from . import utils
3435

3536
# =============================================================================
@@ -490,3 +491,123 @@ def analyze_source_files(
490491
)
491492

492493
return parsed_data, all_imports, all_types, request_arg_schema
494+
495+
496+
# =============================================================================
497+
# Section 3: Code Generation
498+
# =============================================================================
499+
500+
501+
def _generate_import_statement(
502+
context: List[Dict[str, Any]], key: str, path: str
503+
) -> str:
504+
"""Generates a formatted import statement from a list of context dictionaries.
505+
506+
Args:
507+
context: A list of dictionaries containing the data.
508+
key: The key to extract from each dictionary in the context.
509+
path: The base import path (e.g., "google.cloud.bigquery_v2.services").
510+
511+
Returns:
512+
A formatted, multi-line import statement string.
513+
"""
514+
names = sorted(list(set([item[key] for item in context])))
515+
names_str = ",\n ".join(names)
516+
return f"from {path} import (\n {names_str}\n)"
517+
518+
519+
def generate_code(config: Dict[str, Any], analysis_results: tuple) -> None:
520+
"""
521+
Generates source code files using Jinja2 templates.
522+
"""
523+
data, all_imports, all_types, request_arg_schema = analysis_results
524+
project_root = config["project_root"]
525+
config_dir = config["config_dir"]
526+
527+
templates_config = config.get("templates", [])
528+
for item in templates_config:
529+
template_path = os.path.join(config_dir, item["template"])
530+
output_path = os.path.join(project_root, item["output"])
531+
532+
template = utils.load_template(template_path)
533+
methods_context = []
534+
for class_name, methods in data.items():
535+
for method_name, method_info in methods.items():
536+
context = {
537+
"name": method_name,
538+
"class_name": class_name,
539+
"return_type": method_info["return_type"],
540+
}
541+
542+
# Infer the request class and find its schema.
543+
inferred_request_name = name_utils.method_to_request_class_name(
544+
method_name
545+
)
546+
547+
# Check for a request class name override in the config.
548+
method_overrides = (
549+
config.get("filter", {}).get("methods", {}).get("overrides", {})
550+
)
551+
if method_name in method_overrides:
552+
inferred_request_name = method_overrides[method_name].get(
553+
"request_class_name", inferred_request_name
554+
)
555+
556+
fq_request_name = ""
557+
for key in request_arg_schema.keys():
558+
if key.endswith(f".{inferred_request_name}"):
559+
fq_request_name = key
560+
break
561+
562+
# If found, augment the method context.
563+
if fq_request_name:
564+
context["request_class_full_name"] = fq_request_name
565+
context["request_id_args"] = request_arg_schema[fq_request_name]
566+
567+
methods_context.append(context)
568+
569+
# Prepare imports for the template
570+
services_context = []
571+
client_class_names = sorted(
572+
list(set([m["class_name"] for m in methods_context]))
573+
)
574+
575+
for class_name in client_class_names:
576+
service_name_cluster = name_utils.generate_service_names(class_name)
577+
services_context.append(service_name_cluster)
578+
579+
# Also need to update methods_context to include the service_name and module_name
580+
# so the template knows which client to use for each method.
581+
class_to_service_map = {s["service_client_class"]: s for s in services_context}
582+
for method in methods_context:
583+
service_info = class_to_service_map.get(method["class_name"])
584+
if service_info:
585+
method["service_name"] = service_info["service_name"]
586+
method["service_module_name"] = service_info["service_module_name"]
587+
588+
# Prepare new imports
589+
service_imports = [
590+
_generate_import_statement(
591+
services_context,
592+
"service_module_name",
593+
"google.cloud.bigquery_v2.services",
594+
)
595+
]
596+
597+
# Prepare type imports
598+
type_imports = [
599+
_generate_import_statement(
600+
services_context, "service_name", "google.cloud.bigquery_v2.types"
601+
)
602+
]
603+
604+
final_code = template.render(
605+
service_name=config.get("service_name"),
606+
methods=methods_context,
607+
services=services_context,
608+
service_imports=service_imports,
609+
type_imports=type_imports,
610+
request_arg_schema=request_arg_schema,
611+
)
612+
613+
utils.write_code_to_file(output_path, final_code)

0 commit comments

Comments
 (0)