61 changes: 19 additions & 42 deletions llmc/compression/token_reduction/fastv.py
@@ -48,25 +48,24 @@ def hook_prepare_inputs_labels_for_multimodal(
past_key_values,
labels,
images,
image_sizes
modalities=['image'],
image_sizes=None,
):
if 'image_token_start_index' not in pruning_paras:
token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()
return self._original_prepare_inputs_labels_for_multimodal(
input_ids, position_ids, attention_mask,
past_key_values, labels, images, image_sizes
past_key_values, labels, images, modalities, image_sizes
)
return hook_prepare_inputs_labels_for_multimodal

def update_output_attentions_hook(module, args, kwargs):
def update_output_attentions_hook(module, args, kwargs, pruning_paras):
kwargs['output_attentions'] = True
pruning_paras['attn_scores'] = module.__class__.forward(module, *args, **kwargs)[1]
kwargs['output_attentions'] = False
return args, kwargs
Comment on lines +63 to 67

critical

The update_output_attentions_hook is registered as a forward pre-hook, yet it manually executes the module's forward pass via module.__class__.forward. As a result, this layer's forward pass runs twice: once inside the hook and a second time when PyTorch invokes the module after the hook returns. This is a major performance bottleneck.

The correct way to capture attention scores is a forward hook, which runs after the forward method and receives its output. The deleted store_attention_hook was the right pattern.

If the goal is to obtain attention scores without changing the layer's output signature (which might break subsequent layers), the forward hook can strip the attention scores from the output tuple before returning it.

I strongly recommend reverting to the pre-hook/forward-hook pattern to fix this critical performance issue.
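
A minimal sketch of that pattern, assuming PyTorch 2.x hooks (with_kwargs=True) and the same blocks, pruning_loc, and pruning_paras objects used in this file; register_attention_capture and request_attentions_pre_hook are hypothetical names introduced only for illustration:

```python
import functools

def register_attention_capture(blocks, pruning_loc, pruning_paras):
    """Hypothetical helper: capture the attention scores of the layer before
    the pruning location without running its forward pass twice."""

    def request_attentions_pre_hook(module, args, kwargs):
        # Only request attentions here; PyTorch runs forward() exactly once
        # after the pre-hook returns.
        kwargs['output_attentions'] = True
        return args, kwargs

    def store_attention_hook(module, inputs, layer_outputs, pruning_paras):
        # With output_attentions=True the layer returns
        # (hidden_states, attn_weights, ...).
        pruning_paras['attn_scores'] = layer_outputs[1]
        # Optionally strip the weights so later layers see the usual shape.
        return (layer_outputs[0],) + tuple(layer_outputs[2:])

    block = blocks[pruning_loc - 1]
    block.register_forward_pre_hook(request_attentions_pre_hook, with_kwargs=True)
    block.register_forward_hook(
        functools.partial(store_attention_hook, pruning_paras=pruning_paras)
    )
```

Whether stripping the attention weights from the returned tuple is safe depends on how callers index this layer's outputs, so that last return line should be treated as optional.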


def store_attention_hook(m, x, layer_outputs, pruning_paras):
layer_attention = layer_outputs[1]
pruning_paras['attn_scores'] = layer_attention

@prefill_wrapper
def fastv_pruning_hook(module, args, kwargs, pruning_paras):

@@ -76,7 +75,6 @@ def fastv_pruning_hook(module, args, kwargs, pruning_paras):

hidden_states = args[0]
causal_mask = kwargs['attention_mask']
cache_position = kwargs['cache_position']

device = hidden_states.device
# last_layer_attention = layer_outputs[1]
@@ -106,37 +104,26 @@ def fastv_pruning_hook(module, args, kwargs, pruning_paras):

# sort index
keep_indexs = keep_indexs.sort().values
# update seq length
new_seq_length = keep_indexs.shape[0]
# filter hidden states &

hidden_states = hidden_states[:, keep_indexs, :]
# update position ids
position_ids = keep_indexs.unsqueeze(0)
# update attention mask
causal_mask = _update_causal_mask(
causal_mask, None, hidden_states, 0
) if causal_mask is not None else None
kwargs['attention_mask'] = causal_mask
kwargs['cache_position'] = cache_position[:new_seq_length]
kwargs['position_ids'] = position_ids
kwargs['position_embeddings'] = None
pruning_paras['attention_mask'] = causal_mask
pruning_paras['cache_position'] = cache_position[:new_seq_length]
pruning_paras['position_ids'] = position_ids
pruning_paras['position_embeddings'] = None
if causal_mask is not None:
causal_mask = causal_mask[:, :, :hidden_states.shape[1], :hidden_states.shape[1]]
kwargs['attention_mask'].resize_as_(causal_mask).copy_(causal_mask.clone())
kwargs['cache_position'].resize_as_(position_ids.squeeze(0)).copy_(
position_ids.squeeze(0).clone())
kwargs['position_ids'].resize_as_(position_ids).copy_(position_ids.clone())

position_embeddings = kwargs['position_embeddings']
new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)

return (hidden_states,), kwargs

@prefill_wrapper
def read_parameter_hook(module, args, kwargs, pruning_paras):
kwargs['attention_mask'] = pruning_paras['attention_mask']
kwargs['cache_position'] = pruning_paras['cache_position']
kwargs['position_ids'] = pruning_paras['position_ids']
kwargs['position_embeddings'] = pruning_paras['position_embeddings']

return args, kwargs

if self.model.__class__.__name__ == 'LlavaHf':
self.model.embed_tokens.register_forward_pre_hook(
functools.partial(input_hook, pruning_paras=self.pruning_paras)
@@ -151,21 +138,11 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
)

self.blocks[self.pruning_loc - 1].register_forward_pre_hook(
update_output_attentions_hook,
functools.partial(update_output_attentions_hook, pruning_paras=self.pruning_paras),
with_kwargs=True
)

self.blocks[self.pruning_loc - 1].register_forward_hook(
functools.partial(store_attention_hook, pruning_paras=self.pruning_paras),
)

self.blocks[self.pruning_loc].register_forward_pre_hook(
functools.partial(fastv_pruning_hook, pruning_paras=self.pruning_paras),
with_kwargs=True
)

for idx in range(self.pruning_loc + 1, len(self.blocks)):
self.blocks[idx].register_forward_pre_hook(
functools.partial(read_parameter_hook, pruning_paras=self.pruning_paras),
with_kwargs=True
)
39 changes: 3 additions & 36 deletions llmc/models/llava.py
@@ -44,46 +44,13 @@ def build_model(self):
self.llava_config.use_cache = True
self.vlm_model_config.use_cache = True
logger.info(f'self.vlm_model_config : {self.vlm_model_config}')

self.tokenizer, self.vlm_model, image_processor, context_len = load_pretrained_model(
self.model_path,
None,
get_model_name_from_path(self.model_path),
load_8bit=False,
load_4bit=False,
device='cpu',
torch_dtype=self.torch_dtype,
config=self.llava_config,
)

# llava forward not support "cache_position"
ori_forward = self.vlm_model.forward

def safe_forward(*args, **kwargs):
kwargs.pop('cache_position', None)
return ori_forward(*args, **kwargs)
self.vlm_model.forward = safe_forward

# llava generate use "inputs" instead of "input_ids"
ori_generate = self.vlm_model.generate

def safe_generate(*args, **kwargs):
if 'input_ids' in kwargs:
kwargs['inputs'] = kwargs.pop('input_ids')
return ori_generate(*args, **kwargs)
self.vlm_model.generate = safe_generate

# "attention_mask" is passed via kwargs rather than as an explicit keyword argument.
ori_prepare_inputs_for_generation = self.vlm_model.prepare_inputs_for_generation

def safe_prepare_inputs_for_generation(
self, input_ids, past_key_values=None,
inputs_embeds=None, attention_mask=None, **kwargs):
if attention_mask is not None:
kwargs['attention_mask'] = attention_mask
return ori_prepare_inputs_for_generation(
input_ids, past_key_values, inputs_embeds, **kwargs)
self.vlm_model.prepare_inputs_for_generation = types.MethodType(
safe_prepare_inputs_for_generation, self.vlm_model
device_map='cpu',
attn_implementation='sdpa'
)
Comment on lines +47 to 54

medium

The removal of the monkey patches for cache_position, input_ids, and attention_mask implies the underlying LLaVA library has been updated significantly. Please verify that loading with attn_implementation='sdpa' and device_map='cpu' is thoroughly tested across the supported hardware setups so performance and compatibility remain consistent.
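
As a hedged aid for that verification, the sketch below is a hypothetical smoke check (not part of this PR) that only probes the behaviours the removed wrappers used to patch; check_llava_compat and its return keys are made-up names, and vlm_model is assumed to be the model returned by load_pretrained_model above:

```python
import inspect

def check_llava_compat(vlm_model):
    """Hypothetical probe of the interfaces the deleted monkey patches
    worked around: cache_position in forward, input_ids in generate,
    and the configured attention implementation."""
    fwd_params = inspect.signature(vlm_model.forward).parameters
    gen_params = inspect.signature(vlm_model.generate).parameters

    def accepts(params, name):
        # True if the parameter is declared explicitly or swallowed by **kwargs.
        return name in params or any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )

    return {
        'forward_accepts_cache_position': accepts(fwd_params, 'cache_position'),
        'generate_accepts_input_ids': accepts(gen_params, 'input_ids'),
        'attn_implementation': getattr(
            vlm_model.config, '_attn_implementation', None
        ),
    }
```

A passing probe is only a signature-level check; end-to-end evaluation on the targeted hardware setups is still needed.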



self.eval_name = 'LlavaEval'