holitom #395
Conversation
Summary of Changes
Hello @chengtao-lv, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces the HoliTom module, a novel token reduction strategy aimed at optimizing multimodal model performance, especially for video inputs. It achieves this by intelligently identifying and merging redundant visual tokens, thereby reducing computational overhead. The implementation involves dynamically injecting custom processing logic into existing LLaVA model components to handle visual feature extraction and token preparation more efficiently.
Highlights
- New Token Reduction Module: Introduced a new token reduction module named `HoliTom` within the `llmc.compression.token_reduction` package. This module is designed to optimize the processing of visual tokens in multimodal models.
- Dynamic Method Patching for LLaVA: Implemented dynamic patching of key methods in the `LlavaMetaForCausalLM`, `SigLipEncoder`, and `SigLipVisionTower` classes. This allows the `HoliTom` logic to intercept and modify the multimodal input preparation and vision tower forwarding processes, specifically for models like `Llava_OneVision` (see the patching sketch after this list).
- Advanced Visual Token Processing: The `HoliTom` module incorporates sophisticated algorithms for visual token reduction, including density peak clustering (`cluster_dpc_knn`), selection of static windows (`select_static_windows`), and merging tokens based on attention density and general density (`merge_tokens_by_attention_density`, `merge_tokens_by_density`, `merge_tokens_by_clustering`). This enables intelligent pruning of redundant visual information, particularly useful for video inputs.
- Support for Multimodal Input Optimization: The core `holitom` function orchestrates the token reduction process by separating static and dynamic features, applying different merging strategies based on window size, and handling newline token insertion for grid-based or frame-wise processing. This aims to reduce the sequence length of visual tokens while preserving critical information.
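As a quick orientation for reviewers, the sketch below illustrates the general monkey-patching pattern that "dynamic method patching" refers to. It is a self-contained toy example, not the PR's code: `ToyVisionTower` and `patched_forward` are hypothetical stand-ins for the real `SigLipVisionTower` / `LlavaMetaForCausalLM` targets.

```python
# Minimal illustration of instance-level method patching. The class and
# function names here are stand-ins, NOT the actual llmc / LLaVA classes.
import types


class ToyVisionTower:
    def forward(self, images):
        # Original behavior: pretend to return one feature per image.
        return [f"feat({img})" for img in images]


def patched_forward(self, images):
    # Replacement behavior: run the original logic, then post-process the
    # result (dropping every second feature as a stand-in for token reduction).
    features = ToyVisionTower.forward(self, images)
    return features[::2]


tower = ToyVisionTower()
# Rebind the method on the instance so only this object gets the new behavior.
tower.forward = types.MethodType(patched_forward, tower)
print(tower.forward(["img0", "img1", "img2", "img3"]))  # ['feat(img0)', 'feat(img2)']
```

In the actual module, the replacement methods would carry the HoliTom pruning logic instead of the placeholder post-processing shown here.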
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes

[^1]: Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request introduces a new token reduction method for multimodal models. The implementation uses monkey-patching to add static and dynamic token pruning strategies based on feature similarity and attention density. A critical bug was identified that will cause a runtime error, and a core function is over 500 lines long and should be refactored for maintainability.
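For context on the merging primitives named in the highlights (e.g. `cluster_dpc_knn`), here is a generic, hedged sketch of density-peak clustering with kNN applied to token merging. It illustrates the general technique only, not the code under review; the function name `dpc_knn_merge` and its signature are hypothetical.

```python
# Generic DPC-kNN token merging sketch (illustrative only).
import torch


def dpc_knn_merge(tokens: torch.Tensor, num_clusters: int, k: int = 7) -> torch.Tensor:
    """tokens: (N, C) -> merged tokens: (num_clusters, C)."""
    dist = torch.cdist(tokens, tokens)  # (N, N) pairwise distances

    # Local density: tokens whose k nearest neighbours are close are "dense".
    knn_dist, _ = dist.topk(k, dim=-1, largest=False)
    density = (-(knn_dist ** 2).mean(dim=-1)).exp()

    # DPC "delta": distance to the nearest token of strictly higher density.
    higher = density[None, :] > density[:, None]                      # (N, N)
    masked = torch.where(higher, dist, torch.full_like(dist, float('inf')))
    delta, _ = masked.min(dim=-1)
    delta[density.argmax()] = dist.max()  # densest token gets the maximum delta

    # Cluster centres: high density AND far from any denser token.
    centers = (density * delta).topk(num_clusters).indices

    # Assign every token to its nearest centre and average within each cluster.
    assign = dist[:, centers].argmin(dim=-1)                          # (N,)
    return torch.stack(
        [tokens[assign == c].mean(dim=0) for c in range(num_clusters)]
    )


merged = dpc_knn_merge(torch.randn(196, 64), num_clusters=32)
print(merged.shape)  # torch.Size([32, 64])
```

The PR's implementation additionally weights merging by attention density and distinguishes static from dynamic windows; the sketch shows only the clustering-and-average core.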
```python
    output_hidden_states=True,
)
image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
assert image_features.shape[-2] == 729
```

(This appears to be the site of the critical bug called out in the review summary: the encoder output is assigned to `image_feature`, while the assertion checks `image_features`, so the check will fail at runtime.)
```python
def prepare_inputs_labels_for_multimodal(
    self,
    input_ids,
    position_ids,
    attention_mask,
    past_key_values,
    labels,
    images,
    modalities=['image'],
    image_sizes=None,
):
    import os

    vision_tower = self.get_vision_tower()
    # rank_print(modalities)
    if vision_tower is None or images is None or input_ids.shape[1] == 1:
        return (
            input_ids,
            position_ids,
            attention_mask,
            past_key_values,
            None,
            labels,
        )

    if isinstance(modalities, str):
        modalities = [modalities]

    # import pdb; pdb.set_trace()
    if type(images) is list or images.ndim == 5:
        mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
        image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
        mm_newline_position = getattr(
            self.config, 'mm_newline_position', 'one_token'
        )

        if type(images) is list:
            images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]

        video_idx_in_batch = []
        for _ in range(len(modalities)):
            if modalities[_] == 'video':
                video_idx_in_batch.append(_)

        images_list = []
        for image in images:
            if image.ndim == 4:
                images_list.append(image)
            else:
                images_list.append(image.unsqueeze(0))

        concat_images = torch.cat([image for image in images_list], dim=0)
        split_sizes = [image.shape[0] for image in images_list]
        encoded_image_features, attn_weights, _, images_dtype = (
            self.encode_images_multi(concat_images)
        )
        retain_ratio = self.pruning_paras.get('RETAIN_RATIO', 0.1)
        # C = int(os.environ.get("C", 8))
        # tau = float(os.environ.get("T", 0.8))
        tau = self.pruning_paras.get('T', 0.1)
        # P = int(os.environ.get("P", 4))
        Beta = float(os.environ.get('BETA', 0.6))
        D = float(os.environ.get('D', 0))
        K = int(os.environ.get('K', 7))
        max_window_size = int(os.environ.get('MAX_WINDOW_SIZE', 1024))
        # NO_BETA = os.environ.get('NO_BETA', '1')
        # rank0_print(f"retain_ratio: {retain_ratio},
        # tau: {tau}, Beta: {Beta}, D: {D}, K: {K},
        # max_window_size: {max_window_size}, NO_BETA: {NO_BETA}")
        # image_features,all_faster_video_features =
        # self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)

        # This is a list, each element is [num_images, patch * patch, dim]
        # rank_print(f"Concat images : {concat_images.shape}")
        encoded_image_features = torch.split(encoded_image_features, split_sizes)
        image_features = []
        for idx, image_feat in enumerate(encoded_image_features):
            if idx in video_idx_in_batch:
                # [modify]
                # image_features.append(self.get_2dPool(image_feat))
                # image_feat: (batch_size, seq_len, embed_dim)
                # attn_weights: (batch_size, seq_len)
                pooled_image_feat = self.get_2dPool(
                    image_feat
                )  # (batch_size, seq_len', embed_dim)
                attn_weights = attn_weights.unsqueeze(-1)
                attn_weights = self.get_2dPool(attn_weights)
                attn_weights = attn_weights.squeeze(-1)  # (batch_size, seq_len')

                batch_size, seq_len, embed_dim = pooled_image_feat.shape

                pooled_image_feat_normed = torch.nn.functional.normalize(
                    pooled_image_feat, p=2, dim=-1
                )
                feature_sim = torch.nn.functional.cosine_similarity(
                    pooled_image_feat_normed[:-1],
                    pooled_image_feat_normed[1:],
                    dim=-1,
                )  # (batch_size-1, seq_len')

                selected_frames, total_reduced = self.select_static_windows(
                    feature_sim, batch_size, tau, max_window_size
                )
                # rank0_print(f"Selected frames: {selected_frames}")
                # rank0_print(f"Total reduced features: {total_reduced}")

                total_tokens = batch_size * seq_len
                retain_ratio = min(
                    retain_ratio / ((total_tokens - total_reduced) / total_tokens),
                    1,
                )
                # rank0_print(f"After static pruning, retain ratio: {retain_ratio}")

                (
                    static_feat,
                    dynamic_feat,
                    _,
                    dynamic_attn,
                    static_pos,
                    dynamic_pos,
                ) = self.get_static_dynamic_features(
                    pooled_image_feat,
                    attn_weights,
                    selected_frames,
                    feature_sim,
                    tau,
                )

                segment_features = []
                for idx, (start, end) in enumerate(selected_frames):
                    window_size = end - start + 1
                    segment_features.append(
                        self.holitom(
                            static_feat[idx],
                            dynamic_feat[idx],
                            dynamic_attn[idx],
                            static_pos[idx],
                            dynamic_pos[idx],
                            window_size,
                            retain_ratio,
                            D,
                            Beta,
                            K,
                            images_dtype,
                            mm_newline_position,
                        )
                    )
                image_features.append(torch.cat(segment_features, dim=0))

            else:
                image_features.append(image_feat)
        # image_features =
        # self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
        # rank_print(f"Encoded image feats : {[x.shape for x in image_features]}")
        # image_features = torch.split(image_features, split_sizes, dim=0)

        if mm_patch_merge_type == 'flat':
            image_features = [x.flatten(0, 1) for x in image_features]

        elif mm_patch_merge_type.startswith('spatial'):
            new_image_features = []
            for image_idx, image_feature in enumerate(image_features):
                # FIXME: now assume the image is square, and split to 2x2 patches
                # num_patches = h * w, where h = w = sqrt(num_patches)
                # currently image_feature is a tensor of shape (4, num_patches, hidden_size)
                # we want to first unflatten it to (2, 2, h, w, hidden_size)
                # rank0_print("At least we are reaching here")
                # import pdb; pdb.set_trace()
                if image_idx in video_idx_in_batch:  # video operations
                    # rank0_print("Video")
                    if mm_newline_position == 'grid':
                        new_image_features.append(image_feature)
                    elif mm_newline_position == 'frame':
                        # Frame-wise
                        image_feature = self.add_token_per_frame(image_feature)

                        new_image_features.append(image_feature.flatten(0, 1))

                    elif mm_newline_position == 'one_token':
                        # one-token
                        # image_feature = image_feature.flatten(0, 1)
                        if 'unpad' in mm_patch_merge_type:
                            image_feature = torch.cat(
                                (
                                    image_feature,
                                    self.model.image_newline[None].to(
                                        image_feature.device
                                    ),
                                ),
                                dim=0,
                            )
                        new_image_features.append(image_feature)
                    elif mm_newline_position == 'no_token':
                        new_image_features.append(image_feature.flatten(0, 1))
                    else:
                        raise ValueError(
                            f'Unexpected mm_newline_position: {mm_newline_position}'
                        )
                elif (
                    image_feature.shape[0] > 1
                ):  # multi patches and multi images operations
                    # rank0_print("Single-images")
                    base_image_feature = image_feature[0]
                    image_feature = image_feature[1:]
                    height = width = self.get_vision_tower().num_patches_per_side
                    assert height * width == base_image_feature.shape[0]

                    if 'anyres_max' in image_aspect_ratio:
                        matched_anyres_max_num_patches = re.match(
                            r'anyres_max_(\d+)', image_aspect_ratio
                        )
                        if matched_anyres_max_num_patches:
                            max_num_patches = int(
                                matched_anyres_max_num_patches.group(1)
                            )

                    if (
                        image_aspect_ratio == 'anyres'
                        or 'anyres_max' in image_aspect_ratio
                    ):
                        if hasattr(self.get_vision_tower(), 'image_size'):
                            vision_tower_image_size = (
                                self.get_vision_tower().image_size
                            )
                        else:
                            raise ValueError(
                                'vision_tower_image_size is not found in the vision tower.'
                            )
                        try:
                            num_patch_width, num_patch_height = (
                                get_anyres_image_grid_shape(
                                    image_sizes[image_idx],
                                    self.config.image_grid_pinpoints,
                                    vision_tower_image_size,
                                )
                            )
                        except Exception as e:
                            rank0_print(f'Error: {e}')
                            num_patch_width, num_patch_height = 2, 2
                        image_feature = image_feature.view(
                            num_patch_height, num_patch_width, height, width, -1
                        )
                    else:
                        image_feature = image_feature.view(2, 2, height, width, -1)

                    if 'maxpool2x2' in mm_patch_merge_type:
                        image_feature = image_feature.permute(
                            4, 0, 2, 1, 3
                        ).contiguous()
                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                        image_feature = nn.functional.max_pool2d(image_feature, 2)
                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                    elif (
                        'unpad' in mm_patch_merge_type
                        and 'anyres_max' in image_aspect_ratio
                        and matched_anyres_max_num_patches
                    ):
                        unit = image_feature.shape[2]
                        image_feature = image_feature.permute(
                            4, 0, 2, 1, 3
                        ).contiguous()
                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                        image_feature = unpad_image(
                            image_feature, image_sizes[image_idx]
                        )
                        c, h, w = image_feature.shape
                        times = math.sqrt(h * w / (max_num_patches * unit**2))
                        if times > 1.1:
                            image_feature = image_feature[None]
                            image_feature = nn.functional.interpolate(
                                image_feature,
                                [int(h // times), int(w // times)],
                                mode='bilinear',
                            )[0]
                        image_feature = torch.cat(
                            (
                                image_feature,
                                self.model.image_newline[:, None, None]
                                .expand(*image_feature.shape[:-1], 1)
                                .to(image_feature.device),
                            ),
                            dim=-1,
                        )
                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                    elif 'unpad' in mm_patch_merge_type:
                        image_feature = image_feature.permute(
                            4, 0, 2, 1, 3
                        ).contiguous()
                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                        image_feature = unpad_image(
                            image_feature, image_sizes[image_idx]
                        )
                        image_feature = torch.cat(
                            (
                                image_feature,
                                self.model.image_newline[:, None, None]
                                .expand(*image_feature.shape[:-1], 1)
                                .to(image_feature.device),
                            ),
                            dim=-1,
                        )
                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                    else:
                        image_feature = image_feature.permute(
                            0, 2, 1, 3, 4
                        ).contiguous()
                        image_feature = image_feature.flatten(0, 3)
                    if 'nobase' in mm_patch_merge_type:
                        pass
                    else:
                        image_feature = torch.cat(
                            (base_image_feature, image_feature), dim=0
                        )
                    new_image_features.append(image_feature)
                else:  # single image operations
                    image_feature = image_feature[0]
                    if 'unpad' in mm_patch_merge_type:
                        image_feature = torch.cat(
                            (image_feature, self.model.image_newline[None]), dim=0
                        )

                    new_image_features.append(image_feature)
            image_features = new_image_features
        else:
            raise ValueError(
                f'Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}'
            )
    else:
        image_features = self.encode_images(images)

    # TODO: image start / end is not implemented here to support pretraining.
    if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(
        self.config, 'mm_use_im_start_end', False
    ):
        raise NotImplementedError
    # rank_print(f"Total images : {len(image_features)}")

    # Let's just add dummy tensors if they do not exist,
    # it is a headache to deal with None all the time.
    # But it is not ideal, and if you have a better idea,
    # please open an issue / submit a PR, thanks.
    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    else:
        attention_mask = attention_mask.bool()
    if position_ids is None:
        position_ids = torch.arange(
            0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
        )
    if labels is None:
        labels = torch.full_like(input_ids, IGNORE_INDEX)

    # remove the padding using attention_mask -- FIXME
    # _input_ids = input_ids
    input_ids = [
        cur_input_ids[cur_attention_mask]
        for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
    ]
    labels = [
        cur_labels[cur_attention_mask]
        for cur_labels, cur_attention_mask in zip(labels, attention_mask)
    ]

    new_input_embeds = []
    new_labels = []
    if os.getenv('HOLITOM_k') is not None and os.getenv('HOLITOM_r') is not None:
        # [modified]
        image_token_posi = []
        prompt_len = []
    cur_image_idx = 0
    # rank_print("Inserting Images embedding")
    for batch_idx, cur_input_ids in enumerate(input_ids):
        if (
            os.getenv('HOLITOM_k') is not None
            and os.getenv('HOLITOM_r') is not None
        ):
            # [modified]
            # record image position for further dropping
            image_index = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[
                0
            ].tolist()
            if image_index == []:
                image_token_posi.append(-1)
            else:
                image_token_posi.append(image_index[0])

            # record input instruction length in inference mode
            if not self.training:
                if image_index == []:
                    prompt_len.append(cur_input_ids.shape[0])
                else:
                    prompt_len.append(
                        cur_input_ids.shape[0] - 1
                    )  # consider image place holder

        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
        # rank0_print(num_images)
        if num_images == 0:
            cur_image_features = image_features[cur_image_idx]
            cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
            cur_input_embeds = torch.cat(
                [cur_input_embeds_1, cur_image_features[0:0]], dim=0
            )
            new_input_embeds.append(cur_input_embeds)
            new_labels.append(labels[batch_idx])
            cur_image_idx += 1
            continue

        image_token_indices = (
            [-1]
            + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
            + [cur_input_ids.shape[0]]
        )
        cur_input_ids_noim = []
        cur_labels = labels[batch_idx]
        cur_labels_noim = []
        for i in range(len(image_token_indices) - 1):
            cur_input_ids_noim.append(
                cur_input_ids[
                    image_token_indices[i] + 1: image_token_indices[i + 1]
                ]
            )
            cur_labels_noim.append(
                cur_labels[image_token_indices[i] + 1: image_token_indices[i + 1]]
            )
        # [modify]
        # text_token_count = sum([x.shape[0] for x in cur_labels_noim])
        # vision_token_count = len(image_features[cur_image_idx])
        # rank0_print(f"Batch {batch_idx}:
        # Text tokens: {text_token_count} Original Vision tokens: {vision_token_count}")

        split_sizes = [x.shape[0] for x in cur_labels_noim]
        cur_input_embeds = self.get_model().embed_tokens(
            torch.cat(cur_input_ids_noim)
        )
        cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
        cur_new_input_embeds = []
        cur_new_labels = []

        for i in range(num_images + 1):
            cur_new_input_embeds.append(cur_input_embeds_no_im[i])
            cur_new_labels.append(cur_labels_noim[i])
            if i < num_images:
                try:
                    cur_image_features = image_features[cur_image_idx]
                except IndexError:
                    cur_image_features = image_features[cur_image_idx - 1]
                cur_image_idx += 1
                cur_new_input_embeds.append(cur_image_features)
                cur_new_labels.append(
                    torch.full(
                        (cur_image_features.shape[0],),
                        IGNORE_INDEX,
                        device=cur_labels.device,
                        dtype=cur_labels.dtype,
                    )
                )

        cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

        # import pdb; pdb.set_trace()
        cur_new_input_embeds = torch.cat(cur_new_input_embeds)
        cur_new_labels = torch.cat(cur_new_labels)

        new_input_embeds.append(cur_new_input_embeds)
        new_labels.append(cur_new_labels)

    if os.getenv('HOLITOM_k') is not None and os.getenv('HOLITOM_r') is not None:
        # [modified]
        self.model.image_token_posi = image_token_posi
        self.model.prompt_len = prompt_len
        self.model.image_tokens = [
            image_feature.shape[0] for image_feature in image_features
        ]

    # Truncate sequences to max length as image embeddings can make the sequence longer
    tokenizer_model_max_length = getattr(
        self.config, 'tokenizer_model_max_length', None
    )
    # rank_print("Finishing Inserting")

    new_input_embeds = [
        x[:tokenizer_model_max_length]
        for x, modality in zip(new_input_embeds, modalities)
    ]
    new_labels = [
        x[:tokenizer_model_max_length]
        for x, modality in zip(new_labels, modalities)
    ]

    # Combine them
    max_len = max(x.shape[0] for x in new_input_embeds)
    batch_size = len(new_input_embeds)

    new_input_embeds_padded = []
    new_labels_padded = torch.full(
        (batch_size, max_len),
        IGNORE_INDEX,
        dtype=new_labels[0].dtype,
        device=new_labels[0].device,
    )
    attention_mask = torch.zeros(
        (batch_size, max_len),
        dtype=attention_mask.dtype,
        device=attention_mask.device,
    )
    position_ids = torch.zeros(
        (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
    )
    # rank0_print("Prepare pos id")

    for i, (cur_new_embed, cur_new_labels) in enumerate(
        zip(new_input_embeds, new_labels)
    ):
        cur_len = cur_new_embed.shape[0]
        if getattr(self.config, 'tokenizer_padding_side', 'right') == 'left':
            new_input_embeds_padded.append(
                torch.cat(
                    (
                        torch.zeros(
                            (max_len - cur_len, cur_new_embed.shape[1]),
                            dtype=cur_new_embed.dtype,
                            device=cur_new_embed.device,
                        ),
                        cur_new_embed,
                    ),
                    dim=0,
                )
            )
            if cur_len > 0:
                new_labels_padded[i, -cur_len:] = cur_new_labels
                attention_mask[i, -cur_len:] = True
                position_ids[i, -cur_len:] = torch.arange(
                    0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                )
        else:
            new_input_embeds_padded.append(
                torch.cat(
                    (
                        cur_new_embed,
                        torch.zeros(
                            (max_len - cur_len, cur_new_embed.shape[1]),
                            dtype=cur_new_embed.dtype,
                            device=cur_new_embed.device,
                        ),
                    ),
                    dim=0,
                )
            )
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(
                    0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                )

    new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
    # rank0_print("tokenizer padding")

    if _labels is None:
        new_labels = None
    else:
        new_labels = new_labels_padded

    if _attention_mask is None:
        attention_mask = None
    else:
        attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

    if _position_ids is None:
        position_ids = None
    if getattr(self.config, 'use_pos_skipping', False) and self.training:
        position_ids = (
            torch.arange(new_input_embeds.size(1), device=new_input_embeds.device)
            .unsqueeze(0)
            .to(new_input_embeds.device)
        )
        split_position = random.randint(0, new_input_embeds.size(1))
        left_add = random.randint(0, self.config.pos_skipping_range)
        right_add = random.randint(left_add, self.config.pos_skipping_range)
        position_ids[:, :split_position] += left_add
        position_ids[:, split_position:] += right_add
    # import pdb; pdb.set_trace()
    # rank0_print("Finish preparing")
    return (
        None,
        position_ids,
        attention_mask,
        past_key_values,
        new_input_embeds,
        new_labels,
    )
```
The function prepare_inputs_labels_for_multimodal is excessively long (over 500 lines). Its size and complexity make it very difficult to read, debug, and maintain. It handles multiple, distinct responsibilities, such as image encoding, complex token pruning logic, and preparing embeddings. To improve maintainability, please refactor this monolithic function into several smaller, well-defined helper functions, each with a single responsibility.
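A hedged sketch of one possible decomposition is shown below. The helper names (`_encode_and_prune_visual_tokens`, `_splice_image_embeddings`, `_pad_and_build_masks`) are hypothetical, not existing llmc or LLaVA APIs; each would absorb one of the blocks currently inlined in the function.

```python
# Refactoring outline only: the three helpers are hypothetical and would need
# to be implemented from the corresponding blocks of the current function.
def prepare_inputs_labels_for_multimodal(
    self, input_ids, position_ids, attention_mask,
    past_key_values, labels, images, modalities=['image'], image_sizes=None,
):
    vision_tower = self.get_vision_tower()
    if vision_tower is None or images is None or input_ids.shape[1] == 1:
        return input_ids, position_ids, attention_mask, past_key_values, None, labels

    # 1. Vision encoding + HoliTom static/dynamic pruning (currently ~400 lines).
    image_features = self._encode_and_prune_visual_tokens(images, modalities, image_sizes)

    # 2. Replace <image> placeholders with visual embeddings and build labels.
    new_input_embeds, new_labels = self._splice_image_embeddings(
        input_ids, labels, attention_mask, image_features
    )

    # 3. Truncate, pad, and rebuild attention_mask / position_ids.
    return self._pad_and_build_masks(
        new_input_embeds, new_labels, position_ids, attention_mask, past_key_values
    )
```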
```python
Beta = float(os.environ.get('BETA', 0.6))
D = float(os.environ.get('D', 0))
K = int(os.environ.get('K', 7))
max_window_size = int(os.environ.get('MAX_WINDOW_SIZE', 1024))
```
This function reads several hyperparameters (BETA, D, K, MAX_WINDOW_SIZE) from environment variables. This approach is not ideal for production code as it makes configuration implicit and harder to manage. It would be more robust to pass these values through the model's configuration, similar to how retain_ratio and tau are handled.
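A minimal sketch of the suggested change, assuming the extra keys are added to the same `pruning_paras` dict that already supplies `RETAIN_RATIO` and `T` (the key names below are assumptions, not existing configuration options):

```python
# All HoliTom hyper-parameters read from the module configuration instead of
# environment variables; the BETA/D/K/MAX_WINDOW_SIZE keys are assumed here.
retain_ratio = self.pruning_paras.get('RETAIN_RATIO', 0.1)
tau = self.pruning_paras.get('T', 0.1)
Beta = float(self.pruning_paras.get('BETA', 0.6))
D = float(self.pruning_paras.get('D', 0))
K = int(self.pruning_paras.get('K', 7))
max_window_size = int(self.pruning_paras.get('MAX_WINDOW_SIZE', 1024))
```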
```python
try:
    cur_image_features = image_features[cur_image_idx]
except IndexError:
    cur_image_features = image_features[cur_image_idx - 1]
```
This try-except IndexError block appears to be masking a potential bug. Silently catching the exception and reusing the previous image features could lead to subtle errors. It would be much safer to investigate the root cause of why cur_image_idx might go out of bounds and fix the underlying logic that increments it.
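A minimal sketch of the safer alternative, written as a drop-in for the quoted block (variable names taken from the surrounding function); it surfaces the index drift immediately instead of hiding it:

```python
# Fail loudly when the number of <image> placeholders and the number of
# produced feature tensors disagree, rather than silently reusing features.
if cur_image_idx >= len(image_features):
    raise RuntimeError(
        f'Mismatch between <image> tokens and image_features: needed index '
        f'{cur_image_idx} but only {len(image_features)} feature tensors were produced.'
    )
cur_image_features = image_features[cur_image_idx]
```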
No description provided.