1+ """Low-level API for QuantLLM - provides detailed control over model loading and quantization."""
2+
3+ import torch
4+ from typing import Optional , Dict , Any , Tuple , Union , List
5+ from transformers import PreTrainedModel , PreTrainedTokenizer , BitsAndBytesConfig
6+ from ..model .model import Model
7+ from ..config .model_config import ModelConfig
8+ from ..quant .quantization_engine import QuantizationEngine
9+ from ..quant .kernels import TritonKernelManager
10+
11+ class LowLevelQuantLLM :
12+ """Low-level interface providing fine-grained control over model loading and quantization."""
13+
14+ def __init__ (self ):
15+ self .quant_engine = QuantizationEngine ()
16+ self .kernel_manager = TritonKernelManager ()
17+
18+ def load_model_advanced (
19+ self ,
20+ model_name : str ,
21+ * ,
22+ quant_config : Optional [BitsAndBytesConfig ] = None ,
23+ device_map : Union [str , Dict [str , str ]] = "auto" ,
24+ max_memory : Optional [Dict [str , str ]] = None ,
25+ use_triton_kernels : bool = False ,
26+ optimize_layers : Optional [List [str ]] = None ,
27+ ** kwargs
28+ ) -> Tuple [PreTrainedModel , PreTrainedTokenizer ]:
29+ """
30+ Load a model with detailed quantization and optimization controls.
31+
32+ Args:
33+ model_name: Model name or path
34+ quant_config: Optional custom BitsAndBytes quantization config
35+ device_map: Device mapping strategy
36+ max_memory: Maximum memory per device
37+ use_triton_kernels: Whether to use optimized Triton kernels
38+ optimize_layers: List of layer names to optimize with Triton
39+ **kwargs: Additional arguments for model loading
40+
41+ Returns:
42+ Tuple of (model, tokenizer)
43+ """
44+ config = ModelConfig (
45+ model_name = model_name ,
46+ device_map = device_map ,
47+ max_memory = max_memory ,
48+ kwargs = kwargs
49+ )
50+
51+ if quant_config :
52+ config .quantization_config = quant_config .to_dict ()
53+
54+ model_loader = Model (config )
55+ model , tokenizer = model_loader .get_model (), model_loader .get_tokenizer ()
56+
57+ if use_triton_kernels :
58+ model = self .kernel_manager .optimize_model (
59+ model ,
60+ target_modules = optimize_layers
61+ )
62+
63+ return model , tokenizer
64+
65+ def quantize_model_weights (
66+ self ,
67+ model : PreTrainedModel ,
68+ bits : int = 4 ,
69+ group_size : int = 128 ,
70+ compute_dtype : torch .dtype = torch .bfloat16 ,
71+ use_double_quant : bool = True
72+ ) -> PreTrainedModel :
73+ """
74+ Apply quantization to an existing model's weights.
75+
76+ Args:
77+ model: Model to quantize
78+ bits: Number of bits for quantization
79+ group_size: Size of quantization groups
80+ compute_dtype: Compute dtype for operations
81+ use_double_quant: Whether to use double quantization
82+
83+ Returns:
84+ Quantized model
85+ """
86+ return self .quant_engine .quantize_weights (
87+ model ,
88+ bits = bits ,
89+ group_size = group_size ,
90+ compute_dtype = compute_dtype ,
91+ use_double_quant = use_double_quant
92+ )
93+
94+ def replace_layer_with_triton (
95+ self ,
96+ model : PreTrainedModel ,
97+ layer_name : str ,
98+ kernel_type : str = "auto"
99+ ) -> PreTrainedModel :
100+ """
101+ Replace a specific layer with its optimized Triton version.
102+
103+ Args:
104+ model: Model to modify
105+ layer_name: Name of layer to replace
106+ kernel_type: Type of Triton kernel to use
107+
108+ Returns:
109+ Model with replaced layer
110+ """
111+ return self .kernel_manager .replace_layer (
112+ model ,
113+ layer_name = layer_name ,
114+ kernel_type = kernel_type
115+ )
116+
117+ def get_memory_stats (self , model : PreTrainedModel ) -> Dict [str , Any ]:
118+ """Get detailed memory statistics for model."""
119+ return self .quant_engine .get_memory_stats (model )
0 commit comments