
Commit 437cb87

Merge pull request #4 from codewithdark-git/quant-optim-docs-tests
Feat: Introduce QuantizerFactory API and Refactor Quantization Workflow
2 parents 31938b5 + 082196c commit 437cb87

File tree

11 files changed (+702, -656 lines)


docs/api_reference/quantization.rst

Lines changed: 201 additions & 97 deletions
Large diffs are not rendered by default.

quantllm/api.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
from typing import Optional, Dict, Any, Tuple
from transformers import PreTrainedModel
from .quant.awq import AWQQuantizer
from .quant.gptq import GPTQQuantizer
from .quant.gguf import GGUFQuantizer
from .trainer.logger import TrainingLogger

class QuantizerFactory:
    @staticmethod
    def quantize_from_pretrained(
        model_name_or_path: str,
        method: str,
        quant_config_dict: Optional[Dict[str, Any]] = None,
        calibration_data: Optional[Any] = None,  # Typically torch.Tensor or similar
        calibration_steps: Optional[int] = 100,  # Specific to AWQ's quantize method
        device: Optional[str] = None  # Explicit device control
    ) -> Tuple[PreTrainedModel, Any]:  # Returns (quantized_model, tokenizer)
        """
        Loads a model from Hugging Face, quantizes it using the specified method,
        and returns the quantized model and its tokenizer.

        Args:
            model_name_or_path (str): Hugging Face model ID or local path.
            method (str): Quantization method to use ('awq', 'gptq', 'gguf').
            quant_config_dict (Optional[Dict[str, Any]]): Dictionary with quantization parameters.
                Common keys: 'bits', 'group_size', 'batch_size' (for quantizer init).
                AWQ specific: 'zero_point', 'awq_version' (maps to 'version' in AWQQuantizer).
                GPTQ specific: 'actorder', 'percdamp', 'sym'.
                GGUF specific: 'use_packed', 'cpu_offload', 'desc_act', 'desc_ten', 'legacy_format'.
            calibration_data (Optional[Any]): Calibration data required for quantization.
            calibration_steps (Optional[int]): Number of calibration steps, primarily for AWQ's
                quantize() method. Defaults to 100.
            device (Optional[str]): Device to run quantization on ('cpu', 'cuda', 'cuda:x').
                If None, the default device selection logic in BaseQuantizer is used.

        Returns:
            Tuple[PreTrainedModel, Any]: The quantized model and its associated tokenizer.

        Raises:
            ValueError: If an unsupported quantization method is specified or essential parameters are missing.
            RuntimeError: If quantization fails.
        """
        logger = TrainingLogger()
        if quant_config_dict is None:
            quant_config_dict = {}

        method_lower = method.lower()
        logger.log_info(f"Attempting to quantize model '{model_name_or_path}' using method: {method_lower}")

        bits = quant_config_dict.get('bits', 4)
        group_size = quant_config_dict.get('group_size', 128)
        quantizer_batch_size = quant_config_dict.get('batch_size', 4)

        quantizer = None

        if method_lower == 'awq':
            awq_zero_point = quant_config_dict.get('zero_point', True)
            awq_version = quant_config_dict.get('awq_version', 'v2')

            quantizer = AWQQuantizer(
                model_or_model_name_or_path=model_name_or_path,
                bits=bits,
                group_size=group_size,
                zero_point=awq_zero_point,
                version=awq_version,
                batch_size=quantizer_batch_size,
                device=device
            )
            logger.log_info(f"Quantizing with AWQ... Bits: {bits}, Group Size: {group_size}, Zero Point: {awq_zero_point}, Version: {awq_version}")
            quantizer.quantize(  # Call quantize; the model is updated in place
                calibration_data=calibration_data,
                calibration_steps=calibration_steps
            )

        elif method_lower == 'gptq':
            gptq_actorder = quant_config_dict.get('actorder', True)
            gptq_percdamp = quant_config_dict.get('percdamp', 0.01)
            gptq_sym = quant_config_dict.get('sym', True)

            quantizer = GPTQQuantizer(
                model_or_model_name_or_path=model_name_or_path,
                bits=bits,
                group_size=group_size,
                actorder=gptq_actorder,
                percdamp=gptq_percdamp,
                sym=gptq_sym,
                batch_size=quantizer_batch_size,
                device=device
            )
            logger.log_info(f"Quantizing with GPTQ... Bits: {bits}, Group Size: {group_size}, ActOrder: {gptq_actorder}, Sym: {gptq_sym}")
            quantizer.quantize(calibration_data=calibration_data)  # Model updated in place

        elif method_lower == 'gguf':
            gguf_use_packed = quant_config_dict.get('use_packed', True)
            gguf_cpu_offload = quant_config_dict.get('cpu_offload', False)
            gguf_desc_act = quant_config_dict.get('desc_act', False)
            gguf_desc_ten = quant_config_dict.get('desc_ten', False)
            gguf_legacy_format = quant_config_dict.get('legacy_format', False)

            quantizer = GGUFQuantizer(
                model_or_model_name_or_path=model_name_or_path,
                bits=bits,
                group_size=group_size,
                use_packed=gguf_use_packed,
                cpu_offload=gguf_cpu_offload,
                desc_act=gguf_desc_act,
                desc_ten=gguf_desc_ten,
                legacy_format=gguf_legacy_format,
                batch_size=quantizer_batch_size,
                device=device
            )
            logger.log_info(f"Quantizing with GGUF... Bits: {bits}, Group Size: {group_size}, Packed: {gguf_use_packed}, CPU Offload: {gguf_cpu_offload}")
            quantizer.quantize(calibration_data=calibration_data)  # Model updated in place

        else:
            logger.log_error(f"Unsupported quantization method: {method}")
            raise ValueError(f"Unsupported quantization method: {method}. Supported methods are 'awq', 'gptq', 'gguf'.")

        if quantizer is None or quantizer.model is None:
            logger.log_error(f"Failed to initialize quantizer or obtain quantized model for method: {method}")
            raise RuntimeError(f"Quantization failed for method: {method}. Quantizer or model is None.")

        logger.log_info(f"Successfully quantized model with method: {method_lower}")
        return quantizer.model, quantizer.tokenizer
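
Usage sketch (not part of this commit): a minimal example of calling the new factory API end to end. The model ID, calibration tensor shape, and device are illustrative assumptions, not values defined by the PR.

from quantllm.api import QuantizerFactory
import torch

# Hypothetical calibration batch of token IDs, shape (num_samples, seq_len);
# the real shape and vocabulary size depend on the model being quantized.
calibration_data = torch.randint(0, 32000, (32, 512))

model, tokenizer = QuantizerFactory.quantize_from_pretrained(
    model_name_or_path="facebook/opt-125m",  # illustrative model ID
    method="awq",
    quant_config_dict={
        "bits": 4,
        "group_size": 128,
        "zero_point": True,
        "awq_version": "v2",  # maps to AWQQuantizer's 'version' argument
        "batch_size": 4,
    },
    calibration_data=calibration_data,
    calibration_steps=100,
    device="cuda",
)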

quantllm/quant/awq.py

Lines changed: 18 additions & 4 deletions
@@ -6,14 +6,14 @@
 import numpy as np
 from typing import Optional, Dict, Any, List, Union, Tuple
 from transformers import PreTrainedModel
-from .quantization_engine import BaseQuantizer, QuantizationConfig, QuantizedLinear
+from .quantization_engine import move_to_device, BaseQuantizer, QuantizationConfig, QuantizedLinear

 class AWQQuantizer(BaseQuantizer):
     """AWQ quantization implementation with memory-efficient processing."""

     def __init__(
         self,
-        model: PreTrainedModel,
+        model_or_model_name_or_path: Union[str, PreTrainedModel],  # Changed parameter name
         bits: int = 4,
         group_size: int = 128,
         zero_point: bool = True,
@@ -27,7 +27,8 @@ def __init__(
         Initializes the AWQQuantizer.

         Args:
-            model (PreTrainedModel): The model to be quantized.
+            model_or_model_name_or_path (Union[str, PreTrainedModel]):
+                The Hugging Face model name/path or a PreTrainedModel instance to be quantized.
             bits (int, optional): Number of bits for quantization. Defaults to 4.
             group_size (int, optional): Size of the quantization group. Defaults to 128.
             zero_point (bool, optional): Whether to use zero-point quantization for activations. Defaults to True.
@@ -39,7 +40,9 @@ def __init__(
                 The device for quantization operations ('cpu', 'cuda', etc.).
                 Inherited from BaseQuantizer. Defaults to None (auto-detection).
         """
-        super().__init__(model=model, bits=bits, device=device)
+        # Pass all relevant kwargs to BaseQuantizer;
+        # AWQQuantizer-specific args are handled here.
+        super().__init__(model_or_model_name_or_path=model_or_model_name_or_path, bits=bits, device=device)
         self.group_size = group_size
         self.zero_point = zero_point
         self.scale_dtype = scale_dtype
@@ -101,7 +104,18 @@ def quantize(

         self._clear_memory()

+        # Update model config with quantization parameters
+        awq_specific_params = {
+            "zero_point": self.zero_point,
+            "version": self.version,
+            "scale_dtype": self.scale_dtype,  # Added from __init__
+            "enable_mnn_kernel": self.enable_mnn_kernel  # Added from __init__
+            # batch_size is more of a process param, not a model config param usually
+        }
+        self._update_model_config_with_quant_params("awq", awq_specific_params)
+
         return self.model
+
     def _collect_activation_stats(
         self,
         data: torch.Tensor  # Removed num_steps parameter
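
The signature change above means AWQQuantizer (and likewise GPTQQuantizer and GGUFQuantizer below) can now be constructed either from a Hugging Face model ID/path or from an already-loaded PreTrainedModel. A brief sketch, with a placeholder model ID and placeholder calibration data:

from quantllm.quant.awq import AWQQuantizer
import torch

# Construct the quantizer directly from a model ID; an in-memory
# PreTrainedModel instance is equally valid per the Union type above.
quantizer = AWQQuantizer(
    model_or_model_name_or_path="facebook/opt-125m",  # placeholder model ID
    bits=4,
    group_size=128,
    zero_point=True,
)
quantized_model = quantizer.quantize(
    calibration_data=torch.randint(0, 32000, (8, 256)),  # placeholder calibration batch
    calibration_steps=100,
)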

quantllm/quant/gguf.py

Lines changed: 17 additions & 5 deletions
@@ -6,7 +6,7 @@
 import torch.nn as nn
 from typing import Optional, Dict, Any, List, Union, Tuple
 from transformers import PreTrainedModel
-from .quantization_engine import BaseQuantizer, QuantizationConfig, QuantizedLinear
+from .quantization_engine import move_to_device, BaseQuantizer, QuantizationConfig, QuantizedLinear

 try:
     import ctransformers
@@ -19,7 +19,7 @@ class GGUFQuantizer(BaseQuantizer):

     def __init__(
         self,
-        model: PreTrainedModel,
+        model_or_model_name_or_path: Union[str, PreTrainedModel],  # Changed parameter name
         bits: int = 4,
         group_size: int = 32,
         desc_act: bool = False,
@@ -34,7 +34,8 @@ def __init__(
         Initializes the GGUFQuantizer.

         Args:
-            model (PreTrainedModel): The model to be quantized.
+            model_or_model_name_or_path (Union[str, PreTrainedModel]):
+                The Hugging Face model name/path or a PreTrainedModel instance to be quantized.
             bits (int, optional): Number of bits for quantization. Defaults to 4.
             group_size (int, optional): Size of the quantization group. Defaults to 32.
             desc_act (bool, optional): Whether to describe activations in GGUF metadata. Defaults to False.
@@ -52,7 +53,7 @@ def __init__(
         if not CT_AVAILABLE:
             raise ImportError("CTransformers is required for GGUF quantization. Install with: pip install ctransformers")

-        super().__init__(model=model, bits=bits, device=device)
+        super().__init__(model_or_model_name_or_path=model_or_model_name_or_path, bits=bits, device=device)
         self.group_size = group_size
         self.desc_act = desc_act
         self.desc_ten = desc_ten
@@ -94,9 +95,20 @@ def quantize(
                 setattr(self.model, name, quantized)

         self._clear_memory()
+
+        # Update model config with quantization parameters
+        gguf_specific_params = {
+            "use_packed": self.use_packed,
+            "cpu_offload": self.cpu_offload,
+            "desc_act": self.desc_act,
+            "desc_ten": self.desc_ten,
+            "legacy_format": self.legacy_format
+            # group_size is handled by BaseQuantizer if present as self.group_size
+        }
+        self._update_model_config_with_quant_params("gguf", gguf_specific_params)

         return self.model
-
+
     def _collect_stats(self, data: torch.Tensor) -> Dict[str, Dict[str, torch.Tensor]]:
         """Collect statistics for quantization with memory-efficient batch processing."""
         stats = {}
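
For reference, these are the GGUF-specific keys the factory reads from quant_config_dict (per the docstring in quantllm/api.py); the values shown are only illustrative defaults, not prescribed by this commit:

# Illustrative GGUF settings for QuantizerFactory.quantize_from_pretrained(method="gguf", ...)
gguf_config = {
    "bits": 4,
    "group_size": 32,       # GGUFQuantizer's own default is 32, unlike AWQ/GPTQ's 128
    "use_packed": True,     # pack quantized weights
    "cpu_offload": False,   # keep tensors on the quantization device
    "desc_act": False,      # describe activations in GGUF metadata
    "desc_ten": False,
    "legacy_format": False,
    "batch_size": 4,
}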

quantllm/quant/gptq.py

Lines changed: 16 additions & 5 deletions
@@ -6,14 +6,14 @@
 import torch.nn as nn
 from typing import Optional, Dict, Any, List, Union
 from transformers import PreTrainedModel
-from .quantization_engine import BaseQuantizer, QuantizationConfig, QuantizedLinear
+from .quantization_engine import move_to_device, BaseQuantizer, QuantizationConfig, QuantizedLinear

 class GPTQQuantizer(BaseQuantizer):
     """GPTQ quantization implementation with memory-efficient processing."""

     def __init__(
         self,
-        model: PreTrainedModel,
+        model_or_model_name_or_path: Union[str, PreTrainedModel],  # Changed parameter name
         bits: int = 4,
         group_size: int = 128,
         actorder: bool = False,
@@ -28,7 +28,8 @@ def __init__(
         Initializes the GPTQQuantizer.

         Args:
-            model (PreTrainedModel): The model to be quantized.
+            model_or_model_name_or_path (Union[str, PreTrainedModel]):
+                The Hugging Face model name/path or a PreTrainedModel instance to be quantized.
             bits (int, optional): Number of bits for quantization. Defaults to 4.
             group_size (int, optional): Size of the quantization group. Defaults to 128.
             actorder (bool, optional): Whether to use activation order for columns. Defaults to False.
@@ -43,7 +44,7 @@ def __init__(
                 The device for quantization operations ('cpu', 'cuda', etc.).
                 Inherited from BaseQuantizer. Defaults to None (auto-detection).
         """
-        super().__init__(model=model, bits=bits, device=device)
+        super().__init__(model_or_model_name_or_path=model_or_model_name_or_path, bits=bits, device=device)
         self.group_size = group_size
         self.actorder = actorder
         self.allow_mixed_bits = allow_mixed_bits
@@ -101,8 +102,18 @@ def quantize(self, calibration_data: Optional[torch.Tensor] = None) -> PreTrainedModel:
             self._clear_memory()
             del self.H[name]

+        # Update model config with quantization parameters
+        gptq_specific_params = {
+            "actorder": self.actorder,
+            "sym": self.sym,
+            "percdamp": self.percdamp,
+            "allow_mixed_bits": self.allow_mixed_bits  # Added from __init__
+            # use_triton is more of a runtime/environment flag, might not be essential in model config
+        }
+        self._update_model_config_with_quant_params("gptq", gptq_specific_params)
+
         return self.model
-
+
     def _compute_hessian(self, layer: nn.Linear, data: torch.Tensor) -> torch.Tensor:
         """Compute Hessian approximation for a layer with memory-efficient processing."""
         n = layer.in_features
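
Similarly, a sketch of the GPTQ-specific keys the factory accepts; the values mirror the defaults used in quantllm/api.py and are illustrative only:

# Illustrative GPTQ settings for QuantizerFactory.quantize_from_pretrained(method="gptq", ...)
gptq_config = {
    "bits": 4,
    "group_size": 128,
    "actorder": True,   # quantize columns in activation order
    "percdamp": 0.01,   # dampening added to the Hessian diagonal
    "sym": True,        # symmetric quantization
    "batch_size": 4,
}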
