1- """AWQ (Activation-Aware Weight Quantization) implementation with memory-efficient processing ."""
1+ """AWQ (Activation-Aware Weight Quantization) implementation."""
22
33import gc
44import torch
55import torch .nn as nn
66import numpy as np
77from typing import Optional , Dict , Any , List , Union , Tuple
88from transformers import PreTrainedModel
9- from .quantization_engine import QuantizationConfig , QuantizedLinear
9+ from .quantization_engine import BaseQuantizer , QuantizationConfig , QuantizedLinear
1010
11- class AWQQuantizer :
12- """AWQ quantization implementation with memory-efficient processing."""
11+ class AWQQuantizer (BaseQuantizer ):
1312 """AWQ quantization implementation with memory-efficient processing."""
1413
1514 def __init__ (
@@ -21,24 +20,21 @@ def __init__(
         scale_dtype: str = "fp32",
         version: str = "v2",
         enable_mnn_kernel: bool = False,
-        batch_size: int = 2,  # Small batch size for memory efficiency
-        cpu_offload: bool = True  # Enable CPU offloading
+        batch_size: int = 2,
+        device: Optional[Union[str, torch.device]] = None
     ):
-        self.model = model
-        self.bits = bits
+        super().__init__(model=model, bits=bits, device=device)
         self.group_size = group_size
         self.zero_point = zero_point
         self.scale_dtype = scale_dtype
         self.version = version
         self.enable_mnn_kernel = enable_mnn_kernel
+        self.batch_size = batch_size
 
         # Initialize activation statistics dictionaries
         self.act_scales = {}
         self.weight_scales = {}
 
-        self.batch_size = batch_size
-        self.cpu_offload = cpu_offload
-
     def _clear_memory(self):
         """Clear GPU memory and run garbage collection."""
         if torch.cuda.is_available():
@@ -54,8 +50,8 @@ def quantize(
         if calibration_data is None:
             raise ValueError("AWQ requires calibration data for quantization")
 
-        # Keep model on CPU initially
-        self.model.cpu()
+        # Prepare calibration data
+        calibration_data = self.prepare_calibration_data(calibration_data)
         self.model.eval()
 
         # Process calibration data in batches
@@ -68,21 +64,9 @@ def quantize(
             end_idx = min(step + self.batch_size, total_steps)
             batch = calibration_data[step:end_idx]
 
-            # Move batch to appropriate device
-            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-            batch = batch.to(device)
-
-            # Temporarily move model to device
-            if device.type == "cuda":
-                self.model = self.model.cuda()
-
             # Collect statistics for this batch
             self._collect_activation_stats(batch)
 
-            # Move model back to CPU if offloading enabled
-            if self.cpu_offload and device.type == "cuda":
-                self.model = self.model.cpu()
-
             # Clean up batch
             del batch
             self._clear_memory()
@@ -93,35 +77,25 @@ def quantize(
         # Quantize the model layer by layer
         for name, module in self.model.named_modules():
             if isinstance(module, nn.Linear):
-                # Move layer to device temporarily for quantization
-                if device.type == "cuda":
-                    module = module.cuda()
-
+                self.logger.info(f"Processing layer: {name}")
+
                 # Get activation scale for this layer
                 act_scale = self.act_scales.get(name)
-                if act_scale is not None:
-                    # Quantize layer
-                    quantized = self._quantize_layer(module, act_scale)
-
-                    # Move quantized layer back to CPU if offloading
-                    if self.cpu_offload:
-                        quantized = quantized.cpu()
-
-                    # Replace layer in model
-                    parent_name = '.'.join(name.split('.')[:-1])
-                    child_name = name.split('.')[-1]
-
-                    if parent_name:
-                        parent = self.model.get_submodule(parent_name)
-                        setattr(parent, child_name, quantized)
-                    else:
-                        setattr(self.model, name, quantized)
-
-                    # Clean up
-                    self._clear_memory()
+                quantized = self._quantize_layer(module, act_scale)
 
+                # Replace layer in model
+                parent_name = '.'.join(name.split('.')[:-1])
+                child_name = name.split('.')[-1]
+                if parent_name:
+                    parent = self.model.get_submodule(parent_name)
+                    setattr(parent, child_name, quantized)
+                else:
+                    setattr(self.model, name, quantized)
+
+                self._clear_memory()
+
         return self.model
-    def _collect_activation_stats(
+    def _collect_activation_stats(
         self,
         data: torch.Tensor,
         num_steps: int
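
This diff leans on a shared BaseQuantizer in quantization_engine that is not shown here: the `super().__init__(model=model, bits=bits, device=device)` call, `self.logger`, and `self.prepare_calibration_data(...)` all come from it. As a rough guide only, a minimal sketch of that surface might look like the following — the attribute and method names are taken from the calls above, but everything else (defaults, device resolution, logger setup) is assumed and may differ from the real class:

# Hypothetical sketch of the BaseQuantizer surface this diff calls into;
# only model/bits/device, logger, and prepare_calibration_data appear in
# the diff itself -- the rest is assumed.
import logging
from typing import Optional, Union

import torch
import torch.nn as nn

class BaseQuantizer:
    def __init__(
        self,
        model: nn.Module,
        bits: int = 4,
        device: Optional[Union[str, torch.device]] = None,
    ):
        self.model = model
        self.bits = bits
        # Resolve the device once here, so subclasses no longer hard-code
        # "cuda" checks or shuttle the model between devices themselves.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.logger = logging.getLogger(self.__class__.__name__)

    def prepare_calibration_data(self, data: torch.Tensor) -> torch.Tensor:
        # One central place to move calibration tensors to the target device.
        return data.to(self.device)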
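Under the same assumptions, the refactored quantizer would be driven roughly as below. The constructor's leading model/bits/group_size parameters are inferred from the diff (the signature is only partially visible), and the model choice and random token IDs are purely illustrative — real usage would calibrate on representative text:

import torch
from transformers import AutoModelForCausalLM

# Illustrative model; any nn.Module with nn.Linear layers would do.
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Dummy calibration batch of token IDs, shape (num_samples, seq_len).
calibration_data = torch.randint(0, model.config.vocab_size, (8, 128))

quantizer = AWQQuantizer(model, bits=4, group_size=128, batch_size=2)
quantized_model = quantizer.quantize(calibration_data=calibration_data)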