
Commit 0d2b20a

Add the quantization methods.
1 parent 03c1fe9 commit 0d2b20a

File tree

3 files changed: +208 −144 lines


quantllm/quant/gptq.py

Lines changed: 68 additions & 93 deletions
@@ -1,12 +1,12 @@
-"""GPTQ (Generative Pre-trained Transformer Quantization) implementation for LLM quantization."""
+"""GPTQ (Generative Pre-trained Transformer Quantization) implementation."""
 
-import math
 import gc
+import math
 import torch
 import torch.nn as nn
 from typing import Optional, Dict, Any, List, Union
 from transformers import PreTrainedModel
-from .quantization_engine import QuantizationConfig, QuantizedLinear
+from .quantization_engine import QuantizationConfig, QuantizedLinear, DeviceManager
 
 class GPTQQuantizer:
     """GPTQ quantization implementation with memory-efficient processing."""
@@ -22,18 +22,22 @@ def __init__(
         percdamp: float = 0.01,
         sym: bool = True,
         batch_size: int = 4,
-        cpu_offload: bool = False
+        device: Optional[Union[str, torch.device]] = None
     ):
         self.model = model
         self.bits = bits
         self.group_size = group_size
         self.actorder = actorder
         self.allow_mixed_bits = allow_mixed_bits
-        self.use_triton = use_triton
+        self.use_triton = use_triton and torch.cuda.is_available()
         self.percdamp = percdamp
         self.sym = sym
         self.batch_size = batch_size
-        self.cpu_offload = cpu_offload
+
+        # Initialize device manager
+        self.device_manager = DeviceManager(
+            torch.device(device) if device else None
+        )
 
         # Initialize H matrices for each layer
         self.H = {}
@@ -43,6 +47,7 @@ def _clear_memory(self):
         gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
+        self.device_manager.sync()
 
     def quantize(self, calibration_data: Optional[torch.Tensor] = None) -> PreTrainedModel:
         """
@@ -56,24 +61,25 @@ def quantize(self, calibration_data: Optional[torch.Tensor] = None) -> PreTrainedModel:
         """
         if calibration_data is None:
             raise ValueError("GPTQ requires calibration data for quantization")
-
-        # Prepare model for quantization
+
+        device = self.device_manager.primary_device
+        self.model.to(device)
         self.model.eval()
 
-        # Convert all linear layers to quantizable versions
+        # Convert calibration data to correct device
+        calibration_data = calibration_data.to(device)
+
+        # Process layers
        for name, module in self.model.named_modules():
             if isinstance(module, nn.Linear):
                 print(f"Processing layer: {name}")
 
-                # Compute Hessian approximation for this layer
-                self.H[name] = self._compute_hessian(module, calibration_data)
-
-                # Move Hessian to CPU if offloading is enabled
-                if self.cpu_offload:
-                    self.H[name] = self.H[name].cpu()
+                # Ensure layer is on correct device
+                module.to(device)
 
-                # Convert to quantized layer
-                quantized = self._quantize_layer(module, self.H[name].to(module.weight.device) if self.cpu_offload else self.H[name])
+                # Compute Hessian approximation
+                self.H[name] = self._compute_hessian(module, calibration_data)
+                quantized = self._quantize_layer(module, self.H[name])
 
                 # Replace layer in model
                 parent_name = '.'.join(name.split('.')[:-1])
@@ -87,65 +93,57 @@ def quantize(self, calibration_data: Optional[torch.Tensor] = None) -> PreTrainedModel:
                 # Clear memory after processing each layer
                 self._clear_memory()
 
-                # Remove processed Hessian to free memory
+                # Remove processed Hessian
                 del self.H[name]
 
         return self.model
-
+
     def _compute_hessian(self, layer: nn.Linear, data: torch.Tensor) -> torch.Tensor:
-        """Compute Hessian approximation for a layer with memory-efficient batch processing."""
-        device = next(layer.parameters()).device
-
-        # Initialize accumulator on CPU if offloading is enabled
+        """Compute Hessian approximation for a layer with memory-efficient processing."""
+        device = self.device_manager.primary_device
         n = layer.in_features
-        H = torch.zeros((n, n), device='cpu' if self.cpu_offload else device)
+        H = torch.zeros((n, n), device=device)
 
         def hook_fn(module, input, output):
             x = input[0].detach()
-            # Reshape input if needed (batch_size * seq_len, hidden_size)
             if len(x.shape) == 3:
                 x = x.view(-1, x.size(-1))
 
             with torch.no_grad():
-                # Process in smaller chunks to save memory
-                chunk_size = 1024  # Adjust based on available memory
+                chunk_size = 1024
                 num_chunks = math.ceil(x.size(0) / chunk_size)
 
                 for i in range(num_chunks):
                     chunk = x[i * chunk_size:(i + 1) * chunk_size]
-                    # Compute contribution to Hessian
-                    if self.cpu_offload:
-                        chunk_H = torch.matmul(chunk.t(), chunk).cpu()
-                    else:
-                        chunk_H = torch.matmul(chunk.t(), chunk)
+                    chunk_H = torch.matmul(chunk.t(), chunk)
                     H.add_(chunk_H)
 
-                    # Clear intermediate tensors
                     del chunk_H
-                    if i % 10 == 0:  # Periodic memory cleanup
+                    if i % 10 == 0:
                         self._clear_memory()
 
        # Register forward hook
        handle = layer.register_forward_hook(hook_fn)
 
-        # Run calibration data through model in batches
+        # Process calibration data in batches
        with torch.no_grad():
            for i in range(0, len(data), self.batch_size):
-                batch = data[i:i+self.batch_size]
+                batch = data[i:i+self.batch_size].to(device)
                self.model(batch)
 
-                # Periodic memory cleanup
                if i % (self.batch_size * 10) == 0:
                    self._clear_memory()
 
-        # Remove hook
        handle.remove()
-
        return H
 
    def _quantize_layer(self, layer: nn.Linear, H: torch.Tensor) -> QuantizedLinear:
-        """Quantize a single layer using GPTQ."""
-        device = next(layer.parameters()).device
+        """Quantize a single layer using GPTQ with memory management."""
+        device = self.device_manager.primary_device
+
+        # Ensure tensors are on the correct device
+        H = H.to(device)
+        W = layer.weight.data.to(device)
 
        # Initialize quantized layer
        quantized = QuantizedLinear(
@@ -155,65 +153,42 @@ def _quantize_layer(self, layer: nn.Linear, H: torch.Tensor) -> QuantizedLinear:
            config=QuantizationConfig(
                bits=self.bits,
                scheme="symmetric" if self.sym else "asymmetric",
-                granularity="per-tensor",
-                calibration="minmax",
-                channel_wise=False,
-                dtype=f"{'u' if not self.sym else ''}int{self.bits}",
-                format="gptq"
+                granularity="per-channel",
+                calibration="gptq"
            )
-        )
+        ).to(device)
 
-        # Copy bias if exists
        if layer.bias is not None:
            quantized.bias.data.copy_(layer.bias.data)
 
-        # Get weight matrix
-        W = layer.weight.data.clone()
-
-        # Compute optimal scales and zero points
-        if self.group_size > 0:
-            n_groups = W.shape[0] // self.group_size
-            W_groups = W.view(n_groups, self.group_size, -1)
-            scales = []
-            zero_points = []
+        # Process in chunks to save memory
+        chunk_size = min(1024, layer.out_features)
+        for i in range(0, layer.out_features, chunk_size):
+            chunk_end = min(i + chunk_size, layer.out_features)
+            W_chunk = W[i:chunk_end]
 
-            for idx in range(n_groups):
-                group = W_groups[idx]
-                if self.sym:
-                    scale = (2 ** (self.bits - 1) - 1) / torch.max(torch.abs(group))
-                    zero_point = 0
-                else:
-                    min_val = torch.min(group)
-                    max_val = torch.max(group)
-                    scale = (2 ** self.bits - 1) / (max_val - min_val)
-                    zero_point = -min_val * scale
-
-                scales.append(scale)
-                zero_points.append(zero_point)
-
-            scales = torch.stack(scales)
-            zero_points = torch.stack(zero_points)
-        else:
+            # Compute optimal scaling factors for this chunk
            if self.sym:
-                scales = (2 ** (self.bits - 1) - 1) / torch.max(torch.abs(W), dim=1)[0]
-                zero_points = torch.zeros_like(scales)
+                max_val = W_chunk.abs().max(dim=1)[0]
+                scale = (2 ** (self.bits - 1) - 1) / max_val
            else:
-                min_vals = torch.min(W, dim=1)[0]
-                max_vals = torch.max(W, dim=1)[0]
-                scales = (2 ** self.bits - 1) / (max_vals - min_vals)
-                zero_points = -min_vals * scales
-
-        # Quantize weights
-        W_quant = torch.round(W * scales.view(-1, 1) - zero_points.view(-1, 1))
-
-        # Apply GPTQ optimization
-        recon_loss = torch.sum((W - (W_quant + zero_points.view(-1, 1)) / scales.view(-1, 1)).pow(2))
-        if H is not None:
-            recon_loss = recon_loss * torch.trace(H)
-
-        # Store quantized weights and parameters
-        quantized.weight_quantized.copy_(W_quant.to(torch.int8))
-        quantized.weight_scale.copy_(1.0 / scales)
-        quantized.weight_zero_point.copy_(zero_points)
+                min_val = W_chunk.min(dim=1)[0]
+                max_val = W_chunk.max(dim=1)[0]
+                scale = (2 ** self.bits - 1) / (max_val - min_val)
+
+            # Quantize chunk
+            W_quant = torch.round(W_chunk * scale.unsqueeze(1))
+            W_quant = torch.clamp(
+                W_quant,
+                -(2 ** (self.bits - 1)),
+                2 ** (self.bits - 1) - 1
+            )
+
+            # Store quantized weights and scale
+            quantized.weight_quantized.data[i:chunk_end] = W_quant.to(torch.int8)
+            quantized.weight_scale.data[i:chunk_end] = 1.0 / scale
+
+            del W_chunk, W_quant
+            self._clear_memory()
 
        return quantized
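
For orientation, a minimal usage sketch of the quantizer after this change, based only on the constructor arguments and quantize() signature visible in the diff; the checkpoint name, tokenizer handling, and argument values below are illustrative assumptions, not part of the commit.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from quantllm.quant.gptq import GPTQQuantizer

    # Illustrative small model; any causal LM whose nn.Linear layers should be quantized.
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

    # GPTQ requires calibration data; a few tokenized samples batched into one tensor.
    texts = ["Quantization shrinks large language models.", "Calibration data drives GPTQ."]
    calibration_data = tokenizer(texts, return_tensors="pt", padding=True).input_ids

    quantizer = GPTQQuantizer(
        model,
        bits=4,
        group_size=128,
        actorder=False,
        allow_mixed_bits=False,
        use_triton=False,
        sym=True,
        batch_size=2,
        device="cuda:0" if torch.cuda.is_available() else "cpu",  # replaces the removed cpu_offload flag
    )
    quantized_model = quantizer.quantize(calibration_data)

With cpu_offload gone, the Hessian accumulator and per-layer work now live on the device chosen here (or on the one DeviceManager auto-selects when device is None), so memory placement is controlled through this single argument.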

quantllm/quant/quantization_engine.py

Lines changed: 61 additions & 1 deletion
@@ -7,6 +7,62 @@
 import numpy as np
 from ..trainer.logger import TrainingLogger
 
+def get_device_map(model: PreTrainedModel) -> Dict[str, torch.device]:
+    """Get device mapping for model parameters."""
+    device_map = {}
+    for name, param in model.named_parameters():
+        device_map[name] = param.device
+    return device_map
+
+def move_to_device(
+    tensor: torch.Tensor,
+    device: torch.device,
+    force_copy: bool = False
+) -> torch.Tensor:
+    """Safely move tensor to device with proper error handling."""
+    try:
+        if force_copy:
+            return tensor.to(device, copy=True)
+        if tensor.device == device:
+            return tensor
+        return tensor.to(device)
+    except Exception as e:
+        raise RuntimeError(f"Failed to move tensor to {device}: {str(e)}")
+
+class DeviceManager:
+    """Manage device placement and synchronization."""
+
+    def __init__(self, primary_device: Optional[torch.device] = None):
+        self.primary_device = primary_device or self._get_default_device()
+        self.device_maps = {}
+
+    def _get_default_device(self) -> torch.device:
+        """Get the best available device."""
+        if torch.cuda.is_available():
+            # Automatically select GPU with most free memory
+            max_free = 0
+            best_device = 0
+            for i in range(torch.cuda.device_count()):
+                free_mem = torch.cuda.get_device_properties(i).total_memory - torch.cuda.memory_allocated(i)
+                if free_mem > max_free:
+                    max_free = free_mem
+                    best_device = i
+            return torch.device(f'cuda:{best_device}')
+        return torch.device('cpu')
+
+    def sync(self):
+        """Synchronize all CUDA devices."""
+        if torch.cuda.is_available():
+            for i in range(torch.cuda.device_count()):
+                torch.cuda.synchronize(i)
+
+    def ensure_same_device(self, *tensors: torch.Tensor) -> List[torch.Tensor]:
+        """Ensure all tensors are on the same device."""
+        if not tensors:
+            return []
+        target_device = tensors[0].device
+        return [move_to_device(t, target_device) for t in tensors]
+
 class QuantizationConfig:
     """Configuration for quantization parameters."""
 
@@ -108,10 +164,14 @@ class QuantizationEngine:
    def __init__(
        self,
        config: QuantizationConfig,
-        logger: Optional[TrainingLogger] = None
+        logger: Optional[TrainingLogger] = None,
+        device: Optional[Union[str, torch.device]] = None
    ):
        self.config = config
        self.logger = logger or TrainingLogger()
+        self.device_manager = DeviceManager(
+            torch.device(device) if device else None
+        )
 
    def quantize_model(
        self,
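
A short sketch of how the new device helpers compose, using only the names added above; the tensors are placeholders for illustration.

    import torch
    from quantllm.quant.quantization_engine import DeviceManager, move_to_device

    # With no argument, DeviceManager picks the CUDA device with the most free memory, else the CPU.
    dm = DeviceManager()

    x = torch.randn(4, 16)
    w = torch.randn(16, 16, device=dm.primary_device)

    # move_to_device returns the tensor unchanged when it is already on the target device.
    x = move_to_device(x, dm.primary_device)
    y = x @ w

    # ensure_same_device aligns a group of tensors to the first tensor's device.
    x, w = dm.ensure_same_device(x, w)

    dm.sync()  # synchronizes every visible CUDA device; a no-op without CUDA

This is the same DeviceManager that GPTQQuantizer and QuantizationEngine now construct from their optional device argument, which is what allows the old cpu_offload flag to be dropped.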
