Commit 54b07e7

Add the Quantization Methods.

1 parent bb57f2e commit 54b07e7

File tree

7 files changed: +992 −2 lines changed

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
Quantization Methods
====================

QuantLLM provides three primary methods for model quantization, each with its own advantages:

1. GPTQ (Generative Pre-trained Transformer Quantization)
----------------------------------------------------------

GPTQ offers Hessian-based quantization with activation ordering for high accuracy:

.. code-block:: python

    from quantllm.quant import GPTQQuantizer

    # Initialize quantizer
    quantizer = GPTQQuantizer(
        model=model,
        bits=4,            # Quantization bits (2-8)
        group_size=128,    # Size of quantization groups
        actorder=True,     # Enable activation ordering
        use_triton=True    # Use Triton kernels for acceleration
    )

    # Quantize model
    quantized_model = quantizer.quantize(calibration_data=calibration_data)
2. AWQ (Activation-Aware Weight Quantization)
---------------------------------------------

AWQ adapts quantization based on activation patterns:

.. code-block:: python

    from quantllm.quant import AWQQuantizer

    quantizer = AWQQuantizer(
        model=model,
        bits=4,            # Quantization bits
        group_size=128,    # Group size for quantization
        zero_point=True,   # Enable zero point computation
        version="v2"       # AWQ version
    )

    # Quantize with activation statistics
    quantized_model = quantizer.quantize(
        calibration_data=calibration_data,
        calibration_steps=100
    )
3. GGUF (GGML Universal Format)
-------------------------------

GGUF provides an efficient format with CTransformers integration:

.. code-block:: python

    from quantllm.quant import GGUFQuantizer

    quantizer = GGUFQuantizer(
        model=model,
        bits=4,            # Quantization bits
        group_size=32,     # Group size
        use_packed=True    # Enable weight packing
    )

    # Quantize model
    quantized_model = quantizer.quantize()

    # Export to GGUF format
    quantizer.convert_to_gguf("model-q4.gguf")
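Once exported, the GGUF file can be served with CTransformers. The snippet below is a minimal sketch rather than part of the QuantLLM API: it assumes a LLaMA-family checkpoint and reuses the ``model-q4.gguf`` path from the example above; set ``model_type`` to match your architecture.

.. code-block:: python

    # Minimal sketch: serving the exported GGUF file with CTransformers.
    # Assumes a LLaMA-family model; adjust model_type for your architecture.
    from ctransformers import AutoModelForCausalLM

    llm = AutoModelForCausalLM.from_pretrained(
        "model-q4.gguf",     # path produced by convert_to_gguf above
        model_type="llama"   # assumed architecture
    )

    print(llm("Hello, world!", max_new_tokens=32))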
Choosing the Right Method
-------------------------

- **GPTQ**: Best for highest accuracy with slightly slower quantization
- **AWQ**: Best balance of speed and accuracy, good for general use
- **GGUF**: Best for deployment and inference with CTransformers
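As an illustration only (the ``pick_quantizer`` helper below is hypothetical, not part of QuantLLM), this guidance can be expressed as a simple dispatch:

.. code-block:: python

    # Hypothetical helper mapping a priority to a quantizer class.
    from quantllm.quant import GPTQQuantizer, AWQQuantizer, GGUFQuantizer

    def pick_quantizer(priority: str):
        """Return a quantizer class for 'accuracy', 'balanced', or 'deployment'."""
        return {
            "accuracy": GPTQQuantizer,    # highest accuracy, slower quantization
            "balanced": AWQQuantizer,     # good speed/accuracy trade-off
            "deployment": GGUFQuantizer,  # GGUF export for CTransformers inference
        }[priority]

    quantizer_cls = pick_quantizer("balanced")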
Resource Requirements
---------------------

+-------------+------------+-------------+------------+
| Method      | Memory     | Speed       | Accuracy   |
+=============+============+=============+============+
| GPTQ        | High       | Slow        | Highest    |
+-------------+------------+-------------+------------+
| AWQ         | Medium     | Fast        | High       |
+-------------+------------+-------------+------------+
| GGUF        | Low        | Very Fast   | Good       |
+-------------+------------+-------------+------------+
Common Parameters
-----------------

All quantizers support these common parameters:

- **bits**: Number of quantization bits (2-8)
- **group_size**: Size of quantization groups
- **calibration_data**: Data used for computing statistics
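The exact form of ``calibration_data`` is not pinned down on this page; a minimal sketch, assuming a batch of token IDs produced by the matching Hugging Face tokenizer, might look like this:

.. code-block:: python

    # Rough sketch: building calibration data as a batch of token IDs.
    # Assumes a Hugging Face tokenizer for the checkpoint being quantized.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    samples = [
        "Quantization reduces the memory footprint of large language models.",
        "Calibration text should resemble the data seen at inference time.",
    ]
    calibration_data = tokenizer(
        samples,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )["input_ids"]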
Example Workflow
----------------

Here's a complete example of quantizing a model:

.. code-block:: python

    import torch
    from transformers import AutoTokenizer
    from quantllm import Model, ModelConfig
    from quantllm.quant import AWQQuantizer

    # 1. Load model and tokenizer
    model_config = ModelConfig(model_name="facebook/opt-350m")
    model = Model(model_config).get_model()
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

    # 2. Prepare calibration data
    calibration_data = prepare_calibration_data()  # Your calibration data

    # 3. Initialize quantizer
    quantizer = AWQQuantizer(
        model=model,
        bits=4,
        group_size=128
    )

    # 4. Quantize model
    quantized_model = quantizer.quantize(
        calibration_data=calibration_data,
        calibration_steps=100
    )

    # 5. Use the quantized model
    inputs = tokenizer("Hello, world!", return_tensors="pt")
    outputs = quantized_model(**inputs)

For more detailed examples, see the ``examples/quantization_examples.py`` file in the repository.

quantllm/quant/__init__.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
"""Quantization functionality for LLMs."""

from .quantization_engine import (
    QuantizationConfig,
    QuantizedLinear,
    QuantizationEngine
)
from .gptq import GPTQQuantizer
from .awq import AWQQuantizer
from .gguf import GGUFQuantizer

__all__ = [
    "QuantizationConfig",
    "QuantizedLinear",
    "QuantizationEngine",
    "GPTQQuantizer",
    "AWQQuantizer",
    "GGUFQuantizer"
]

quantllm/quant/awq.py

Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
"""AWQ (Activation-Aware Weight Quantization) implementation for LLM quantization."""

import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Dict, Any, List, Union, Tuple
from transformers import PreTrainedModel
from .quantization_engine import QuantizationConfig, QuantizedLinear

class AWQQuantizer:
    """AWQ quantization implementation."""

    def __init__(
        self,
        model: PreTrainedModel,
        bits: int = 4,
        group_size: int = 128,
        zero_point: bool = True,
        scale_dtype: str = "fp32",
        version: str = "v2",
        enable_mnn_kernel: bool = False
    ):
        self.model = model
        self.bits = bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.scale_dtype = scale_dtype
        self.version = version
        self.enable_mnn_kernel = enable_mnn_kernel

        # Initialize activation statistics dictionaries
        self.act_scales = {}
        self.weight_scales = {}

    def quantize(
        self,
        calibration_data: Optional[torch.Tensor] = None,
        calibration_steps: int = 100
    ) -> PreTrainedModel:
        """
        Quantize model using AWQ algorithm.

        Args:
            calibration_data: Data used for computing activation statistics
            calibration_steps: Number of steps for calibration

        Returns:
            Quantized model
        """
        if calibration_data is None:
            raise ValueError("AWQ requires calibration data for quantization")

        # Prepare model for quantization
        self.model.eval()

        # Collect activation statistics
        self._collect_activation_stats(calibration_data, calibration_steps)

        # Convert linear layers to quantized versions
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                # Get activation scale for this layer
                act_scale = self.act_scales.get(name, None)
                if act_scale is None:
                    continue

                # Convert to quantized layer
                quantized = self._quantize_layer(module, act_scale)

                # Replace layer in model
                parent_name = '.'.join(name.split('.')[:-1])
                child_name = name.split('.')[-1]
                if parent_name:
                    parent = self.model.get_submodule(parent_name)
                    setattr(parent, child_name, quantized)
                else:
                    setattr(self.model, name, quantized)

        return self.model

    def _collect_activation_stats(
        self,
        data: torch.Tensor,
        num_steps: int
    ):
        """Collect activation statistics for each layer."""

        # Register hooks for all linear layers
        handles = []
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                def hook_fn(name):
                    def fn(module, input, output):
                        if name not in self.act_scales:
                            self.act_scales[name] = []
                        x = input[0].detach()
                        scale = torch.max(torch.abs(x))
                        self.act_scales[name].append(scale)
                    return fn

                handles.append(
                    module.register_forward_hook(hook_fn(name))
                )

        # Run calibration
        with torch.no_grad():
            for _ in range(num_steps):
                self.model(data)

        # Remove hooks
        for handle in handles:
            handle.remove()

        # Process collected statistics
        for name in self.act_scales:
            scales = torch.stack(self.act_scales[name])
            # Use 99.9th percentile for more robust statistics
            self.act_scales[name] = torch.quantile(scales, 0.999)

    def _quantize_layer(
        self,
        layer: nn.Linear,
        act_scale: torch.Tensor
    ) -> QuantizedLinear:
        """Quantize a single layer using AWQ."""
        device = next(layer.parameters()).device

        # Initialize quantized layer
        quantized = QuantizedLinear(
            layer.in_features,
            layer.out_features,
            bias=layer.bias is not None,
            config=QuantizationConfig(
                bits=self.bits,
                scheme="symmetric",
                granularity="per-channel" if self.group_size > 0 else "per-tensor",
                calibration="minmax",
                channel_wise=True,
                dtype=f"int{self.bits}",
                format="awq"
            )
        )

        # Copy bias if exists
        if layer.bias is not None:
            quantized.bias.data.copy_(layer.bias.data)

        # Get weight matrix
        W = layer.weight.data.clone()

        # Scale weights by activation scale
        W = W / act_scale.view(1, -1)

        # Compute quantization scales per group of output rows
        if self.group_size > 0:
            n_groups = W.shape[0] // self.group_size
            W_groups = W.view(n_groups, self.group_size, -1)

            scales = []
            zero_points = [] if self.zero_point else None

            for idx in range(n_groups):
                group = W_groups[idx]
                max_abs = torch.max(torch.abs(group))
                scale = (2 ** (self.bits - 1) - 1) / max_abs
                scales.append(scale)

                if self.zero_point:
                    zero_point = -(torch.max(group) + torch.min(group)) / 2 * scale
                    zero_points.append(zero_point)

            scales = torch.stack(scales)
            if self.zero_point:
                zero_points = torch.stack(zero_points)
            else:
                zero_points = torch.zeros_like(scales)

            # Expand per-group parameters to one value per output row so they
            # broadcast against the full (out_features, in_features) matrix.
            scales = scales.repeat_interleave(self.group_size)
            zero_points = zero_points.repeat_interleave(self.group_size)
        else:
            max_abs = torch.max(torch.abs(W), dim=1)[0]
            scales = (2 ** (self.bits - 1) - 1) / max_abs
            if self.zero_point:
                max_vals = torch.max(W, dim=1)[0]
                min_vals = torch.min(W, dim=1)[0]
                zero_points = -(max_vals + min_vals) / 2 * scales
            else:
                zero_points = torch.zeros_like(scales)

        # Quantize weights
        W_quant = torch.round(W * scales.view(-1, 1) - zero_points.view(-1, 1))

        # Store quantized weights and parameters
        quantized.weight_quantized.copy_(W_quant.to(torch.int8))
        quantized.weight_scale.copy_(1.0 / scales)
        quantized.weight_zero_point.copy_(zero_points)

        # Store additional AWQ-specific information
        if hasattr(quantized, 'act_scale'):
            quantized.act_scale.copy_(act_scale)

        return quantized

0 commit comments
