@@ -90,7 +90,30 @@ def quantize(
             self.logger.log_info(f"Processing layer: {name}")

             # Get activation scale for this layer
-            act_scale = self.act_scales.get(name)
+            act_scale_list_or_tensor = self.act_scales.get(name)
+
+            if act_scale_list_or_tensor is not None:
+                if isinstance(act_scale_list_or_tensor, list):
+                    if all(isinstance(t, torch.Tensor) for t in act_scale_list_or_tensor):
+                        # Average the list of per-batch scale tensors
+                        act_scale = torch.stack(act_scale_list_or_tensor).mean(dim=0)
+                    else:
+                        # Handle unexpected content in the list
+                        self.logger.log_error(f"Activation scales for {name} contain non-tensor elements. Quantization may be incorrect.")
+                        # Fallback: the mixed list cannot be used directly,
+                        # so create a default scale of 1.0 for safety
+                        act_scale = torch.ones(module.in_features, device=self.device_manager.primary_device)
+                elif isinstance(act_scale_list_or_tensor, torch.Tensor):
+                    # Already a tensor (e.g. averaging was done elsewhere, or only one batch)
+                    act_scale = act_scale_list_or_tensor
+                else:
+                    self.logger.log_error(f"Unexpected type for activation scales of {name}: {type(act_scale_list_or_tensor)}. Using default.")
+                    act_scale = torch.ones(module.in_features, device=self.device_manager.primary_device)
+            else:
+                self.logger.log_warning(f"No activation scales found for {name}. Using default scale of 1.0.")
+                # module.in_features should correspond to the expected dimension of the scale
+                act_scale = torch.ones(module.in_features, device=self.device_manager.primary_device)
+
             quantized = self._quantize_layer(module, act_scale)

             # Replace layer in model
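
The averaging branch above reduces a list of per-batch, per-channel max tensors to one per-channel scale. A minimal standalone sketch of that reduction (the values and the `collected` name are illustrative, not from the PR):

import torch

# Two hypothetical per-batch max-abs scale tensors for one layer (hidden_size = 4)
collected = [
    torch.tensor([0.5, 2.0, 1.0, 0.25]),
    torch.tensor([1.5, 1.0, 3.0, 0.75]),
]

# Same reduction as the diff: stack to (n_batches, hidden_size), then mean over batches
act_scale = torch.stack(collected).mean(dim=0)
print(act_scale)  # tensor([1.0000, 1.5000, 2.0000, 0.5000]), shape (hidden_size,)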
@@ -135,10 +158,13 @@ def fn(module, input, output):
             # Handle both 2D and 3D inputs
             if len(x.shape) == 3:
                 # For 3D input (batch_size, seq_len, hidden_size)
-                scale = torch.max(torch.abs(x.view(-1, x.size(-1))))
+                # Compute scales per hidden channel: (hidden_size,)
+                scale = torch.amax(torch.abs(x), dim=[0, 1])
             else:
-                scale = torch.max(torch.abs(x))
-            # Store scale in our temporary dictionary
+                # For 2D input (batch_size, hidden_size)
+                # Compute scales per hidden channel: (hidden_size,)
+                scale = torch.amax(torch.abs(x), dim=0)
+            # Store scale tensor (moved to CPU) in our temporary dictionary
             batch_scales[name].append(scale.cpu())
         return fn
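
The change from `torch.max` to `torch.amax` with explicit dims is a shape change, not just a renaming: the old code produced one scalar per layer, the new code one value per hidden channel. A small sketch of the difference (`x3d` stands in for the hook's input):

import torch

x3d = torch.randn(2, 5, 8)  # (batch_size, seq_len, hidden_size)

# Old behavior: a single global scalar over all elements
global_scale = torch.max(torch.abs(x3d))              # shape: () — one value for all channels

# New behavior: one max-abs value per hidden channel
per_channel = torch.amax(torch.abs(x3d), dim=[0, 1])  # shape: (8,)

assert per_channel.shape == (x3d.size(-1),)
assert torch.isclose(per_channel.max(), global_scale)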
@@ -150,6 +176,7 @@ def fn(module, input, output):
         with torch.no_grad():
             data_on_device = move_to_device(data, self.device_manager.primary_device)
             self.model(data_on_device)
+            del data_on_device  # Free memory after forward pass

         # Remove hooks
         for handle in handles:
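
The added `del` drops the last Python reference to the device copy of the batch as soon as the forward pass returns, rather than holding it until the name is reassigned or goes out of scope; for CUDA tensors the caching allocator can then reuse those blocks for the next batch. A hypothetical calibration loop illustrating the same pattern (`model`, `batches`, and `device` are stand-in names):

import torch

def run_calibration(model, batches, device):
    with torch.no_grad():
        for data in batches:
            data_on_device = data.to(device)
            model(data_on_device)
            # Release the reference immediately instead of at loop/function end
            del data_on_device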
@@ -158,12 +185,12 @@ def fn(module, input, output):
         # Process the collected scales
         for name in batch_scales:
             if batch_scales[name]:  # If we collected any scales for this layer
-                scales_tensor = torch.stack(batch_scales[name])
-                # If this is the first batch
+                # If this is the first batch for this layer
                 if name not in self.act_scales:
                     self.act_scales[name] = []
-                # Add the processed scales to our main storage
-                self.act_scales[name].extend([s.item() for s in scales_tensor])
+                # Extend the list of scale tensors for this layer;
+                # batch_scales[name] already contains CPU tensors
+                self.act_scales[name].extend(batch_scales[name])

         # Clean up
         del batch_scales
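
Dropping the `.item()` conversion is required, not just cosmetic: once the hook produces per-channel tensors, `.item()` only works on one-element tensors and would raise, and it would collapse the per-channel structure to Python floats anyway. A self-contained sketch (values are illustrative):

import torch

scale_a = torch.tensor([0.5, 2.0, 1.0])  # per-channel scales from batch 1
scale_b = torch.tensor([1.5, 1.0, 3.0])  # per-channel scales from batch 2

# Old behavior: .item() raises on multi-element tensors
try:
    floats = [s.item() for s in torch.stack([scale_a, scale_b])]
except RuntimeError as e:
    print(f"item() fails on per-channel tensors: {e}")

# New behavior: keep the CPU tensors whole so per-channel structure survives
stored = []
stored.extend([scale_a, scale_b])
print(torch.stack(stored).mean(dim=0))  # tensor([1.0000, 1.5000, 2.0000])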
@@ -206,13 +233,43 @@ def _quantize_layer(

         # Ensure act_scale is on the same device as W before division
         act_scale_on_device = move_to_device(act_scale, W.device)
-        W = W / act_scale_on_device.view(1, -1)
+
+        try:
+            W = W / act_scale_on_device.view(1, -1)
+        except RuntimeError as e:
+            error_message = (
+                f"Failed to scale weights with activation scales in _quantize_layer.\n"
+                f"  Weight (W) shape: {W.shape}\n"
+                f"  Activation scale (act_scale_on_device) shape: {act_scale_on_device.shape}\n"
+                f"  Original error: {str(e)}"
+            )
+            self.logger.log_error(error_message)
+            raise RuntimeError(error_message) from e

         # Compute quantization scales per group
         # All computations for scales and zero_points should happen on target_device
         if self.group_size > 0:
+            if W.shape[0] % self.group_size != 0:
+                error_message = (
+                    f"Weight dimension {W.shape[0]} is not divisible by group_size {self.group_size} "
+                    f"in _quantize_layer for the layer being processed."
+                )
+                self.logger.log_error(error_message)
+                raise ValueError(error_message)  # ValueError is more appropriate here
+
             n_groups = W.shape[0] // self.group_size
-            W_groups = W.view(n_groups, self.group_size, -1)
+            try:
+                W_groups = W.view(n_groups, self.group_size, -1)
+            except RuntimeError as e:
+                error_message = (
+                    f"Failed to create view for grouped weights in _quantize_layer.\n"
+                    f"  Weight (W) shape: {W.shape}\n"
+                    f"  Calculated n_groups: {n_groups}\n"
+                    f"  Group size: {self.group_size}\n"
+                    f"  Original error: {str(e)}"
+                )
+                self.logger.log_error(error_message)
+                raise RuntimeError(error_message) from e

             scales_list = []  # Renamed from scales to scales_list
             zero_points_list = [] if self.zero_point else None  # Renamed
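
A minimal sketch of the grouped view the new divisibility guard protects (the per-group reduction shown is illustrative; the actual scale/zero-point math lives further down in `_quantize_layer`):

import torch

group_size = 4
W = torch.randn(8, 16)  # (out_features, in_features); 8 % 4 == 0, so the view is valid

n_groups = W.shape[0] // group_size          # 2
W_groups = W.view(n_groups, group_size, -1)  # (2, 4, 16)

# Per-group max-abs, reduced over the group and input dimensions
group_absmax = W_groups.abs().amax(dim=(1, 2))  # shape: (n_groups,)

# With a non-divisible dim 0 the guard fires before the view is attempted:
W_bad = torch.randn(7, 16)
assert W_bad.shape[0] % group_size != 0  # 7 % 4 == 3 -> the diff raises ValueError here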
@@ -246,12 +303,14 @@ def _quantize_layer(
         # W, scales, zero_points are on target_device
         W_quant = torch.round(W * scales.view(-1, 1) - zero_points.view(-1, 1))
         W_quant = W_quant.to(torch.int8)  # Cast to int8
+        del W  # Free memory for W as it's no longer needed

         # Store quantized weights and parameters
         # quantized module and its buffers are already on target_device
         quantized.weight_quantized.copy_(W_quant)  # W_quant is already on target_device and int8
         quantized.weight_scale.copy_(1.0 / scales)  # scales is on target_device
         quantized.weight_zero_point.copy_(zero_points)  # zero_points is on target_device
+        del scales, zero_points  # Free memory for scales and zero_points

         # Store additional AWQ-specific information
         # Ensure act_scale is on the same device as the quantized layer's parameters
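
For reference, a round-trip sketch of the affine scheme used above: the diff stores `W_quant = round(W * scales - zero_points)` together with `weight_scale = 1 / scales`, so dequantization is `W ≈ (W_quant + zero_points) * weight_scale`. The shapes and values below are illustrative:

import torch

W = torch.randn(4, 8)
scales = torch.full((4,), 10.0)  # hypothetical per-row quantization scales
zero_points = torch.zeros(4)     # symmetric case for simplicity

W_quant = torch.round(W * scales.view(-1, 1) - zero_points.view(-1, 1)).to(torch.int8)
weight_scale = 1.0 / scales

W_dequant = (W_quant.float() + zero_points.view(-1, 1)) * weight_scale.view(-1, 1)
print((W - W_dequant).abs().max())  # bounded by ~0.5 * weight_scale = 0.05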