Commit bb57f2e

Add the API and triton implementation.
1 parent ce8e976 commit bb57f2e

File tree

5 files changed: 788 additions, 0 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ __pycache__/
 *.dylib
 
 upcoming.md
+examples/
 
 logs
 main.py

quantllm/api/high_level.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
"""High-level API for QuantLLM - provides simple, user-friendly interfaces."""

import torch
from typing import Optional, Dict, Any, Tuple, Union
from transformers import PreTrainedModel, PreTrainedTokenizer
from ..model.model import Model
from ..model.lora_config import LoraConfigManager
from ..config.model_config import ModelConfig

class QuantLLM:
    """Main interface for QuantLLM, providing simplified model loading and training."""

    @staticmethod
    def from_pretrained(
        model_name: str,
        *,
        quant_bits: int = 4,
        bnb_4bit_compute_dtype: str = "bfloat16",
        max_seq_len: Optional[int] = None,
        device_map: Union[str, Dict[str, str]] = "auto",
        max_memory: Optional[Dict[str, str]] = None,
        **kwargs
    ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
        """
        Load a pre-trained model with optional quantization.

        Args:
            model_name: Name or path of the model to load
            quant_bits: Number of bits for quantization (4 or 8)
            bnb_4bit_compute_dtype: Compute dtype for 4-bit quantization
            max_seq_len: Maximum sequence length
            device_map: Device mapping strategy or explicit mapping
            max_memory: Maximum memory allocation per device
            **kwargs: Additional arguments passed to from_pretrained

        Returns:
            Tuple of (model, tokenizer)
        """
        config = ModelConfig(
            model_name=model_name,
            load_in_4bit=(quant_bits == 4),
            load_in_8bit=(quant_bits == 8),
            bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
            device_map=device_map,
            max_memory=max_memory,
            kwargs=kwargs
        )

        model_loader = Model(config)
        return model_loader.get_model(), model_loader.get_tokenizer()

    @staticmethod
    def get_adapter_model(
        base_model: PreTrainedModel,
        r: int = 16,
        target_modules: Optional[list] = None,
        lora_alpha: int = 16,
        lora_dropout: float = 0.1,
        bias: str = "none"
    ) -> PreTrainedModel:
        """
        Attach LoRA adapters to a model for efficient fine-tuning.

        Args:
            base_model: Base model to attach adapters to
            r: LoRA attention dimension
            target_modules: List of module names to apply LoRA to
            lora_alpha: LoRA alpha parameter
            lora_dropout: Dropout probability for LoRA layers
            bias: Bias type ("none", "all", or "lora_only")

        Returns:
            Model with LoRA adapters attached
        """
        from peft import prepare_model_for_kbit_training, get_peft_model

        lora_config = LoraConfigManager().create_custom_config(
            r=r,
            target_modules=target_modules,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            bias=bias
        )

        model = prepare_model_for_kbit_training(base_model)
        return get_peft_model(model, lora_config)
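
As a quick illustration of this high-level API, the following is a minimal usage sketch (not part of the commit). It assumes the class is importable as quantllm.api.high_level.QuantLLM; the checkpoint name and target_modules are placeholders that depend on the model architecture.

from quantllm.api.high_level import QuantLLM

# Load a checkpoint with 4-bit quantization (placeholder model name).
model, tokenizer = QuantLLM.from_pretrained(
    "facebook/opt-125m",
    quant_bits=4,
    bnb_4bit_compute_dtype="bfloat16",
    device_map="auto",
)

# Attach LoRA adapters for parameter-efficient fine-tuning.
peft_model = QuantLLM.get_adapter_model(
    model,
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],  # architecture-dependent placeholder
)
peft_model.print_trainable_parameters()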

quantllm/api/low_level.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
"""Low-level API for QuantLLM - provides detailed control over model loading and quantization."""

import torch
from typing import Optional, Dict, Any, Tuple, Union, List
from transformers import PreTrainedModel, PreTrainedTokenizer, BitsAndBytesConfig
from ..model.model import Model
from ..config.model_config import ModelConfig
from ..quant.quantization_engine import QuantizationEngine
from ..quant.kernels import TritonKernelManager

class LowLevelQuantLLM:
    """Low-level interface providing fine-grained control over model loading and quantization."""

    def __init__(self):
        self.quant_engine = QuantizationEngine()
        self.kernel_manager = TritonKernelManager()

    def load_model_advanced(
        self,
        model_name: str,
        *,
        quant_config: Optional[BitsAndBytesConfig] = None,
        device_map: Union[str, Dict[str, str]] = "auto",
        max_memory: Optional[Dict[str, str]] = None,
        use_triton_kernels: bool = False,
        optimize_layers: Optional[List[str]] = None,
        **kwargs
    ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
        """
        Load a model with detailed quantization and optimization controls.

        Args:
            model_name: Model name or path
            quant_config: Optional custom BitsAndBytes quantization config
            device_map: Device mapping strategy
            max_memory: Maximum memory per device
            use_triton_kernels: Whether to use optimized Triton kernels
            optimize_layers: List of layer names to optimize with Triton
            **kwargs: Additional arguments for model loading

        Returns:
            Tuple of (model, tokenizer)
        """
        config = ModelConfig(
            model_name=model_name,
            device_map=device_map,
            max_memory=max_memory,
            kwargs=kwargs
        )

        if quant_config:
            config.quantization_config = quant_config.to_dict()

        model_loader = Model(config)
        model, tokenizer = model_loader.get_model(), model_loader.get_tokenizer()

        if use_triton_kernels:
            model = self.kernel_manager.optimize_model(
                model,
                target_modules=optimize_layers
            )

        return model, tokenizer

    def quantize_model_weights(
        self,
        model: PreTrainedModel,
        bits: int = 4,
        group_size: int = 128,
        compute_dtype: torch.dtype = torch.bfloat16,
        use_double_quant: bool = True
    ) -> PreTrainedModel:
        """
        Apply quantization to an existing model's weights.

        Args:
            model: Model to quantize
            bits: Number of bits for quantization
            group_size: Size of quantization groups
            compute_dtype: Compute dtype for operations
            use_double_quant: Whether to use double quantization

        Returns:
            Quantized model
        """
        return self.quant_engine.quantize_weights(
            model,
            bits=bits,
            group_size=group_size,
            compute_dtype=compute_dtype,
            use_double_quant=use_double_quant
        )

    def replace_layer_with_triton(
        self,
        model: PreTrainedModel,
        layer_name: str,
        kernel_type: str = "auto"
    ) -> PreTrainedModel:
        """
        Replace a specific layer with its optimized Triton version.

        Args:
            model: Model to modify
            layer_name: Name of layer to replace
            kernel_type: Type of Triton kernel to use

        Returns:
            Model with replaced layer
        """
        return self.kernel_manager.replace_layer(
            model,
            layer_name=layer_name,
            kernel_type=kernel_type
        )

    def get_memory_stats(self, model: PreTrainedModel) -> Dict[str, Any]:
        """Get detailed memory statistics for model."""
        return self.quant_engine.get_memory_stats(model)
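
For the low-level API, a minimal usage sketch (not part of the commit) might look as follows. It assumes the class is importable as quantllm.api.low_level.LowLevelQuantLLM and uses a placeholder checkpoint name; the BitsAndBytesConfig settings are just one reasonable choice.

import torch
from transformers import BitsAndBytesConfig
from quantllm.api.low_level import LowLevelQuantLLM

llm = LowLevelQuantLLM()

# Custom 4-bit NF4 quantization config, passed through to model loading.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model, tokenizer = llm.load_model_advanced(
    "facebook/opt-125m",                      # placeholder checkpoint
    quant_config=quant_config,
    device_map="auto",
    use_triton_kernels=torch.cuda.is_available(),
)

print(llm.get_memory_stats(model))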
