
Conversation

@chengtao-lv
Collaborator

No description provided.


@gemini-code-assist bot left a comment


Summary of Changes

Hello @chengtao-lv, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the HoliTom module, a novel token reduction strategy aimed at optimizing multimodal model performance, especially for video inputs. It achieves this by intelligently identifying and merging redundant visual tokens, thereby reducing computational overhead. The implementation involves dynamically injecting custom processing logic into existing LLaVA model components to handle visual feature extraction and token preparation more efficiently.

Highlights

  • New Token Reduction Module: Introduced a new token reduction module named HoliTom within the llmc.compression.token_reduction package. This module is designed to optimize the processing of visual tokens in multimodal models.
  • Dynamic Method Patching for LLaVA: Implemented dynamic patching of key methods in LlavaMetaForCausalLM, SigLipEncoder, and SigLipVisionTower classes. This allows the HoliTom logic to intercept and modify the multimodal input preparation and vision tower forwarding processes, specifically for models like Llava_OneVision.
  • Advanced Visual Token Processing: The HoliTom module incorporates several algorithms for visual token reduction, including density peak clustering (cluster_dpc_knn), selection of static frame windows (select_static_windows; see the sketch after this list), and token merging based on attention density, feature density, or clustering (merge_tokens_by_attention_density, merge_tokens_by_density, merge_tokens_by_clustering). Together these prune redundant visual information, which is particularly useful for video inputs.
  • Support for Multimodal Input Optimization: The core holitom function orchestrates the token reduction process by separating static and dynamic features, applying different merging strategies based on window size, and handling newline token insertion for grid-based or frame-wise processing. This aims to reduce the sequence length of visual tokens while preserving critical information.
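
As a rough intuition aid for the static-window step mentioned above, here is a minimal, hypothetical sketch: it mean-pools each frame's visual features, measures cosine similarity between consecutive frames, and groups runs of near-identical frames into windows whose tokens can later be merged once. The PR's actual select_static_windows operates on per-token similarities and feeds the downstream merging functions, so the name select_static_windows_sketch and all parameter names below are illustrative only.

import torch
import torch.nn.functional as F

def select_static_windows_sketch(frame_feats: torch.Tensor, tau: float, max_window: int):
    """frame_feats: (num_frames, num_tokens, dim) pooled visual features per frame."""
    # Mean-pool each frame to one vector and compare neighbouring frames.
    frame_vecs = F.normalize(frame_feats.mean(dim=1), dim=-1)   # (num_frames, dim)
    sim = (frame_vecs[:-1] * frame_vecs[1:]).sum(dim=-1)        # cosine sim of neighbours

    windows, start = [], 0
    for i in range(sim.shape[0]):
        next_len = (i + 1) - start + 1           # window length if frame i+1 joins
        if sim[i] < tau or next_len > max_window:
            windows.append((start, i))           # close the current window at frame i
            start = i + 1
    windows.append((start, frame_feats.shape[0] - 1))
    return windows  # list of (first_frame, last_frame) spans of near-static frames

Frames inside each returned span can then share one merged set of tokens, which is also why the diff below rescales retain_ratio after the static reduction.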
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in issue comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist bot left a comment


Code Review

This pull request introduces a new token reduction method for multimodal models. The implementation uses monkey-patching to add static and dynamic token pruning strategies based on feature similarity and attention density. A critical bug was identified that will cause a runtime error, and a core function is over 500 lines long and should be refactored for maintainability.
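
For readers unfamiliar with the mechanism, the snippet below is only a schematic sketch of the monkey-patching approach described here; prepare_inputs_labels_for_multimodal is the real hook this PR overrides on LlavaMetaForCausalLM, while the model variable and the surrounding scaffolding are illustrative.

import types

# Keep a handle to the original implementation defined on the class.
_orig_prepare = type(model).prepare_inputs_labels_for_multimodal

def patched_prepare(self, *args, **kwargs):
    # HoliTom's static/dynamic token pruning would run in (or around) this call.
    return _orig_prepare(self, *args, **kwargs)

# Rebind the method on this specific model instance so the patched logic is used.
model.prepare_inputs_labels_for_multimodal = types.MethodType(patched_prepare, model)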

output_hidden_states=True,
)
image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
assert image_features.shape[-2] == 729


Severity: critical

This line will raise an AttributeError because image_features is an empty list at this point in the loop. You likely intended to check the shape of image_feature, which was defined on the previous line.

Suggested change
- assert image_features.shape[-2] == 729
+ assert image_feature.shape[-2] == 729

Comment on lines +569 to +1163
def prepare_inputs_labels_for_multimodal(
self,
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
images,
modalities=['image'],
image_sizes=None,
):
import os

vision_tower = self.get_vision_tower()
# rank_print(modalities)
if vision_tower is None or images is None or input_ids.shape[1] == 1:
return (
input_ids,
position_ids,
attention_mask,
past_key_values,
None,
labels,
)

if isinstance(modalities, str):
modalities = [modalities]

# import pdb; pdb.set_trace()
if type(images) is list or images.ndim == 5:
mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
mm_newline_position = getattr(
self.config, 'mm_newline_position', 'one_token'
)

if type(images) is list:
images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]

video_idx_in_batch = []
for _ in range(len(modalities)):
if modalities[_] == 'video':
video_idx_in_batch.append(_)

images_list = []
for image in images:
if image.ndim == 4:
images_list.append(image)
else:
images_list.append(image.unsqueeze(0))

concat_images = torch.cat([image for image in images_list], dim=0)
split_sizes = [image.shape[0] for image in images_list]
encoded_image_features, attn_weights, _, images_dtype = (
self.encode_images_multi(concat_images)
)
retain_ratio = self.pruning_paras.get('RETAIN_RATIO', 0.1)
# C = int(os.environ.get("C", 8))
# tau = float(os.environ.get("T", 0.8))
tau = self.pruning_paras.get('T', 0.1)
# P = int(os.environ.get("P", 4))
Beta = float(os.environ.get('BETA', 0.6))
D = float(os.environ.get('D', 0))
K = int(os.environ.get('K', 7))
max_window_size = int(os.environ.get('MAX_WINDOW_SIZE', 1024))
# NO_BETA = os.environ.get('NO_BETA', '1')
# rank0_print(f"retain_ratio: {retain_ratio},
# tau: {tau}, Beta: {Beta}, D: {D}, K: {K},
# max_window_size: {max_window_size}, NO_BETA: {NO_BETA}")
# image_features,all_faster_video_features =
# self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)

# This is a list, each element is [num_images, patch * patch, dim]
# rank_print(f"Concat images : {concat_images.shape}")
encoded_image_features = torch.split(encoded_image_features, split_sizes)
image_features = []
for idx, image_feat in enumerate(encoded_image_features):
if idx in video_idx_in_batch:
# [modify]
# image_features.append(self.get_2dPool(image_feat))
# image_feat: (batch_size, seq_len, embed_dim)
# attn_weights: (batch_size, seq_len)
pooled_image_feat = self.get_2dPool(
image_feat
) # (batch_size, seq_len', embed_dim)
attn_weights = attn_weights.unsqueeze(-1)
attn_weights = self.get_2dPool(attn_weights)
attn_weights = attn_weights.squeeze(-1) # (batch_size, seq_len')

batch_size, seq_len, embed_dim = pooled_image_feat.shape

pooled_image_feat_normed = torch.nn.functional.normalize(
pooled_image_feat, p=2, dim=-1
)
feature_sim = torch.nn.functional.cosine_similarity(
pooled_image_feat_normed[:-1],
pooled_image_feat_normed[1:],
dim=-1,
) # (batch_size-1, seq_len')

selected_frames, total_reduced = self.select_static_windows(
feature_sim, batch_size, tau, max_window_size
)
# rank0_print(f"Selected frames: {selected_frames}")
# rank0_print(f"Total reduced features: {total_reduced}")

total_tokens = batch_size * seq_len
retain_ratio = min(
retain_ratio / ((total_tokens - total_reduced) / total_tokens),
1,
)
# rank0_print(f"After static pruning, retain ratio: {retain_ratio}")

(
static_feat,
dynamic_feat,
_,
dynamic_attn,
static_pos,
dynamic_pos,
) = self.get_static_dynamic_features(
pooled_image_feat,
attn_weights,
selected_frames,
feature_sim,
tau,
)

segment_features = []
for idx, (start, end) in enumerate(selected_frames):
window_size = end - start + 1
segment_features.append(
self.holitom(
static_feat[idx],
dynamic_feat[idx],
dynamic_attn[idx],
static_pos[idx],
dynamic_pos[idx],
window_size,
retain_ratio,
D,
Beta,
K,
images_dtype,
mm_newline_position,
)
)
image_features.append(torch.cat(segment_features, dim=0))

else:
image_features.append(image_feat)
# image_features =
# self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
# rank_print(f"Encoded image feats : {[x.shape for x in image_features]}")
# image_features = torch.split(image_features, split_sizes, dim=0)

if mm_patch_merge_type == 'flat':
image_features = [x.flatten(0, 1) for x in image_features]

elif mm_patch_merge_type.startswith('spatial'):
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
# FIXME: now assume the image is square, and split to 2x2 patches
# num_patches = h * w, where h = w = sqrt(num_patches)
# currently image_feature is a tensor of shape (4, num_patches, hidden_size)
# we want to first unflatten it to (2, 2, h, w, hidden_size)
# rank0_print("At least we are reaching here")
# import pdb; pdb.set_trace()
if image_idx in video_idx_in_batch: # video operations
# rank0_print("Video")
if mm_newline_position == 'grid':
new_image_features.append(image_feature)
elif mm_newline_position == 'frame':
# Frame-wise
image_feature = self.add_token_per_frame(image_feature)

new_image_features.append(image_feature.flatten(0, 1))

elif mm_newline_position == 'one_token':
# one-token
# image_feature = image_feature.flatten(0, 1)
if 'unpad' in mm_patch_merge_type:
image_feature = torch.cat(
(
image_feature,
self.model.image_newline[None].to(
image_feature.device
),
),
dim=0,
)
new_image_features.append(image_feature)
elif mm_newline_position == 'no_token':
new_image_features.append(image_feature.flatten(0, 1))
else:
raise ValueError(
f'Unexpected mm_newline_position: {mm_newline_position}'
)
elif (
image_feature.shape[0] > 1
): # multi patches and multi images operations
# rank0_print("Single-images")
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.get_vision_tower().num_patches_per_side
assert height * width == base_image_feature.shape[0]

if 'anyres_max' in image_aspect_ratio:
matched_anyres_max_num_patches = re.match(
r'anyres_max_(\d+)', image_aspect_ratio
)
if matched_anyres_max_num_patches:
max_num_patches = int(
matched_anyres_max_num_patches.group(1)
)

if (
image_aspect_ratio == 'anyres'
or 'anyres_max' in image_aspect_ratio
):
if hasattr(self.get_vision_tower(), 'image_size'):
vision_tower_image_size = (
self.get_vision_tower().image_size
)
else:
raise ValueError(
'vision_tower_image_size is not found in the vision tower.'
)
try:
num_patch_width, num_patch_height = (
get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
vision_tower_image_size,
)
)
except Exception as e:
rank0_print(f'Error: {e}')
num_patch_width, num_patch_height = 2, 2
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
else:
image_feature = image_feature.view(2, 2, height, width, -1)

if 'maxpool2x2' in mm_patch_merge_type:
image_feature = image_feature.permute(
4, 0, 2, 1, 3
).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = nn.functional.max_pool2d(image_feature, 2)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
elif (
'unpad' in mm_patch_merge_type
and 'anyres_max' in image_aspect_ratio
and matched_anyres_max_num_patches
):
unit = image_feature.shape[2]
image_feature = image_feature.permute(
4, 0, 2, 1, 3
).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(
image_feature, image_sizes[image_idx]
)
c, h, w = image_feature.shape
times = math.sqrt(h * w / (max_num_patches * unit**2))
if times > 1.1:
image_feature = image_feature[None]
image_feature = nn.functional.interpolate(
image_feature,
[int(h // times), int(w // times)],
mode='bilinear',
)[0]
image_feature = torch.cat(
(
image_feature,
self.model.image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
elif 'unpad' in mm_patch_merge_type:
image_feature = image_feature.permute(
4, 0, 2, 1, 3
).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(
image_feature, image_sizes[image_idx]
)
image_feature = torch.cat(
(
image_feature,
self.model.image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
else:
image_feature = image_feature.permute(
0, 2, 1, 3, 4
).contiguous()
image_feature = image_feature.flatten(0, 3)
if 'nobase' in mm_patch_merge_type:
pass
else:
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
new_image_features.append(image_feature)
else: # single image operations
image_feature = image_feature[0]
if 'unpad' in mm_patch_merge_type:
image_feature = torch.cat(
(image_feature, self.model.image_newline[None]), dim=0
)

new_image_features.append(image_feature)
image_features = new_image_features
else:
raise ValueError(
f'Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}'
)
else:
image_features = self.encode_images(images)

# TODO: image start / end is not implemented here to support pretraining.
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(
self.config, 'mm_use_im_start_end', False
):
raise NotImplementedError
# rank_print(f"Total images : {len(image_features)}")

# Let's just add dummy tensors if they do not exist,
# it is a headache to deal with None all the time.
# But it is not ideal, and if you have a better idea,
# please open an issue / submit a PR, thanks.
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(
0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
)
if labels is None:
labels = torch.full_like(input_ids, IGNORE_INDEX)

# remove the padding using attention_mask -- FIXME
# _input_ids = input_ids
input_ids = [
cur_input_ids[cur_attention_mask]
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
]
labels = [
cur_labels[cur_attention_mask]
for cur_labels, cur_attention_mask in zip(labels, attention_mask)
]

new_input_embeds = []
new_labels = []
if os.getenv('HOLITOM_k') is not None and os.getenv('HOLITOM_r') is not None:
# [modified]
image_token_posi = []
prompt_len = []
cur_image_idx = 0
# rank_print("Inserting Images embedding")
for batch_idx, cur_input_ids in enumerate(input_ids):
if (
os.getenv('HOLITOM_k') is not None
and os.getenv('HOLITOM_r') is not None
):
# [modified]
# record image position for further dropping
image_index = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[
0
].tolist()
if image_index == []:
image_token_posi.append(-1)
else:
image_token_posi.append(image_index[0])

# record input instruction length in inference mode
if not self.training:
if image_index == []:
prompt_len.append(cur_input_ids.shape[0])
else:
prompt_len.append(
cur_input_ids.shape[0] - 1
) # consider image place holder

num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
# rank0_print(num_images)
if num_images == 0:
cur_image_features = image_features[cur_image_idx]
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
cur_input_embeds = torch.cat(
[cur_input_embeds_1, cur_image_features[0:0]], dim=0
)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
cur_image_idx += 1
continue

image_token_indices = (
[-1]
+ torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
+ [cur_input_ids.shape[0]]
)
cur_input_ids_noim = []
cur_labels = labels[batch_idx]
cur_labels_noim = []
for i in range(len(image_token_indices) - 1):
cur_input_ids_noim.append(
cur_input_ids[
image_token_indices[i] + 1: image_token_indices[i + 1]
]
)
cur_labels_noim.append(
cur_labels[image_token_indices[i] + 1: image_token_indices[i + 1]]
)
# [modify]
# text_token_count = sum([x.shape[0] for x in cur_labels_noim])
# vision_token_count = len(image_features[cur_image_idx])
# rank0_print(f"Batch {batch_idx}:
# Text tokens: {text_token_count} Original Vision tokens: {vision_token_count}")

split_sizes = [x.shape[0] for x in cur_labels_noim]
cur_input_embeds = self.get_model().embed_tokens(
torch.cat(cur_input_ids_noim)
)
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []

for i in range(num_images + 1):
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
cur_new_labels.append(cur_labels_noim[i])
if i < num_images:
try:
cur_image_features = image_features[cur_image_idx]
except IndexError:
cur_image_features = image_features[cur_image_idx - 1]
cur_image_idx += 1
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(
torch.full(
(cur_image_features.shape[0],),
IGNORE_INDEX,
device=cur_labels.device,
dtype=cur_labels.dtype,
)
)

cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

# import pdb; pdb.set_trace()
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)

new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)

if os.getenv('HOLITOM_k') is not None and os.getenv('HOLITOM_r') is not None:
# [modified]
self.model.image_token_posi = image_token_posi
self.model.prompt_len = prompt_len
self.model.image_tokens = [
image_feature.shape[0] for image_feature in image_features
]

# Truncate sequences to max length as image embeddings can make the sequence longer
tokenizer_model_max_length = getattr(
self.config, 'tokenizer_model_max_length', None
)
# rank_print("Finishing Inserting")

new_input_embeds = [
x[:tokenizer_model_max_length]
for x, modality in zip(new_input_embeds, modalities)
]
new_labels = [
x[:tokenizer_model_max_length]
for x, modality in zip(new_labels, modalities)
]

# Combine them
max_len = max(x.shape[0] for x in new_input_embeds)
batch_size = len(new_input_embeds)

new_input_embeds_padded = []
new_labels_padded = torch.full(
(batch_size, max_len),
IGNORE_INDEX,
dtype=new_labels[0].dtype,
device=new_labels[0].device,
)
attention_mask = torch.zeros(
(batch_size, max_len),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
position_ids = torch.zeros(
(batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
)
# rank0_print("Prepare pos id")

for i, (cur_new_embed, cur_new_labels) in enumerate(
zip(new_input_embeds, new_labels)
):
cur_len = cur_new_embed.shape[0]
if getattr(self.config, 'tokenizer_padding_side', 'right') == 'left':
new_input_embeds_padded.append(
torch.cat(
(
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
cur_new_embed,
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
)
else:
new_input_embeds_padded.append(
torch.cat(
(
cur_new_embed,
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
)

new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
# rank0_print("tokenizer padding")

if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded

if _attention_mask is None:
attention_mask = None
else:
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

if _position_ids is None:
position_ids = None
if getattr(self.config, 'use_pos_skipping', False) and self.training:
position_ids = (
torch.arange(new_input_embeds.size(1), device=new_input_embeds.device)
.unsqueeze(0)
.to(new_input_embeds.device)
)
split_position = random.randint(0, new_input_embeds.size(1))
left_add = random.randint(0, self.config.pos_skipping_range)
right_add = random.randint(left_add, self.config.pos_skipping_range)
position_ids[:, :split_position] += left_add
position_ids[:, split_position:] += right_add
# import pdb; pdb.set_trace()
# rank0_print("Finish preparing")
return (
None,
position_ids,
attention_mask,
past_key_values,
new_input_embeds,
new_labels,
)


Severity: high

The function prepare_inputs_labels_for_multimodal is excessively long (over 500 lines). Its size and complexity make it very difficult to read, debug, and maintain. It handles multiple, distinct responsibilities, such as image encoding, complex token pruning logic, and preparing embeddings. To improve maintainability, please refactor this monolithic function into several smaller, well-defined helper functions, each with a single responsibility.

Comment on lines +630 to +633
Beta = float(os.environ.get('BETA', 0.6))
D = float(os.environ.get('D', 0))
K = int(os.environ.get('K', 7))
max_window_size = int(os.environ.get('MAX_WINDOW_SIZE', 1024))


Severity: high

This function reads several hyperparameters (BETA, D, K, MAX_WINDOW_SIZE) from environment variables. This approach is not ideal for production code as it makes configuration implicit and harder to manage. It would be more robust to pass these values through the model's configuration, similar to how retain_ratio and tau are handled.
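
A minimal sketch of that suggestion, assuming the same pruning_paras dictionary already used for RETAIN_RATIO and T (the key names below are illustrative and not part of the PR as written):

# Read the remaining hyperparameters from the pruning configuration instead of
# the process environment, keeping the same defaults the env-var path uses today.
Beta = float(self.pruning_paras.get('BETA', 0.6))
D = float(self.pruning_paras.get('D', 0))
K = int(self.pruning_paras.get('K', 7))
max_window_size = int(self.pruning_paras.get('MAX_WINDOW_SIZE', 1024))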

Comment on lines +1015 to +1018
try:
cur_image_features = image_features[cur_image_idx]
except IndexError:
cur_image_features = image_features[cur_image_idx - 1]


Severity: high

This try-except IndexError block appears to be masking a potential bug. Silently catching the exception and reusing the previous image features could lead to subtle errors. It would be much safer to investigate the root cause of why cur_image_idx might go out of bounds and fix the underlying logic that increments it.
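
If the indexing invariant is expected to hold, a stricter sketch along these lines would surface the mismatch instead of hiding it (the check and error message are illustrative, not a proposed final fix):

# Fail loudly when image tokens and encoded features fall out of alignment,
# rather than silently reusing the previous entry's features.
if cur_image_idx >= len(image_features):
    raise RuntimeError(
        f'cur_image_idx={cur_image_idx} is out of range for '
        f'{len(image_features)} image feature group(s); '
        'image tokens and encoded features are misaligned.'
    )
cur_image_features = image_features[cur_image_idx]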

@helloyongyang merged commit 948c6ed into main on Jun 27, 2025
1 of 2 checks passed
@helloyongyang deleted the vlm branch on June 27, 2025 at 11:14