
Commit f630e89

Add the quantization methods.
1 parent 0d2b20a commit f630e89

5 files changed: +259 additions, -174 deletions

quantllm/quant/awq.py

Lines changed: 24 additions & 50 deletions
@@ -1,15 +1,14 @@
-"""AWQ (Activation-Aware Weight Quantization) implementation with memory-efficient processing."""
+"""AWQ (Activation-Aware Weight Quantization) implementation."""
 
 import gc
 import torch
 import torch.nn as nn
 import numpy as np
 from typing import Optional, Dict, Any, List, Union, Tuple
 from transformers import PreTrainedModel
-from .quantization_engine import QuantizationConfig, QuantizedLinear
+from .quantization_engine import BaseQuantizer, QuantizationConfig, QuantizedLinear
 
-class AWQQuantizer:
-    """AWQ quantization implementation with memory-efficient processing."""
+class AWQQuantizer(BaseQuantizer):
     """AWQ quantization implementation with memory-efficient processing."""
 
     def __init__(
@@ -21,24 +20,21 @@ def __init__(
         scale_dtype: str = "fp32",
         version: str = "v2",
         enable_mnn_kernel: bool = False,
-        batch_size: int = 2,  # Small batch size for memory efficiency
-        cpu_offload: bool = True  # Enable CPU offloading
+        batch_size: int = 2,
+        device: Optional[Union[str, torch.device]] = None
     ):
-        self.model = model
-        self.bits = bits
+        super().__init__(model=model, bits=bits, device=device)
         self.group_size = group_size
         self.zero_point = zero_point
         self.scale_dtype = scale_dtype
         self.version = version
         self.enable_mnn_kernel = enable_mnn_kernel
+        self.batch_size = batch_size
 
         # Initialize activation statistics dictionaries
         self.act_scales = {}
         self.weight_scales = {}
 
-        self.batch_size = batch_size
-        self.cpu_offload = cpu_offload
-
     def _clear_memory(self):
         """Clear GPU memory and run garbage collection."""
         if torch.cuda.is_available():
@@ -54,8 +50,8 @@ def quantize(
         if calibration_data is None:
             raise ValueError("AWQ requires calibration data for quantization")
 
-        # Keep model on CPU initially
-        self.model.cpu()
+        # Prepare calibration data
+        calibration_data = self.prepare_calibration_data(calibration_data)
         self.model.eval()
 
         # Process calibration data in batches
@@ -68,21 +64,9 @@
             end_idx = min(step + self.batch_size, total_steps)
             batch = calibration_data[step:end_idx]
 
-            # Move batch to appropriate device
-            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-            batch = batch.to(device)
-
-            # Temporarily move model to device
-            if device.type == "cuda":
-                self.model = self.model.cuda()
-
             # Collect statistics for this batch
             self._collect_activation_stats(batch)
 
-            # Move model back to CPU if offloading enabled
-            if self.cpu_offload and device.type == "cuda":
-                self.model = self.model.cpu()
-
             # Clean up batch
             del batch
             self._clear_memory()
@@ -93,35 +77,25 @@
         # Quantize the model layer by layer
         for name, module in self.model.named_modules():
             if isinstance(module, nn.Linear):
-                # Move layer to device temporarily for quantization
-                if device.type == "cuda":
-                    module = module.cuda()
-
+                self.logger.info(f"Processing layer: {name}")
+
                 # Get activation scale for this layer
                 act_scale = self.act_scales.get(name)
-                if act_scale is not None:
-                    # Quantize layer
-                    quantized = self._quantize_layer(module, act_scale)
-
-                    # Move quantized layer back to CPU if offloading
-                    if self.cpu_offload:
-                        quantized = quantized.cpu()
-
-                    # Replace layer in model
-                    parent_name = '.'.join(name.split('.')[:-1])
-                    child_name = name.split('.')[-1]
-
-                    if parent_name:
-                        parent = self.model.get_submodule(parent_name)
-                        setattr(parent, child_name, quantized)
-                    else:
-                        setattr(self.model, name, quantized)
-
-                    # Clean up
-                    self._clear_memory()
+                quantized = self._quantize_layer(module, act_scale)
 
+                # Replace layer in model
+                parent_name = '.'.join(name.split('.')[:-1])
+                child_name = name.split('.')[-1]
+                if parent_name:
+                    parent = self.model.get_submodule(parent_name)
+                    setattr(parent, child_name, quantized)
+                else:
+                    setattr(self.model, name, quantized)
+
+                self._clear_memory()
+
         return self.model
-    def _collect_activation_stats(
+    def _collect_activation_stats(
         self,
         data: torch.Tensor,
         num_steps: int
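For reference, a minimal usage sketch of the refactored AWQQuantizer. This is not part of the commit: the model name, bit-width, and calibration tensor shape are illustrative assumptions, and device placement is assumed to be handled by BaseQuantizer as the diff above suggests.

import torch
from transformers import AutoModelForCausalLM
from quantllm.quant.awq import AWQQuantizer

# Hypothetical example: any causal LM plus a small token-ID tensor as calibration data.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # illustrative model
calibration_data = torch.randint(0, model.config.vocab_size, (8, 128))  # illustrative shape

quantizer = AWQQuantizer(
    model=model,
    bits=4,          # target weight bit-width (assumed value)
    batch_size=2,    # calibration batch size, default shown in the diff
    device="cuda" if torch.cuda.is_available() else "cpu",
)
# AWQ raises a ValueError without calibration data (see the quantize() hunk above).
quantized_model = quantizer.quantize(calibration_data=calibration_data)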

quantllm/quant/gguf.py

Lines changed: 13 additions & 36 deletions
@@ -6,16 +6,16 @@
 import torch.nn as nn
 from typing import Optional, Dict, Any, List, Union, Tuple
 from transformers import PreTrainedModel
-from .quantization_engine import QuantizationConfig, QuantizedLinear
+from .quantization_engine import BaseQuantizer, QuantizationConfig, QuantizedLinear
 
 try:
     import ctransformers
     CT_AVAILABLE = True
 except ImportError:
     CT_AVAILABLE = False
 
-class GGUFQuantizer:
-    """GGUF quantization implementation with CTransformers integration and memory-efficient processing."""
+class GGUFQuantizer(BaseQuantizer):
+    """GGUF quantization implementation with CTransformers integration."""
 
     def __init__(
         self,
@@ -27,41 +27,27 @@ def __init__(
         use_packed: bool = True,
         legacy_format: bool = False,
         batch_size: int = 4,
-        cpu_offload: bool = False
+        device: Optional[Union[str, torch.device]] = None
     ):
         if not CT_AVAILABLE:
             raise ImportError("CTransformers is required for GGUF quantization. Install with: pip install ctransformers")
-
-        self.model = model
-        self.bits = bits
+
+        super().__init__(model=model, bits=bits, device=device)
         self.group_size = group_size
         self.desc_act = desc_act
         self.desc_ten = desc_ten
         self.use_packed = use_packed
         self.legacy_format = legacy_format
         self.batch_size = batch_size
-        self.cpu_offload = cpu_offload
 
-    def _clear_memory(self):
-        """Clear CUDA memory and run garbage collection."""
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
     def quantize(
         self,
         calibration_data: Optional[torch.Tensor] = None
     ) -> PreTrainedModel:
-        """
-        Quantize model using GGUF format with memory-efficient processing.
-
-        Args:
-            calibration_data: Optional tensor for computing quantization statistics
-
-        Returns:
-            Quantized model
-        """
-        # Prepare model for quantization
+        """Quantize model using GGUF format with memory-efficient processing."""
+        # Prepare model and calibration data
+        if calibration_data is not None:
+            calibration_data = self.prepare_calibration_data(calibration_data)
         self.model.eval()
 
         # Collect statistics if provided
@@ -72,19 +58,11 @@ def quantize(
         # Convert linear layers to quantized versions
         for name, module in self.model.named_modules():
             if isinstance(module, nn.Linear):
-                print(f"Processing layer: {name}")
+                self.logger.info(f"Processing layer: {name}")
 
                 # Create quantized layer
                 layer_stats = stats.get(name, None)
-
-                # Move stats to appropriate device
-                if layer_stats is not None and self.cpu_offload:
-                    layer_stats = {k: v.to('cpu') for k, v in layer_stats.items()}
-
-                quantized = self._quantize_layer(
-                    module,
-                    {k: v.to(module.weight.device) for k, v in layer_stats.items()} if self.cpu_offload and layer_stats else layer_stats
-                )
+                quantized = self._quantize_layer(module, layer_stats)
 
                 # Replace layer in model
                 parent_name = '.'.join(name.split('.')[:-1])
@@ -95,7 +73,6 @@ def quantize(
                 else:
                     setattr(self.model, name, quantized)
 
-                # Clear memory after processing each layer
                 self._clear_memory()
 
         return self.model
@@ -325,7 +302,7 @@ def _quantize_layer(
             quantized.input_std.copy_(stats["std"].to(target_device))
 
         return quantized
-    def convert_to_gguf(self, output_path: str):
+    def convert_to_gguf(self, output_path: str):
         """Convert quantized model to GGUF format using CTransformers."""
         if not CT_AVAILABLE:
             raise ImportError("CTransformers is required for GGUF conversion")
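For reference, a minimal usage sketch of the refactored GGUFQuantizer, assuming the ctransformers package is installed (the class raises ImportError otherwise, as shown above). The model name, calibration tensor, bit-width, and output path are illustrative assumptions, not part of this commit.

import torch
from transformers import AutoModelForCausalLM
from quantllm.quant.gguf import GGUFQuantizer

# Hypothetical example: calibration data is optional for GGUF quantization.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # illustrative model
calibration_data = torch.randint(0, model.config.vocab_size, (8, 128))  # illustrative shape

quantizer = GGUFQuantizer(
    model=model,
    bits=4,          # target weight bit-width (assumed value)
    batch_size=4,    # default shown in the diff
    device="cpu",
)
quantized_model = quantizer.quantize(calibration_data=calibration_data)
quantizer.convert_to_gguf("opt-125m-q4.gguf")  # output path is illustrative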

quantllm/quant/gptq.py

Lines changed: 6 additions & 20 deletions
@@ -6,9 +6,9 @@
 import torch.nn as nn
 from typing import Optional, Dict, Any, List, Union
 from transformers import PreTrainedModel
-from .quantization_engine import QuantizationConfig, QuantizedLinear, DeviceManager
+from .quantization_engine import BaseQuantizer, QuantizationConfig, QuantizedLinear
 
-class GPTQQuantizer:
+class GPTQQuantizer(BaseQuantizer):
     """GPTQ quantization implementation with memory-efficient processing."""
 
     def __init__(
@@ -24,8 +24,7 @@ def __init__(
         batch_size: int = 4,
         device: Optional[Union[str, torch.device]] = None
     ):
-        self.model = model
-        self.bits = bits
+        super().__init__(model=model, bits=bits, device=device)
         self.group_size = group_size
         self.actorder = actorder
         self.allow_mixed_bits = allow_mixed_bits
@@ -34,11 +33,6 @@ def __init__(
         self.sym = sym
         self.batch_size = batch_size
 
-        # Initialize device manager
-        self.device_manager = DeviceManager(
-            torch.device(device) if device else None
-        )
-
         # Initialize H matrices for each layer
         self.H = {}
 
@@ -62,20 +56,14 @@ def quantize(self, calibration_data: Optional[torch.Tensor] = None) -> PreTraine
         if calibration_data is None:
             raise ValueError("GPTQ requires calibration data for quantization")
 
-        device = self.device_manager.primary_device
-        self.model.to(device)
+        # Prepare model and data
+        calibration_data = self.prepare_for_quantization(calibration_data)
         self.model.eval()
 
-        # Convert calibration data to correct device
-        calibration_data = calibration_data.to(device)
-
         # Process layers
         for name, module in self.model.named_modules():
             if isinstance(module, nn.Linear):
-                print(f"Processing layer: {name}")
-
-                # Ensure layer is on correct device
-                module.to(device)
+                self.logger.info(f"Processing layer: {name}")
 
                 # Compute Hessian approximation
                 self.H[name] = self._compute_hessian(module, calibration_data)
@@ -92,8 +80,6 @@ def quantize(self, calibration_data: Optional[torch.Tensor] = None) -> PreTraine
 
                 # Clear memory after processing each layer
                 self._clear_memory()
-
-                # Remove processed Hessian
                 del self.H[name]
 
         return self.model
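For reference, a minimal usage sketch of the refactored GPTQQuantizer. This is not part of the commit: the model name, bit-width, and calibration tensor shape are illustrative assumptions; the constructor and quantize() signatures follow the diff above.

import torch
from transformers import AutoModelForCausalLM
from quantllm.quant.gptq import GPTQQuantizer

# Hypothetical example: GPTQ requires calibration data (see the ValueError above).
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # illustrative model
calibration_data = torch.randint(0, model.config.vocab_size, (8, 128))  # illustrative shape

quantizer = GPTQQuantizer(
    model=model,
    bits=4,          # target weight bit-width (assumed value)
    batch_size=4,    # default shown in the diff
    device="cuda" if torch.cuda.is_available() else "cpu",
)
quantized_model = quantizer.quantize(calibration_data=calibration_data)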
