295 changes: 290 additions & 5 deletions llmc/compression/token_reduction/holitom.py
@@ -35,7 +35,7 @@ def SigLipEncoder_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple]:
):
output_attentions = (
output_attentions
if output_attentions is not None
@@ -934,16 +934,19 @@ def prepare_inputs_labels_for_multimodal(

new_input_embeds = []
new_labels = []
if os.getenv('HOLITOM_k') is not None and os.getenv('HOLITOM_r') is not None:
if (
self.pruning_paras.get('HOLITOM_k', None) is not None
and self.pruning_paras.get('HOLITOM_r', None) is not None
):
# [modified]
image_token_posi = []
prompt_len = []
cur_image_idx = 0
# rank_print("Inserting Images embedding")
for batch_idx, cur_input_ids in enumerate(input_ids):
if (
os.getenv('HOLITOM_k') is not None
and os.getenv('HOLITOM_r') is not None
self.pruning_paras.get('HOLITOM_k', None) is not None
and self.pruning_paras.get('HOLITOM_r', None) is not None
):
# [modified]
# record image position for further dropping
@@ -1036,7 +1039,10 @@ def prepare_inputs_labels_for_multimodal(
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)

if os.getenv('HOLITOM_k') is not None and os.getenv('HOLITOM_r') is not None:
if (
self.pruning_paras.get('HOLITOM_k', None) is not None
and self.pruning_paras.get('HOLITOM_r', None) is not None
):
# [modified]
self.model.image_token_posi = image_token_posi
self.model.prompt_len = prompt_len
@@ -1173,6 +1179,7 @@ def __init__(self, config, model, blocks):
def add_sparse_config(self):
special_config = self.config.get('special', {})
self.model.model.pruning_paras = special_config
self.model.model.model.pruning_paras = special_config

high

The line self.model.model.model.pruning_paras = special_config appears to be redundant, since self.model.model.pruning_paras is set on the preceding line. Verify whether this assignment is necessary and remove it if it is not.


if self.model.__class__.__name__ == 'Llava_OneVision':
SigLipEncoder.forward = SigLipEncoder_forward
@@ -1211,5 +1218,283 @@ def add_sparse_config(self):
LlavaMetaForCausalLM_holitom.add_newline_token
)

if (
self.special_config.get('HOLITOM_k', None) is not None
and self.special_config.get('HOLITOM_r', None) is not None
):
from functools import partial

from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_flash_attention_utils import \
FlashAttentionKwargs
from transformers.modeling_outputs import \
BaseModelOutputWithPast
from transformers.processing_utils import Unpack

def qwen_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = (
use_cache if use_cache is not None else self.config.use_cache
)

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
'You must specify exactly one of input_ids or inputs_embeds'
)

if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
'`use_cache=True` is incompatible with gradient checkpointing. ' +
'Setting `use_cache=False`.'
)
use_cache = False

# TODO (joao): remove this exception in v4.56 --
# it exists for users that try to pass a legacy cache
if not isinstance(past_key_values, (type(None), Cache)):
raise ValueError(
'The `past_key_values` should be either a `Cache` object or `None`.'
)

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

if use_cache and past_key_values is None:
past_key_values = DynamicCache()

if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length()
if past_key_values is not None
else 0
)
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
past_key_values,
output_attentions,
)

hidden_states = inputs_embeds

# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

HOLITOM_k = self.pruning_paras.get('HOLITOM_k', 3)
HOLITOM_r = self.pruning_paras.get('HOLITOM_r', 0.5)
HOLITOM_image_token_start_index = self.image_token_posi[0]
HOLITOM_image_token_length = self.image_tokens[0]
seq_length_with_past = past_seen_tokens + inputs_embeds.shape[1]

for layer_idx, decoder_layer in enumerate(
self.layers[: self.config.num_hidden_layers]
):
if output_hidden_states:
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
if layer_idx < HOLITOM_k:
pass
elif layer_idx == HOLITOM_k and position_ids.size(1) > 1:
# compute pruned tokens, generate fastv sign
last_layer_attention = layer_outputs[1]
# compute average attention over different head
last_layer_attention_avg = torch.mean(
last_layer_attention, dim=1
)[0]
# generate new attention mask based on the average attention,
# sample the top ATTENTION_RANK tokens with highest attention
last_layer_attention_avg_last_tok = (
last_layer_attention_avg[-1]
)
# get the attention in image token
last_layer_attention_avg_last_tok_image = \
last_layer_attention_avg_last_tok[
HOLITOM_image_token_start_index:
HOLITOM_image_token_start_index
+ HOLITOM_image_token_length
]
# get the indexes of the top ATTENTION_RANK tokens
top_attention_rank_index = (
last_layer_attention_avg_last_tok_image.topk(
round(
HOLITOM_image_token_length * (1 - HOLITOM_r)
)
).indices
+ HOLITOM_image_token_start_index
)
# print("Before merge:", HOLITOM_image_token_length, "After merge:",
# round(HOLITOM_image_token_length*(1-HOLITOM_r)))

device = hidden_states.device
# [modified]
all_indices = torch.arange(
HOLITOM_image_token_length, device=device
)
non_topk_mask = ~torch.isin(
all_indices,
top_attention_rank_index
- HOLITOM_image_token_start_index,
)
non_topk_indices = (
all_indices[non_topk_mask]
+ HOLITOM_image_token_start_index
)
non_topk_states = hidden_states[
:, non_topk_indices, :
] # [batch_size, len(non_topk), hidden_size]
topk_states = hidden_states[
:, top_attention_rank_index, :
] # [batch_size, len(topk), hidden_size]
non_topk_norm = torch.norm(
non_topk_states, dim=-1, keepdim=True
) # [batch_size, len(non_topk), 1]
topk_norm = torch.norm(
topk_states, dim=-1, keepdim=True
) # [batch_size, len(topk), 1]
dot_product = torch.bmm(
non_topk_states, topk_states.transpose(1, 2)
) # [batch_size, len(non_topk), len(topk)]
sim_matrix = dot_product / (
non_topk_norm * topk_norm.transpose(1, 2)
)
sim_max, sim_max_index = torch.max(sim_matrix, dim=-1)

for b in range(hidden_states.size(0)):
for i in range(len(non_topk_indices)):
non_topk_idx = non_topk_indices[i]
most_similar_topk_idx = (
top_attention_rank_index[
sim_max_index[b, i]
]
)
hidden_states[b, most_similar_topk_idx, :] = (
hidden_states[b, most_similar_topk_idx, :]
+ hidden_states[b, non_topk_idx, :]
) / 2
Comment on lines +1410 to +1421

high

The nested Python loops are inefficient for tensor operations on a GPU. Refactor this section to use vectorized PyTorch operations (e.g., torch.scatter_add_).
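
A minimal, hypothetical sketch of the vectorized merge this comment suggests, written with torch.scatter_add_. The helper name merge_non_topk_into_topk and its argument names are made up for illustration; topk_idx, non_topk_idx, and sim_max_index stand in for top_attention_rank_index, non_topk_indices, and sim_max_index above. Unlike the sequential loop, it folds each top-k token together with all of its assigned non-top-k tokens in a single averaging step, so the result can differ slightly when several sources map to the same target.

# Hypothetical vectorized sketch (not part of the PR) of the merge step above.
import torch

def merge_non_topk_into_topk(hidden_states, topk_idx, non_topk_idx, sim_max_index):
    # hidden_states: [B, L, D]; topk_idx: [K] and non_topk_idx: [M] hold absolute
    # sequence positions; sim_max_index: [B, M] indexes into topk_idx.
    B, L, D = hidden_states.shape
    target_idx = topk_idx[sim_max_index]  # [B, M] absolute target position per source token
    src = hidden_states.gather(1, non_topk_idx.view(1, -1, 1).expand(B, -1, D))
    # Sum every source token into its target position.
    summed = torch.zeros_like(hidden_states)
    summed.scatter_add_(1, target_idx.unsqueeze(-1).expand(-1, -1, D), src)
    # Count how many sources each target position received.
    counts = torch.zeros(B, L, 1, dtype=hidden_states.dtype, device=hidden_states.device)
    ones = torch.ones(B, non_topk_idx.numel(), 1, dtype=hidden_states.dtype, device=hidden_states.device)
    counts.scatter_add_(1, target_idx.unsqueeze(-1), ones)
    # Average each receiving target with its sources; untouched positions keep their value.
    merged = (hidden_states + summed) / (counts + 1)
    return torch.where(counts > 0, merged, hidden_states)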

# [modified]

# keep index
keep_indexes = torch.cat(
(
torch.arange(
HOLITOM_image_token_start_index,
device=device,
),
top_attention_rank_index,
torch.arange(
HOLITOM_image_token_start_index
+ HOLITOM_image_token_length,
seq_length_with_past,
device=device,
),
)
)
# sort index
keep_indexes = keep_indexes.sort().values
# update seq length
new_seq_length = keep_indexes.shape[0]
# filter hidden states

hidden_states = hidden_states[
:, keep_indexes, :
]
# leads to the cuda error in the
# second iteration of decoding, layer_idx 3
# update position ids
position_ids = keep_indexes.unsqueeze(0)

position_embeddings = self.rotary_emb(
hidden_states, position_ids
)

cache_position = cache_position[:new_seq_length]

if layer_idx == HOLITOM_k - 1:
output_attentions = True
else:
output_attentions = False

layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)

hidden_states = layer_outputs[0]

# if output_attentions:
# all_self_attns += (layer_outputs[1],)
Comment on lines +1460 to +1480

high

The logic for handling output_attentions incorrectly overwrites the user-provided output_attentions value in each iteration. This prevents users from getting attention outputs for all layers if they request it.

Suggested change
if layer_idx == HOLITOM_k - 1:
output_attentions = True
else:
output_attentions = False
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
# if output_attentions:
# all_self_attns += (layer_outputs[1],)
# Determine if we need attentions for this layer, either for pruning or because the user requested them.
_get_attentions = output_attentions or (layer_idx == HOLITOM_k - 1)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=_get_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
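
For context, a hypothetical usage sketch of what this suggestion preserves (qwen_model and inputs_embeds are assumed names, not part of the PR): with the fix, a caller requesting attentions receives one tensor per decoder layer rather than an empty tuple.

# Hypothetical usage sketch against the patched Qwen2Model.forward.
outputs = qwen_model(
    inputs_embeds=inputs_embeds,   # [batch, seq_len, hidden_size]
    output_attentions=True,        # honored for every layer with the suggested fix
    use_cache=False,
)
per_layer_attn = outputs.attentions  # tuple of [batch, heads, q_len, k_len] tensors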


hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)

from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

Qwen2Model.forward = qwen_forward

def register_reduction_modules(self):
pass
5 changes: 5 additions & 0 deletions llmc/compression/token_reduction/token_reduction_module.py
@@ -1,3 +1,8 @@
import time

import torch
from loguru import logger


class TokenReductionModule:
def __init__(self, config, model, blocks):