Skip to content

Commit 0603f10

Browse files
feat(logging): add verbose mode for timed progress messages
Add `verbose: bool = False` parameter to `log_forward_pass`, `show_model_graph`, and internal pipeline functions. When enabled, prints `[torchlens]`-prefixed progress at each major pipeline stage with timing. Also fixes `_trim_and_reorder_model_history_fields` to preserve all non-ordered attributes (not just private ones). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 087d42d commit 0603f10

File tree

7 files changed

+142
-67
lines changed

7 files changed

+142
-67
lines changed

torchlens/capture/trace.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from ..utils.arg_handling import safe_copy_args, safe_copy_kwargs, normalize_input_args
5050
from .source_tensors import log_source_tensor
5151
from ..data_classes.interface import _give_user_feedback_about_lookup_key
52+
from ..utils.display import _vprint, _vtimed
5253

5354

5455
def save_new_activations(
@@ -133,6 +134,7 @@ def save_new_activations(
133134
self._raw_layer_dict.pop(label, None)
134135

135136
# Now run and log the new inputs.
137+
_vprint(self, "Running fast pass (saving requested activations)")
136138
self._run_and_log_inputs_through_model(
137139
model, input_args, input_kwargs, layers_to_save, random_seed
138140
)
@@ -431,12 +433,23 @@ def run_and_log_inputs_through_model(
431433
# Per-session model preparation
432434
_prepare_model_session(self, model, self._optimizer)
433435
self.elapsed_time_setup = time.time() - self.pass_start_time
436+
_vprint(self, f"Model prepared ({self.elapsed_time_setup:.2f}s)")
437+
438+
# Print input summary
439+
if getattr(self, "verbose", False):
440+
devices = set()
441+
for t in input_tensors:
442+
if hasattr(t, "device"):
443+
devices.add(str(t.device))
444+
device_str = ", ".join(sorted(devices)) if devices else "unknown"
445+
_vprint(self, f"Inputs: {len(input_tensors)} tensor(s) on {device_str}")
434446

435447
# Turn on the logging toggle and run the forward pass.
436448
# Inside this context, every decorated torch function will log its
437449
# inputs/outputs. Source tensors (model inputs) are logged explicitly
438450
# before invoking the model; all subsequent operations are captured
439451
# automatically by the decorated wrappers.
452+
_vprint(self, f"Running {self.logging_mode} forward pass...")
440453
with _state.active_logging(self):
441454
for i, t in enumerate(input_tensors):
442455
log_source_tensor(self, t, "input", input_tensor_addresses[i])
@@ -446,10 +459,16 @@ def run_and_log_inputs_through_model(
446459
self.elapsed_time_forward_pass = (
447460
time.time() - self.pass_start_time - self.elapsed_time_setup
448461
)
462+
_vprint(
463+
self,
464+
f"Forward pass complete ({self.elapsed_time_forward_pass:.2f}s, "
465+
f"{len(self._raw_layer_dict)} raw operations)",
466+
)
449467

450468
output_tensors, output_tensor_addresses = _extract_and_mark_outputs(self, outputs)
451469

452470
_cleanup_model_session(model, input_tensors)
471+
_vprint(self, f"Postprocessing {len(self._raw_layer_dict)} operations...")
453472
self._postprocess(output_tensors, output_tensor_addresses)
454473

455474
except Exception as e:

torchlens/data_classes/model_log.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def __init__(
112112
save_source_context: bool = False,
113113
save_rng_states: bool = False,
114114
detect_loops: bool = True,
115+
verbose: bool = False,
115116
):
116117
"""Initialise a fresh ModelLog for a new logging session.
117118
@@ -130,6 +131,7 @@ def __init__(
130131
around each function call (used by FuncCallLocation).
131132
optimizer: Optional torch optimizer, used to annotate which params
132133
have optimizers attached.
134+
verbose: If True, print timed progress messages at each major pipeline stage.
133135
"""
134136
# Callables are effectively immutable — deepcopy is unnecessary.
135137

@@ -159,6 +161,7 @@ def __init__(
159161
self.save_source_context = save_source_context
160162
self.save_rng_states = save_rng_states
161163
self.detect_loops = detect_loops
164+
self.verbose = verbose
162165
self.has_saved_gradients = False
163166
self.mark_input_output_distances = mark_input_output_distances
164167

torchlens/postprocess/__init__.py

Lines changed: 58 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@
6969
if TYPE_CHECKING:
7070
from ..data_classes.model_log import ModelLog
7171

72+
from ..utils.display import _vprint, _vtimed
73+
7274

7375
def postprocess(
7476
self: "ModelLog", output_tensors: List[torch.Tensor], output_tensor_addresses: List[str]
@@ -97,89 +99,83 @@ def postprocess(
9799
_set_pass_finished(self)
98100
return
99101

100-
# Step 1: Add dedicated output nodes
101-
102-
_add_output_layers(self, output_tensors, output_tensor_addresses)
103-
104-
# Step 2: Trace which nodes are ancestors of output nodes
102+
# Steps 1-3: Graph traversal (output nodes, ancestry, orphan removal)
103+
with _vtimed(self, "Steps 1-3: Graph traversal"):
104+
# Step 1: Add dedicated output nodes
105+
_add_output_layers(self, output_tensors, output_tensor_addresses)
105106

106-
_find_output_ancestors(self)
107+
# Step 2: Trace which nodes are ancestors of output nodes
108+
_find_output_ancestors(self)
107109

108-
# Step 3: Remove orphan nodes, find nodes that don't terminate in output node
109-
110-
_remove_orphan_nodes(self)
110+
# Step 3: Remove orphan nodes, find nodes that don't terminate in output node
111+
_remove_orphan_nodes(self)
111112

112113
# Step 4: Find min/max distance from input and output nodes.
113114
# Conditional: only runs when the user requested distance metadata.
114-
115115
if self.mark_input_output_distances:
116-
_mark_input_output_distances(self)
116+
with _vtimed(self, "Step 4: Input/output distances"):
117+
_mark_input_output_distances(self)
117118

118-
# Step 5: Starting from terminal single boolean tensors, mark the conditional branches.
119+
# Steps 5-7: Control flow (conditional branches, module fixing, buffers)
120+
with _vtimed(self, "Steps 5-7: Control flow"):
121+
# Step 5: Starting from terminal single boolean tensors, mark the conditional branches.
122+
_mark_conditional_branches(self)
119123

120-
_mark_conditional_branches(self)
124+
# Step 6: Annotate the containing modules for all internally-generated tensors.
125+
_fix_modules_for_internal_tensors(self)
121126

122-
# Step 6: Annotate the containing modules for all internally-generated tensors.
123-
# Internally-initialized tensors (e.g., constants, arange results) don't know
124-
# what module they belong to. This traces backward from input-descendant tensors
125-
# to infer module containment. IMPORTANT: also appends module path suffixes to
126-
# operation_equivalence_type, which affects Step 8 loop detection grouping.
127-
128-
_fix_modules_for_internal_tensors(self)
129-
130-
# Step 7: Fix the buffer passes and parent information.
131-
# Connects buffer parents, merges duplicate buffer nodes (same module, same
132-
# value, same parents), and assigns buffer pass numbers.
133-
134-
_fix_buffer_layers(self)
127+
# Step 7: Fix the buffer passes and parent information.
128+
_fix_buffer_layers(self)
135129

136130
# Step 8: Identify all loops, mark repeated layers.
131+
loop_desc = (
132+
"Step 8: Loop detection (full)"
133+
if self.detect_loops
134+
else "Step 8: Loop detection (params only)"
135+
)
136+
with _vtimed(self, loop_desc):
137+
if self.detect_loops:
138+
_detect_and_label_loops(self)
139+
else:
140+
_group_by_shared_params(self)
137141

138-
if self.detect_loops:
139-
_detect_and_label_loops(self)
140-
else:
141-
_group_by_shared_params(self)
142+
# Steps 9-12: Labeling (label mapping, final info, rename, cleanup)
143+
with _vtimed(self, "Steps 9-12: Labeling"):
144+
# Step 9: Go down tensor list, get the mapping from raw tensor names to final tensor names.
145+
_map_raw_labels_to_final_labels(self)
142146

143-
# Step 9: Go down tensor list, get the mapping from raw tensor names to final tensor names.
147+
# Step 10: Log final info for all layers
148+
_log_final_info_for_all_layers(self)
144149

145-
_map_raw_labels_to_final_labels(self)
150+
# Step 11: Rename all raw labels to final labels
151+
_rename_model_history_layer_names(self)
152+
_trim_and_reorder_model_history_fields(self)
146153

147-
# Step 10: Log final info for all layers (operation numbers, module hierarchy,
148-
# param tallies, structural flags). MUST run before Step 12 because lookup key
149-
# generation in Step 12 needs module hierarchy data populated here.
150-
_log_final_info_for_all_layers(self)
154+
# Step 12: Remove unsaved layers, build lookup key mappings
155+
_remove_unwanted_entries_and_log_remaining(self)
151156

152-
# Step 11: Rename all raw labels (e.g., "cos_3_raw") to final labels
153-
# (e.g., "cos_1_3:2") in both ModelLog-level fields and LayerPassLog fields.
154-
# Then reorder ModelLog fields into the canonical display order.
155-
_rename_model_history_layer_names(self)
156-
_trim_and_reorder_model_history_fields(self)
157+
# Steps 13-18: Finalization
158+
with _vtimed(self, "Steps 13-18: Finalization"):
159+
# Step 13: Undecorate all saved tensors and remove saved grad_fns.
160+
_undecorate_all_saved_tensors(self)
157161

158-
# Step 12: Remove unsaved layers (unless keep_unsaved_layers=True), build
159-
# lookup key mappings, and log remaining layer metadata.
160-
_remove_unwanted_entries_and_log_remaining(self)
162+
# Step 14: Clear the cache after any tensor deletions for garbage collection purposes:
163+
torch.cuda.empty_cache()
161164

162-
# Step 13: Undecorate all saved tensors and remove saved grad_fns.
163-
_undecorate_all_saved_tensors(self)
165+
# Step 15: Log time elapsed.
166+
_log_time_elapsed(self)
164167

165-
# Step 14: Clear the cache after any tensor deletions for garbage collection purposes:
166-
torch.cuda.empty_cache()
168+
# Step 16: Populate ParamLog reverse mappings, linked params, num_passes, and gradient metadata.
169+
_finalize_param_logs(self)
167170

168-
# Step 15: Log time elapsed.
169-
_log_time_elapsed(self)
171+
# Step 16.5: Build aggregate LayerLog objects from per-pass LayerPassLog entries.
172+
_build_layer_logs(self)
170173

171-
# Step 16: Populate ParamLog reverse mappings, linked params, num_passes, and gradient metadata.
172-
_finalize_param_logs(self)
173-
174-
# Step 16.5: Build aggregate LayerLog objects from per-pass LayerPassLog entries.
175-
_build_layer_logs(self)
174+
# Step 17: Build structured ModuleLog objects from raw module_* dicts.
175+
_build_module_logs(self)
176176

177-
# Step 17: Build structured ModuleLog objects from raw module_* dicts.
178-
_build_module_logs(self)
179-
180-
# Step 18: log the pass as finished, changing the ModelLog behavior to its user-facing version.
181-
182-
_set_pass_finished(self)
177+
# Step 18: log the pass as finished, changing the ModelLog behavior to its user-facing version.
178+
_set_pass_finished(self)
183179

184180

185181
def postprocess_fast(self: "ModelLog") -> None:
@@ -202,6 +198,7 @@ def postprocess_fast(self: "ModelLog") -> None:
202198
- Step 17: _build_module_logs — module structure doesn't change between
203199
passes and _module_build_data isn't repopulated in fast mode (#108).
204200
"""
201+
_vprint(self, "Fast-pass postprocessing...")
205202
# Use layer_dict_main_keys to get LayerPassLog directly (not LayerLog)
206203
for output_layer_label in self.output_layers:
207204
output_layer = self.layer_dict_main_keys[output_layer_label]

torchlens/postprocess/labeling.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -584,8 +584,9 @@ def _trim_and_reorder_model_history_fields(self) -> None:
584584
new_dir_dict = OrderedDict()
585585
for field in MODEL_LOG_FIELD_ORDER:
586586
new_dir_dict[field] = getattr(self, field)
587-
# Preserve private/internal fields not in the canonical order.
587+
# Preserve all remaining fields not in the canonical order (private/internal
588+
# fields AND runtime-config attributes like ``verbose``).
588589
for field, value in self.__dict__.items():
589-
if field.startswith("_") and field not in new_dir_dict:
590+
if field not in new_dir_dict:
590591
new_dir_dict[field] = value
591592
self.__dict__ = new_dir_dict

torchlens/user_funcs.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
from .utils.introspection import get_vars_of_type_from_obj
3030
from .utils.rng import set_random_seed
31-
from .utils.display import warn_parallel
31+
from .utils.display import warn_parallel, _vprint
3232
from .utils.arg_handling import safe_copy_args, safe_copy_kwargs, normalize_input_args
3333
from .data_classes.model_log import (
3434
ModelLog,
@@ -83,6 +83,7 @@ def _run_model_and_save_specified_activations(
8383
save_source_context: bool = False,
8484
save_rng_states: bool = False,
8585
detect_loops: bool = True,
86+
verbose: bool = False,
8687
) -> ModelLog:
8788
"""Run a forward pass with logging enabled, returning a populated ModelLog.
8889
@@ -112,6 +113,7 @@ def _run_model_and_save_specified_activations(
112113
detect_loops: If True (default), run full isomorphic subgraph expansion to
113114
detect repeated patterns (loops). If False, only group operations that
114115
share the same parameters — much faster for very large graphs.
116+
verbose: If True, print timed progress messages at each major pipeline stage.
115117
116118
Returns:
117119
Fully-populated ModelLog.
@@ -140,6 +142,7 @@ def _run_model_and_save_specified_activations(
140142
save_source_context,
141143
save_rng_states,
142144
detect_loops,
145+
verbose,
143146
)
144147
model_log._run_and_log_inputs_through_model(
145148
model, input_args, input_kwargs, layers_to_save, random_seed
@@ -179,6 +182,7 @@ def log_forward_pass(
179182
num_context_lines: int = 7,
180183
optimizer=None,
181184
detect_loops: bool = True,
185+
verbose: bool = False,
182186
) -> ModelLog:
183187
"""Run a forward pass through *model*, log every operation, and return a ModelLog.
184188
@@ -241,6 +245,7 @@ def log_forward_pass(
241245
random_seed: Fixed RNG seed for reproducibility with stochastic models.
242246
num_context_lines: Lines of source context to capture per function call.
243247
optimizer: Optional optimizer to annotate which params are being optimized.
248+
verbose: If True, print timed progress messages at each major pipeline stage.
244249
245250
Returns:
246251
A ``ModelLog`` containing layer activations (if requested) and full metadata.
@@ -279,12 +284,15 @@ def log_forward_pass(
279284
save_source_context=save_source_context,
280285
save_rng_states=save_rng_states,
281286
detect_loops=detect_loops,
287+
verbose=verbose,
282288
)
283289
else:
284290
# --- TWO-PASS path ---
285291
# Pass 1 (exhaustive): Run with layers_to_save=None and keep_unsaved_layers=True
286292
# so the full graph is discovered and all layer labels are assigned. No
287293
# activations are saved yet — this pass is purely for metadata/structure.
294+
if verbose:
295+
print("[torchlens] Two-pass mode: Pass 1 (exhaustive, metadata only)")
288296
model_log = _run_model_and_save_specified_activations(
289297
model=model,
290298
input_args=input_args, # type: ignore[arg-type]
@@ -303,9 +311,11 @@ def log_forward_pass(
303311
save_source_context=save_source_context,
304312
save_rng_states=save_rng_states,
305313
detect_loops=detect_loops,
314+
verbose=verbose,
306315
)
307316
# Pass 2 (fast): Now that layer labels exist, resolve the user's requested
308317
# layers and replay the model, saving only the matching activations.
318+
_vprint(model_log, "Two-pass mode: Pass 2 (fast, saving requested layers)")
309319
model_log.keep_unsaved_layers = keep_unsaved_layers
310320
model_log.save_new_activations(
311321
model=model,
@@ -315,6 +325,14 @@ def log_forward_pass(
315325
random_seed=random_seed,
316326
)
317327

328+
# Print final summary.
329+
_vprint(
330+
model_log,
331+
f"Done: {len(model_log.layer_logs)} layers, "
332+
f"{model_log.num_tensors_saved} saved, "
333+
f"{model_log.tensor_fsize_total_nice}",
334+
)
335+
318336
# Visualize if desired.
319337
if vis_opt != "none":
320338
model_log.render_graph(
@@ -386,6 +404,7 @@ def show_model_graph(
386404
vis_node_placement: str = "auto",
387405
random_seed: Optional[int] = None,
388406
detect_loops: bool = True,
407+
verbose: bool = False,
389408
) -> None:
390409
"""Convenience wrapper: visualize the computational graph without saving activations.
391410
@@ -428,6 +447,7 @@ def show_model_graph(
428447
save_gradients=False,
429448
random_seed=random_seed,
430449
detect_loops=detect_loops,
450+
verbose=verbose,
431451
)
432452
# Render in a try/finally so temporary tl_ attributes on the model are
433453
# always cleaned up, even if Graphviz rendering raises.

0 commit comments

Comments (0)