5 changes: 5 additions & 0 deletions llmc/compression/quantization/awq.py
@@ -203,6 +203,10 @@ def subset_transform(
prev_op = subset['prev_op']
input_name = subset['input'][0]
inspect_module = subset['inspect']
do_trans = subset.get('do_trans', True)
if not do_trans:
logger.info('do_trans is set to False. Do not transform this subset.')
return

if not check_do_quant(
self.block_idx,
@@ -241,6 +245,7 @@ def subset_transform(
if (
isinstance(prev_op[0], (nn.Linear, FakeQuantLinear))
and prev_op[0].out_features != layers[0].in_features * 3
and prev_op[0].out_features != layers[0].in_features * 2
and prev_op[0].out_features != layers[0].in_features
):
logger.info('Cannot apply scale. Do not transform this subset.')
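The first hunk introduces a per-subset `do_trans` switch: a model's `get_subsets_in_block` can now mark a subset with `'do_trans': False` to skip the AWQ scale transform while leaving quantization untouched, and the flag defaults to `True` so existing model definitions behave as before. The second hunk additionally accepts `prev_op` layers whose `out_features` is twice the following layer's `in_features`, matching the fused gate/up projection that ChatGLM's `dense_h_to_4h` uses (see the new model file below). A minimal standalone sketch of the guard, using a hypothetical `maybe_transform` helper rather than the real `subset_transform`:

```python
import logging

logger = logging.getLogger(__name__)


def maybe_transform(subset: dict) -> bool:
    """Mirror the new guard: 'do_trans' defaults to True, so existing
    model definitions keep their previous behaviour."""
    if not subset.get('do_trans', True):
        logger.info('do_trans is set to False. Do not transform this subset.')
        return False
    return True


# The OPT fc2 subset in this PR opts out of the transform explicitly.
assert maybe_transform({'layers': {'fc2': None}, 'do_trans': False}) is False
assert maybe_transform({'layers': {'fc2': None}}) is True
```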
13 changes: 12 additions & 1 deletion llmc/compression/quantization/base_blockwise_quantization.py
@@ -598,6 +598,7 @@ def apply_shift(self, shifts, prev_op, layers):
def scale_fc_fc(self, fc1, fc2, scales):
scales = scales.to(fc1.weight.device)
if fc1.out_features == fc2.in_features * 3:
logger.info('fc1.out_features == fc2.in_features * 3')
num_heads = self.model.get_num_attention_heads()
fc1.weight.t_()
org_shape = fc1.weight.shape
@@ -616,13 +617,23 @@ def scale_fc_fc(self, fc1, fc2, scales):
fc1.bias[:, 2, :].shape
)
fc1.bias.data = fc1.bias.data.reshape(-1)
else:
elif fc1.out_features == fc2.in_features * 2:
logger.info('fc1.out_features == fc2.in_features * 2')
fc1.weight.data[fc1.weight.data.shape[0] // 2:].div_(scales.view(-1, 1))
if hasattr(fc1, 'bias') and fc1.bias is not None:
fc1.bias.data[fc1.bias.data.shape[0] // 2:].div_(scales.view(-1))
elif fc1.out_features == fc2.in_features:
logger.info('fc1.out_features == fc2.in_features')
assert fc1.out_features == fc2.in_features

if hasattr(fc1, 'bias') and fc1.bias is not None:
fc1.bias.div_(scales.view(-1))

fc1.weight.div_(scales.view(-1, 1))
else:
logger.error(f'fc1.out_features: {fc1.out_features}')
logger.error(f'fc2.in_features: {fc2.in_features}')
raise Exception('Cannot scale this fc-fc.')

fc2.weight.mul_(scales.view(1, -1))

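The rewritten `scale_fc_fc` now handles three explicit width relations between `fc1` and `fc2` — fused QKV (`out_features == in_features * 3`), fused gate/up (`* 2`, dividing only the second half of `fc1`'s output rows), and equal widths — and raises on anything else instead of silently assuming the equal-width case. The identity it relies on is that dividing `fc1`'s output channels by per-channel scales and multiplying `fc2`'s matching input columns by the same scales leaves the composed mapping unchanged. A small numerical check of that identity for the equal-width branch (a sketch, not the repo's code; the `* 2` branch applies the same division to only the half of the fused projection that feeds `fc2` linearly):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d = 8
fc1 = nn.Linear(d, d)
fc2 = nn.Linear(d, d, bias=False)
x = torch.randn(4, d)
ref = fc2(fc1(x))

s = torch.rand(d) + 0.5                 # per-output-channel scales
with torch.no_grad():
    fc1.weight.div_(s.view(-1, 1))      # scale fc1's output channels down
    fc1.bias.div_(s.view(-1))           # keep the bias consistent, as the code does
    fc2.weight.mul_(s.view(1, -1))      # scale fc2's matching input columns up

assert torch.allclose(fc2(fc1(x)), ref, atol=1e-5)
```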
1 change: 1 addition & 0 deletions llmc/models/__init__.py
@@ -1,4 +1,5 @@
from .bloom import Bloom
from .chatglm import ChatGLM
from .deepseekv2 import DeepseekV2
from .falcon import Falcon
from .gemma2 import Gemma2
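Adding the import in `llmc/models/__init__.py` is what triggers the `@MODEL_REGISTRY` registration for the new class. A hedged sketch of how the registered model would then be constructed — the dict-style registry lookup and the config value are assumptions; only the constructor signature comes from the new `chatglm.py` below:

```python
import llmc.models  # importing the package registers ChatGLM (and the others)
from llmc.utils.registry_factory import MODEL_REGISTRY

model_type = 'ChatGLM'                  # hypothetical value taken from a config file
model_cls = MODEL_REGISTRY[model_type]  # assumes the registry allows name lookup
# model = model_cls(config, device_map=None, use_cache=False)
```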
26 changes: 2 additions & 24 deletions llmc/models/base_model.py
@@ -106,8 +106,7 @@ def get_attention_rotary_layers(self):
def batch_process(self):
raise Exception('batch_process should not be called here.')

def get_vision_catcher(self, first_block_input):

def get_catcher(self, first_block_input):
class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
@@ -125,24 +124,6 @@ def forward(self, *args, **kwargs):
kwargs.pop('output_router_logits')
first_block_input['kwargs'].append(kwargs)
raise ValueError

return Catcher

def get_language_catcher(self, first_block_input):

class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module

def forward(self, inp, **kwargs):
first_block_input['data'].append(inp)
if 'output_router_logits' in kwargs:
assert kwargs['output_router_logits'] is False
kwargs.pop('output_router_logits')
first_block_input['kwargs'].append(kwargs)
raise ValueError

return Catcher

def __str__(self):
Expand Down Expand Up @@ -184,10 +165,7 @@ def collect_first_block_input(self, calib_data, padding_mask=None,
first_block_input = defaultdict(list)

self.find_blocks(modality)
if modality == 'language':
Catcher = self.get_language_catcher(first_block_input)
elif modality == 'vision':
Catcher = self.get_vision_catcher(first_block_input)
Catcher = self.get_catcher(first_block_input)

self.move_embed_to_device('cuda')
if data_type == 'img_txt':
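With a single generic `get_catcher`, `collect_first_block_input` no longer branches on modality: the `*args/**kwargs` forward covers both the language path (a single tensor input) and the vision path. The sketch below illustrates the catcher pattern itself — wrap the first block, record what reaches it during calibration, and abort the forward pass — using a toy model and an explicitly passed dict instead of the closure the repo uses:

```python
from collections import defaultdict

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


class Catcher(nn.Module):
    """Wrap the first block, record its inputs, then abort the forward pass."""

    def __init__(self, module, first_block_input):
        super().__init__()
        self.module = module
        self.first_block_input = first_block_input

    def forward(self, *args, **kwargs):
        self.first_block_input['data'].append(args[0])
        self.first_block_input['kwargs'].append(kwargs)
        raise ValueError  # the caller catches this to stop after block 0


model = TinyModel()
first_block_input = defaultdict(list)
model.blocks[0] = Catcher(model.blocks[0], first_block_input)
try:
    model(torch.randn(2, 4))
except ValueError:
    pass
print(len(first_block_input['data']))  # -> 1 captured calibration input
```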
88 changes: 88 additions & 0 deletions llmc/models/chatglm.py
@@ -0,0 +1,88 @@
import inspect

import torch.nn as nn

from llmc.utils.registry_factory import MODEL_REGISTRY

from .base_model import BaseModel


@MODEL_REGISTRY
class ChatGLM(BaseModel):
def __init__(self, config, device_map=None, use_cache=False):
super().__init__(config, device_map, use_cache)

def find_blocks(self, modality='language'):
self.blocks = self.model.transformer.encoder.layers

def find_embed_layers(self):
self.embedding = self.model.transformer.embedding
self.rotary_pos_emb = self.model.transformer.rotary_pos_emb

def find_block_name(self):
self.block_name_prefix = 'transformer.encoder.layers'

def get_embed_layers(self):
return [self.embedding]

def get_attention_rotary_layers(self):
return [self.rotary_pos_emb]

def get_head_layers(self):
return [self.model.transformer.output_layer]

def get_pre_head_layernorm_layers(self):
return [self.model.transformer.encoder.final_layernorm]

def get_layers_except_blocks(self):
return [self.embedding, self.rotary_pos_emb, self.model.transformer.output_layer, self.model.transformer.encoder.final_layernorm] # noqa

def skip_layer_name(self):
return ['final_layernorm']

def has_bias(self):
return False

def get_layernorms_in_block(self, block):
return {
'input_layernorm': block.input_layernorm,
'post_attention_layernorm': block.post_attention_layernorm,
}

def get_subsets_in_block(self, block):
return [
{
'layers': {
'self_attention.query_key_value': block.self_attention.query_key_value
},
'prev_op': [block.input_layernorm],
'input': ['self_attention.query_key_value'],
'inspect': block.self_attention,
'has_kwargs': True,
},
{
'layers': {'self_attention.dense': block.self_attention.dense},
'prev_op': [block.self_attention.query_key_value],
'input': ['self_attention.dense'],
'inspect': block.self_attention.dense,
'has_kwargs': False,
},
{
'layers': {
'mlp.dense_h_to_4h': block.mlp.dense_h_to_4h
},
'prev_op': [block.post_attention_layernorm],
'input': ['mlp.dense_h_to_4h'],
'inspect': block.mlp,
'has_kwargs': False,
'is_mlp': True,
},
{
'layers': {'mlp.dense_4h_to_h': block.mlp.dense_4h_to_h},
'prev_op': [block.mlp.dense_h_to_4h],
'input': ['mlp.dense_4h_to_h'],
'inspect': block.mlp.dense_4h_to_h,
'has_kwargs': False,
'is_mlp': True,
},
]
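Each subset returned by `get_subsets_in_block` tells the compressor which linear layers to handle together, which preceding op the smoothing scales can be folded into (`prev_op`), which input activations to capture (`input`), and which module to run for inspection. The second MLP subset is the one that exercises the new `* 2` branches above, because ChatGLM fuses the gate and up projections into a single `dense_h_to_4h`. A toy shape check (illustrative sizes only, not ChatGLM's real dimensions):

```python
import torch.nn as nn

hidden, ffn_hidden = 8, 32                                     # illustrative sizes
dense_h_to_4h = nn.Linear(hidden, 2 * ffn_hidden, bias=False)  # fused gate/up projection
dense_4h_to_h = nn.Linear(ffn_hidden, hidden, bias=False)

# The subset pairing prev_op=dense_h_to_4h with dense_4h_to_h hits the
# "out_features == in_features * 2" case added in awq.py and scale_fc_fc.
assert dense_h_to_4h.out_features == dense_4h_to_h.in_features * 2
```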
1 change: 1 addition & 0 deletions llmc/models/opt.py
@@ -85,5 +85,6 @@ def get_subsets_in_block(self, block):
'inspect': block.fc2,
'has_kwargs': False,
'is_mlp': True,
'do_trans': False
},
]
22 changes: 0 additions & 22 deletions llmc/models/vit.py
@@ -78,28 +78,6 @@ def batch_process(self, imgs):
samples.append(sample)
return samples

def get_catcher(self, first_block_input):

class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
self.signature = inspect.signature(module.forward)

def forward(self, *args, **kwargs):
params = list(self.signature.parameters.keys())
for i, arg in enumerate(args):
if i > 0:
kwargs[params[i]] = arg
first_block_input['data'].append(args[0])
if 'output_router_logits' in kwargs:
assert kwargs['output_router_logits'] is False
kwargs.pop('output_router_logits')
first_block_input['kwargs'].append(kwargs)
raise ValueError

return Catcher

def get_subsets_in_block(self, block):
return [
{