@@ -116,76 +116,58 @@ def quantize(
 
         return self.model
 
-    def _collect_activation_stats(
-        self,
-        data: torch.Tensor  # Removed num_steps parameter
-    ):
-        """Collect activation statistics for each layer."""
-
-        # Register hooks for all linear layers
-        handles = []
-        for name, module in self.model.named_modules():
-            if isinstance(module, nn.Linear):
-                def hook_fn(name):
-                    def fn(module, input, output):
-                        if name not in self.act_scales:
-                            self.act_scales[name] = []
-                        x = input[0].detach()
-                        # Handle both 2D and 3D inputs
-                        if len(x.shape) == 3:
-                            # For 3D input (batch_size, seq_len, hidden_size)
-                            scale = torch.max(torch.abs(x.view(-1, x.size(-1))))
-                        else:
-                            scale = torch.max(torch.abs(x))
-                        self.act_scales[name].append(scale.cpu())  # Move to CPU to save memory
-                    return fn
-
-                handles.append(
-                    module.register_forward_hook(hook_fn(name))
-                )
-
-        # Run calibration (forward pass on the provided data batch)
-        with torch.no_grad():
-            # Ensure data is on the primary device for model processing
-            data_on_device = move_to_device(data, self.device_manager.primary_device)
-            self.model(data_on_device)
-            # Data can be moved back to CPU if it's large and memory is a concern,
-            # but hooks should have already captured necessary info to CPU.
-            # For simplicity here, we assume hooks manage CPU transfer if needed.
-            # del data_on_device  # Optionally delete if memory is very tight
-
-        # Remove hooks
-        for handle in handles:
-            handle.remove()
-
-        # model is already on self.device_manager.primary_device from the quantize method's perspective
-        # or moved by prepare_calibration_data.
-        # The processing of act_scales should happen after all batches are processed.
-        # However, the current structure calls this per batch.
-        # For now, let's keep the quantile calculation here, but ideally, it would be after the main loop in `quantize`.
-        # To avoid issues with model device, let's ensure model is on CPU for this CPU-bound operation,
-        # then move it back if it was on GPU.
-
-        original_model_device = self.model.device  # Store original device
-        self.model = move_to_device(self.model, torch.device('cpu'))
-        self._clear_memory()
-
-        # Process collected statistics on CPU
-        for name in self.act_scales:
-            if self.act_scales[name]:  # Ensure list is not empty
-                scales_list = self.act_scales[name]
-                # If scales_list contains tensors that are not on CPU, move them.
-                # Assuming they are already on CPU due to `scale.cpu()` in hook.
-                scales_tensor = torch.stack(scales_list)
-                self.act_scales[name] = torch.quantile(scales_tensor, 0.999)
-            else:
-                # Handle cases where a layer might not have collected scales (e.g. not used in forward pass)
-                self.logger.log_warning(f"No activation scales collected for layer {name}. Using default scale of 1.0.")
-                self.act_scales[name] = torch.tensor(1.0, device='cpu')  # Default to a CPU tensor
-
-        # Restore model to its original device
-        self.model = move_to_device(self.model, original_model_device)
-        # The duplicated block of "Process collected statistics" is now removed.
+    def _collect_activation_stats(self, data: torch.Tensor):
+        """Collect activation statistics for each layer."""
+        # Store temporary scales for this batch
+        batch_scales = {}
+
+        # Register hooks for all linear layers
+        handles = []
+        for name, module in self.model.named_modules():
+            if isinstance(module, nn.Linear):
+                def hook_fn(name):
+                    def fn(module, input, output):
+                        # Initialize the list for this layer if it doesn't exist yet
+                        if name not in batch_scales:
+                            batch_scales[name] = []
+
+                        x = input[0].detach()
+                        # Handle both 2D and 3D inputs
+                        if len(x.shape) == 3:
+                            # For 3D input (batch_size, seq_len, hidden_size)
+                            scale = torch.max(torch.abs(x.view(-1, x.size(-1))))
+                        else:
+                            scale = torch.max(torch.abs(x))
+                        # Store the scale in the temporary dictionary (on CPU to save memory)
+                        batch_scales[name].append(scale.cpu())
+                    return fn
+
+                handles.append(
+                    module.register_forward_hook(hook_fn(name))
+                )
+
+        # Run calibration (forward pass on the provided data batch)
+        with torch.no_grad():
+            data_on_device = move_to_device(data, self.device_manager.primary_device)
+            self.model(data_on_device)
+
+        # Remove hooks
+        for handle in handles:
+            handle.remove()
+
+        # Process the collected scales
+        for name in batch_scales:
+            if batch_scales[name]:  # If we collected any scales for this layer
+                scales_tensor = torch.stack(batch_scales[name])
+                # First batch for this layer: create its entry in the main storage
+                if name not in self.act_scales:
+                    self.act_scales[name] = []
+                # Add the collected scales to the main storage as plain floats
+                self.act_scales[name].extend([s.item() for s in scales_tensor])
+
+        # Clean up
+        del batch_scales
+        self._clear_memory()
 
     def _quantize_layer(
         self,
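
Note for reviewers: the refactored method no longer reduces scales to a single 0.999-quantile per layer inside `_collect_activation_stats`; it appends raw per-batch maxima (as Python floats) to `self.act_scales` and defers the reduction until all calibration batches have run, as the removed comments recommended. A minimal sketch of what that post-loop reduction could look like, assuming the same 0.999 quantile and 1.0 fallback the removed code used; `finalize_act_scales` is a hypothetical helper, not part of this commit:

import torch

def finalize_act_scales(act_scales):
    """Reduce the per-batch max-abs scales gathered by
    _collect_activation_stats to one robust scale per layer."""
    finalized = {}
    for name, scales in act_scales.items():
        if scales:
            # 99.9th percentile over all calibration batches, matching the
            # torch.quantile(scales_tensor, 0.999) call in the removed code
            finalized[name] = torch.quantile(torch.tensor(scales), 0.999)
        else:
            # Layer never fired during calibration; default scale of 1.0,
            # as the removed code's warning branch did
            finalized[name] = torch.tensor(1.0)
    return finalized

Keeping the reduction out of the per-batch path also means the model no longer has to bounce to CPU and back on every calibration batch, which is the device shuffling the removed comments were working around.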