Skip to content

Commit fb47cff

Browse files
Merge pull request #123 from johnmarktaylor91/feat/verbose-mode
2 parents 087d42d + 07a8186 commit fb47cff

File tree

11 files changed

+227
-147
lines changed

11 files changed

+227
-147
lines changed

scripts/CLAUDE.md

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
| File | Purpose |
66
|------|---------|
77
| `check_flops_coverage.py` | Reports FLOPs module coverage against all decorated torch functions |
8-
| `run_250k.py` | Render a 250k-node graph with ELK layout |
9-
| `run_1M.py` | Render a 1M-node graph with ELK (phased: construct → log → render) |
8+
| `render_large_graph.py` | Render a large random graph with ELK layout (any node count) |
109

1110
## check_flops_coverage.py
1211

@@ -20,3 +19,12 @@ Run with:
2019
```bash
2120
python scripts/check_flops_coverage.py
2221
```
22+
23+
## render_large_graph.py
24+
25+
Renders a large random graph using the ELK layout engine. Accepts any target node count.
26+
27+
```bash
28+
python scripts/render_large_graph.py 250000
29+
python scripts/render_large_graph.py 1000000 --format png --seed 123
30+
```

scripts/render_large_graph.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
"""Render a large random graph with ELK layout.

Usage:
    python scripts/render_large_graph.py NUM_NODES [OPTIONS]

Examples:
    python scripts/render_large_graph.py 250000
    python scripts/render_large_graph.py 1000000 --format png --seed 123
    python scripts/render_large_graph.py 50000 --outdir /tmp/graphs
"""

import argparse
import os
import sys
import time

# Make the repo's tests/ directory importable so example_models resolves
# when this script is run from the repository root.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "tests"))

import torch
from example_models import RandomGraphModel
from torchlens import log_forward_pass


def main():
    """Build, log, and render a large random graph, printing per-phase timings.

    Three timed phases: (1) model construction, (2) log_forward_pass with
    verbose progress, (3) ELK-based rendering. Prints the total wall time
    and the size of the output file if it was produced.
    """
    parser = argparse.ArgumentParser(description="Render a large random graph with ELK layout.")
    parser.add_argument("num_nodes", type=int, help="Target number of nodes in the graph")
    parser.add_argument("--format", default="svg", help="Output format (default: svg)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
    parser.add_argument(
        "--outdir",
        default=os.path.join("tests", "test_outputs", "visualizations", "large"),
        help="Output directory",
    )
    args = parser.parse_args()

    os.makedirs(args.outdir, exist_ok=True)
    # Label like "elk_250k"; below 1000 nodes fall back to the raw count so
    # the output filename is never a misleading "elk_0k".
    if args.num_nodes >= 1000:
        label = f"elk_{args.num_nodes // 1000}k"
    else:
        label = f"elk_{args.num_nodes}"

    # Phase 1: Model construction
    t0 = time.time()
    model = RandomGraphModel(target_nodes=args.num_nodes, seed=args.seed)
    x = torch.randn(2, 64)
    t1 = time.time()
    print(f"Phase 1 — Model construction: {t1 - t0:.1f}s", flush=True)

    # Phase 2: log_forward_pass (logging + postprocessing).
    # detect_loops=False: loop detection is expensive and irrelevant for a
    # random DAG of this size; verbose=True surfaces per-stage progress.
    ml = log_forward_pass(model, x, layers_to_save=None, detect_loops=False, verbose=True)
    t2 = time.time()
    print(f"Phase 2 — log_forward_pass: {t2 - t1:.1f}s ({len(ml)} layers)", flush=True)

    try:
        # Phase 3: Render
        ml.render_graph(
            vis_opt="unrolled",
            vis_nesting_depth=1000,
            vis_outpath=os.path.join(args.outdir, label),
            save_only=True,
            vis_fileformat=args.format,
            vis_node_placement="elk",
        )
        t3 = time.time()
        print(f"Phase 3 — ELK render: {t3 - t2:.1f}s", flush=True)

        total = t3 - t0
        print(f"Total: {total:.1f}s ({total / 60:.1f} min)", flush=True)

        out_path = os.path.join(args.outdir, f"{label}.{args.format}")
        if os.path.exists(out_path):
            size_mb = os.path.getsize(out_path) / 1024 / 1024
            print(f"Output: {out_path} ({size_mb:.1f} MB)")
    finally:
        # Release logged activations even if rendering raised, so a failed
        # render on a million-node graph does not strand the saved tensors.
        ml.cleanup()


if __name__ == "__main__":
    main()

scripts/run_1M.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

scripts/run_250k.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

torchlens/capture/trace.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from ..utils.arg_handling import safe_copy_args, safe_copy_kwargs, normalize_input_args
5050
from .source_tensors import log_source_tensor
5151
from ..data_classes.interface import _give_user_feedback_about_lookup_key
52+
from ..utils.display import _vprint, _vtimed
5253

5354

5455
def save_new_activations(
@@ -133,6 +134,7 @@ def save_new_activations(
133134
self._raw_layer_dict.pop(label, None)
134135

135136
# Now run and log the new inputs.
137+
_vprint(self, "Running fast pass (saving requested activations)")
136138
self._run_and_log_inputs_through_model(
137139
model, input_args, input_kwargs, layers_to_save, random_seed
138140
)
@@ -431,12 +433,23 @@ def run_and_log_inputs_through_model(
431433
# Per-session model preparation
432434
_prepare_model_session(self, model, self._optimizer)
433435
self.elapsed_time_setup = time.time() - self.pass_start_time
436+
_vprint(self, f"Model prepared ({self.elapsed_time_setup:.2f}s)")
437+
438+
# Print input summary
439+
if getattr(self, "verbose", False):
440+
devices = set()
441+
for t in input_tensors:
442+
if hasattr(t, "device"):
443+
devices.add(str(t.device))
444+
device_str = ", ".join(sorted(devices)) if devices else "unknown"
445+
_vprint(self, f"Inputs: {len(input_tensors)} tensor(s) on {device_str}")
434446

435447
# Turn on the logging toggle and run the forward pass.
436448
# Inside this context, every decorated torch function will log its
437449
# inputs/outputs. Source tensors (model inputs) are logged explicitly
438450
# before invoking the model; all subsequent operations are captured
439451
# automatically by the decorated wrappers.
452+
_vprint(self, f"Running {self.logging_mode} forward pass...")
440453
with _state.active_logging(self):
441454
for i, t in enumerate(input_tensors):
442455
log_source_tensor(self, t, "input", input_tensor_addresses[i])
@@ -446,10 +459,16 @@ def run_and_log_inputs_through_model(
446459
self.elapsed_time_forward_pass = (
447460
time.time() - self.pass_start_time - self.elapsed_time_setup
448461
)
462+
_vprint(
463+
self,
464+
f"Forward pass complete ({self.elapsed_time_forward_pass:.2f}s, "
465+
f"{len(self._raw_layer_dict)} raw operations)",
466+
)
449467

450468
output_tensors, output_tensor_addresses = _extract_and_mark_outputs(self, outputs)
451469

452470
_cleanup_model_session(model, input_tensors)
471+
_vprint(self, f"Postprocessing {len(self._raw_layer_dict)} operations...")
453472
self._postprocess(output_tensors, output_tensor_addresses)
454473

455474
except Exception as e:

torchlens/data_classes/model_log.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def __init__(
112112
save_source_context: bool = False,
113113
save_rng_states: bool = False,
114114
detect_loops: bool = True,
115+
verbose: bool = False,
115116
):
116117
"""Initialise a fresh ModelLog for a new logging session.
117118
@@ -130,6 +131,7 @@ def __init__(
130131
around each function call (used by FuncCallLocation).
131132
optimizer: Optional torch optimizer, used to annotate which params
132133
have optimizers attached.
134+
verbose: If True, print timed progress messages at each major pipeline stage.
133135
"""
134136
# Callables are effectively immutable — deepcopy is unnecessary.
135137

@@ -159,6 +161,7 @@ def __init__(
159161
self.save_source_context = save_source_context
160162
self.save_rng_states = save_rng_states
161163
self.detect_loops = detect_loops
164+
self.verbose = verbose
162165
self.has_saved_gradients = False
163166
self.mark_input_output_distances = mark_input_output_distances
164167

0 commit comments

Comments
 (0)