-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathquant.py
More file actions
292 lines (263 loc) · 12.1 KB
/
quant.py
File metadata and controls
292 lines (263 loc) · 12.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
# Copyright (c) ModelScope Contributors. All rights reserved.
import torch
import torch.nn as nn
import transformers
from collections import defaultdict
from contextlib import contextmanager
from packaging import version
from tqdm import tqdm
from typing import Dict, List, Optional
from swift.arguments import ExportArguments
from swift.dataset import load_dataset
from swift.model import save_checkpoint
from swift.template import MaxLengthError
from swift.utils import HfConfigFactory, ProcessorMixin, deep_getattr, get_logger, get_model_parameter_info, to_device
from ..utils import prepare_model_template
logger = get_logger()
class QuantEngine(ProcessorMixin):
    """Drive post-training quantization of a model and save the result.

    The quantization backend is selected by ``args.quant_method``:
    ``awq``, ``gptq``, ``gptq_v2``, ``bnb`` or ``fp8``. For awq/gptq the
    calibration dataset is built from ``args.dataset`` via the template's
    encoder; bnb/fp8 models are simply saved as loaded.
    """
    def __init__(self, args: ExportArguments):
        self.args = args
        kwargs = {}
        if args.quant_method == 'awq':
            # awq loads through its own wrapper class instead of AutoModel.
            from awq import AutoAWQForCausalLM
            kwargs['auto_model_cls'] = AutoAWQForCausalLM
        self.model, self.template = prepare_model_template(args, **kwargs)
        self.template.set_mode('train')
        # Disable the KV cache both on the top-level config attribute and
        # recursively via HfConfigFactory (sub-configs may carry their own
        # `use_cache`, e.g. for multimodal models).
        self.model.config.use_cache = False
        HfConfigFactory.set_config_attr(self.model.config, 'use_cache', False)
        self.processor = self.template.processor
        # Persist the export arguments alongside the run for reproducibility.
        args.save_args()
    def quantize(self):
        """Quantize ``self.model`` per ``args.quant_method`` and save to ``args.output_dir``.

        Raises:
            ValueError: if ``quant_bits`` is unset (for non-fp8 methods) or
                ``quant_method`` is unrecognized.
        """
        args = self.args
        if args.quant_bits is None and args.quant_method != 'fp8':
            raise ValueError(f'Please set the quant_bits. args.quant_bits: {args.quant_bits}')
        if args.quant_method == 'awq':
            # The awq wrapper holds the HF model at `.model`; the template
            # must point at the inner model for encoding/collation.
            self.template.model = self.model.model
            self.awq_model_quantize()
            self.model.save_quantized(
                args.output_dir, safetensors=args.safe_serialization, shard_size=args.max_shard_size)
        elif args.quant_method in {'gptq', 'gptq_v2'}:
            self.template.model = self.model
            gptq_quantizer = self.gptq_model_quantize(v2=(args.quant_method == 'gptq_v2'))
            if args.quant_method == 'gptq_v2':
                # NOTE(review): gptq_v2 layers appear to register these extra
                # buffers; listing them as dynamic tied-weight keys presumably
                # keeps (safetensors) serialization happy — confirm upstream.
                if not getattr(self.model, '_dynamic_tied_weights_keys', None):
                    self.model._dynamic_tied_weights_keys = []
                self.model._dynamic_tied_weights_keys += ['wf_unsqueeze_zero', 'wf_unsqueeze_neg_one']
            gptq_quantizer.save(
                self.model,
                args.output_dir,
                safe_serialization=args.safe_serialization,
                max_shard_size=args.max_shard_size)
        elif args.quant_method in {'bnb', 'fp8'}:
            # bnb/fp8 quantization happens at load time; only saving is needed.
            self.model.save_pretrained(
                args.output_dir, safe_serialization=args.safe_serialization, max_shard_size=args.max_shard_size)
        else:
            raise ValueError(f'args.quant_method: {args.quant_method}')
        logger.info(f'model: {self.model}')
        logger.info(f'model_parameter_info: {get_model_parameter_info(self.model)}')
        # Copy processor/tokenizer files and model-specific extras next to the
        # quantized weights (model=None: weights were already saved above).
        save_checkpoint(
            None,
            self.processor,
            args.output_dir,
            model_dirs=[args.model_dir],
            additional_saved_files=self.model.model_meta.additional_saved_files)
        logger.info(f'Successfully quantized the model and saved in `{args.output_dir}`.')
    @torch.inference_mode()
    def _prepare_gptq_dataset(self, examples: List[Dict[str, torch.LongTensor]], batch_size: int = 1, *args, **kwargs):
        """Collate pre-encoded examples into batches for the gptq quantizer.

        Installed in place of ``optimum.gptq.quantizer.prepare_dataset`` (see
        ``_patch_gptq``). Batches are collated on the model device, run
        through the multimodal pre-forward hook when needed, then moved back
        to CPU so the quantizer can manage device placement itself.
        """
        res = []
        for start in tqdm(range(0, len(examples), batch_size)):
            batched_inputs = examples[start:start + batch_size]
            inputs = to_device(self.template.data_collator(batched_inputs), self.model.device)
            if self.model.model_meta.is_multimodal:
                # Resolve multimodal inputs (the template's pre-forward hook)
                # so the quantizer only ever sees plain LM tensor inputs.
                _, inputs = self.template.pre_forward_hook(self.model, None, inputs)
            res.append(to_device(inputs, 'cpu'))
        return res
    @torch.inference_mode()
    def _get_quant_dataset(self, *args, **kwargs):
        """Build the calibration dataset for awq/gptq quantization.

        Returns either a list of encoded example dicts (multimodal gptq) or,
        for the text-only path, fixed-size blocks of token ids: dicts with
        ``input_ids`` for gptq, ``[1, block_size]`` tensors for awq.
        """
        args = self.args
        assert args.quant_method in {'awq', 'gptq', 'gptq_v2'}
        template = self.template
        n_samples = args.quant_n_samples
        block_size = args.max_length
        # only use train_dataset
        dataset = load_dataset(
            args.dataset, split_dataset_ratio=0, shuffle=args.dataset_shuffle, **args.get_dataset_kwargs())[0]
        logger.info(f'quant_dataset: {dataset}')
        dataset = dataset.shuffle()
        samples = []
        i = 0
        prog_bar = tqdm(total=n_samples, dynamic_ncols=True)
        is_multimodal = self.model.model_meta.is_multimodal
        for data in dataset:
            try:
                inputs = template.encode(data)
            except MaxLengthError:
                # Skip over-length samples rather than truncating them.
                continue
            if is_multimodal and args.quant_method in {'gptq', 'gptq_v2'}:
                # Keep whole examples; labels are irrelevant for calibration.
                inputs.pop('labels', None)
                samples.append(inputs)
            else:
                # Text-only path: pool raw token ids for later re-blocking.
                input_ids = inputs['input_ids']
                samples += input_ids
            i += 1
            prog_bar.update()
            if i == n_samples:
                break
        prog_bar.close()
        if is_multimodal and args.quant_method in {'gptq', 'gptq_v2'}:
            return samples
        # now concatenate all samples and split according to block size
        n_split = max(len(samples) // block_size, 1)
        logger.info(f'Split into {n_split} blocks')
        res = []
        for i in range(n_split):
            input_ids = samples[i * block_size:(i + 1) * block_size]
            if args.quant_method in {'gptq', 'gptq_v2'}:
                res.append({'input_ids': input_ids})
            else:
                # awq expects batched tensors: shape [1, block_size].
                res.append(torch.tensor(input_ids)[None])
        return res
    @staticmethod
    @contextmanager
    def _patch_awq_move_embed(awq_model):
        """Temporarily patch ``awq_model.move_embed`` for hook-dispatched models.

        When the model carries an accelerate ``_hf_hook`` (device placement is
        managed by hooks), moving embeddings to a non-cpu device is skipped —
        presumably to avoid fighting the hook's own placement; confirm against
        the awq implementation. The original method is always restored.
        """
        _origin_move_embed = awq_model.move_embed
        def _move_embed(model, device: str):
            if hasattr(model, '_hf_hook') and device != 'cpu':
                return
            _origin_move_embed(model, device)
        awq_model.move_embed = _move_embed
        try:
            yield
        finally:
            awq_model.move_embed = _origin_move_embed
    def awq_model_quantize(self) -> None:
        """Run awq quantization with our own calibration-dataset provider."""
        from awq.quantize import quantizer
        args = self.args
        logger.info(f'Quantization dataset: {args.dataset}')
        # Swap awq's calibration-dataset loader for ours (restored below).
        _origin_get_calib_dataset = quantizer.get_calib_dataset
        quantizer.get_calib_dataset = self._get_quant_dataset
        quant_config = {
            'zero_point': True,
            'q_group_size': args.group_size,
            'w_bit': args.quant_bits,
            'version': 'GEMM'
        }
        if self.model.model_info.is_moe_model:
            # MoE: leave method-specific modules (e.g. gates) unquantized.
            quant_config['modules_to_not_convert'] = self.args.get_modules_to_not_convert()
        logger.info(f'quant_config: {quant_config}')
        logger.info('Start quantizing the model...')
        with self._patch_awq_move_embed(self.model):
            self.model.quantize(
                self.tokenizer, quant_config=quant_config, n_parallel_calib_samples=args.quant_batch_size)
        quantizer.get_calib_dataset = _origin_get_calib_dataset  # recover
        if self.model.quant_config.modules_to_not_convert:
            # Ensure lm_head is recorded as unconverted in the saved config.
            model_arch = args.model_meta.model_arch
            lm_head_key = getattr(model_arch, 'lm_head', None) or 'lm_head'
            if lm_head_key not in self.model.quant_config.modules_to_not_convert:
                self.model.quant_config.modules_to_not_convert.append(lm_head_key)
    @contextmanager
    def _patch_gptq(self):
        """Temporarily redirect optimum's gptq dataset helpers to our own."""
        from optimum.gptq import quantizer
        _get_dataset_origin = quantizer.get_dataset
        _prepare_dataset_origin = quantizer.prepare_dataset
        quantizer.get_dataset = self._get_quant_dataset
        quantizer.prepare_dataset = self._prepare_gptq_dataset
        try:
            yield
        finally:
            quantizer.get_dataset = _get_dataset_origin
            quantizer.prepare_dataset = _prepare_dataset_origin
    @staticmethod
    def get_block_name_to_quantize(model: nn.Module) -> Optional[str]:
        """Locate the transformer-layer container to quantize block-by-block.

        For multimodal models the search is restricted to the language-model
        submodule. Picks the largest ModuleList/Sequential with >= 10 entries
        whose first element is not an MLP (to avoid MoE expert lists).
        Returns a dotted attribute path, or None if nothing matches.
        """
        model_arch = model.model_meta.model_arch
        prefix = ''
        if hasattr(model_arch, 'language_model'):
            language_model = [lm for lm in model_arch.language_model if not lm.endswith('lm_head')]
            assert len(language_model) == 1, f'model_arch.language_model: {language_model}'
            prefix = language_model[0]
            model = deep_getattr(model, prefix)
        module_lists = []
        for n, m in model.named_modules():
            if (isinstance(m, (nn.ModuleList, nn.Sequential)) and len(m) >= 10
                    and 'mlp' not in m[0].__class__.__name__.lower()):  # fix moe
                module_lists.append((n, m))
        if module_lists:
            module_list = max(module_lists, key=lambda x: len(x[1]))
            return f'{prefix}.{module_list[0]}'.strip('.')
    @staticmethod
    def _get_experts(block):
        """Return ``(name, module)`` of the first ModuleList/Sequential in `block`.

        In an MoE decoder layer this is the experts container. Implicitly
        returns None when no such submodule exists (callers unpack the result,
        so they rely on one being present).
        """
        for n, m in block.named_modules():
            if isinstance(m, (nn.ModuleList, nn.Sequential)):
                return n, m
    @staticmethod
    def get_modules_in_block_to_quantize(model, block_name: str):
        """Group a MoE block's linear layers for gptq; None for dense models.

        Parallel expert layers are grouped by their trailing attribute name so
        gptq quantizes corresponding layers of all experts together; gate
        modules ('mlp.gate') are excluded from quantization.
        """
        if not model.model_info.is_moe_model:
            return
        from optimum.gptq.utils import get_layers
        # Do not quantize the gate part.
        block = deep_getattr(model, block_name)[-1]
        prefix, experts = QuantEngine._get_experts(block)
        layers = get_layers(block)
        res = []
        # NOTE(review): this rebind discards the experts module returned by
        # _get_experts above — only `prefix` is used; looks intentional but
        # worth confirming.
        experts = defaultdict(list)
        experts_idx = None
        for name, layer in layers.items():
            if model.model_info.model_type == 'qwen3_next' and name.startswith('self_attn.'):
                # ignore attn
                continue
            if name.startswith(prefix):
                # Expert layer: bucket by suffix (e.g. 'gate_proj') so the
                # same sublayer across experts is quantized together.
                suffix = name.rsplit('.', 1)[-1]
                experts[suffix].append(name)
                experts_idx = len(res)
            elif 'mlp.gate' not in name:
                res.append([name])
        # Splice the expert groups back in at the position where the first
        # expert layer was encountered, preserving layer order.
        res[experts_idx:experts_idx] = experts.values()
        return res
    @contextmanager
    def _patch_gptq_block(self, model, block_name_to_quantize):
        """Wrap each block's output in a tuple for transformers >= 4.54.

        Newer transformers decoder layers may return a bare tensor, while
        optimum's gptq code indexes the output like a tuple; forward hooks
        normalize the output shape. No-op on older transformers.
        """
        if version.parse(transformers.__version__) < version.parse('4.54'):
            yield
            return
        # compat transformers>=4.54
        blocks = deep_getattr(model, block_name_to_quantize)
        hooks = []
        def _to_tuple(module, input, output):
            if not isinstance(output, (list, tuple)):
                output = (output, )
            return output
        for block in blocks:
            hooks.append(block.register_forward_hook(_to_tuple))
        try:
            yield
        finally:
            for hook in hooks:
                hook.remove()
    def gptq_model_quantize(self, v2: bool = False):
        """Run gptq (or gptq_v2) quantization; return the GPTQQuantizer used.

        Args:
            v2: use the 'gptq_v2' checkpoint format instead of 'gptq'.
        """
        from optimum.gptq import GPTQQuantizer
        args = self.args
        logger.info(f'Quantization dataset: {args.dataset}')
        block_name_to_quantize = self.get_block_name_to_quantize(self.model)
        modules_in_block_to_quantize = self.get_modules_in_block_to_quantize(self.model, block_name_to_quantize)
        logger.info(f'block_name_to_quantize: {block_name_to_quantize}')
        logger.info(f'modules_in_block_to_quantize: {modules_in_block_to_quantize}')
        with self._patch_gptq():
            gptq_quantizer = GPTQQuantizer(
                bits=args.quant_bits,
                group_size=args.group_size,
                dataset=','.join(args.dataset),
                batch_size=args.quant_batch_size,
                block_name_to_quantize=block_name_to_quantize,
                modules_in_block_to_quantize=modules_in_block_to_quantize,
                checkpoint_format='gptq_v2' if v2 else 'gptq')
            # Persist the detected block name into the saved quant config.
            gptq_quantizer.serialization_keys.append('block_name_to_quantize')
            logger.info('Start quantizing the model...')
            logger.warning('The process of packing the model takes a long time and there is no progress bar. '
                           'Please be patient and wait...')
            if not hasattr(self.model, 'hf_device_map'):
                # optimum's quantizer expects an accelerate device map.
                self.model.hf_device_map = {'': torch.device('cuda:0')}
            with self._patch_gptq_block(self.model, block_name_to_quantize):
                gptq_quantizer.quantize_model(self.model, self.tokenizer)
        # The dataset spec is not meaningful in the exported config; drop it.
        self.model.config.quantization_config.pop('dataset', None)
        return gptq_quantizer
def quantize_model(args: ExportArguments):
    """Entry point: build a QuantEngine from `args` and run quantization."""
    engine = QuantEngine(args)
    engine.quantize()