1- """GPTQ (Goyal-Pham-Tan-Quant) implementation for LLM quantization ."""
1+ """GPTQ (Goyal-Pham-Tan-Quant) implementation."""
 
-import math
 import gc
+import math
 import torch
 import torch.nn as nn
 from typing import Optional, Dict, Any, List, Union
 from transformers import PreTrainedModel
-from .quantization_engine import QuantizationConfig, QuantizedLinear
+from .quantization_engine import QuantizationConfig, QuantizedLinear, DeviceManager
 
 class GPTQQuantizer:
     """GPTQ quantization implementation with memory-efficient processing."""
@@ -22,18 +22,22 @@ def __init__(
         percdamp: float = 0.01,
         sym: bool = True,
         batch_size: int = 4,
-        cpu_offload: bool = False
+        device: Optional[Union[str, torch.device]] = None
     ):
         self.model = model
         self.bits = bits
         self.group_size = group_size
         self.actorder = actorder
         self.allow_mixed_bits = allow_mixed_bits
-        self.use_triton = use_triton
+        self.use_triton = use_triton and torch.cuda.is_available()
         self.percdamp = percdamp
         self.sym = sym
         self.batch_size = batch_size
-        self.cpu_offload = cpu_offload
+
+        # Initialize device manager
+        self.device_manager = DeviceManager(
+            torch.device(device) if device else None
+        )
 
         # Initialize H matrices for each layer
         self.H = {}
@@ -43,6 +47,7 @@ def _clear_memory(self):
         gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
+        self.device_manager.sync()
 
     def quantize(self, calibration_data: Optional[torch.Tensor] = None) -> PreTrainedModel:
         """
@@ -56,24 +61,25 @@ def quantize(self, calibration_data: Optional[torch.Tensor] = None) -> PreTraine
5661 """
5762 if calibration_data is None :
5863 raise ValueError ("GPTQ requires calibration data for quantization" )
59-
60- # Prepare model for quantization
64+
65+ device = self .device_manager .primary_device
66+ self .model .to (device )
6167 self .model .eval ()
6268
63- # Convert all linear layers to quantizable versions
69+ # Convert calibration data to correct device
70+ calibration_data = calibration_data .to (device )
71+
72+ # Process layers
6473 for name , module in self .model .named_modules ():
6574 if isinstance (module , nn .Linear ):
6675 print (f"Processing layer: { name } " )
6776
68- # Compute Hessian approximation for this layer
69- self .H [name ] = self ._compute_hessian (module , calibration_data )
70-
71- # Move Hessian to CPU if offloading is enabled
72- if self .cpu_offload :
73- self .H [name ] = self .H [name ].cpu ()
77+ # Ensure layer is on correct device
78+ module .to (device )
7479
75- # Convert to quantized layer
76- quantized = self ._quantize_layer (module , self .H [name ].to (module .weight .device ) if self .cpu_offload else self .H [name ])
80+ # Compute Hessian approximation
81+ self .H [name ] = self ._compute_hessian (module , calibration_data )
82+ quantized = self ._quantize_layer (module , self .H [name ])
7783
7884 # Replace layer in model
7985 parent_name = '.' .join (name .split ('.' )[:- 1 ])
@@ -87,65 +93,57 @@ def quantize(self, calibration_data: Optional[torch.Tensor] = None) -> PreTraine
                 # Clear memory after processing each layer
                 self._clear_memory()
 
-                # Remove processed Hessian to free memory
+                # Remove processed Hessian
                 del self.H[name]
 
         return self.model
-
+
     def _compute_hessian(self, layer: nn.Linear, data: torch.Tensor) -> torch.Tensor:
-        """Compute Hessian approximation for a layer with memory-efficient batch processing."""
-        device = next(layer.parameters()).device
-
-        # Initialize accumulator on CPU if offloading is enabled
+        """Compute Hessian approximation for a layer with memory-efficient processing."""
+        device = self.device_manager.primary_device
         n = layer.in_features
-        H = torch.zeros((n, n), device='cpu' if self.cpu_offload else device)
+        H = torch.zeros((n, n), device=device)
 
         def hook_fn(module, input, output):
             x = input[0].detach()
-            # Reshape input if needed (batch_size * seq_len, hidden_size)
             if len(x.shape) == 3:
                 x = x.view(-1, x.size(-1))
 
             with torch.no_grad():
-                # Process in smaller chunks to save memory
-                chunk_size = 1024  # Adjust based on available memory
+                chunk_size = 1024
                 num_chunks = math.ceil(x.size(0) / chunk_size)
 
                 for i in range(num_chunks):
                     chunk = x[i * chunk_size:(i + 1) * chunk_size]
-                    # Compute contribution to Hessian
-                    if self.cpu_offload:
-                        chunk_H = torch.matmul(chunk.t(), chunk).cpu()
-                    else:
-                        chunk_H = torch.matmul(chunk.t(), chunk)
+                    chunk_H = torch.matmul(chunk.t(), chunk)
                     H.add_(chunk_H)
 
-                    # Clear intermediate tensors
                     del chunk_H
-                    if i % 10 == 0:  # Periodic memory cleanup
+                    if i % 10 == 0:
                         self._clear_memory()
 
         # Register forward hook
         handle = layer.register_forward_hook(hook_fn)
 
-        # Run calibration data through model in batches
+        # Process calibration data in batches
        with torch.no_grad():
             for i in range(0, len(data), self.batch_size):
-                batch = data[i:i + self.batch_size]
+                batch = data[i:i + self.batch_size].to(device)
                 self.model(batch)
 
-                # Periodic memory cleanup
                 if i % (self.batch_size * 10) == 0:
                     self._clear_memory()
 
-        # Remove hook
         handle.remove()
-
         return H
 
     def _quantize_layer(self, layer: nn.Linear, H: torch.Tensor) -> QuantizedLinear:
-        """Quantize a single layer using GPTQ."""
-        device = next(layer.parameters()).device
+        """Quantize a single layer using GPTQ with memory management."""
+        device = self.device_manager.primary_device
+
+        # Ensure tensors are on the correct device
+        H = H.to(device)
+        W = layer.weight.data.to(device)
 
         # Initialize quantized layer
         quantized = QuantizedLinear(
@@ -155,65 +153,42 @@ def _quantize_layer(self, layer: nn.Linear, H: torch.Tensor) -> QuantizedLinear:
             config=QuantizationConfig(
                 bits=self.bits,
                 scheme="symmetric" if self.sym else "asymmetric",
-                granularity="per-tensor",
-                calibration="minmax",
-                channel_wise=False,
-                dtype=f"{'u' if not self.sym else ''}int{self.bits}",
-                format="gptq"
+                granularity="per-channel",
+                calibration="gptq"
             )
-        )
+        ).to(device)
 
-        # Copy bias if exists
         if layer.bias is not None:
             quantized.bias.data.copy_(layer.bias.data)
 
-        # Get weight matrix
-        W = layer.weight.data.clone()
-
-        # Compute optimal scales and zero points
-        if self.group_size > 0:
-            n_groups = W.shape[0] // self.group_size
-            W_groups = W.view(n_groups, self.group_size, -1)
-            scales = []
-            zero_points = []
+        # Process in chunks to save memory
+        chunk_size = min(1024, layer.out_features)
+        for i in range(0, layer.out_features, chunk_size):
+            chunk_end = min(i + chunk_size, layer.out_features)
+            W_chunk = W[i:chunk_end]
 
-            for idx in range(n_groups):
-                group = W_groups[idx]
-                if self.sym:
-                    scale = (2 ** (self.bits - 1) - 1) / torch.max(torch.abs(group))
-                    zero_point = 0
-                else:
-                    min_val = torch.min(group)
-                    max_val = torch.max(group)
-                    scale = (2 ** self.bits - 1) / (max_val - min_val)
-                    zero_point = -min_val * scale
-
-                scales.append(scale)
-                zero_points.append(zero_point)
-
-            scales = torch.stack(scales)
-            zero_points = torch.stack(zero_points)
-        else:
+            # Compute per-row scale and zero point for this chunk (clamp guards against zero ranges)
             if self.sym:
-                scales = (2 ** (self.bits - 1) - 1) / torch.max(torch.abs(W), dim=1)[0]
-                zero_points = torch.zeros_like(scales)
+                max_val = W_chunk.abs().max(dim=1)[0]
+                scale = (2 ** (self.bits - 1) - 1) / max_val.clamp(min=1e-8)
+                zero_point = torch.zeros_like(scale)
+                qmin, qmax = -(2 ** (self.bits - 1)), 2 ** (self.bits - 1) - 1
             else:
-                min_vals = torch.min(W, dim=1)[0]
-                max_vals = torch.max(W, dim=1)[0]
-                scales = (2 ** self.bits - 1) / (max_vals - min_vals)
-                zero_points = -min_vals * scales
-
-        # Quantize weights
-        W_quant = torch.round(W * scales.view(-1, 1) - zero_points.view(-1, 1))
-
-        # Apply GPTQ optimization
-        recon_loss = torch.sum((W - (W_quant + zero_points.view(-1, 1)) / scales.view(-1, 1)).pow(2))
-        if H is not None:
-            recon_loss = recon_loss * torch.trace(H)
-
-        # Store quantized weights and parameters
-        quantized.weight_quantized.copy_(W_quant.to(torch.int8))
-        quantized.weight_scale.copy_(1.0 / scales)
-        quantized.weight_zero_point.copy_(zero_points)
+                min_val = W_chunk.min(dim=1)[0]
+                max_val = W_chunk.max(dim=1)[0]
+                scale = (2 ** self.bits - 1) / (max_val - min_val).clamp(min=1e-8)
+                zero_point = torch.round(-min_val * scale)
+                qmin, qmax = 0, 2 ** self.bits - 1
+
+            # Quantize chunk (affine mapping: q = round(w * scale) + zero_point)
+            W_quant = torch.round(W_chunk * scale.unsqueeze(1)) + zero_point.unsqueeze(1)
+            W_quant = torch.clamp(W_quant, qmin, qmax)
+
+            # Store quantized weights, per-row scales, and zero points
+            quantized.weight_quantized.data[i:chunk_end] = W_quant.to(torch.int8)
+            quantized.weight_scale.data[i:chunk_end] = 1.0 / scale
+            quantized.weight_zero_point.data[i:chunk_end] = zero_point
+
+            del W_chunk, W_quant
+            self._clear_memory()
 
         return quantized
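
Usage note: with this change, callers pass a device instead of cpu_offload. Below is a minimal sketch, assuming a causal LM from transformers; the model name, sample texts, and bit/group settings are illustrative placeholders, not part of this commit.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    # GPTQQuantizer is assumed to be importable from this package's GPTQ module.

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

    # Small batch of tokenized calibration text, shape [n_samples, seq_len].
    texts = ["GPTQ calibrates on a few representative sentences."] * 8
    calibration = tokenizer(texts, return_tensors="pt", padding=True).input_ids

    quantizer = GPTQQuantizer(
        model,
        bits=4,
        group_size=128,
        device="cuda" if torch.cuda.is_available() else "cpu",  # replaces cpu_offload
    )
    quantized_model = quantizer.quantize(calibration_data=calibration)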
0 commit comments
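The diff only shows DeviceManager being constructed with an optional torch.device and used through primary_device and sync(); the actual class lives in quantization_engine and is not part of this commit. A minimal compatible sketch, purely as an assumption about that interface:

    from typing import Optional
    import torch

    class DeviceManager:
        """Assumed interface: pick a primary device and synchronize it."""

        def __init__(self, device: Optional[torch.device] = None):
            # Fall back to CUDA when available, otherwise CPU.
            self.primary_device = device or torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )

        def sync(self) -> None:
            # Block until queued kernels on the primary device finish.
            if self.primary_device.type == "cuda":
                torch.cuda.synchronize(self.primary_device)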