 from auto_round.low_cpu_mem.utils import get_layers_before_block
 from auto_round.mllm.mllm_dataset import get_mllm_dataloader
 from auto_round.mllm.template import Template, get_template
+from auto_round.schemes import QuantizationScheme
 from auto_round.special_model_handler import (
     NOT_SUPPORT_ONLY_TEXT_MODELS,
     SUPPORT_ONLY_TEXT_MODELS,
@@ -126,61 +127,56 @@ class AutoRoundMLLM(AutoRound):
     """
+    bits: int | None
+    group_size: int | None
+    sym: bool | None
+    data_type: str | None
+    act_bits: int | None
+    act_group_size: int | None
+    act_sym: bool | None
+    act_data_type: str | None
+    act_dynamic: bool | None
+    super_bits: int | None
+    super_group_size: int | None
+
     def __init__(
         self,
         model: Union[torch.nn.Module, str],
         tokenizer=None,
         processor=None,
         image_processor=None,
-        bits: int = 4,
-        group_size: int = 128,
-        sym: bool = True,
-        layer_config: dict = None,
-        batch_size: int = 8,
-        amp: bool = True,
-        device: Union[str, torch.device, int] = 0,
-        lr_scheduler=None,
-        dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = None,
-        extra_data_dir: str = None,
-        template: Union[str, Template] = None,
+        scheme: Union[str, dict, QuantizationScheme] = "W4A16",
+        layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None,
+        dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
         quant_nontext_module: bool = False,
-        enable_quanted_input: bool = True,
-        enable_minmax_tuning: bool = True,
-        lr: float = None,
-        minmax_lr: float = None,
-        low_gpu_mem_usage: bool = False,
-        low_cpu_mem_usage: bool = False,
         iters: int = 200,
-        seqlen: int = None,
+        seqlen: int = 2048,
         nsamples: int = 128,
-        sampler: str = "rand",
-        seed: int = 42,
-        nblocks: int = 1,
+        batch_size: int = 8,
         gradient_accumulate_steps: int = 1,
-        not_use_best_mse: bool = False,
-        dynamic_max_gap: int = -1,
-        data_type: str = "int",
-        scale_dtype: str = "fp16",
-        act_bits: int = 32,
-        act_group_size: int = None,
-        act_sym: bool = None,
-        act_dynamic: bool = True,
-        to_quant_block_names: Union[str, list] = None,
-        enable_norm_bias_tuning: bool = False,
-        truncation: bool = None,
+        low_gpu_mem_usage: bool = False,
+        device_map: Union[str, torch.device, int, dict] = 0,
         enable_torch_compile: bool = False,
-        model_kwargs: dict = None,
+        seed: int = 42,
         **kwargs,
     ):
+        extra_data_dir = kwargs.pop("extra_data_dir", None)
+        template = kwargs.pop("template", None)
+
+        to_quant_block_names: Union[str, list, None] = kwargs.pop("to_quant_block_names", None)
+        if device_map is None:
+            device_map = 0
+        self._set_device(device_map)
+
         if isinstance(model, str):
-            model, processor, tokenizer, image_processor = mllm_load_model(model, device=device)
+            model, processor, tokenizer, image_processor = mllm_load_model(model, device=self.device)
 
+        self.model = model
         quant_nontext_module = self._check_quant_nontext(layer_config, quant_nontext_module)
         all_blocks = get_block_names(model, quant_nontext_module)
         self.quant_block_list = find_matching_blocks(model, all_blocks, to_quant_block_names)
         if to_quant_block_names is None:
             to_quant_block_names = extract_block_names_to_str(self.quant_block_list)
-        self.to_quant_block_names = to_quant_block_names
         self.extra_data_dir = extra_data_dir
         self.quant_nontext_module = quant_nontext_module
         self.processor = processor
@@ -219,7 +215,7 @@ def __init__(
                     " switching to liuhaotian/llava_conv_58k"
                 )
                 dataset = "liuhaotian/llava_conv_58k"
-            elif not _only_text_test(model, tokenizer, device, self.template.model_type):
+            elif not _only_text_test(model, tokenizer, self.device, self.template.model_type):
                 logger.warning(
                     f"{model.config.model_type} does not support for {dataset},"
                     " will use liuhaotian/llava_conv_58k with default config as an alternative."
@@ -248,7 +244,7 @@ def __init__(
             gradient_accumulate_steps = batch_size * gradient_accumulate_steps
             batch_size = 1
         seqlen = 2048 if seqlen is None else seqlen
-        truncation = True if truncation is None else truncation
+        truncation = True
         self.truncation = truncation
 
         if nsamples % batch_size != 0:
@@ -258,40 +254,20 @@ def __init__(
         super(AutoRoundMLLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            bits=bits,
-            group_size=group_size,
-            sym=sym,
+            scheme=scheme,
             layer_config=layer_config,
-            batch_size=batch_size,
-            amp=amp,
-            device=device,
-            lr_scheduler=lr_scheduler,
             dataset=dataset,
-            enable_quanted_input=enable_quanted_input,
-            enable_minmax_tuning=enable_minmax_tuning,
-            lr=lr,
-            minmax_lr=minmax_lr,
-            low_gpu_mem_usage=low_gpu_mem_usage,
-            low_cpu_mem_usage=low_cpu_mem_usage,
             iters=iters,
             seqlen=seqlen,
             nsamples=nsamples,
-            sampler=sampler,
-            seed=seed,
-            nblocks=nblocks,
+            batch_size=batch_size,
             gradient_accumulate_steps=gradient_accumulate_steps,
-            not_use_best_mse=not_use_best_mse,
-            dynamic_max_gap=dynamic_max_gap,
-            data_type=data_type,
-            scale_dtype=scale_dtype,
-            act_bits=act_bits,
-            act_group_size=act_group_size,
-            act_sym=act_sym,
-            act_dynamic=act_dynamic,
-            to_quant_block_names=self.to_quant_block_names,
-            enable_norm_bias_tuning=enable_norm_bias_tuning,
+            low_gpu_mem_usage=low_gpu_mem_usage,
+            device_map=device_map,
             enable_torch_compile=enable_torch_compile,
+            seed=seed,
             vlm=True,
+            to_quant_block_names=to_quant_block_names,
             **kwargs,
         )
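
A minimal usage sketch of the scheme-based constructor introduced by this diff. It assumes AutoRoundMLLM is still re-exported from the auto_round package top level and that quantize_and_save() remains the save entry point; the model id below is only a placeholder.

# Sketch of the new AutoRoundMLLM call signature (assumptions noted above).
from auto_round import AutoRoundMLLM

autoround = AutoRoundMLLM(
    "Qwen/Qwen2-VL-2B-Instruct",   # placeholder model id; string inputs are loaded via mllm_load_model()
    scheme="W4A16",                # replaces the old bits/group_size/sym/data_type/act_* arguments
    dataset="NeelNanda/pile-10k",  # new default calibration dataset
    iters=200,
    nsamples=128,
    batch_size=8,
    device_map=0,                  # replaces the old device argument
)
autoround.quantize_and_save("./Qwen2-VL-2B-Instruct-W4A16")

Removed tuning knobs such as lr, minmax_lr, sampler, or enable_norm_bias_tuning are no longer explicit parameters; per this diff they would have to be forwarded through **kwargs, and extra_data_dir, template, and to_quant_block_names are now popped from kwargs for backward compatibility.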