-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathrun_environment_generation.py
More file actions
151 lines (119 loc) · 4.81 KB
/
run_environment_generation.py
File metadata and controls
151 lines (119 loc) · 4.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
Environment Generation Entry Point
Usage:
python run_environment_generation.py --config config/env_gen.yaml
python run_environment_generation.py --theme "A puzzle game"
"""
import argparse
import asyncio
from pathlib import Path
import yaml
from autoenv.pipeline import VisualPipeline, GeneratorPipeline
from base.engine.cost_monitor import CostMonitor
# Config file path used as the default for the --config CLI option.
DEFAULT_CONFIG = "config/env_gen.yaml"
def load_config(path: str) -> dict:
    """Load the YAML configuration at *path* as a dictionary.

    Args:
        path: Filesystem path of the YAML config file.

    Returns:
        The parsed top-level mapping. An empty dict is returned when *path*
        does not point to a regular file (missing, or a directory) or when
        the document is empty / parses to a falsy value.

    Raises:
        ValueError: If the document's top level is a non-empty value that is
            not a mapping (e.g. a list), since callers use ``cfg.get(...)``.
    """
    config_file = Path(path)
    # is_file() rather than exists(): a directory path would otherwise reach
    # read_text() and crash instead of falling back to the empty config.
    if not config_file.is_file():
        return {}

    data = yaml.safe_load(config_file.read_text(encoding="utf-8")) or {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Config file {path!r} must contain a YAML mapping at the top level, "
            f"got {type(data).__name__}"
        )
    return data
def get_themes(themes_folder: str) -> list[str]:
    """Collect the ``*.txt`` entries directly inside *themes_folder*.

    Args:
        themes_folder: Directory to scan (non-recursive).

    Returns:
        The matching paths as strings, in ascending sorted order; an empty
        list when *themes_folder* is not a directory.
    """
    root = Path(themes_folder)
    if root.is_dir():
        theme_paths = [str(entry) for entry in root.glob("*.txt")]
        theme_paths.sort()
        return theme_paths
    return []
async def run_generation(
    theme: str,
    model: str,
    output: str,
    mode: str = "textual",
    image_model: str | None = None,
) -> None:
    """Run a single generation task.

    Runs the generator pipeline for one theme and, in ``"multimodal"`` mode,
    follows it with the visual pipeline. Progress and failures are reported
    via ``print``; nothing is returned.

    Args:
        theme: Either the theme text itself, or the path of an existing
            ``.txt`` file whose contents are used as the theme text.
        model: LLM name passed to both pipelines.
        output: Output directory passed to the generator pipeline.
        mode: ``"multimodal"`` additionally runs the visual pipeline; any
            other value stops after environment generation.
        image_model: Image model for the visual pipeline. When falsy in
            multimodal mode, the visual step is skipped with a warning.
    """
    label = theme
    # A theme ending in ".txt" that names a regular file is read from disk,
    # and the file stem becomes the log label. is_file() (not exists()) keeps
    # a directory that happens to be called "*.txt" from reaching read_text().
    if theme.endswith(".txt") and Path(theme).is_file():
        theme_file = Path(theme)
        label = theme_file.stem
        theme = theme_file.read_text(encoding="utf-8")

    # Step 1: Run generator pipeline
    print(f"🚀 [{label}] Generating environment...")
    gen_pipeline = GeneratorPipeline.create_default(llm_name=model)
    gen_ctx = await gen_pipeline.run(requirements=theme, output_dir=output)

    if not gen_ctx.success:
        print(f"❌ [{label}] Generation failed: {gen_ctx.error}")
        return

    env_path = gen_ctx.env_folder_path
    print(f"✅ [{label}] Environment generated → {env_path}")

    # Step 2: Run visual pipeline if multimodal mode
    if mode != "multimodal":
        return

    if not image_model:
        print(f"⚠️ [{label}] Skipping visual: no image_model configured")
        return

    print(f"🎨 [{label}] Generating visuals...")
    # NOTE(review): env_folder_path is presumably a Path; Path(...) is a no-op
    # in that case and also keeps the join working if it is a plain string.
    visual_output = Path(env_path) / "visual"
    visual_pipeline = VisualPipeline.create_default(
        llm_name=model,
        image_model=image_model,
    )
    visual_ctx = await visual_pipeline.run(
        benchmark_path=env_path,
        output_dir=visual_output,
    )

    if visual_ctx.success:
        print(f"✅ [{label}] Visuals generated → {visual_output}")
    else:
        print(f"❌ [{label}] Visual generation failed: {visual_ctx.error}")
async def main():
    """Parse CLI arguments, resolve settings, and run all generation tasks.

    Settings are resolved as CLI flag, then config-file value, then built-in
    default. ``image_model`` and ``concurrency`` are read from the config file
    only. Themes come from ``--theme``, else the config's ``themes_folder``,
    else the config's ``theme``. Tasks run under a semaphore-bounded
    ``asyncio.gather`` inside a ``CostMonitor`` context, after which a cost
    summary is printed and saved.
    """
    parser = argparse.ArgumentParser(description="Generate environments")
    parser.add_argument("--config", default=DEFAULT_CONFIG, help="Config YAML path")
    parser.add_argument("--theme", help="Override: single theme text or .txt file")
    parser.add_argument("--model", help="Override: LLM model name")
    parser.add_argument("--output", help="Override: output directory")
    parser.add_argument("--mode", choices=["textual", "multimodal"], help="Override mode")
    args = parser.parse_args()

    cfg = load_config(args.config)

    # CLI args override config
    model = args.model or cfg.get("model") or "claude-sonnet-4-5"
    output = args.output or cfg.get("envs_root_path") or "workspace/envs"
    mode = args.mode or cfg.get("mode") or "textual"
    image_model = cfg.get("image_model")

    # Semaphore(0) would block every task forever, and a null / non-numeric
    # value would raise inside Semaphore, so coerce and clamp to at least 1.
    raw_concurrency = cfg.get("concurrency", 1)
    try:
        concurrency = int(raw_concurrency)
    except (TypeError, ValueError):
        concurrency = 0
    if concurrency < 1:
        print(f"⚠️ Invalid concurrency {raw_concurrency!r}; falling back to 1")
        concurrency = 1

    Path(output).mkdir(parents=True, exist_ok=True)

    print(f"🔧 Config: {args.config}")
    print(f"🤖 Model: {model}")
    print(f"🎨 Image Model: {image_model}")
    print(f"📁 Output: {output}")
    print(f"📦 Mode: {mode}")

    # Determine themes (priority: CLI --theme > themes_folder > theme)
    themes: list[str] = []
    if args.theme:
        themes = [args.theme]
    elif cfg.get("themes_folder"):
        themes = get_themes(cfg["themes_folder"])
    elif cfg.get("theme"):
        themes = [cfg["theme"]]

    if not themes:
        print("❌ No theme provided. Set 'theme' or 'themes_folder' in config.")
        return

    # Concurrent execution with cost tracking
    sem = asyncio.Semaphore(concurrency)

    async def task(t: str):
        # The semaphore caps how many generations are in flight at once.
        async with sem:
            await run_generation(t, model, output, mode, image_model)

    with CostMonitor() as monitor:
        await asyncio.gather(*[task(t) for t in themes])

    # Print and save cost summary
    summary = monitor.summary()
    print("\n" + "=" * 50)
    print("💰 Cost Summary")
    print("=" * 50)
    print(f"Total Cost: ${summary['total_cost']:.4f}")
    print(f"Total Calls: {summary['call_count']}")
    print(f"Input Tokens: {summary['total_input_tokens']:,}")
    print(f"Output Tokens: {summary['total_output_tokens']:,}")

    if summary["by_model"]:
        print("\nBy Model:")
        for model_name, stats in summary["by_model"].items():
            print(f" {model_name}: ${stats['cost']:.4f} ({stats['calls']} calls)")

    cost_file = monitor.save()
    print(f"\n📄 Cost saved: {cost_file}")
# Script entry point: run the async CLI driver to completion.
if __name__ == "__main__":
    asyncio.run(main())