
Commit aecd6c4

Merge pull request #185 from ModelTC/dev_fixbug
Fix dp bugs
2 parents: 5eb0005 + 789ce18

File tree: 3 files changed, +169 additions, -155 deletions

llmc/__main__.py

Lines changed: 142 additions & 135 deletions
@@ -6,6 +6,7 @@
 import time

 import torch
+import torch.distributed as dist
 import yaml
 from easydict import EasyDict
 from loguru import logger
@@ -31,34 +32,36 @@ def main(config):
     logger.info(tokenizer)
     logger.info(model)

-    if 'eval' in config and len(config.eval.eval_pos):
-        eval_list = []
-        name_list = (
-            config.eval.name
-            if not isinstance(config.eval.name, str)
-            else [config.eval.name]
-        )
-        for name in name_list:
-            eval_config = copy.deepcopy(config.eval)
-            eval_config.name = name
-            if len(name_list) != 1:  # eval multi datasets
-                eval_config.path = os.path.join(config.eval.path, name)
+    if int(os.environ['RANK']) == 0:
+        if 'eval' in config and len(config.eval.eval_pos):
+            eval_list = []
+            name_list = (
+                config.eval.name
+                if not isinstance(config.eval.name, str)
+                else [config.eval.name]
+            )
+            for name in name_list:
+                eval_config = copy.deepcopy(config.eval)
+                eval_config.name = name
+                if len(name_list) != 1:  # eval multi datasets
+                    eval_config.path = os.path.join(config.eval.path, name)
+                if config.eval.type == 'acc':
+                    acc_eval = AccuracyEval(eval_config)
+                    eval_list.append(acc_eval)
+                else:
+                    ppl_eval = PerplexityEval(tokenizer.get_tokenizer(), eval_config)
+                    eval_list.append(ppl_eval)
+
+        if 'eval' in config and 'pretrain' in config.eval.eval_pos:
             if config.eval.type == 'acc':
-                acc_eval = AccuracyEval(eval_config)
-                eval_list.append(acc_eval)
+                for acc_eval in eval_list:
+                    acc = acc_eval.eval(model)
+                    logger.info(f'{config.eval.name} acc : {acc}')
             else:
-                ppl_eval = PerplexityEval(tokenizer.get_tokenizer(), eval_config)
-                eval_list.append(ppl_eval)
-
-    if 'eval' in config and 'pretrain' in config.eval.eval_pos:
-        if config.eval.type == 'acc':
-            for acc_eval in eval_list:
-                acc = acc_eval.eval(model)
-                logger.info(f'{config.eval.name} acc : {acc}')
-        else:
-            for ppl_eval in eval_list:
-                ppl = ppl_eval.eval(model)
-                logger.info(f'{ppl_eval.dataset} ppl : {ppl}')
+                for ppl_eval in eval_list:
+                    ppl = ppl_eval.eval(model)
+                    logger.info(f'{ppl_eval.dataset} ppl : {ppl}')
+
     if not config.get('calib', False):
         blockwise_opt = ALGO_REGISTRY[config.quant.method](
             model,
@@ -68,6 +71,7 @@ def main(config):
             config=config
         )
         blockwise_opt.run_block_loop()
+        dist.barrier()
     else:
         dataset = BaseDataset(tokenizer.get_tokenizer(), config.calib, model.batch_process)
         calib_data, padding_mask = dataset.get_calib_dataset()
@@ -93,121 +97,124 @@ def main(config):
             config
         )
         blockwise_opt.run_block_loop()
+        dist.barrier()

-        if 'eval' in config and 'transformed' in config.eval.eval_pos:
-            blockwise_opt.deploy('origin_float')
-            if config.eval.type == 'acc':
-                for acc_eval in eval_list:
-                    acc = acc_eval.eval(model)
-                    logger.info(f'{config.eval.name} acc : {acc}')
-            else:
-                for ppl_eval in eval_list:
-                    ppl = ppl_eval.eval(model)
-                    logger.info(f'{ppl_eval.dataset} ppl : {ppl}')
+        if int(os.environ['RANK']) == 0:
+            if 'eval' in config and 'transformed' in config.eval.eval_pos:
+                blockwise_opt.deploy('origin_float')
+                if config.eval.type == 'acc':
+                    for acc_eval in eval_list:
+                        acc = acc_eval.eval(model)
+                        logger.info(f'{config.eval.name} acc : {acc}')
+                else:
+                    for ppl_eval in eval_list:
+                        ppl = ppl_eval.eval(model)
+                        logger.info(f'{ppl_eval.dataset} ppl : {ppl}')

-        if 'save' in config and config.save.get('save_trans', False):
-            blockwise_opt.save_model(save_trans_path)
+            if 'save' in config and config.save.get('save_trans', False):
+                blockwise_opt.save_model(save_trans_path)

-        if 'save' in config and config.save.get('save_trtllm', False):
-            blockwise_opt.save_model(save_trtllm_trans_path)
-            from llmc.utils.export_trtllm import cvt_trtllm_engine
+            if 'save' in config and config.save.get('save_trtllm', False):
+                blockwise_opt.save_model(save_trtllm_trans_path)
+                from llmc.utils.export_trtllm import cvt_trtllm_engine

-            cvt_trtllm_engine(
-                save_trtllm_trans_path,
-                save_trtllm_engine_path,
-                config.save.get('trtllm_cfg'),
-            )
+                cvt_trtllm_engine(
+                    save_trtllm_trans_path,
+                    save_trtllm_engine_path,
+                    config.save.get('trtllm_cfg'),
+                )

-        if 'eval' in config and 'fake_quant' in config.eval.eval_pos:
-            blockwise_opt.deploy('fake_quant')
-            if config.eval.type == 'acc':
-                for acc_eval in eval_list:
-                    acc = acc_eval.eval(model)
-                    logger.info(f'{config.eval.name} acc : {acc}')
-            else:
-                for ppl_eval in eval_list:
-                    ppl = ppl_eval.eval(model)
-                    logger.info(f'{ppl_eval.dataset} ppl : {ppl}')
+            if 'eval' in config and 'fake_quant' in config.eval.eval_pos:
+                blockwise_opt.deploy('fake_quant')
+                if config.eval.type == 'acc':
+                    for acc_eval in eval_list:
+                        acc = acc_eval.eval(model)
+                        logger.info(f'{config.eval.name} acc : {acc}')
+                else:
+                    for ppl_eval in eval_list:
+                        ppl = ppl_eval.eval(model)
+                        logger.info(f'{ppl_eval.dataset} ppl : {ppl}')

-        if 'eval_token_consist' in config.eval and config.eval.eval_token_consist:
-            org_model = MODEL_REGISTRY[config.model.type](
-                config.model.path, config.model.torch_dtype
-            )
-            token_consist_eval = TokenConsistencyEval(tokenizer.get_tokenizer(),
-                                                      eval_config)
-            consistency_ratio = token_consist_eval.eval(model, org_model)
-            logger.info(f'Token consistency ratio: {consistency_ratio}')
-            del org_model
-
-        if 'save' in config and config.save.get('save_fake', False):
-            blockwise_opt.deploy('fake_quant')
-            blockwise_opt.save_model(save_fake_path)
-
-        if 'save' in config and config.save.get('save_vllm', False):
-            w, a = config.quant.weight, config.quant.get('act')
-            if isinstance(w.bit, str):
-                assert a, 'Only WA float quant is supported.'
-                assert w.symmetric and a.symmetric, 'Only symmetric quant is supported.'
-                assert w.bit == a.bit and w.bit in ['e4m3', 'e5m2'] and \
-                    a.bit in ['e4m3', 'e5m2'], 'Only WA FP8 quant is supported'
-            else:
-                assert w.symmetric, 'Only symmetric quant is supported.'
-                assert w.bit in [4, 8], 'Supported quant: w4a16, w8a16, w8a8.'
-                if a:
-                    assert a.symmetric, 'Only symmetric quant is supported.'
-                    assert a.bit == 8, 'Supported quant: w4a16, w8a16, w8a8.'
-            blockwise_opt.deploy('vllm_quant')
-            blockwise_opt.save_model(save_quant_path)
-            update_vllm_quant_config(blockwise_opt.model, config, save_quant_path)
-
-        if 'save' in config and config.save.get('save_sgl', False):
-            w, a = config.quant.weight, config.quant.get('act')
-            if isinstance(w.bit, str):
-                assert a, 'Only WA float quant is supported.'
-                assert w.symmetric and a.symmetric, 'Only symmetric quant is supported.'
-                assert w.bit == a.bit and w.bit in ['e4m3', 'e5m2'] and \
-                    a.bit in ['e4m3', 'e5m2'], 'Only WA FP8 quant is supported'
-            else:
-                assert w.symmetric, 'Only symmetric quant is supported.'
-                assert w.bit in [4, 8], 'Supported quant: w4a16, w8a16, w8a8.'
-                if a:
-                    assert a.symmetric, 'Only symmetric quant is supported.'
-                    assert a.bit == 8, 'Supported quant: w4a16, w8a16, w8a8.'
-            blockwise_opt.deploy('sgl_quant')
-            blockwise_opt.save_model(save_quant_path)
-            update_vllm_quant_config(blockwise_opt.model, config, save_quant_path)
-
-        if 'save' in config and config.save.get('save_autoawq', False):
-            assert config.quant.weight.bit in [4] and 'act' not in config.quant, \
-                'AutoAWQ supports only 4-bit weight-only quantization.'
-            assert not config.quant.weight.symmetric, 'Only asymmetric quant is supported.'
-
-            blockwise_opt.deploy('autoawq_quant')
-            blockwise_opt.save_model(save_quant_path)
-            update_autoawq_quant_config(config, save_quant_path)
-
-        if 'save' in config and config.save.get('save_mlcllm', False):
-            assert config.quant.weight.bit in [4] and 'act' not in config.quant, \
-                'MlcLLM supports only 4-bit weight-only quantization.'
-            assert not config.quant.weight.symmetric, 'Only asymmetric quant is supported.'
-
-            blockwise_opt.deploy('mlcllm_quant')
-            blockwise_opt.save_model(save_quant_path)
-            update_autoawq_quant_config(config, save_quant_path)
-
-        if 'opencompass' in config:
-            assert config.save.get('save_trans', False)
-            cfg_path = config['opencompass']['cfg_path']
-            output_path = config['opencompass']['output_path']
-            eval_model_path = os.path.abspath(save_trans_path)
-            opencompass_cmd = (
-                f'opencompass {cfg_path} -w {output_path} '
-                f'--llmc_cfg {args.config} '
-                f'--llmc_eval_mode quant '
-                f'--llmc_model_path {eval_model_path}'
-            )
-            logger.info(f'opencompass_cmd : {opencompass_cmd}')
-            os.system(opencompass_cmd)
+            if 'eval_token_consist' in config.eval and config.eval.eval_token_consist:
+                org_model = MODEL_REGISTRY[config.model.type](
+                    config.model.path, config.model.torch_dtype
+                )
+                token_consist_eval = TokenConsistencyEval(tokenizer.get_tokenizer(),
+                                                          eval_config)
+                consistency_ratio = token_consist_eval.eval(model, org_model)
+                logger.info(f'Token consistency ratio: {consistency_ratio}')
+                del org_model
+
+            if 'save' in config and config.save.get('save_fake', False):
+                blockwise_opt.deploy('fake_quant')
+                blockwise_opt.save_model(save_fake_path)
+
+            if 'save' in config and config.save.get('save_vllm', False):
+                w, a = config.quant.weight, config.quant.get('act')
+                if isinstance(w.bit, str):
+                    assert a, 'Only WA float quant is supported.'
+                    assert w.symmetric and a.symmetric, 'Only symmetric quant is supported.'
+                    assert w.bit == a.bit and w.bit in ['e4m3', 'e5m2'] and \
+                        a.bit in ['e4m3', 'e5m2'], 'Only WA FP8 quant is supported'
+                else:
+                    assert w.symmetric, 'Only symmetric quant is supported.'
+                    assert w.bit in [4, 8], 'Supported quant: w4a16, w8a16, w8a8.'
+                    if a:
+                        assert a.symmetric, 'Only symmetric quant is supported.'
+                        assert a.bit == 8, 'Supported quant: w4a16, w8a16, w8a8.'
+                blockwise_opt.deploy('vllm_quant')
+                blockwise_opt.save_model(save_quant_path)
+                update_vllm_quant_config(blockwise_opt.model, config, save_quant_path)
+
+            if 'save' in config and config.save.get('save_sgl', False):
+                w, a = config.quant.weight, config.quant.get('act')
+                if isinstance(w.bit, str):
+                    assert a, 'Only WA float quant is supported.'
+                    assert w.symmetric and a.symmetric, 'Only symmetric quant is supported.'
+                    assert w.bit == a.bit and w.bit in ['e4m3', 'e5m2'] and \
+                        a.bit in ['e4m3', 'e5m2'], 'Only WA FP8 quant is supported'
+                else:
+                    assert w.symmetric, 'Only symmetric quant is supported.'
+                    assert w.bit in [4, 8], 'Supported quant: w4a16, w8a16, w8a8.'
+                    if a:
+                        assert a.symmetric, 'Only symmetric quant is supported.'
+                        assert a.bit == 8, 'Supported quant: w4a16, w8a16, w8a8.'
+                blockwise_opt.deploy('sgl_quant')
+                blockwise_opt.save_model(save_quant_path)
+                update_vllm_quant_config(blockwise_opt.model, config, save_quant_path)
+
+            if 'save' in config and config.save.get('save_autoawq', False):
+                assert config.quant.weight.bit in [4] and 'act' not in config.quant, \
+                    'AutoAWQ supports only 4-bit weight-only quantization.'
+                assert not config.quant.weight.symmetric, 'Only asymmetric quant is supported.'
+
+                blockwise_opt.deploy('autoawq_quant')
+                blockwise_opt.save_model(save_quant_path)
+                update_autoawq_quant_config(config, save_quant_path)
+
+            if 'save' in config and config.save.get('save_mlcllm', False):
+                assert config.quant.weight.bit in [4] and 'act' not in config.quant, \
+                    'MlcLLM supports only 4-bit weight-only quantization.'
+                assert not config.quant.weight.symmetric, 'Only asymmetric quant is supported.'
+
+                blockwise_opt.deploy('mlcllm_quant')
+                blockwise_opt.save_model(save_quant_path)
+                update_autoawq_quant_config(config, save_quant_path)
+
+            if 'opencompass' in config:
+                assert config.save.get('save_trans', False)
+                cfg_path = config['opencompass']['cfg_path']
+                output_path = config['opencompass']['output_path']
+                eval_model_path = os.path.abspath(save_trans_path)
+                opencompass_cmd = (
+                    f'opencompass {cfg_path} -w {output_path} '
+                    f'--llmc_cfg {args.config} '
+                    f'--llmc_eval_mode quant '
+                    f'--llmc_model_path {eval_model_path}'
+                )
+                logger.info(f'opencompass_cmd : {opencompass_cmd}')
+                os.system(opencompass_cmd)
+        dist.barrier()


 if __name__ == '__main__':
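For reference, the hunks above rely on a common data-parallel pattern: only rank 0 performs evaluation, saving, and export, while every rank still reaches dist.barrier() so the process group stays synchronized. A minimal, self-contained sketch of that pattern, assuming the script is launched with torchrun (which sets RANK and WORLD_SIZE); the gloo backend is used only so the sketch runs without GPUs, whereas llmc itself would typically use nccl:

# Minimal sketch of the rank-0 gating plus barrier pattern shown above.
# Assumes launch via `torchrun --nproc_per_node=N sketch.py`.
import os

import torch.distributed as dist


def main():
    rank = int(os.environ['RANK'])

    # ... every rank does its share of the blockwise work here ...

    if rank == 0:
        # Side effects (evaluation, saving, exporting) run once, on rank 0,
        # instead of once per data-parallel rank.
        print('rank 0: evaluating / saving')

    # Every rank must reach the barrier; if only rank 0 called it,
    # the group would deadlock.
    dist.barrier()


if __name__ == '__main__':
    dist.init_process_group(backend='gloo')
    main()
    dist.destroy_process_group()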
@@ -270,7 +277,7 @@ def main(config):
         mkdirs(save_fake_path)

     # Synchronize all processes after directory creation
-    torch.distributed.barrier()
+    dist.barrier()

     main(config)
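The save_vllm and save_sgl branches above enforce a small set of supported export formats: symmetric integer weights at 4 or 8 bits (w4a16, w8a16, w8a8), or matching FP8 types (e4m3 / e5m2) for weights and activations. A standalone restatement of those checks, with hypothetical function and argument names rather than llmc API, in case the logic is easier to read outside the diff:

# Hypothetical standalone version of the export guardrails asserted above
# (function and argument names are illustrative, not llmc API).
def check_vllm_sgl_export(w_bit, w_symmetric, a_bit=None, a_symmetric=False):
    if isinstance(w_bit, str):
        # FP8 path: weights and activations must both be e4m3 or e5m2.
        assert a_bit is not None, 'Only WA float quant is supported.'
        assert w_symmetric and a_symmetric, 'Only symmetric quant is supported.'
        assert w_bit == a_bit and w_bit in ('e4m3', 'e5m2'), \
            'Only WA FP8 quant is supported'
    else:
        # Integer path: w4a16, w8a16 or w8a8, all symmetric.
        assert w_symmetric, 'Only symmetric quant is supported.'
        assert w_bit in (4, 8), 'Supported quant: w4a16, w8a16, w8a8.'
        if a_bit is not None:
            assert a_symmetric, 'Only symmetric quant is supported.'
            assert a_bit == 8, 'Supported quant: w4a16, w8a16, w8a8.'


check_vllm_sgl_export(4, True)                                       # w4a16
check_vllm_sgl_export(8, True, a_bit=8, a_symmetric=True)            # w8a8
check_vllm_sgl_export('e4m3', True, a_bit='e4m3', a_symmetric=True)  # WA FP8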

llmc/compression/blockwise_optimization.py

Lines changed: 13 additions & 4 deletions
@@ -47,10 +47,6 @@ def run_block_loop(self):
             os.makedirs(self.clip_path, exist_ok=True)
             torch.save(self.auto_clipper.weight_clips, os.path.join(self.clip_path, 'clips.pth'))

-    @abstractmethod
-    def block_opt(self, block):
-        pass
-
     def cache_input_hook(self, m, x, y, name, feat_dict):
         inputs = [i.detach().cpu() for i in x]
         if len(inputs) == 1:
@@ -60,3 +56,16 @@ def cache_input_hook(self, m, x, y, name, feat_dict):
             feat_dict[name].append(inp)
         else:
             feat_dict[name].append(tuple(inputs))
+
+    @abstractmethod
+    def block_opt(self, block):
+        pass
+
+    def layer_init(self, layer):
+        pass
+
+    def subset_init(self, subset):
+        pass
+
+    def block_init(self, block):
+        pass
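The net effect of the two hunks above is that block_opt stays abstract on the base class while layer_init, subset_init, and block_init gain no-op defaults there (they previously lived on the quantization subclass, see the next file), so subclasses only override the hooks they actually need. A simplified sketch of that shape, with illustrative class names rather than the real llmc ones:

# Simplified sketch of the hook layout after this change; class names are
# illustrative, not the actual llmc classes.
from abc import ABC, abstractmethod


class BlockwiseOptBase(ABC):
    @abstractmethod
    def block_opt(self, block):
        # Every concrete optimizer must implement the per-block pass.
        ...

    # Optional hooks with no-op defaults: subclasses override only if needed.
    def layer_init(self, layer):
        pass

    def subset_init(self, subset):
        pass

    def block_init(self, block):
        pass


class DummyOpt(BlockwiseOptBase):
    def block_opt(self, block):
        print(f'optimizing {block}')


DummyOpt().block_opt('block_0')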

llmc/compression/quantization/base_blockwise_quantization.py

Lines changed: 14 additions & 16 deletions
@@ -522,15 +522,6 @@ def rehook_next_subset(self, block, subset, next_subset):

         return input_feat_subset

-    def layer_init(self, layer):
-        pass
-
-    def subset_init(self, subset):
-        pass
-
-    def block_init(self, block):
-        pass
-
     def collect_layers_weights(self, layers, tensor_parallelize_style=None):
         weights = []
         for _m in layers:
@@ -566,13 +557,20 @@ def register_act_qparams(self, layers_dict, act_tensors):
         scales_list, zeros_list, qmin_list, qmax_list = (
             self.aquantizer.get_batch_tensors_qparams(act_tensors)
         )
-        for i in range(len(scales_list)):
-            scales, zeros, qmin, qmax = scales_list[i], zeros_list[i], qmin_list[i], qmax_list[i]
-            for name in layers_dict:
-                layers_dict[name].register_buffer(f'buf_act_scales_{i}', scales)
-                layers_dict[name].register_buffer(f'buf_act_zeros_{i}', zeros)
-                layers_dict[name].register_buffer(f'buf_act_qmin_{i}', qmin)
-                layers_dict[name].register_buffer(f'buf_act_qmax_{i}', qmax)
+        world_size = int(os.environ['WORLD_SIZE'])
+
+        for i, (scales, zeros, qmin, qmax) in enumerate(
+            zip(scales_list, zeros_list, qmin_list, qmax_list)
+        ):
+            scales = scales.cuda()
+            dist.all_reduce(scales, op=dist.ReduceOp.SUM)
+            scales = (scales / world_size).cpu()
+
+            for name, layer in layers_dict.items():
+                layer.register_buffer(f'buf_act_scales_{i}', scales)
+                layer.register_buffer(f'buf_act_zeros_{i}', zeros)
+                layer.register_buffer(f'buf_act_qmin_{i}', qmin)
+                layer.register_buffer(f'buf_act_qmax_{i}', qmax)

     @torch.no_grad()
     def apply_scale(self, scales, prev_op, layers):
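The rewritten loop above makes each data-parallel rank average its activation scales with the other ranks (sum via all_reduce, then divide by WORLD_SIZE) before registering them as buffers, so every rank ends up with identical quantization parameters. A standalone sketch of just that averaging step, assuming a torchrun launch; gloo and CPU tensors are used here only so it runs without GPUs, whereas the code above moves scales to CUDA first:

# Standalone sketch of the cross-rank scale averaging added above.
# Assumes `torchrun --nproc_per_node=N sketch.py`, which sets RANK and WORLD_SIZE.
import os

import torch
import torch.distributed as dist


def average_across_ranks(scales: torch.Tensor) -> torch.Tensor:
    world_size = int(os.environ['WORLD_SIZE'])
    dist.all_reduce(scales, op=dist.ReduceOp.SUM)  # in-place sum over all ranks
    return scales / world_size                     # arithmetic mean


if __name__ == '__main__':
    dist.init_process_group(backend='gloo')
    local_scales = torch.rand(4)  # stand-in for this rank's calibration scales
    avg = average_across_ranks(local_scales)
    print(f"rank {os.environ['RANK']} averaged scales: {avg}")
    dist.destroy_process_group()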
