1- """AWQ (Activation-Aware Weight Quantization) implementation for LLM quantization ."""
1+ """AWQ (Activation-Aware Weight Quantization) implementation with memory-efficient processing ."""
22
3+ import gc
34import torch
45import torch .nn as nn
56import numpy as np
89from .quantization_engine import QuantizationConfig , QuantizedLinear
910
1011class AWQQuantizer :
11- """AWQ quantization implementation."""
12+ """AWQ quantization implementation with memory-efficient processing."""
13+ """AWQ quantization implementation with memory-efficient processing."""
1214
1315 def __init__ (
1416 self ,
@@ -18,7 +20,9 @@ def __init__(
1820 zero_point : bool = True ,
1921 scale_dtype : str = "fp32" ,
2022 version : str = "v2" ,
21- enable_mnn_kernel : bool = False
23+ enable_mnn_kernel : bool = False ,
24+ batch_size : int = 2 , # Small batch size for memory efficiency
25+ cpu_offload : bool = True # Enable CPU offloading
2226 ):
2327 self .model = model
2428 self .bits = bits
@@ -32,50 +36,90 @@ def __init__(
3236 self .act_scales = {}
3337 self .weight_scales = {}
3438
39+ self .batch_size = batch_size
40+ self .cpu_offload = cpu_offload
41+
42+ def _clear_memory (self ):
43+ """Clear GPU memory and run garbage collection."""
44+ if torch .cuda .is_available ():
45+ torch .cuda .empty_cache ()
46+ gc .collect ()
47+
3548 def quantize (
3649 self ,
3750 calibration_data : Optional [torch .Tensor ] = None ,
3851 calibration_steps : int = 100
3952 ) -> PreTrainedModel :
40- """
41- Quantize model using AWQ algorithm.
42-
43- Args:
44- calibration_data: Data used for computing activation statistics
45- calibration_steps: Number of steps for calibration
46-
47- Returns:
48- Quantized model
49- """
53+ """Memory-efficient quantization using AWQ algorithm."""
5054 if calibration_data is None :
5155 raise ValueError ("AWQ requires calibration data for quantization" )
52-
53- # Prepare model for quantization
56+
57+ # Keep model on CPU initially
58+ self .model .cpu ()
5459 self .model .eval ()
5560
56- # Collect activation statistics
57- self ._collect_activation_stats (calibration_data , calibration_steps )
61+ # Process calibration data in batches
62+ total_steps = min (calibration_steps , len (calibration_data ))
63+ for step in range (0 , total_steps , self .batch_size ):
64+ # Clear memory before processing batch
65+ self ._clear_memory ()
66+
67+ # Get batch
68+ end_idx = min (step + self .batch_size , total_steps )
69+ batch = calibration_data [step :end_idx ]
70+
71+ # Move batch to appropriate device
72+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
73+ batch = batch .to (device )
74+
75+ # Temporarily move model to device
76+ if device .type == "cuda" :
77+ self .model = self .model .cuda ()
78+
79+ # Collect statistics for this batch
80+ self ._collect_activation_stats (batch )
81+
82+ # Move model back to CPU if offloading enabled
83+ if self .cpu_offload and device .type == "cuda" :
84+ self .model = self .model .cpu ()
85+
86+ # Clean up batch
87+ del batch
88+ self ._clear_memory ()
89+
90+ # Process collected statistics
91+ self ._process_activation_stats ()
5892
59- # Convert linear layers to quantized versions
93+ # Quantize the model layer by layer
6094 for name , module in self .model .named_modules ():
6195 if isinstance (module , nn .Linear ):
96+ # Move layer to device temporarily for quantization
97+ if device .type == "cuda" :
98+ module = module .cuda ()
99+
62100 # Get activation scale for this layer
63- act_scale = self .act_scales .get (name , None )
64- if act_scale is None :
65- continue
101+ act_scale = self .act_scales .get (name )
102+ if act_scale is not None :
103+ # Quantize layer
104+ quantized = self ._quantize_layer (module , act_scale )
66105
67- # Convert to quantized layer
68- quantized = self ._quantize_layer (module , act_scale )
106+ # Move quantized layer back to CPU if offloading
107+ if self .cpu_offload :
108+ quantized = quantized .cpu ()
109+
110+ # Replace layer in model
111+ parent_name = '.' .join (name .split ('.' )[:- 1 ])
112+ child_name = name .split ('.' )[- 1 ]
113+
114+ if parent_name :
115+ parent = self .model .get_submodule (parent_name )
116+ setattr (parent , child_name , quantized )
117+ else :
118+ setattr (self .model , name , quantized )
119+
120+ # Clean up
121+ self ._clear_memory ()
69122
70- # Replace layer in model
71- parent_name = '.' .join (name .split ('.' )[:- 1 ])
72- child_name = name .split ('.' )[- 1 ]
73- if parent_name :
74- parent = self .model .get_submodule (parent_name )
75- setattr (parent , child_name , quantized )
76- else :
77- setattr (self .model , name , quantized )
78-
79123 return self .model
80124 def _collect_activation_stats (
81125 self ,
0 commit comments