Skip to content

Commit 69099e8

Browse files
Merge pull request #131 from johnmarktaylor91/perf/postprocess-pipeline-optimizations
2 parents bd8b348 + 11ea006 commit 69099e8

File tree

5 files changed: +133 additions, -81 deletions (lines changed)

torchlens/postprocess/__init__.py

Lines changed: 48 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@
3434

3535
from typing import TYPE_CHECKING, List
3636

37+
import time
3738
import torch
3839

3940
from ..utils.tensor_utils import safe_copy
@@ -99,84 +100,102 @@ def postprocess(
99100
_set_pass_finished(self)
100101
return
101102

102-
# Steps 1-3: Graph traversal (output nodes, ancestry, orphan removal)
103-
with _vtimed(self, "Steps 1-3: Graph traversal"):
104-
# Step 1: Add dedicated output nodes
103+
_vprint(
104+
self,
105+
f"Postprocessing {len(self._raw_layer_labels_list):,} layers "
106+
f"({len(self.buffer_layers):,} buffers)...",
107+
)
108+
_post_t0 = time.time() if getattr(self, "verbose", False) else 0
109+
110+
# Step 1: Add dedicated output nodes
111+
with _vtimed(self, " Step 1: Add output layers"):
105112
_add_output_layers(self, output_tensors, output_tensor_addresses)
106113

107-
# Step 2: Trace which nodes are ancestors of output nodes
114+
# Step 2: Trace which nodes are ancestors of output nodes
115+
with _vtimed(self, " Step 2: Trace output ancestors"):
108116
_find_output_ancestors(self)
109117

110-
# Step 3: Remove orphan nodes, find nodes that don't terminate in output node
118+
# Step 3: Remove orphan nodes, find nodes that don't terminate in output node
119+
with _vtimed(self, " Step 3: Remove orphan nodes"):
111120
_remove_orphan_nodes(self)
112121

113122
# Step 4: Find min/max distance from input and output nodes.
114123
# Conditional: only runs when the user requested distance metadata.
115124
if self.mark_input_output_distances:
116-
with _vtimed(self, "Step 4: Input/output distances"):
125+
with _vtimed(self, " Step 4: Input/output distances"):
117126
_mark_input_output_distances(self)
118127

119-
# Steps 5-7: Control flow (conditional branches, module fixing, buffers)
120-
with _vtimed(self, "Steps 5-7: Control flow"):
121-
# Step 5: Starting from terminal single boolean tensors, mark the conditional branches.
128+
# Step 5: Starting from terminal single boolean tensors, mark the conditional branches.
129+
with _vtimed(self, " Step 5: Mark conditional branches"):
122130
_mark_conditional_branches(self)
123131

124-
# Step 6: Annotate the containing modules for all internally-generated tensors.
132+
# Step 6: Annotate the containing modules for all internally-generated tensors.
133+
with _vtimed(self, " Step 6: Fix module containment"):
125134
_fix_modules_for_internal_tensors(self)
126135

127-
# Step 7: Fix the buffer passes and parent information.
136+
# Step 7: Fix the buffer passes and parent information.
137+
with _vtimed(self, " Step 7: Fix buffer layers"):
128138
_fix_buffer_layers(self)
129139

130140
# Step 8: Identify all loops, mark repeated layers.
131141
loop_desc = (
132-
"Step 8: Loop detection (full)"
142+
" Step 8: Loop detection (full)"
133143
if self.detect_loops
134-
else "Step 8: Loop detection (params only)"
144+
else " Step 8: Loop detection (params only)"
135145
)
136146
with _vtimed(self, loop_desc):
137147
if self.detect_loops:
138148
_detect_and_label_loops(self)
139149
else:
140150
_group_by_shared_params(self)
141151

142-
# Steps 9-12: Labeling (label mapping, final info, rename, cleanup)
143-
with _vtimed(self, "Steps 9-12: Labeling"):
144-
# Step 9: Go down tensor list, get the mapping from raw tensor names to final tensor names.
152+
# Step 9: Go down tensor list, get the mapping from raw tensor names to final tensor names.
153+
with _vtimed(self, " Step 9: Map labels"):
145154
_map_raw_labels_to_final_labels(self)
146155

147-
# Step 10: Log final info for all layers
156+
# Step 10: Log final info for all layers
157+
with _vtimed(self, " Step 10: Log final info"):
148158
_log_final_info_for_all_layers(self)
149159

150-
# Step 11: Rename all raw labels to final labels
160+
# Step 11: Rename all raw labels to final labels
161+
with _vtimed(self, " Step 11: Rename labels"):
151162
_rename_model_history_layer_names(self)
152163
_trim_and_reorder_model_history_fields(self)
153164

154-
# Step 12: Remove unsaved layers, build lookup key mappings
165+
# Step 12: Remove unsaved layers, build lookup key mappings
166+
with _vtimed(self, " Step 12: Build lookup keys"):
155167
_remove_unwanted_entries_and_log_remaining(self)
156168

157-
# Steps 13-18: Finalization
158-
with _vtimed(self, "Steps 13-18: Finalization"):
159-
# Step 13: Undecorate all saved tensors and remove saved grad_fns.
169+
# Step 13: Undecorate all saved tensors and remove saved grad_fns.
170+
with _vtimed(self, " Step 13: Undecorate tensors"):
160171
_undecorate_all_saved_tensors(self)
161172

162-
# Step 14: Clear the cache after any tensor deletions for garbage collection purposes:
163-
torch.cuda.empty_cache()
173+
# Step 14: Clear the cache after any tensor deletions for garbage collection purposes.
174+
torch.cuda.empty_cache()
164175

165-
# Step 15: Log time elapsed.
176+
# Step 15: Log time elapsed.
177+
with _vtimed(self, " Step 15: Log timing"):
166178
_log_time_elapsed(self)
167179

168-
# Step 16: Populate ParamLog reverse mappings, linked params, num_passes, and gradient metadata.
180+
# Step 16: Populate ParamLog reverse mappings, linked params, num_passes, and gradient metadata.
181+
with _vtimed(self, " Step 16: Finalize params"):
169182
_finalize_param_logs(self)
170183

171-
# Step 16.5: Build aggregate LayerLog objects from per-pass LayerPassLog entries.
184+
# Step 16.5: Build aggregate LayerLog objects from per-pass LayerPassLog entries.
185+
with _vtimed(self, " Step 16.5: Build layer logs"):
172186
_build_layer_logs(self)
173187

174-
# Step 17: Build structured ModuleLog objects from raw module_* dicts.
188+
# Step 17: Build structured ModuleLog objects from raw module_* dicts.
189+
with _vtimed(self, " Step 17: Build module logs"):
175190
_build_module_logs(self)
176191

177-
# Step 18: log the pass as finished, changing the ModelLog behavior to its user-facing version.
192+
# Step 18: log the pass as finished, changing the ModelLog behavior to its user-facing version.
193+
with _vtimed(self, " Step 18: Mark pass finished"):
178194
_set_pass_finished(self)
179195

196+
if getattr(self, "verbose", False):
197+
print(f"[torchlens] Postprocessing complete ({time.time() - _post_t0:.2f}s)")
198+
180199

181200
def postprocess_fast(self: "ModelLog") -> None:
182201
"""Lightweight postprocessing for fast (second-pass) logging mode.

torchlens/postprocess/control_flow.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -310,9 +310,14 @@ def _fix_modules_for_internal_tensors(self) -> None:
310310
# Append module path suffix to operation_equivalence_type for ALL tensors.
311311
# This ensures loop detection (Step 8) treats same-function operations in
312312
# different modules as distinct equivalence types.
313+
_module_str_cache = {}
313314
for layer in self:
314-
module_str = "_".join([module_pass[0] for module_pass in layer.containing_modules])
315-
layer.operation_equivalence_type += module_str
315+
cm_key = tuple(layer.containing_modules)
316+
if cm_key not in _module_str_cache:
317+
_module_str_cache[cm_key] = "_".join(
318+
[module_pass[0] for module_pass in layer.containing_modules]
319+
)
320+
layer.operation_equivalence_type += _module_str_cache[cm_key]
316321

317322

318323
def _fix_modules_for_single_internal_tensor(

torchlens/postprocess/finalization.py

Lines changed: 32 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -46,12 +46,14 @@ def _undecorate_all_saved_tensors(self) -> None:
4646
if layer_entry.activation is not None:
4747
tensors_to_undecorate.append(layer_entry.activation)
4848

49-
tensors_to_undecorate.extend(
50-
get_vars_of_type_from_obj(layer_entry.captured_args, torch.Tensor, search_depth=2)
51-
)
52-
tensors_to_undecorate.extend(
53-
get_vars_of_type_from_obj(layer_entry.captured_kwargs, torch.Tensor, search_depth=2)
54-
)
49+
if layer_entry.captured_args:
50+
tensors_to_undecorate.extend(
51+
get_vars_of_type_from_obj(layer_entry.captured_args, torch.Tensor, search_depth=2)
52+
)
53+
if layer_entry.captured_kwargs:
54+
tensors_to_undecorate.extend(
55+
get_vars_of_type_from_obj(layer_entry.captured_kwargs, torch.Tensor, search_depth=2)
56+
)
5557

5658
for t in tensors_to_undecorate:
5759
if hasattr(t, "tl_tensor_label_raw"):
@@ -321,7 +323,9 @@ class ModuleParamInfo(NamedTuple):
321323
buffer_layers: list
322324

323325

324-
def _build_module_param_info(self: "ModelLog", address: str, mbd: dict) -> ModuleParamInfo:
326+
def _build_module_param_info(
327+
self: "ModelLog", address: str, mbd: dict, _buffer_layers_by_module: Optional[dict] = None
328+
) -> ModuleParamInfo:
325329
"""Gather parameter counts, sizes, and buffer layers for a single module."""
326330
from ..data_classes.param_log import ParamAccessor
327331

@@ -332,14 +336,17 @@ def _build_module_param_info(self: "ModelLog", address: str, mbd: dict) -> Modul
332336
m_num_frozen = mbd["module_nparams_frozen"].get(address, 0)
333337
m_fsize = sum(pl.memory for pl in module_param_dict.values())
334338

335-
module_buffer_layers = [
336-
bl
337-
for bl in self.buffer_layers
338-
if bl in self.layer_dict_all_keys
339-
and hasattr(self.layer_dict_all_keys[bl], "buffer_address")
340-
and self.layer_dict_all_keys[bl].buffer_address is not None
341-
and self.layer_dict_all_keys[bl].buffer_address.rsplit(".", 1)[0] == address
342-
]
339+
if _buffer_layers_by_module is not None:
340+
module_buffer_layers = list(_buffer_layers_by_module.get(address, []))
341+
else:
342+
module_buffer_layers = [
343+
bl
344+
for bl in self.buffer_layers
345+
if bl in self.layer_dict_all_keys
346+
and hasattr(self.layer_dict_all_keys[bl], "buffer_address")
347+
and self.layer_dict_all_keys[bl].buffer_address is not None
348+
and self.layer_dict_all_keys[bl].buffer_address.rsplit(".", 1)[0] == address
349+
]
343350

344351
return ModuleParamInfo(
345352
module_params, m_num_params, m_num_trainable, m_num_frozen, m_fsize, module_buffer_layers
@@ -395,6 +402,15 @@ def _build_module_logs(self: "ModelLog") -> None:
395402
for _alias in _meta.get("all_addresses", [_primary_addr]):
396403
_metadata_by_alias[_alias] = _meta
397404

405+
# Pre-compute buffer layers grouped by parent module address (O6).
406+
_buffer_layers_by_module = defaultdict(list)
407+
for bl in self.buffer_layers:
408+
if bl in self.layer_dict_all_keys:
409+
bl_entry = self.layer_dict_all_keys[bl]
410+
if hasattr(bl_entry, "buffer_address") and bl_entry.buffer_address is not None:
411+
module_addr = bl_entry.buffer_address.rsplit(".", 1)[0]
412+
_buffer_layers_by_module[module_addr].append(bl)
413+
398414
# --- Build ModuleLogs for each submodule ---
399415
for address in mbd["module_addresses"]:
400416
meta = _metadata_by_alias.get(address, {})
@@ -425,7 +441,7 @@ def _build_module_logs(self: "ModelLog") -> None:
425441
all_module_addresses=all_addresses,
426442
)
427443
call_children_all, call_parent_addr = _resolve_call_hierarchy(passes)
428-
param_info = _build_module_param_info(self, address, mbd)
444+
param_info = _build_module_param_info(self, address, mbd, _buffer_layers_by_module)
429445

430446
# address_children from metadata may have a different address prefix
431447
# when the metadata was captured for a shared module under a different

torchlens/postprocess/labeling.py

Lines changed: 8 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
"""
2323

2424
import weakref
25-
from collections import OrderedDict, defaultdict
25+
from collections import defaultdict
2626
from typing import TYPE_CHECKING
2727

2828
from ..constants import MODEL_LOG_FIELD_ORDER, LAYER_PASS_LOG_FIELD_ORDER
@@ -491,7 +491,7 @@ def _trim_and_reorder_layer_entry_fields(layer_entry: LayerPassLog) -> None:
491491
Callable attributes (methods) are excluded from the reordered dict.
492492
"""
493493
old_dict = layer_entry.__dict__
494-
new_dir_dict = OrderedDict()
494+
new_dir_dict = {}
495495
# First: fields in canonical order.
496496
for field in LAYER_PASS_LOG_FIELD_ORDER:
497497
if field in old_dict:
@@ -563,18 +563,11 @@ def _rename_model_history_layer_names(self) -> None:
563563

564564
mla = self._module_build_data["module_layer_argnames"]
565565
for module_pass, arglist in mla.items():
566-
inds_to_remove = set()
567-
for a, arg in enumerate(arglist):
568-
raw_name = mla[module_pass][a][0]
569-
if raw_name not in self._raw_to_final_layer_labels:
570-
inds_to_remove.add(a)
571-
continue
572-
new_name = self._raw_to_final_layer_labels[raw_name]
573-
argname = mla[module_pass][a][1]
574-
mla[module_pass][a] = (new_name, argname)
575-
mla[module_pass] = [
576-
mla[module_pass][i] for i in range(len(arglist)) if i not in inds_to_remove
577-
]
566+
new_arglist = []
567+
for raw_name, argname in arglist:
568+
if raw_name in self._raw_to_final_layer_labels:
569+
new_arglist.append((self._raw_to_final_layer_labels[raw_name], argname))
570+
mla[module_pass] = new_arglist
578571

579572

580573
def _trim_and_reorder_model_history_fields(self) -> None:
@@ -584,7 +577,7 @@ def _trim_and_reorder_model_history_fields(self) -> None:
584577
Public fields listed in MODEL_LOG_FIELD_ORDER come first, followed by any
585578
private fields (starting with ``_``) not already in the order list.
586579
"""
587-
new_dir_dict = OrderedDict()
580+
new_dir_dict = {}
588581
for field in MODEL_LOG_FIELD_ORDER:
589582
new_dir_dict[field] = getattr(self, field)
590583
# Preserve all remaining fields not in the canonical order (private/internal

torchlens/postprocess/loop_detection.py

Lines changed: 38 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -382,9 +382,17 @@ def union(x, y, uf=uf_parent):
382382
if rx != ry:
383383
uf[rx] = ry
384384

385-
for member1, member2 in it.combinations(members, 2):
386-
if member_neighbor_isos[member1] & member_neighbor_isos[member2]:
387-
union(member1, member2)
385+
# Reverse-index approach: union members sharing a neighbor key.
386+
# O(members × avg_neighbors) instead of O(members²).
387+
_reverse_index = defaultdict(list)
388+
for member_label in members:
389+
for neighbor_key in member_neighbor_isos[member_label]:
390+
_reverse_index[neighbor_key].append(member_label)
391+
for members_with_key in _reverse_index.values():
392+
if len(members_with_key) > 1:
393+
first = members_with_key[0]
394+
for other in members_with_key[1:]:
395+
union(first, other)
388396

389397
components = defaultdict(list)
390398
for member in members:
@@ -628,10 +636,16 @@ def _find_isomorphic_matches(
628636

629637
# Remove collisions: if the same node appears in multiple (node, subgraph) tuples,
630638
# discard all occurrences to avoid assigning one node to multiple subgraphs.
631-
node_labels = [node[0] for node in new_equivalent_nodes]
632-
new_equivalent_nodes = [
633-
node for node in new_equivalent_nodes if node_labels.count(node[0]) == 1
634-
]
639+
_seen_labels = set()
640+
_dupe_labels = set()
641+
for node in new_equivalent_nodes:
642+
if node[0] in _seen_labels:
643+
_dupe_labels.add(node[0])
644+
_seen_labels.add(node[0])
645+
if _dupe_labels:
646+
new_equivalent_nodes = [
647+
node for node in new_equivalent_nodes if node[0] not in _dupe_labels
648+
]
635649
return new_equivalent_nodes
636650

637651

@@ -776,26 +790,31 @@ def _union(x: str, y: str) -> None:
776790
for iso_nodes_orig in iso_node_groups.values():
777791
all_iso_nodes.update(iso_nodes_orig)
778792

793+
# Pre-compute param types per subgraph for O(1) lookup in the pair loop (O10).
794+
_sg_param_types: Dict[str, frozenset] = {}
795+
for iso_nodes_orig in iso_node_groups.values():
796+
for node_label in iso_nodes_orig:
797+
sg = node_to_subgraph[node_label]
798+
sg_label = sg.starting_node
799+
if sg_label not in _sg_param_types:
800+
_sg_param_types[sg_label] = frozenset(
801+
self[pnode].operation_equivalence_type for pnode in sg.param_nodes
802+
)
803+
779804
# PASS 1: Within iso-groups — merge nodes whose subgraphs share param types or are adjacent.
780805
for iso_group_label, iso_nodes_orig in iso_node_groups.items():
781806
iso_nodes = sorted(iso_nodes_orig)
782807
for node1_label, node2_label in it.combinations(iso_nodes, 2):
783-
node1_subgraph = node_to_subgraph[node1_label]
784-
node2_subgraph = node_to_subgraph[node2_label]
785-
node1_subgraph_label = node1_subgraph.starting_node
786-
node2_subgraph_label = node2_subgraph.starting_node
787-
node1_param_types = [
788-
self[pnode].operation_equivalence_type for pnode in node1_subgraph.param_nodes
789-
]
790-
node2_param_types = [
791-
self[pnode].operation_equivalence_type for pnode in node2_subgraph.param_nodes
792-
]
793-
overlapping_param_types = set(node1_param_types).intersection(set(node2_param_types))
808+
node1_subgraph_label = node_to_subgraph[node1_label].starting_node
809+
node2_subgraph_label = node_to_subgraph[node2_label].starting_node
810+
overlapping_param_types = (
811+
_sg_param_types[node1_subgraph_label] & _sg_param_types[node2_subgraph_label]
812+
)
794813
subgraphs_are_adjacent = (
795814
node1_subgraph_label in adjacent_subgraphs
796815
and node2_subgraph_label in adjacent_subgraphs[node1_subgraph_label]
797816
)
798-
if (len(overlapping_param_types) > 0) or subgraphs_are_adjacent:
817+
if overlapping_param_types or subgraphs_are_adjacent:
799818
_union(node1_label, node2_label)
800819

801820
# PASS 2: Cross iso-groups — unconditionally merge by (func, params) identity.

0 commit comments

Comments
 (0)