Commit f4fb949

add do_trans in config & remove language catcher & support chatglm (#219)
1 parent 135fe9c commit f4fb949

File tree: 7 files changed (+109, -47 lines)


llmc/compression/quantization/awq.py
Lines changed: 5 additions & 0 deletions

@@ -203,6 +203,10 @@ def subset_transform(
         prev_op = subset['prev_op']
         input_name = subset['input'][0]
         inspect_module = subset['inspect']
+        do_trans = subset.get('do_trans', True)
+        if not do_trans:
+            logger.info('do_trans is set to False. Do not transform this subset.')
+            return
 
         if not check_do_quant(
             self.block_idx,
@@ -241,6 +245,7 @@ def subset_transform(
         if (
             isinstance(prev_op[0], (nn.Linear, FakeQuantLinear))
             and prev_op[0].out_features != layers[0].in_features * 3
+            and prev_op[0].out_features != layers[0].in_features * 2
             and prev_op[0].out_features != layers[0].in_features
         ):
             logger.info('Cannot apply scale. Do not transform this subset.')

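The flag is read per subset and defaults to True, so existing model definitions keep their previous behaviour. A minimal, self-contained sketch of the new guard (the helper name is made up for illustration; the loguru-style logger is assumed from the diff):

from loguru import logger


def should_transform(subset):
    # Mirrors the new check in subset_transform: a missing flag defaults to True.
    do_trans = subset.get('do_trans', True)
    if not do_trans:
        logger.info('do_trans is set to False. Do not transform this subset.')
        return False
    return True


# Only the second subset opts out of the AWQ scale transform.
print(should_transform({'layers': {}, 'prev_op': []}))      # True
print(should_transform({'layers': {}, 'do_trans': False}))  # False
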
llmc/compression/quantization/base_blockwise_quantization.py
Lines changed: 12 additions & 1 deletion

@@ -598,6 +598,7 @@ def apply_shift(self, shifts, prev_op, layers):
     def scale_fc_fc(self, fc1, fc2, scales):
         scales = scales.to(fc1.weight.device)
         if fc1.out_features == fc2.in_features * 3:
+            logger.info('fc1.out_features == fc2.in_features * 3')
             num_heads = self.model.get_num_attention_heads()
             fc1.weight.t_()
             org_shape = fc1.weight.shape
@@ -616,13 +617,23 @@ def scale_fc_fc(self, fc1, fc2, scales):
                     fc1.bias[:, 2, :].shape
                 )
                 fc1.bias.data = fc1.bias.data.reshape(-1)
-        else:
+        elif fc1.out_features == fc2.in_features * 2:
+            logger.info('fc1.out_features == fc2.in_features * 2')
+            fc1.weight.data[fc1.weight.data.shape[0] // 2:].div_(scales.view(-1, 1))
+            if hasattr(fc1, 'bias') and fc1.bias is not None:
+                fc1.bias.data[fc1.bias.data.shape[0] // 2:].div_(scales.view(-1))
+        elif fc1.out_features == fc2.in_features:
+            logger.info('fc1.out_features == fc2.in_features')
             assert fc1.out_features == fc2.in_features
 
             if hasattr(fc1, 'bias') and fc1.bias is not None:
                 fc1.bias.div_(scales.view(-1))
 
             fc1.weight.div_(scales.view(-1, 1))
+        else:
+            logger.error(f'fc1.out_features: {fc1.out_features}')
+            logger.error(f'fc2.in_features: {fc2.in_features}')
+            raise Exception('Can not scale this fc-fc.')
 
         fc2.weight.mul_(scales.view(1, -1))
 
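The new fc1.out_features == fc2.in_features * 2 branch covers fused gate/up projections such as ChatGLM's mlp.dense_h_to_4h, whose output is split in half and combined as silu(first_half) * second_half before mlp.dense_4h_to_h. Only the second, purely linear half can absorb a per-channel scale exactly, which is why just the lower half of fc1's weight rows is divided. A self-contained sketch with toy sizes (assuming that SwiGLU ordering) checking that the rewrite leaves the MLP output unchanged:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, ffn = 8, 16
fc1 = nn.Linear(hidden, 2 * ffn, bias=False)  # fused gate/up projection
fc2 = nn.Linear(ffn, hidden, bias=False)      # down projection


def mlp(x):
    gate, up = fc1(x).chunk(2, dim=-1)
    return fc2(F.silu(gate) * up)


x = torch.randn(4, hidden)
ref = mlp(x)

# Same rewrite as the new branch: divide only the second (linear) half of
# fc1's output rows, then multiply fc2's input columns by the same scales.
scales = torch.rand(ffn) + 0.5
fc1.weight.data[fc1.weight.data.shape[0] // 2:].div_(scales.view(-1, 1))
fc2.weight.data.mul_(scales.view(1, -1))

print(torch.allclose(mlp(x), ref, atol=1e-5))  # True: the output is preserved
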
llmc/models/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 from .bloom import Bloom
+from .chatglm import ChatGLM
 from .deepseekv2 import DeepseekV2
 from .falcon import Falcon
 from .gemma2 import Gemma2

llmc/models/base_model.py
Lines changed: 2 additions & 24 deletions

@@ -106,8 +106,7 @@ def get_attention_rotary_layers(self):
     def batch_process(self):
         raise Exception('batch_process should not be called here.')
 
-    def get_vision_catcher(self, first_block_input):
-
+    def get_catcher(self, first_block_input):
         class Catcher(nn.Module):
             def __init__(self, module):
                 super().__init__()
@@ -125,24 +124,6 @@ def forward(self, *args, **kwargs):
                    kwargs.pop('output_router_logits')
                first_block_input['kwargs'].append(kwargs)
                raise ValueError
-
-        return Catcher
-
-    def get_language_catcher(self, first_block_input):
-
-        class Catcher(nn.Module):
-            def __init__(self, module):
-                super().__init__()
-                self.module = module
-
-            def forward(self, inp, **kwargs):
-                first_block_input['data'].append(inp)
-                if 'output_router_logits' in kwargs:
-                    assert kwargs['output_router_logits'] is False
-                    kwargs.pop('output_router_logits')
-                first_block_input['kwargs'].append(kwargs)
-                raise ValueError
-
         return Catcher
 
     def __str__(self):
@@ -184,10 +165,7 @@ def collect_first_block_input(self, calib_data, padding_mask=None,
         first_block_input = defaultdict(list)
 
         self.find_blocks(modality)
-        if modality == 'language':
-            Catcher = self.get_language_catcher(first_block_input)
-        elif modality == 'vision':
-            Catcher = self.get_vision_catcher(first_block_input)
+        Catcher = self.get_catcher(first_block_input)
 
         self.move_embed_to_device('cuda')
         if data_type == 'img_txt':

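With the language/vision split removed, every model now goes through the single get_catcher path. For reference, a small self-contained sketch of the catcher pattern itself (toy block and names are illustrative, not llmc internals): wrap the first block, record its inputs on the first forward call, and deliberately abort the pass with ValueError.

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, hidden_states, attention_mask=None):
        return self.fc(hidden_states)


class Catcher(nn.Module):
    def __init__(self, module, first_block_input):
        super().__init__()
        self.module = module
        self.store = first_block_input

    def forward(self, *args, **kwargs):
        self.store['data'].append(args[0])
        self.store['kwargs'].append(kwargs)
        raise ValueError  # stop on purpose once block 0's input is captured


blocks = nn.ModuleList([ToyBlock(), ToyBlock()])
first_block_input = {'data': [], 'kwargs': []}
blocks[0] = Catcher(blocks[0], first_block_input)

try:
    h = torch.randn(2, 8)
    for block in blocks:
        h = block(h, attention_mask=None)
except ValueError:
    pass

blocks[0] = blocks[0].module  # restore the original block afterwards
print(len(first_block_input['data']), list(first_block_input['kwargs'][0]))
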
llmc/models/chatglm.py
Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+import inspect
+
+import torch.nn as nn
+
+from llmc.utils.registry_factory import MODEL_REGISTRY
+
+from .base_model import BaseModel
+
+
+@MODEL_REGISTRY
+class ChatGLM(BaseModel):
+    def __init__(self, config, device_map=None, use_cache=False):
+        super().__init__(config, device_map, use_cache)
+
+    def find_blocks(self, modality='language'):
+        self.blocks = self.model.transformer.encoder.layers
+
+    def find_embed_layers(self):
+        self.embedding = self.model.transformer.embedding
+        self.rotary_pos_emb = self.model.transformer.rotary_pos_emb
+
+    def find_block_name(self):
+        self.block_name_prefix = 'transformer.encoder.layers'
+
+    def get_embed_layers(self):
+        return [self.embedding]
+
+    def get_attention_rotary_layers(self):
+        return [self.rotary_pos_emb]
+
+    def get_head_layers(self):
+        return [self.model.transformer.output_layer]
+
+    def get_pre_head_layernorm_layers(self):
+        return [self.model.transformer.encoder.final_layernorm]
+
+    def get_layers_except_blocks(self):
+        return [self.embedding, self.rotary_pos_emb, self.model.transformer.output_layer, self.model.transformer.encoder.final_layernorm]  # noqa
+
+    def skip_layer_name(self):
+        return ['final_layernorm']
+
+    def has_bias(self):
+        return False
+
+    def get_layernorms_in_block(self, block):
+        return {
+            'input_layernorm': block.input_layernorm,
+            'post_attention_layernorm': block.post_attention_layernorm,
+        }
+
+    def get_subsets_in_block(self, block):
+        return [
+            {
+                'layers': {
+                    'self_attention.query_key_value': block.self_attention.query_key_value
+                },
+                'prev_op': [block.input_layernorm],
+                'input': ['self_attention.query_key_value'],
+                'inspect': block.self_attention,
+                'has_kwargs': True,
+            },
+            {
+                'layers': {'self_attention.dense': block.self_attention.dense},
+                'prev_op': [block.self_attention.query_key_value],
+                'input': ['self_attention.dense'],
+                'inspect': block.self_attention.dense,
+                'has_kwargs': False,
+            },
+            {
+                'layers': {
+                    'mlp.dense_h_to_4h': block.mlp.dense_h_to_4h
+                },
+                'prev_op': [block.post_attention_layernorm],
+                'input': ['mlp.dense_h_to_4h'],
+                'inspect': block.mlp,
+                'has_kwargs': False,
+                'is_mlp': True,
+            },
+            {
+                'layers': {'mlp.down_proj': block.mlp.dense_4h_to_h},
+                'prev_op': [block.mlp.dense_h_to_4h],
+                'input': ['mlp.dense_4h_to_h'],
+                'inspect': block.mlp.dense_4h_to_h,
+                'has_kwargs': False,
+                'is_mlp': True,
+            },
+        ]

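The new class assumes the module layout of the Hugging Face ChatGLM(3) remote code: blocks under transformer.encoder.layers, a fused self_attention.query_key_value, and a fused mlp.dense_h_to_4h. A hedged sanity check that those paths resolve on a real checkpoint (the checkpoint name is an example, not something this commit pins):

from transformers import AutoModelForCausalLM

# Example ChatGLM-style checkpoint; any GLM checkpoint with the same
# remote-code layout should expose the attributes referenced above.
model = AutoModelForCausalLM.from_pretrained(
    'THUDM/chatglm3-6b', trust_remote_code=True, torch_dtype='auto'
)

block = model.transformer.encoder.layers[0]
for path in ('self_attention.query_key_value', 'self_attention.dense',
             'mlp.dense_h_to_4h', 'mlp.dense_4h_to_h'):
    mod = block
    for attr in path.split('.'):
        mod = getattr(mod, attr)
    print(path, type(mod).__name__, tuple(mod.weight.shape))
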
llmc/models/opt.py
Lines changed: 1 addition & 0 deletions

@@ -85,5 +85,6 @@ def get_subsets_in_block(self, block):
                 'inspect': block.fc2,
                 'has_kwargs': False,
                 'is_mlp': True,
+                'do_trans': False
             },
         ]

llmc/models/vit.py
Lines changed: 0 additions & 22 deletions

@@ -78,28 +78,6 @@ def batch_process(self, imgs):
             samples.append(sample)
         return samples
 
-    def get_catcher(self, first_block_input):
-
-        class Catcher(nn.Module):
-            def __init__(self, module):
-                super().__init__()
-                self.module = module
-                self.signature = inspect.signature(module.forward)
-
-            def forward(self, *args, **kwargs):
-                params = list(self.signature.parameters.keys())
-                for i, arg in enumerate(args):
-                    if i > 0:
-                        kwargs[params[i]] = arg
-                first_block_input['data'].append(args[0])
-                if 'output_router_logits' in kwargs:
-                    assert kwargs['output_router_logits'] is False
-                    kwargs.pop('output_router_logits')
-                first_block_input['kwargs'].append(kwargs)
-                raise ValueError
-
-        return Catcher
-
     def get_subsets_in_block(self, block):
         return [
             {
