|
34 | 34 |
|
35 | 35 | from typing import TYPE_CHECKING, List |
36 | 36 |
|
| 37 | +import time |
37 | 38 | import torch |
38 | 39 |
|
39 | 40 | from ..utils.tensor_utils import safe_copy |
@@ -99,84 +100,102 @@ def postprocess( |
99 | 100 | _set_pass_finished(self) |
100 | 101 | return |
101 | 102 |
|
102 | | - # Steps 1-3: Graph traversal (output nodes, ancestry, orphan removal) |
103 | | - with _vtimed(self, "Steps 1-3: Graph traversal"): |
104 | | - # Step 1: Add dedicated output nodes |
| 103 | + _vprint( |
| 104 | + self, |
| 105 | + f"Postprocessing {len(self._raw_layer_labels_list):,} layers " |
| 106 | + f"({len(self.buffer_layers):,} buffers)...", |
| 107 | + ) |
| 108 | + _post_t0 = time.time() if getattr(self, "verbose", False) else 0 |
| 109 | + |
| 110 | + # Step 1: Add dedicated output nodes |
| 111 | + with _vtimed(self, " Step 1: Add output layers"): |
105 | 112 | _add_output_layers(self, output_tensors, output_tensor_addresses) |
106 | 113 |
|
107 | | - # Step 2: Trace which nodes are ancestors of output nodes |
| 114 | + # Step 2: Trace which nodes are ancestors of output nodes |
| 115 | + with _vtimed(self, " Step 2: Trace output ancestors"): |
108 | 116 | _find_output_ancestors(self) |
109 | 117 |
|
110 | | - # Step 3: Remove orphan nodes, find nodes that don't terminate in output node |
| 118 | + # Step 3: Remove orphan nodes, find nodes that don't terminate in output node |
| 119 | + with _vtimed(self, " Step 3: Remove orphan nodes"): |
111 | 120 | _remove_orphan_nodes(self) |
112 | 121 |
|
113 | 122 | # Step 4: Find min/max distance from input and output nodes. |
114 | 123 | # Conditional: only runs when the user requested distance metadata. |
115 | 124 | if self.mark_input_output_distances: |
116 | | - with _vtimed(self, "Step 4: Input/output distances"): |
| 125 | + with _vtimed(self, " Step 4: Input/output distances"): |
117 | 126 | _mark_input_output_distances(self) |
118 | 127 |
|
119 | | - # Steps 5-7: Control flow (conditional branches, module fixing, buffers) |
120 | | - with _vtimed(self, "Steps 5-7: Control flow"): |
121 | | - # Step 5: Starting from terminal single boolean tensors, mark the conditional branches. |
| 128 | + # Step 5: Starting from terminal single boolean tensors, mark the conditional branches. |
| 129 | + with _vtimed(self, " Step 5: Mark conditional branches"): |
122 | 130 | _mark_conditional_branches(self) |
123 | 131 |
|
124 | | - # Step 6: Annotate the containing modules for all internally-generated tensors. |
| 132 | + # Step 6: Annotate the containing modules for all internally-generated tensors. |
| 133 | + with _vtimed(self, " Step 6: Fix module containment"): |
125 | 134 | _fix_modules_for_internal_tensors(self) |
126 | 135 |
|
127 | | - # Step 7: Fix the buffer passes and parent information. |
| 136 | + # Step 7: Fix the buffer passes and parent information. |
| 137 | + with _vtimed(self, " Step 7: Fix buffer layers"): |
128 | 138 | _fix_buffer_layers(self) |
129 | 139 |
|
130 | 140 | # Step 8: Identify all loops, mark repeated layers. |
131 | 141 | loop_desc = ( |
132 | | - "Step 8: Loop detection (full)" |
| 142 | + " Step 8: Loop detection (full)" |
133 | 143 | if self.detect_loops |
134 | | - else "Step 8: Loop detection (params only)" |
| 144 | + else " Step 8: Loop detection (params only)" |
135 | 145 | ) |
136 | 146 | with _vtimed(self, loop_desc): |
137 | 147 | if self.detect_loops: |
138 | 148 | _detect_and_label_loops(self) |
139 | 149 | else: |
140 | 150 | _group_by_shared_params(self) |
141 | 151 |
|
142 | | - # Steps 9-12: Labeling (label mapping, final info, rename, cleanup) |
143 | | - with _vtimed(self, "Steps 9-12: Labeling"): |
144 | | - # Step 9: Go down tensor list, get the mapping from raw tensor names to final tensor names. |
| 152 | + # Step 9: Go down tensor list, get the mapping from raw tensor names to final tensor names. |
| 153 | + with _vtimed(self, " Step 9: Map labels"): |
145 | 154 | _map_raw_labels_to_final_labels(self) |
146 | 155 |
|
147 | | - # Step 10: Log final info for all layers |
| 156 | + # Step 10: Log final info for all layers |
| 157 | + with _vtimed(self, " Step 10: Log final info"): |
148 | 158 | _log_final_info_for_all_layers(self) |
149 | 159 |
|
150 | | - # Step 11: Rename all raw labels to final labels |
| 160 | + # Step 11: Rename all raw labels to final labels |
| 161 | + with _vtimed(self, " Step 11: Rename labels"): |
151 | 162 | _rename_model_history_layer_names(self) |
152 | 163 | _trim_and_reorder_model_history_fields(self) |
153 | 164 |
|
154 | | - # Step 12: Remove unsaved layers, build lookup key mappings |
| 165 | + # Step 12: Remove unsaved layers, build lookup key mappings |
| 166 | + with _vtimed(self, " Step 12: Build lookup keys"): |
155 | 167 | _remove_unwanted_entries_and_log_remaining(self) |
156 | 168 |
|
157 | | - # Steps 13-18: Finalization |
158 | | - with _vtimed(self, "Steps 13-18: Finalization"): |
159 | | - # Step 13: Undecorate all saved tensors and remove saved grad_fns. |
| 169 | + # Step 13: Undecorate all saved tensors and remove saved grad_fns. |
| 170 | + with _vtimed(self, " Step 13: Undecorate tensors"): |
160 | 171 | _undecorate_all_saved_tensors(self) |
161 | 172 |
|
162 | | - # Step 14: Clear the cache after any tensor deletions for garbage collection purposes: |
163 | | - torch.cuda.empty_cache() |
| 173 | + # Step 14: Clear the cache after any tensor deletions for garbage collection purposes. |
| 174 | + torch.cuda.empty_cache() |
164 | 175 |
|
165 | | - # Step 15: Log time elapsed. |
| 176 | + # Step 15: Log time elapsed. |
| 177 | + with _vtimed(self, " Step 15: Log timing"): |
166 | 178 | _log_time_elapsed(self) |
167 | 179 |
|
168 | | - # Step 16: Populate ParamLog reverse mappings, linked params, num_passes, and gradient metadata. |
| 180 | + # Step 16: Populate ParamLog reverse mappings, linked params, num_passes, and gradient metadata. |
| 181 | + with _vtimed(self, " Step 16: Finalize params"): |
169 | 182 | _finalize_param_logs(self) |
170 | 183 |
|
171 | | - # Step 16.5: Build aggregate LayerLog objects from per-pass LayerPassLog entries. |
| 184 | + # Step 16.5: Build aggregate LayerLog objects from per-pass LayerPassLog entries. |
| 185 | + with _vtimed(self, " Step 16.5: Build layer logs"): |
172 | 186 | _build_layer_logs(self) |
173 | 187 |
|
174 | | - # Step 17: Build structured ModuleLog objects from raw module_* dicts. |
| 188 | + # Step 17: Build structured ModuleLog objects from raw module_* dicts. |
| 189 | + with _vtimed(self, " Step 17: Build module logs"): |
175 | 190 | _build_module_logs(self) |
176 | 191 |
|
177 | | - # Step 18: log the pass as finished, changing the ModelLog behavior to its user-facing version. |
| 192 | + # Step 18: log the pass as finished, changing the ModelLog behavior to its user-facing version. |
| 193 | + with _vtimed(self, " Step 18: Mark pass finished"): |
178 | 194 | _set_pass_finished(self) |
179 | 195 |
|
| 196 | + if getattr(self, "verbose", False): |
| 197 | + print(f"[torchlens] Postprocessing complete ({time.time() - _post_t0:.2f}s)") |
| 198 | + |
180 | 199 |
|
181 | 200 | def postprocess_fast(self: "ModelLog") -> None: |
182 | 201 | """Lightweight postprocessing for fast (second-pass) logging mode. |
|
0 commit comments