|
| 1 | +"""Render a large random graph with ELK layout. |
| 2 | +
|
| 3 | +Usage: |
| 4 | + python scripts/render_large_graph.py NUM_NODES [OPTIONS] |
| 5 | +
|
| 6 | +Examples: |
| 7 | + python scripts/render_large_graph.py 250000 |
| 8 | + python scripts/render_large_graph.py 1000000 --format png --seed 123 |
| 9 | + python scripts/render_large_graph.py 50000 --outdir /tmp/graphs |
| 10 | +""" |
| 11 | + |
| 12 | +import argparse |
| 13 | +import os |
| 14 | +import sys |
| 15 | +import time |
| 16 | + |
| 17 | +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "tests")) |
| 18 | + |
| 19 | +import torch |
| 20 | +from example_models import RandomGraphModel |
| 21 | +from torchlens import log_forward_pass |
| 22 | + |
| 23 | + |
| 24 | +def main(): |
| 25 | + parser = argparse.ArgumentParser(description="Render a large random graph with ELK layout.") |
| 26 | + parser.add_argument("num_nodes", type=int, help="Target number of nodes in the graph") |
| 27 | + parser.add_argument("--format", default="svg", help="Output format (default: svg)") |
| 28 | + parser.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)") |
| 29 | + parser.add_argument( |
| 30 | + "--outdir", |
| 31 | + default=os.path.join("tests", "test_outputs", "visualizations", "large"), |
| 32 | + help="Output directory", |
| 33 | + ) |
| 34 | + args = parser.parse_args() |
| 35 | + |
| 36 | + os.makedirs(args.outdir, exist_ok=True) |
| 37 | + label = f"elk_{args.num_nodes // 1000}k" |
| 38 | + |
| 39 | + # Phase 1: Model construction |
| 40 | + t0 = time.time() |
| 41 | + model = RandomGraphModel(target_nodes=args.num_nodes, seed=args.seed) |
| 42 | + x = torch.randn(2, 64) |
| 43 | + t1 = time.time() |
| 44 | + print(f"Phase 1 — Model construction: {t1 - t0:.1f}s", flush=True) |
| 45 | + |
| 46 | + # Phase 2: log_forward_pass (logging + postprocessing) |
| 47 | + ml = log_forward_pass(model, x, layers_to_save=None, detect_loops=False, verbose=True) |
| 48 | + t2 = time.time() |
| 49 | + print(f"Phase 2 — log_forward_pass: {t2 - t1:.1f}s ({len(ml)} layers)", flush=True) |
| 50 | + |
| 51 | + # Phase 3: Render |
| 52 | + ml.render_graph( |
| 53 | + vis_opt="unrolled", |
| 54 | + vis_nesting_depth=1000, |
| 55 | + vis_outpath=os.path.join(args.outdir, label), |
| 56 | + save_only=True, |
| 57 | + vis_fileformat=args.format, |
| 58 | + vis_node_placement="elk", |
| 59 | + ) |
| 60 | + t3 = time.time() |
| 61 | + print(f"Phase 3 — ELK render: {t3 - t2:.1f}s", flush=True) |
| 62 | + |
| 63 | + total = t3 - t0 |
| 64 | + print(f"Total: {total:.1f}s ({total / 60:.1f} min)", flush=True) |
| 65 | + |
| 66 | + out_path = os.path.join(args.outdir, f"{label}.{args.format}") |
| 67 | + if os.path.exists(out_path): |
| 68 | + size_mb = os.path.getsize(out_path) / 1024 / 1024 |
| 69 | + print(f"Output: {out_path} ({size_mb:.1f} MB)") |
| 70 | + |
| 71 | + ml.cleanup() |
| 72 | + |
| 73 | + |
| 74 | +if __name__ == "__main__": |
| 75 | + main() |
0 commit comments