6969if TYPE_CHECKING :
7070 from ..data_classes .model_log import ModelLog
7171
72+ from ..utils .display import _vprint , _vtimed
73+
7274
7375def postprocess (
7476 self : "ModelLog" , output_tensors : List [torch .Tensor ], output_tensor_addresses : List [str ]
@@ -97,89 +99,83 @@ def postprocess(
9799 _set_pass_finished (self )
98100 return
99101
100- # Step 1: Add dedicated output nodes
101-
102- _add_output_layers (self , output_tensors , output_tensor_addresses )
103-
104- # Step 2: Trace which nodes are ancestors of output nodes
102+ # Steps 1-3: Graph traversal (output nodes, ancestry, orphan removal)
103+ with _vtimed (self , "Steps 1-3: Graph traversal" ):
104+ # Step 1: Add dedicated output nodes
105+ _add_output_layers (self , output_tensors , output_tensor_addresses )
105106
106- _find_output_ancestors (self )
107+ # Step 2: Trace which nodes are ancestors of output nodes
108+ _find_output_ancestors (self )
107109
108- # Step 3: Remove orphan nodes, find nodes that don't terminate in output node
109-
110- _remove_orphan_nodes (self )
110+ # Step 3: Remove orphan nodes, find nodes that don't terminate in output node
111+ _remove_orphan_nodes (self )
111112
112113 # Step 4: Find min/max distance from input and output nodes.
113114 # Conditional: only runs when the user requested distance metadata.
114-
115115 if self .mark_input_output_distances :
116- _mark_input_output_distances (self )
116+ with _vtimed (self , "Step 4: Input/output distances" ):
117+ _mark_input_output_distances (self )
117118
118- # Step 5: Starting from terminal single boolean tensors, mark the conditional branches.
119+ # Steps 5-7: Control flow (conditional branches, module fixing, buffers)
120+ with _vtimed (self , "Steps 5-7: Control flow" ):
121+ # Step 5: Starting from terminal single boolean tensors, mark the conditional branches.
122+ _mark_conditional_branches (self )
119123
120- _mark_conditional_branches (self )
124+ # Step 6: Annotate the containing modules for all internally-generated tensors.
125+ _fix_modules_for_internal_tensors (self )
121126
122- # Step 6: Annotate the containing modules for all internally-generated tensors.
123- # Internally-initialized tensors (e.g., constants, arange results) don't know
124- # what module they belong to. This traces backward from input-descendant tensors
125- # to infer module containment. IMPORTANT: also appends module path suffixes to
126- # operation_equivalence_type, which affects Step 8 loop detection grouping.
127-
128- _fix_modules_for_internal_tensors (self )
129-
130- # Step 7: Fix the buffer passes and parent information.
131- # Connects buffer parents, merges duplicate buffer nodes (same module, same
132- # value, same parents), and assigns buffer pass numbers.
133-
134- _fix_buffer_layers (self )
127+ # Step 7: Fix the buffer passes and parent information.
128+ _fix_buffer_layers (self )
135129
136130 # Step 8: Identify all loops, mark repeated layers.
131+ loop_desc = (
132+ "Step 8: Loop detection (full)"
133+ if self .detect_loops
134+ else "Step 8: Loop detection (params only)"
135+ )
136+ with _vtimed (self , loop_desc ):
137+ if self .detect_loops :
138+ _detect_and_label_loops (self )
139+ else :
140+ _group_by_shared_params (self )
137141
138- if self . detect_loops :
139- _detect_and_label_loops ( self )
140- else :
141- _group_by_shared_params (self )
142+ # Steps 9-12: Labeling (label mapping, final info, rename, cleanup)
143+ with _vtimed ( self , "Steps 9-12: Labeling" ):
144+ # Step 9: Go down tensor list, get the mapping from raw tensor names to final tensor names.
145+ _map_raw_labels_to_final_labels (self )
142146
143- # Step 9: Go down tensor list, get the mapping from raw tensor names to final tensor names.
147+ # Step 10: Log final info for all layers
148+ _log_final_info_for_all_layers (self )
144149
145- _map_raw_labels_to_final_labels (self )
150+ # Step 11: Rename all raw labels to final labels
151+ _rename_model_history_layer_names (self )
152+ _trim_and_reorder_model_history_fields (self )
146153
147- # Step 10: Log final info for all layers (operation numbers, module hierarchy,
148- # param tallies, structural flags). MUST run before Step 12 because lookup key
149- # generation in Step 12 needs module hierarchy data populated here.
150- _log_final_info_for_all_layers (self )
154+ # Step 12: Remove unsaved layers, build lookup key mappings
155+ _remove_unwanted_entries_and_log_remaining (self )
151156
152- # Step 11: Rename all raw labels (e.g., "cos_3_raw") to final labels
153- # (e.g., "cos_1_3:2") in both ModelLog-level fields and LayerPassLog fields.
154- # Then reorder ModelLog fields into the canonical display order.
155- _rename_model_history_layer_names (self )
156- _trim_and_reorder_model_history_fields (self )
157+ # Steps 13-18: Finalization
158+ with _vtimed (self , "Steps 13-18: Finalization" ):
159+ # Step 13: Undecorate all saved tensors and remove saved grad_fns.
160+ _undecorate_all_saved_tensors (self )
157161
158- # Step 12: Remove unsaved layers (unless keep_unsaved_layers=True), build
159- # lookup key mappings, and log remaining layer metadata.
160- _remove_unwanted_entries_and_log_remaining (self )
162+ # Step 14: Clear the cache after any tensor deletions for garbage collection purposes:
163+ torch .cuda .empty_cache ()
161164
162- # Step 13: Undecorate all saved tensors and remove saved grad_fns .
163- _undecorate_all_saved_tensors (self )
165+ # Step 15: Log time elapsed .
166+ _log_time_elapsed (self )
164167
165- # Step 14: Clear the cache after any tensor deletions for garbage collection purposes:
166- torch . cuda . empty_cache ( )
168+ # Step 16: Populate ParamLog reverse mappings, linked params, num_passes, and gradient metadata.
169+ _finalize_param_logs ( self )
167170
168- # Step 15: Log time elapsed .
169- _log_time_elapsed (self )
171+ # Step 16.5: Build aggregate LayerLog objects from per-pass LayerPassLog entries .
172+ _build_layer_logs (self )
170173
171- # Step 16: Populate ParamLog reverse mappings, linked params, num_passes, and gradient metadata.
172- _finalize_param_logs (self )
173-
174- # Step 16.5: Build aggregate LayerLog objects from per-pass LayerPassLog entries.
175- _build_layer_logs (self )
174+ # Step 17: Build structured ModuleLog objects from raw module_* dicts.
175+ _build_module_logs (self )
176176
177- # Step 17: Build structured ModuleLog objects from raw module_* dicts.
178- _build_module_logs (self )
179-
180- # Step 18: log the pass as finished, changing the ModelLog behavior to its user-facing version.
181-
182- _set_pass_finished (self )
177+ # Step 18: log the pass as finished, changing the ModelLog behavior to its user-facing version.
178+ _set_pass_finished (self )
183179
184180
185181def postprocess_fast (self : "ModelLog" ) -> None :
@@ -202,6 +198,7 @@ def postprocess_fast(self: "ModelLog") -> None:
202198 - Step 17: _build_module_logs — module structure doesn't change between
203199 passes and _module_build_data isn't repopulated in fast mode (#108).
204200 """
201+ _vprint (self , "Fast-pass postprocessing..." )
205202 # Use layer_dict_main_keys to get LayerPassLog directly (not LayerLog)
206203 for output_layer_label in self .output_layers :
207204 output_layer = self .layer_dict_main_keys [output_layer_label ]
0 commit comments