66import time
77
88import torch
9+ import torch .distributed as dist
910import yaml
1011from easydict import EasyDict
1112from loguru import logger
@@ -31,34 +32,36 @@ def main(config):
3132 logger .info (tokenizer )
3233 logger .info (model )
3334
34- if 'eval' in config and len (config .eval .eval_pos ):
35- eval_list = []
36- name_list = (
37- config .eval .name
38- if not isinstance (config .eval .name , str )
39- else [config .eval .name ]
40- )
41- for name in name_list :
42- eval_config = copy .deepcopy (config .eval )
43- eval_config .name = name
44- if len (name_list ) != 1 : # eval multi datasets
45- eval_config .path = os .path .join (config .eval .path , name )
35+ if int (os .environ ['RANK' ]) == 0 :
36+ if 'eval' in config and len (config .eval .eval_pos ):
37+ eval_list = []
38+ name_list = (
39+ config .eval .name
40+ if not isinstance (config .eval .name , str )
41+ else [config .eval .name ]
42+ )
43+ for name in name_list :
44+ eval_config = copy .deepcopy (config .eval )
45+ eval_config .name = name
46+ if len (name_list ) != 1 : # eval multi datasets
47+ eval_config .path = os .path .join (config .eval .path , name )
48+ if config .eval .type == 'acc' :
49+ acc_eval = AccuracyEval (eval_config )
50+ eval_list .append (acc_eval )
51+ else :
52+ ppl_eval = PerplexityEval (tokenizer .get_tokenizer (), eval_config )
53+ eval_list .append (ppl_eval )
54+
55+ if 'eval' in config and 'pretrain' in config .eval .eval_pos :
4656 if config .eval .type == 'acc' :
47- acc_eval = AccuracyEval (eval_config )
48- eval_list .append (acc_eval )
57+ for acc_eval in eval_list :
58+ acc = acc_eval .eval (model )
59+ logger .info (f'{ config .eval .name } acc : { acc } ' )
4960 else :
50- ppl_eval = PerplexityEval (tokenizer .get_tokenizer (), eval_config )
51- eval_list .append (ppl_eval )
52-
53- if 'eval' in config and 'pretrain' in config .eval .eval_pos :
54- if config .eval .type == 'acc' :
55- for acc_eval in eval_list :
56- acc = acc_eval .eval (model )
57- logger .info (f'{ config .eval .name } acc : { acc } ' )
58- else :
59- for ppl_eval in eval_list :
60- ppl = ppl_eval .eval (model )
61- logger .info (f'{ ppl_eval .dataset } ppl : { ppl } ' )
61+ for ppl_eval in eval_list :
62+ ppl = ppl_eval .eval (model )
63+ logger .info (f'{ ppl_eval .dataset } ppl : { ppl } ' )
64+
6265 if not config .get ('calib' , False ):
6366 blockwise_opt = ALGO_REGISTRY [config .quant .method ](
6467 model ,
@@ -68,6 +71,7 @@ def main(config):
6871 config = config
6972 )
7073 blockwise_opt .run_block_loop ()
74+ dist .barrier ()
7175 else :
7276 dataset = BaseDataset (tokenizer .get_tokenizer (), config .calib , model .batch_process )
7377 calib_data , padding_mask = dataset .get_calib_dataset ()
@@ -93,121 +97,124 @@ def main(config):
9397 config
9498 )
9599 blockwise_opt .run_block_loop ()
100+ dist .barrier ()
96101
97- if 'eval' in config and 'transformed' in config .eval .eval_pos :
98- blockwise_opt .deploy ('origin_float' )
99- if config .eval .type == 'acc' :
100- for acc_eval in eval_list :
101- acc = acc_eval .eval (model )
102- logger .info (f'{ config .eval .name } acc : { acc } ' )
103- else :
104- for ppl_eval in eval_list :
105- ppl = ppl_eval .eval (model )
106- logger .info (f'{ ppl_eval .dataset } ppl : { ppl } ' )
102+ if int (os .environ ['RANK' ]) == 0 :
103+ if 'eval' in config and 'transformed' in config .eval .eval_pos :
104+ blockwise_opt .deploy ('origin_float' )
105+ if config .eval .type == 'acc' :
106+ for acc_eval in eval_list :
107+ acc = acc_eval .eval (model )
108+ logger .info (f'{ config .eval .name } acc : { acc } ' )
109+ else :
110+ for ppl_eval in eval_list :
111+ ppl = ppl_eval .eval (model )
112+ logger .info (f'{ ppl_eval .dataset } ppl : { ppl } ' )
107113
108- if 'save' in config and config .save .get ('save_trans' , False ):
109- blockwise_opt .save_model (save_trans_path )
114+ if 'save' in config and config .save .get ('save_trans' , False ):
115+ blockwise_opt .save_model (save_trans_path )
110116
111- if 'save' in config and config .save .get ('save_trtllm' , False ):
112- blockwise_opt .save_model (save_trtllm_trans_path )
113- from llmc .utils .export_trtllm import cvt_trtllm_engine
117+ if 'save' in config and config .save .get ('save_trtllm' , False ):
118+ blockwise_opt .save_model (save_trtllm_trans_path )
119+ from llmc .utils .export_trtllm import cvt_trtllm_engine
114120
115- cvt_trtllm_engine (
116- save_trtllm_trans_path ,
117- save_trtllm_engine_path ,
118- config .save .get ('trtllm_cfg' ),
119- )
121+ cvt_trtllm_engine (
122+ save_trtllm_trans_path ,
123+ save_trtllm_engine_path ,
124+ config .save .get ('trtllm_cfg' ),
125+ )
120126
121- if 'eval' in config and 'fake_quant' in config .eval .eval_pos :
122- blockwise_opt .deploy ('fake_quant' )
123- if config .eval .type == 'acc' :
124- for acc_eval in eval_list :
125- acc = acc_eval .eval (model )
126- logger .info (f'{ config .eval .name } acc : { acc } ' )
127- else :
128- for ppl_eval in eval_list :
129- ppl = ppl_eval .eval (model )
130- logger .info (f'{ ppl_eval .dataset } ppl : { ppl } ' )
127+ if 'eval' in config and 'fake_quant' in config .eval .eval_pos :
128+ blockwise_opt .deploy ('fake_quant' )
129+ if config .eval .type == 'acc' :
130+ for acc_eval in eval_list :
131+ acc = acc_eval .eval (model )
132+ logger .info (f'{ config .eval .name } acc : { acc } ' )
133+ else :
134+ for ppl_eval in eval_list :
135+ ppl = ppl_eval .eval (model )
136+ logger .info (f'{ ppl_eval .dataset } ppl : { ppl } ' )
131137
132- if 'eval_token_consist' in config .eval and config .eval .eval_token_consist :
133- org_model = MODEL_REGISTRY [config .model .type ](
134- config .model .path , config .model .torch_dtype
138+ if 'eval_token_consist' in config .eval and config .eval .eval_token_consist :
139+ org_model = MODEL_REGISTRY [config .model .type ](
140+ config .model .path , config .model .torch_dtype
141+ )
142+ token_consist_eval = TokenConsistencyEval (tokenizer .get_tokenizer (),
143+ eval_config )
144+ consistency_ratio = token_consist_eval .eval (model , org_model )
145+ logger .info (f'Token consistency ratio: { consistency_ratio } ' )
146+ del org_model
147+
148+ if 'save' in config and config .save .get ('save_fake' , False ):
149+ blockwise_opt .deploy ('fake_quant' )
150+ blockwise_opt .save_model (save_fake_path )
151+
152+ if 'save' in config and config .save .get ('save_vllm' , False ):
153+ w , a = config .quant .weight , config .quant .get ('act' )
154+ if isinstance (w .bit , str ):
155+ assert a , 'Only WA float quant is supported.'
156+ assert w .symmetric and a .symmetric , 'Only symmetric quant is supported.'
157+ assert w .bit == a .bit and w .bit in ['e4m3' , 'e5m2' ] and \
158+ a .bit in ['e4m3' , 'e5m2' ], 'Only WA FP8 quant is supported'
159+ else :
160+ assert w .symmetric , 'Only symmetric quant is supported.'
161+ assert w .bit in [4 , 8 ], 'Supported quant: w4a16, w8a16, w8a8.'
162+ if a :
163+ assert a .symmetric , 'Only symmetric quant is supported.'
164+ assert a .bit == 8 , 'Supported quant: w4a16, w8a16, w8a8.'
165+ blockwise_opt .deploy ('vllm_quant' )
166+ blockwise_opt .save_model (save_quant_path )
167+ update_vllm_quant_config (blockwise_opt .model , config , save_quant_path )
168+
169+ if 'save' in config and config .save .get ('save_sgl' , False ):
170+ w , a = config .quant .weight , config .quant .get ('act' )
171+ if isinstance (w .bit , str ):
172+ assert a , 'Only WA float quant is supported.'
173+ assert w .symmetric and a .symmetric , 'Only symmetric quant is supported.'
174+ assert w .bit == a .bit and w .bit in ['e4m3' , 'e5m2' ] and \
175+ a .bit in ['e4m3' , 'e5m2' ], 'Only WA FP8 quant is supported'
176+ else :
177+ assert w .symmetric , 'Only symmetric quant is supported.'
178+ assert w .bit in [4 , 8 ], 'Supported quant: w4a16, w8a16, w8a8.'
179+ if a :
180+ assert a .symmetric , 'Only symmetric quant is supported.'
181+ assert a .bit == 8 , 'Supported quant: w4a16, w8a16, w8a8.'
182+ blockwise_opt .deploy ('sgl_quant' )
183+ blockwise_opt .save_model (save_quant_path )
184+ update_vllm_quant_config (blockwise_opt .model , config , save_quant_path )
185+
186+ if 'save' in config and config .save .get ('save_autoawq' , False ):
187+ assert config .quant .weight .bit in [4 ] and 'act' not in config .quant , \
188+ 'AutoAWQ supports only 4-bit weight-only quantization.'
189+ assert not config .quant .weight .symmetric , 'Only asymmetric quant is supported.'
190+
191+ blockwise_opt .deploy ('autoawq_quant' )
192+ blockwise_opt .save_model (save_quant_path )
193+ update_autoawq_quant_config (config , save_quant_path )
194+
195+ if 'save' in config and config .save .get ('save_mlcllm' , False ):
196+ assert config .quant .weight .bit in [4 ] and 'act' not in config .quant , \
197+ 'MlcLLM supports only 4-bit weight-only quantization.'
198+ assert not config .quant .weight .symmetric , 'Only asymmetric quant is supported.'
199+
200+ blockwise_opt .deploy ('mlcllm_quant' )
201+ blockwise_opt .save_model (save_quant_path )
202+ update_autoawq_quant_config (config , save_quant_path )
203+
204+ if 'opencompass' in config :
205+ assert config .save .get ('save_trans' , False )
206+ cfg_path = config ['opencompass' ]['cfg_path' ]
207+ output_path = config ['opencompass' ]['output_path' ]
208+ eval_model_path = os .path .abspath (save_trans_path )
209+ opencompass_cmd = (
210+ f'opencompass { cfg_path } -w { output_path } '
211+ f'--llmc_cfg { args .config } '
212+ f'--llmc_eval_mode quant '
213+ f'--llmc_model_path { eval_model_path } '
135214 )
136- token_consist_eval = TokenConsistencyEval (tokenizer .get_tokenizer (),
137- eval_config )
138- consistency_ratio = token_consist_eval .eval (model , org_model )
139- logger .info (f'Token consistency ratio: { consistency_ratio } ' )
140- del org_model
141-
142- if 'save' in config and config .save .get ('save_fake' , False ):
143- blockwise_opt .deploy ('fake_quant' )
144- blockwise_opt .save_model (save_fake_path )
145-
146- if 'save' in config and config .save .get ('save_vllm' , False ):
147- w , a = config .quant .weight , config .quant .get ('act' )
148- if isinstance (w .bit , str ):
149- assert a , 'Only WA float quant is supported.'
150- assert w .symmetric and a .symmetric , 'Only symmetric quant is supported.'
151- assert w .bit == a .bit and w .bit in ['e4m3' , 'e5m2' ] and \
152- a .bit in ['e4m3' , 'e5m2' ], 'Only WA FP8 quant is supported'
153- else :
154- assert w .symmetric , 'Only symmetric quant is supported.'
155- assert w .bit in [4 , 8 ], 'Supported quant: w4a16, w8a16, w8a8.'
156- if a :
157- assert a .symmetric , 'Only symmetric quant is supported.'
158- assert a .bit == 8 , 'Supported quant: w4a16, w8a16, w8a8.'
159- blockwise_opt .deploy ('vllm_quant' )
160- blockwise_opt .save_model (save_quant_path )
161- update_vllm_quant_config (blockwise_opt .model , config , save_quant_path )
162-
163- if 'save' in config and config .save .get ('save_sgl' , False ):
164- w , a = config .quant .weight , config .quant .get ('act' )
165- if isinstance (w .bit , str ):
166- assert a , 'Only WA float quant is supported.'
167- assert w .symmetric and a .symmetric , 'Only symmetric quant is supported.'
168- assert w .bit == a .bit and w .bit in ['e4m3' , 'e5m2' ] and \
169- a .bit in ['e4m3' , 'e5m2' ], 'Only WA FP8 quant is supported'
170- else :
171- assert w .symmetric , 'Only symmetric quant is supported.'
172- assert w .bit in [4 , 8 ], 'Supported quant: w4a16, w8a16, w8a8.'
173- if a :
174- assert a .symmetric , 'Only symmetric quant is supported.'
175- assert a .bit == 8 , 'Supported quant: w4a16, w8a16, w8a8.'
176- blockwise_opt .deploy ('sgl_quant' )
177- blockwise_opt .save_model (save_quant_path )
178- update_vllm_quant_config (blockwise_opt .model , config , save_quant_path )
179-
180- if 'save' in config and config .save .get ('save_autoawq' , False ):
181- assert config .quant .weight .bit in [4 ] and 'act' not in config .quant , \
182- 'AutoAWQ supports only 4-bit weight-only quantization.'
183- assert not config .quant .weight .symmetric , 'Only asymmetric quant is supported.'
184-
185- blockwise_opt .deploy ('autoawq_quant' )
186- blockwise_opt .save_model (save_quant_path )
187- update_autoawq_quant_config (config , save_quant_path )
188-
189- if 'save' in config and config .save .get ('save_mlcllm' , False ):
190- assert config .quant .weight .bit in [4 ] and 'act' not in config .quant , \
191- 'MlcLLM supports only 4-bit weight-only quantization.'
192- assert not config .quant .weight .symmetric , 'Only asymmetric quant is supported.'
193-
194- blockwise_opt .deploy ('mlcllm_quant' )
195- blockwise_opt .save_model (save_quant_path )
196- update_autoawq_quant_config (config , save_quant_path )
197-
198- if 'opencompass' in config :
199- assert config .save .get ('save_trans' , False )
200- cfg_path = config ['opencompass' ]['cfg_path' ]
201- output_path = config ['opencompass' ]['output_path' ]
202- eval_model_path = os .path .abspath (save_trans_path )
203- opencompass_cmd = (
204- f'opencompass { cfg_path } -w { output_path } '
205- f'--llmc_cfg { args .config } '
206- f'--llmc_eval_mode quant '
207- f'--llmc_model_path { eval_model_path } '
208- )
209- logger .info (f'opencompass_cmd : { opencompass_cmd } ' )
210- os .system (opencompass_cmd )
215+ logger .info (f'opencompass_cmd : { opencompass_cmd } ' )
216+ os .system (opencompass_cmd )
217+ dist .barrier ()
211218
212219
213220if __name__ == '__main__' :
@@ -270,7 +277,7 @@ def main(config):
270277 mkdirs (save_fake_path )
271278
272279 # Synchronize all processes after directory creation
273- torch . distributed .barrier ()
280+ dist .barrier ()
274281
275282 main (config )
276283
0 commit comments