 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
-from typing import List, Optional
+from contextlib import contextmanager
+from types import MethodType
+from typing import Dict, List, Optional

 import json
 import torch
+import torch.nn as nn

 from swift.llm import get_model_tokenizer, get_template
 from swift.utils import (check_json_format, get_logger, get_main, get_model_info, push_to_ms_hub, seed_everything,
                          show_layers)
 from .infer import merge_lora, prepare_model_template, save_checkpoint
-from .utils import ExportArguments, Template, get_dataset, swift_to_peft_format
+from .utils import ExportArguments, Template, deep_getattr, get_dataset, get_mllm_arch, swift_to_peft_format

 logger = get_logger()

 _args: Optional[ExportArguments] = None
 template: Optional[Template] = None


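+# Replacement for optimum.gptq's prepare_dataset: the calibration examples are already encoded by the
+# template, so they only need to be collated into batches here.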
+def _prepare_dataset(examples: List[Dict[str, torch.LongTensor]], batch_size: int = 1, *args, **kwargs):
+    global _args, template
+    assert template is not None
+    examples = [
+        template.data_collator(examples[start:start + batch_size]) for start in range(0, len(examples), batch_size)
+    ]
+    return examples
+
+
 def _get_dataset(*args, **kwargs):
     global _args, template
     assert _args is not None
@@ -39,27 +51,31 @@ def _get_dataset(*args, **kwargs):
     samples = []
     n_run = 0
     for data in dataset:
-        input_ids = template.encode(data)[0].get('input_ids')
+        inputs = template.encode(data)[0]
+        input_ids = inputs['input_ids']
         if input_ids is None or len(input_ids) == 0:
             continue
-        sample = torch.tensor(input_ids)
-        samples.append(sample)
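+        # For multimodal GPTQ calibration, keep each encoded example intact (it may carry image inputs)
+        # and drop the labels, which are not needed; otherwise collect plain token ids as before.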
+        if _args.is_multimodal and _args.quant_method == 'gptq':
+            inputs.pop('labels', None)
+            samples.append(inputs)
+        else:
+            samples += input_ids
         n_run += 1
         if n_run == n_samples:
             break
+    if _args.is_multimodal and _args.quant_method == 'gptq':
+        return samples
     # now concatenate all samples and split according to block size
-    cat_samples = torch.cat(samples, dim=0)  # shape: [X]
-    n_split = cat_samples.shape[0] // block_size
+    n_split = len(samples) // block_size
     logger.info(f'Split into {n_split} blocks')
-    if _args.quant_method == 'awq':
-        return [cat_samples[None, i * block_size:(i + 1) * block_size] for i in range(n_split)]
-    else:  # gptq
-        res = []
-        for i in range(n_split):
-            input_ids = cat_samples[None, i * block_size:(i + 1) * block_size]
-            attention_mask = torch.ones_like(input_ids)
-            res.append({'input_ids': input_ids, 'attention_mask': attention_mask})
-        return res
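+    # AWQ expects [1, block_size] token tensors; for GPTQ, return dicts of raw id lists that the patched
+    # prepare_dataset (_prepare_dataset above) will collate into batches.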
+    res = []
+    for i in range(n_split):
+        input_ids = samples[i * block_size:(i + 1) * block_size]
+        if _args.quant_method == 'awq':
+            res.append(torch.tensor(input_ids)[None])
+        else:
+            res.append({'input_ids': input_ids})
+    return res


 def awq_model_quantize(awq_model, tokenizer, batch_size) -> None:
@@ -80,22 +96,74 @@ def awq_model_quantize(awq_model, tokenizer, batch_size) -> None:
         bits=_args.quant_bits, group_size=group_size, zero_point=True, version='GEMM')


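+# Temporarily swap optimum.gptq's get_dataset/prepare_dataset for the swift versions above,
+# then restore the originals once quantization is done.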
+@contextmanager
+def _patch_gptq():
+    from optimum.gptq import quantizer
+    _get_dataset_origin = quantizer.get_dataset
+    _prepare_dataset_origin = quantizer.prepare_dataset
+    quantizer.get_dataset = _get_dataset
+    quantizer.prepare_dataset = _prepare_dataset
+    yield
+    quantizer.get_dataset = _get_dataset_origin
+    quantizer.prepare_dataset = _prepare_dataset_origin
+
+
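+# Wrap each decoder block's forward: force use_cache off and pad the returned tuple with the remaining
+# positional args, so the output has the same arity as the inputs during the layer-by-layer pass.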
+def _patch_model_forward(module_list):
+
+    def _new_forward(self, *args, **kwargs):
+        if 'use_cache' in kwargs:
+            kwargs['use_cache'] = False
+        layer_ret = self.__old_forward(*args, **kwargs)
+        return layer_ret + args[len(layer_ret):]
+
+    for module in module_list:
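+        # accelerate's device_map hooks keep the real forward in '_old_forward'; patch whichever attribute is in effect.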
+        if hasattr(module, '_old_forward'):  # device_map
+            __old_forward = module._old_forward
+            module._old_forward = MethodType(_new_forward, module)
+        else:
+            __old_forward = module.forward
+            module.forward = MethodType(_new_forward, module)
+        module.__old_forward = __old_forward
+
+
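+# Heuristically locate the decoder layer stack: for multimodal models, descend into the language model first,
+# then pick the largest nn.ModuleList with at least 10 children and return its dotted name for
+# optimum's block_name_to_quantize.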
+def get_block_name_to_quantize(model: nn.Module, model_type: str) -> Optional[str]:
+    mllm_arch = get_mllm_arch(model_type)
+    prefix = ''
+    if mllm_arch is not None:
+        assert len(mllm_arch.language_model) == 1, f'mllm_arch.language_model: {mllm_arch.language_model}'
+        prefix = mllm_arch.language_model[0]
+        model = deep_getattr(model, prefix)
+
+    module_lists = []
+    for n, m in model.named_modules():
+        if isinstance(m, nn.ModuleList) and len(m) >= 10:
+            module_lists.append((n, m))
+    if module_lists:
+        module_list = max(module_lists, key=lambda x: len(x[1]))
+        _patch_model_forward(module_list[1])
+        return f'{prefix}.{module_list[0]}'
+
+
 def gptq_model_quantize(model, tokenizer, batch_size):
-    from optimum.gptq import GPTQQuantizer, quantizer
+    from optimum.gptq import GPTQQuantizer
     global _args
     logger.info(f'Quantization dataset: {_args.dataset}')
-    gptq_quantizer = GPTQQuantizer(bits=_args.quant_bits, dataset=','.join(_args.dataset), batch_size=batch_size)
-    _origin_get_dataset = quantizer.get_dataset
-    quantizer.get_dataset = _get_dataset
-    logger.info('Start quantizing the model...')
-    logger.warning('The process of packing the model takes a long time and there is no progress bar. '
-                   'Please be patient and wait...')
-    gptq_quantizer.quantize_model(model, tokenizer)
-    quantizer.get_dataset = _origin_get_dataset  # recover
+    with _patch_gptq():
+        gptq_quantizer = GPTQQuantizer(
+            bits=_args.quant_bits,
+            dataset=','.join(_args.dataset),
+            batch_size=batch_size,
+            block_name_to_quantize=get_block_name_to_quantize(model, _args.model_type))
+        logger.info('Start quantizing the model...')
+        logger.warning('The process of packing the model takes a long time and there is no progress bar. '
+                       'Please be patient and wait...')
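+        # Some model configs lack a 'use_cache' attribute; create it so code that reads or toggles it
+        # during quantization does not raise AttributeError.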
+        if not hasattr(model.config, 'use_cache'):
+            model.config.use_cache = None
+        gptq_quantizer.quantize_model(model, tokenizer)
     return gptq_quantizer


-def replace_and_concat(template: 'Template', template_list: List, placeholder: str, keyword: str):
+def replace_and_concat(template: Template, template_list: List, placeholder: str, keyword: str):
     final_str = ''
     for t in template_list:
         if isinstance(t, str):