73 changes: 40 additions & 33 deletions llmc/__main__.py
@@ -69,43 +69,50 @@ def main(config):
         for ppl_eval in eval_list:
             ppl = ppl_eval.eval(model)
             logger.info(f'{ppl_eval.dataset} ppl : {ppl}')

-    if not config.get('calib', False):
-        blockwise_opt = ALGO_REGISTRY[config.quant.method](
-            model,
-            quant_config=config.quant,
-            input=None,
-            padding_mask=None,
-            config=config
-        )
-        blockwise_opt.run_block_loop()
-        dist.barrier()
-    else:
-        dataset = BaseDataset(tokenizer.get_tokenizer(), config.calib, model.batch_process)
-        calib_data, padding_mask = dataset.get_calib_dataset()
-        padding_side = getattr(tokenizer.get_tokenizer(), 'padding_side', None)
-        model.collect_first_block_input(calib_data, padding_mask, padding_side, config.calib.type)
-        del calib_data
-        gc.collect()
-        torch.cuda.empty_cache()
-        if not config.get('sparse', False):
-            blockwise_opt = ALGO_REGISTRY[config.quant.method](
-                model,
-                config.quant,
-                model.get_first_block_input(),
-                model.get_padding_mask(),
-                config
-            )
-        else:
-            blockwise_opt = ALGO_REGISTRY[config.sparse.method](
-                model,
-                config.sparse,
-                model.get_first_block_input(),
-                model.get_padding_mask(),
-                config
-            )
-        blockwise_opt.run_block_loop()
-        dist.barrier()
+    for modality in config.quant.get('quant_objects', ['language']):
+        if not config.get('calib', False):
+            blockwise_opt = ALGO_REGISTRY[config.quant.method](
+                model,
+                quant_config=config.quant,
+                input=None,
+                padding_mask=None,
+                config=config,
+                modality=modality,
+            )
+            blockwise_opt.run_block_loop()
+            dist.barrier()
+        else:
+            dataset = BaseDataset(tokenizer.get_tokenizer(), config.calib, model.batch_process)
+            calib_data, padding_mask = dataset.get_calib_dataset()
+            padding_side = getattr(tokenizer.get_tokenizer(), 'padding_side', None)
+            model.collect_first_block_input(calib_data,
+                                            padding_mask,
+                                            padding_side,
+                                            config.calib.type,
+                                            modality)
+            del calib_data
+            gc.collect()
+            torch.cuda.empty_cache()
+            if not config.get('sparse', False):
+                blockwise_opt = ALGO_REGISTRY[config.quant.method](
+                    model,
+                    config.quant,
+                    model.get_first_block_input(),
+                    model.get_padding_mask(),
+                    config,
+                    modality
+                )
+            else:
+                blockwise_opt = ALGO_REGISTRY[config.sparse.method](
+                    model,
+                    config.sparse,
+                    model.get_first_block_input(),
+                    model.get_padding_mask(),
+                    config,
+                    modality
+                )
+            blockwise_opt.run_block_loop()
+            dist.barrier()

     if int(os.environ['RANK']) == 0:
         if 'eval' in config and 'transformed' in config.eval.eval_pos:
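For context, the new outer loop reads its list of targets from config.quant.quant_objects and falls back to ['language'] when the key is absent, so existing single-modality configs keep working unchanged. A minimal sketch of that lookup, assuming an EasyDict-style config and an illustrative two-entry quant_objects list (neither the config values nor the print are part of this diff):

from easydict import EasyDict

# Hypothetical config fragment; only quant.method and quant.quant_objects matter here.
config = EasyDict({
    'quant': {
        'method': 'Awq',
        'quant_objects': ['language', 'vision'],  # omit this key to compress language blocks only
    },
})

# Mirrors the per-modality dispatch added to llmc/__main__.py in this PR.
for modality in config.quant.get('quant_objects', ['language']):
    print(f'running the blockwise loop for the {modality} blocks with {config.quant.method}')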
4 changes: 3 additions & 1 deletion llmc/compression/blockwise_optimization.py
@@ -6,8 +6,10 @@


 class BlockwiseOpt(metaclass=ABCMeta):
-    def __init__(self, model, quant_config, input, padding_mask, config):
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
         self.model = model
+        self.modality = modality
+        self.model.find_blocks(modality)
         self.blocks = model.get_blocks()
         self.quant_config = quant_config
         self.sparsity_config = quant_config
4 changes: 2 additions & 2 deletions llmc/compression/quantization/adadim.py
@@ -9,8 +9,8 @@

 @ALGO_REGISTRY
 class AdaDim(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, config):
-        super().__init__(model, quant_config, input, config)
+    def __init__(self, model, quant_config, input, config, modality='language'):
+        super().__init__(model, quant_config, input, config, modality)

     def get_layer_out(self, x, layer):
         with torch.no_grad():
4 changes: 2 additions & 2 deletions llmc/compression/quantization/awq.py
@@ -17,8 +17,8 @@

 @ALGO_REGISTRY
 class Awq(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         special_config = self.quant_config.get('special', {})
         self.trans = special_config.get('trans', True)
         self.trans_version = special_config.get('trans_version', 'v2')
7 changes: 4 additions & 3 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -27,8 +27,8 @@


 class BaseBlockwiseQuantization(BlockwiseOpt):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.set_quant_config()

     def w_qdq(self, module, wquantizer):
@@ -439,7 +439,8 @@ def run(self, block, input_feat, handles):

     def block_transform(self, block, input_feat, block_kwargs):
         logger.info(f'Start transform the {self.block_idx}-th block')
-        subsets = self.model.get_subsets_in_block(block)
+        subsets = self.model.get_subsets_in_block(block) \
+            if self.modality == 'language' else self.model.get_vision_subsets_in_block(block)

         if self.act_static:
             self.register_non_linear_qparams(block, input_feat)
4 changes: 2 additions & 2 deletions llmc/compression/quantization/dgq.py
@@ -13,8 +13,8 @@

 @ALGO_REGISTRY
 class DGQ(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.model_dtype = next(self.model.model.parameters()).dtype

     def w_qdq(self, module, wquantizer):
4 changes: 2 additions & 2 deletions llmc/compression/quantization/gptq.py
@@ -17,8 +17,8 @@

 @ALGO_REGISTRY
 class GPTQ(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.dev = torch.device('cuda')
         self.model_dtype = next(self.model.model.parameters()).dtype
         self.add_quant_config()
4 changes: 2 additions & 2 deletions llmc/compression/quantization/hqq.py
@@ -11,8 +11,8 @@

 @ALGO_REGISTRY
 class HQQ(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.add_quant_config()

     @torch.no_grad()
4 changes: 2 additions & 2 deletions llmc/compression/quantization/llmint8.py
@@ -9,8 +9,8 @@

 @ALGO_REGISTRY
 class LlmInt8(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.add_quant_config()

     @torch.no_grad()
4 changes: 2 additions & 2 deletions llmc/compression/quantization/ntweak.py
@@ -19,8 +19,8 @@

 @ALGO_REGISTRY
 class NormTweaking(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.add_quant_config()

         model_type = self.config['model']['type']
4 changes: 2 additions & 2 deletions llmc/compression/quantization/omniq.py
@@ -25,8 +25,8 @@

 @ALGO_REGISTRY
 class OmniQuant(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.add_quant_config()

         model_type = self.config['model']['type']
4 changes: 2 additions & 2 deletions llmc/compression/quantization/osplus.py
@@ -17,9 +17,9 @@

 @ALGO_REGISTRY
 class OsPlus(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
         torch.set_grad_enabled(False)
-        super().__init__(model, quant_config, input, padding_mask, config)
+        super().__init__(model, quant_config, input, padding_mask, config, modality)

     @torch.no_grad()
     def filter_subset(self, layers_dict, prev_op):
4 changes: 2 additions & 2 deletions llmc/compression/quantization/quarot.py
@@ -16,8 +16,8 @@

 @ALGO_REGISTRY
 class Quarot(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.dev = torch.device('cuda')
         self.add_quant_config()
         self.preprocess()
4 changes: 2 additions & 2 deletions llmc/compression/quantization/quik.py
@@ -12,8 +12,8 @@

 @ALGO_REGISTRY
 class QUIK(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.add_quant_config()

     def add_quant_config(self):
4 changes: 2 additions & 2 deletions llmc/compression/quantization/rtn.py
@@ -8,8 +8,8 @@

 @ALGO_REGISTRY
 class RTN(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)

     @torch.no_grad()
     def block_opt(self, *opt_kwargs):
4 changes: 2 additions & 2 deletions llmc/compression/quantization/smoothquant.py
@@ -12,8 +12,8 @@

 @ALGO_REGISTRY
 class SmoothQuant(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         special_config = self.quant_config.get('special', {})
         self.alpha = special_config.get('alpha', 0.5)

4 changes: 2 additions & 2 deletions llmc/compression/quantization/spqr.py
@@ -18,8 +18,8 @@

 @ALGO_REGISTRY
 class SpQR(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         assert (
             self.wquantizer.granularity == 'per_group'
         ), 'SpQR only supports per_group quantization'
4 changes: 2 additions & 2 deletions llmc/compression/quantization/tesseraq.py
@@ -23,8 +23,8 @@

 @ALGO_REGISTRY
 class TesseraQ(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.add_quant_config()

         self.attention_mask = self.input['kwargs'][0].get('attention_mask')
llmc/compression/sparsification/base_blockwise_sparsification.py
@@ -12,8 +12,8 @@


 class BaseBlockwiseSparsification(BlockwiseOpt):
-    def __init__(self, model, sparsity_config, input, padding_mask, config):
-        super().__init__(model, sparsity_config, input, padding_mask, config)
+    def __init__(self, model, sparsity_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, sparsity_config, input, padding_mask, config, modality)
         self.set_sparsity_config()

     def block_init(self, block):
4 changes: 2 additions & 2 deletions llmc/compression/sparsification/magnitude.py
@@ -8,8 +8,8 @@

 @ALGO_REGISTRY
 class Magnitude(BaseBlockwiseSparsification):
-    def __init__(self, model, sparsity_config, input, padding_mask, config):
-        super().__init__(model, sparsity_config, input, padding_mask, config)
+    def __init__(self, model, sparsity_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, sparsity_config, input, padding_mask, config, modality)

     @torch.no_grad()
     def subset_transform(
4 changes: 2 additions & 2 deletions llmc/compression/sparsification/shortgpt.py
@@ -17,8 +17,8 @@

 @ALGO_REGISTRY
 class ShortGPT(BaseBlockwiseSparsification):
-    def __init__(self, model, sparsity_config, input, padding_mask, config):
-        super().__init__(model, sparsity_config, input, padding_mask, config)
+    def __init__(self, model, sparsity_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, sparsity_config, input, padding_mask, config, modality)

     def block_opt(self, block):
         block = block.cuda()
4 changes: 2 additions & 2 deletions llmc/compression/sparsification/wanda.py
@@ -9,8 +9,8 @@

 @ALGO_REGISTRY
 class Wanda(BaseBlockwiseSparsification):
-    def __init__(self, model, sparsity_config, input, padding_mask, config):
-        super().__init__(model, sparsity_config, input, padding_mask, config)
+    def __init__(self, model, sparsity_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, sparsity_config, input, padding_mask, config, modality)

     @torch.no_grad()
     def get_row_scale(self, layer, act):
36 changes: 32 additions & 4 deletions llmc/models/base_model.py
@@ -1,4 +1,5 @@
 import gc
+import inspect
 from abc import ABCMeta, abstractmethod
 from collections import defaultdict
 from functools import partial
@@ -33,7 +34,7 @@ def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         self.add_layernorms_class()

     @abstractmethod
-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         pass

     def find_block_name(self):
@@ -87,7 +88,29 @@ def get_attention_rotary_layers(self):
     def batch_process(self):
         raise Exception('batch_process should not be called here.')

-    def get_catcher(self, first_block_input):
+    def get_vision_catcher(self, first_block_input):
+
+        class Catcher(nn.Module):
+            def __init__(self, module):
+                super().__init__()
+                self.module = module
+                self.signature = inspect.signature(module.forward)
+
+            def forward(self, *args, **kwargs):
+                params = list(self.signature.parameters.keys())
+                for i, arg in enumerate(args):
+                    if i > 0:
+                        kwargs[params[i]] = arg
+                first_block_input['data'].append(args[0])
+                if 'output_router_logits' in kwargs:
+                    assert kwargs['output_router_logits'] is False
+                    kwargs.pop('output_router_logits')
+                first_block_input['kwargs'].append(kwargs)
+                raise ValueError
+
+        return Catcher
+
+    def get_language_catcher(self, first_block_input):

         class Catcher(nn.Module):
             def __init__(self, module):
@@ -138,10 +161,15 @@ def add_layernorms_class(self):
         logger.info(f'_TRANSFORMERS_LN_TYPES_ : {_TRANSFORMERS_LN_TYPES_}')

     @torch.no_grad()
-    def collect_first_block_input(self, calib_data, padding_mask=None, padding_side=None, data_type='txt'):  # noqa
+    def collect_first_block_input(self, calib_data, padding_mask=None,
+                                  padding_side=None, data_type='txt', modality='language'):
         first_block_input = defaultdict(list)

-        Catcher = self.get_catcher(first_block_input)
+        self.find_blocks(modality)
+        if modality == 'language':
+            Catcher = self.get_language_catcher(first_block_input)
+        elif modality == 'vision':
+            Catcher = self.get_vision_catcher(first_block_input)

         self.move_embed_to_device('cuda')
         if data_type == 'img_txt':
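Both catcher variants above work the same way: the first block is wrapped, the wrapper records the positional input and keyword arguments of every calibration batch, and then it raises ValueError so the rest of the forward pass is skipped. A minimal standalone sketch of that capture loop, assuming a toy linear block and a try/except driver that are purely illustrative (only the record-then-raise idea comes from this diff):

from collections import defaultdict

import torch
import torch.nn as nn

first_block_input = defaultdict(list)

class Catcher(nn.Module):
    # Wraps the first block; records its input, then aborts the forward pass.
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, hidden_states, **kwargs):
        first_block_input['data'].append(hidden_states)
        first_block_input['kwargs'].append(kwargs)
        raise ValueError  # stop here: only the first block's input is needed

toy_block = nn.Linear(16, 16)  # stand-in for the real first decoder block
model = nn.Sequential(Catcher(toy_block), nn.Linear(16, 16))

for batch in [torch.randn(2, 16), torch.randn(2, 16)]:
    try:
        model(batch)
    except ValueError:  # the Catcher always raises after recording
        pass

print(len(first_block_input['data']))  # 2 captured calibration inputs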
2 changes: 1 addition & 1 deletion llmc/models/bloom.py
@@ -8,7 +8,7 @@ class Bloom(BaseModel):
     def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         super().__init__(model_path, torch_dtype, device_map, use_cache)

-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         self.blocks = self.model.transformer.h

     def find_embed_layers(self):
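Bloom has only language blocks, so its find_blocks can ignore the new argument. A vision-language model would instead branch on it; a hedged sketch of such an override, where the VisionLanguageModel class and its attribute paths are hypothetical and not part of this PR (it reuses the repo's BaseModel):

class VisionLanguageModel(BaseModel):
    def find_blocks(self, modality='language'):
        # Hypothetical layout: pick whichever block list matches the modality being compressed.
        if modality == 'language':
            self.blocks = self.model.language_model.model.layers
        elif modality == 'vision':
            self.blocks = self.model.vision_tower.blocks
        else:
            raise ValueError(f'unsupported modality: {modality}')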