
Commit d1a6dab

Add the Quantization Methods.

1 parent e15d944 commit d1a6dab

File tree

6 files changed: +754 -446 lines changed

quantllm/quant/awq.py

Lines changed: 76 additions & 32 deletions
@@ -1,5 +1,6 @@
-"""AWQ (Activation-Aware Weight Quantization) implementation for LLM quantization."""
+"""AWQ (Activation-Aware Weight Quantization) implementation with memory-efficient processing."""
 
+import gc
 import torch
 import torch.nn as nn
 import numpy as np
@@ -8,7 +9,8 @@
 from .quantization_engine import QuantizationConfig, QuantizedLinear
 
 class AWQQuantizer:
-    """AWQ quantization implementation."""
+    """AWQ quantization implementation with memory-efficient processing."""
+    """AWQ quantization implementation with memory-efficient processing."""
 
     def __init__(
         self,
@@ -18,7 +20,9 @@ def __init__(
         zero_point: bool = True,
         scale_dtype: str = "fp32",
         version: str = "v2",
-        enable_mnn_kernel: bool = False
+        enable_mnn_kernel: bool = False,
+        batch_size: int = 2,  # Small batch size for memory efficiency
+        cpu_offload: bool = True  # Enable CPU offloading
     ):
         self.model = model
         self.bits = bits
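
The two new constructor arguments are the knobs for the memory-efficient path: batch_size bounds how many calibration samples sit on the GPU at once, and cpu_offload parks the weights on the host between GPU passes. A minimal construction sketch (the model choice is illustrative and not part of this commit; the keyword names are assumed to match the attributes assigned above):

    from transformers import AutoModelForCausalLM
    from quantllm.quant.awq import AWQQuantizer

    # Hypothetical target model; any module containing nn.Linear layers applies.
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

    quantizer = AWQQuantizer(
        model,
        bits=4,            # weight bit-width
        batch_size=2,      # calibration samples per GPU pass
        cpu_offload=True,  # move weights back to CPU between passes
    )
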
@@ -32,50 +36,90 @@ def __init__(
         self.act_scales = {}
         self.weight_scales = {}
 
+        self.batch_size = batch_size
+        self.cpu_offload = cpu_offload
+
+    def _clear_memory(self):
+        """Clear GPU memory and run garbage collection."""
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
+
     def quantize(
         self,
         calibration_data: Optional[torch.Tensor] = None,
         calibration_steps: int = 100
     ) -> PreTrainedModel:
-        """
-        Quantize model using AWQ algorithm.
-
-        Args:
-            calibration_data: Data used for computing activation statistics
-            calibration_steps: Number of steps for calibration
-
-        Returns:
-            Quantized model
-        """
+        """Memory-efficient quantization using AWQ algorithm."""
         if calibration_data is None:
             raise ValueError("AWQ requires calibration data for quantization")
-
-        # Prepare model for quantization
+
+        # Keep model on CPU initially
+        self.model.cpu()
         self.model.eval()
 
-        # Collect activation statistics
-        self._collect_activation_stats(calibration_data, calibration_steps)
+        # Process calibration data in batches
+        total_steps = min(calibration_steps, len(calibration_data))
+        for step in range(0, total_steps, self.batch_size):
+            # Clear memory before processing batch
+            self._clear_memory()
+
+            # Get batch
+            end_idx = min(step + self.batch_size, total_steps)
+            batch = calibration_data[step:end_idx]
+
+            # Move batch to appropriate device
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            batch = batch.to(device)
+
+            # Temporarily move model to device
+            if device.type == "cuda":
+                self.model = self.model.cuda()
+
+            # Collect statistics for this batch
+            self._collect_activation_stats(batch)
+
+            # Move model back to CPU if offloading enabled
+            if self.cpu_offload and device.type == "cuda":
+                self.model = self.model.cpu()
+
+            # Clean up batch
+            del batch
+            self._clear_memory()
+
+        # Process collected statistics
+        self._process_activation_stats()
 
-        # Convert linear layers to quantized versions
+        # Quantize the model layer by layer
        for name, module in self.model.named_modules():
             if isinstance(module, nn.Linear):
+                # Move layer to device temporarily for quantization
+                if device.type == "cuda":
+                    module = module.cuda()
+
                 # Get activation scale for this layer
-                act_scale = self.act_scales.get(name, None)
-                if act_scale is None:
-                    continue
+                act_scale = self.act_scales.get(name)
+                if act_scale is not None:
+                    # Quantize layer
+                    quantized = self._quantize_layer(module, act_scale)
 
-                # Convert to quantized layer
-                quantized = self._quantize_layer(module, act_scale)
+                    # Move quantized layer back to CPU if offloading
+                    if self.cpu_offload:
+                        quantized = quantized.cpu()
+
+                    # Replace layer in model
+                    parent_name = '.'.join(name.split('.')[:-1])
+                    child_name = name.split('.')[-1]
+
+                    if parent_name:
+                        parent = self.model.get_submodule(parent_name)
+                        setattr(parent, child_name, quantized)
+                    else:
+                        setattr(self.model, name, quantized)
+
+                # Clean up
+                self._clear_memory()
 
-                # Replace layer in model
-                parent_name = '.'.join(name.split('.')[:-1])
-                child_name = name.split('.')[-1]
-                if parent_name:
-                    parent = self.model.get_submodule(parent_name)
-                    setattr(parent, child_name, quantized)
-                else:
-                    setattr(self.model, name, quantized)
-
         return self.model
     def _collect_activation_stats(
         self,
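Taken together, quantize() now slices the calibration tensor along dim 0 into batch_size chunks, shuttles the model between CPU and GPU per chunk while accumulating activation statistics, and then swaps each nn.Linear that has a recorded scale for its quantized counterpart, clearing CUDA cache between layers. A hedged end-to-end sketch of driving the new method (tokenizer, prompts, and step count are illustrative, not part of the commit; `quantizer` is the instance from the sketch above):

    from transformers import AutoTokenizer

    # Hypothetical calibration set: token IDs shaped (num_samples, seq_len),
    # since quantize() indexes calibration_data along dim 0.
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    prompts = ["The quick brown fox jumps over the lazy dog."] * 16
    calibration_data = tokenizer(prompts, return_tensors="pt").input_ids

    quantized_model = quantizer.quantize(
        calibration_data=calibration_data,
        calibration_steps=16,
    )

With batch_size=2 this keeps at most two calibration sequences plus one copy of the model on the GPU at any time; cpu_offload=True trades extra host-device transfers for that lower peak memory.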