|
| 1 | +import functools |
| 2 | +import math |
| 3 | + |
| 4 | +import torch |
| 5 | +from torch import nn |
| 6 | +from transformers.modeling_attn_mask_utils import \ |
| 7 | + _prepare_4d_causal_attention_mask |
| 8 | +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb |
| 9 | + |
| 10 | +from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY |
| 11 | + |
| 12 | +from .token_reduction_module import TokenReductionModule |
| 13 | + |
| 14 | + |
| 15 | +@TOKEN_REDUCTION_REGISTRY.register('PyramidDrop') |
| 16 | +class PyramidDrop(TokenReductionModule): |
| 17 | + def __init__(self, config, model, blocks): |
| 18 | + super().__init__(config, model, blocks) |
| 19 | + self.add_sparse_config() |
| 20 | + self.register_reduction_modules() |
| 21 | + |
| 22 | + def add_sparse_config(self): |
| 23 | + special_config = self.config.get('special', {}) |
| 24 | + self.pruning_loc = special_config['layer_list'] |
| 25 | + image_token_ratio_list = special_config['image_token_ratio_list'] |
| 26 | + image_token_ratio_list.insert(0, 1.0) |
| 27 | + special_config['image_token_ratio_list'] = image_token_ratio_list |
| 28 | + special_config['tokenizer_padding_side'] = getattr( |
| 29 | + self.model.vlm_model.language_model.model.config, |
| 30 | + 'tokenizer_padding_side', |
| 31 | + 'right', |
| 32 | + ) |
| 33 | + special_config['image_token_index'] = self.model.pruning_config[ |
| 34 | + 'image_token_index' |
| 35 | + ] |
| 36 | + self.model.model.parameters = special_config |
| 37 | + |
| 38 | + def register_reduction_modules(self): |
| 39 | + |
| 40 | + def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx): |
| 41 | + |
| 42 | + if layer_idx == self.pruning_loc[0]: |
| 43 | + position_ids = kwargs['position_ids'] |
| 44 | + attention_mask = kwargs['attention_mask'] |
| 45 | + position_embeddings = kwargs['position_embeddings'] |
| 46 | + else: |
| 47 | + attention_mask = pruning_pars['attention_mask'] |
| 48 | + position_ids = pruning_pars['position_ids'] |
| 49 | + position_embeddings = pruning_pars['position_embeddings'] |
| 50 | + |
| 51 | + features = args[0] |
| 52 | + _position_ids = position_ids |
| 53 | + _attention_mask = attention_mask |
| 54 | + prompt_len = pruning_pars['prompt_len'] |
| 55 | + image_tokens_list = pruning_pars['image_tokens'] |
| 56 | + image_token_posi = pruning_pars['image_token_posi'] |
| 57 | + image_token_ratio_list = pruning_pars['image_token_ratio_list'] |
| 58 | + |
| 59 | + if position_ids is None: |
| 60 | + position_ids = torch.arange( |
| 61 | + 0, features.shape[1], dtype=torch.long, device=features.device |
| 62 | + ).unsqueeze(0) |
| 63 | + |
| 64 | + if pruning_pars['tokenizer_padding_side'] == 'right': |
| 65 | + |
| 66 | + batch_size = features.shape[0] |
| 67 | + image_tokens = [ |
| 68 | + int(cur_image_token * image_token_ratio_list[cur_num]) |
| 69 | + for cur_image_token in image_tokens_list |
| 70 | + ] |
| 71 | + keep_length = [ |
| 72 | + int(cur_image_token * image_token_ratio_list[cur_num + 1]) |
| 73 | + for cur_image_token in image_tokens_list |
| 74 | + ] |
| 75 | + |
| 76 | + features_list = [] |
| 77 | + attention_mask_list = [] |
| 78 | + |
| 79 | + if attention_mask is None: |
| 80 | + attention_mask = torch.ones( |
| 81 | + (batch_size, features.shape[1]), |
| 82 | + dtype=torch.bool, |
| 83 | + device=features.device, |
| 84 | + ) |
| 85 | + else: |
| 86 | + attention_mask = attention_mask.bool() |
| 87 | + |
| 88 | + # obtain query_states and key_states to calculate attention map |
| 89 | + hidden_states = features.clone().detach() |
| 90 | + self_attn = module.self_attn |
| 91 | + hidden_states = module.input_layernorm(hidden_states) |
| 92 | + |
| 93 | + num_heads = self_attn.num_heads |
| 94 | + num_key_value_heads = self_attn.num_key_value_heads |
| 95 | + head_dim = self_attn.head_dim |
| 96 | + |
| 97 | + bsz, q_len, _ = hidden_states.size() |
| 98 | + |
| 99 | + query_states = self_attn.q_proj(hidden_states) |
| 100 | + key_states = self_attn.k_proj(hidden_states) |
| 101 | + value_states = self_attn.v_proj(hidden_states) |
| 102 | + |
| 103 | + query_states = query_states.view( |
| 104 | + bsz, q_len, num_heads, head_dim |
| 105 | + ).transpose(1, 2) |
| 106 | + key_states = key_states.view( |
| 107 | + bsz, q_len, num_key_value_heads, head_dim |
| 108 | + ).transpose(1, 2) |
| 109 | + value_states = value_states.view( |
| 110 | + bsz, q_len, num_key_value_heads, head_dim |
| 111 | + ).transpose(1, 2) |
| 112 | + |
| 113 | + if position_embeddings is None: |
| 114 | + cos, sin = self_attn.rotary_emb(value_states, position_ids) |
| 115 | + else: |
| 116 | + cos, sin = position_embeddings |
| 117 | + |
| 118 | + query_states, key_states = apply_rotary_pos_emb( |
| 119 | + query_states, key_states, cos, sin |
| 120 | + ) |
| 121 | + |
| 122 | + # attention_mask |
| 123 | + eager_attention_mask = _prepare_4d_causal_attention_mask( |
| 124 | + attention_mask, |
| 125 | + (batch_size, q_len), |
| 126 | + hidden_states, |
| 127 | + past_key_values_length=0, |
| 128 | + ).to(device=query_states.device) |
| 129 | + |
| 130 | + # take valid features |
| 131 | + features = [ |
| 132 | + cur_features[cur_attention_mask] |
| 133 | + for cur_features, cur_attention_mask in zip( |
| 134 | + features, attention_mask |
| 135 | + ) |
| 136 | + ] |
| 137 | + attention_mask = [ |
| 138 | + cur_attention_mask[cur_attention_mask] |
| 139 | + for cur_attention_mask, cur_attention_mask in zip( |
| 140 | + attention_mask, attention_mask |
| 141 | + ) |
| 142 | + ] |
| 143 | + |
| 144 | + # rank & drop |
| 145 | + for i in range(batch_size): |
| 146 | + image_index = image_token_posi[i] |
| 147 | + if image_index == -1: |
| 148 | + cur_input_embeds = features[i] |
| 149 | + features_list.append(cur_input_embeds) |
| 150 | + attention_mask_list.append(attention_mask[i]) |
| 151 | + continue |
| 152 | + |
| 153 | + # obtain current states |
| 154 | + cur_key_states = key_states[i] |
| 155 | + cur_query_states = query_states[i] |
| 156 | + cur_eager_attention_mask = eager_attention_mask[i] |
| 157 | + |
| 158 | + prompt_total_len = prompt_len[i] + image_tokens[i] |
| 159 | + text_query_states = cur_query_states[ |
| 160 | + :, prompt_total_len - 1, : |
| 161 | + ].unsqueeze(1) |
| 162 | + text_eager_attention_mask = cur_eager_attention_mask[ |
| 163 | + :, prompt_total_len - 1, : |
| 164 | + ].unsqueeze(1) |
| 165 | + |
| 166 | + # calculate attention map |
| 167 | + attn_weights = torch.matmul( |
| 168 | + text_query_states, cur_key_states.transpose(1, 2) |
| 169 | + ) / math.sqrt( |
| 170 | + head_dim |
| 171 | + ) # (num_head, text_token,seq_len) |
| 172 | + attn_weights = attn_weights + text_eager_attention_mask |
| 173 | + attn_weights = nn.functional.softmax( |
| 174 | + attn_weights, dim=-1, dtype=torch.float32 |
| 175 | + ).to( |
| 176 | + query_states.dtype |
| 177 | + ) # (num_head, text_token,seq_len) |
| 178 | + |
| 179 | + attention_avg_head = torch.mean( |
| 180 | + attn_weights, dim=0 |
| 181 | + ) # ave across heads |
| 182 | + attention_avg_head = attention_avg_head[ |
| 183 | + :, image_index: image_index + image_tokens[i] |
| 184 | + ] # select image token as keys |
| 185 | + attention_avg_text = torch.mean(attention_avg_head, dim=0) # (576) |
| 186 | + |
| 187 | + # rank and drop by attention score |
| 188 | + top_rank_index = attention_avg_text.topk(keep_length[i]).indices |
| 189 | + top_rank_index = top_rank_index + image_index |
| 190 | + top_rank_index = top_rank_index.sort().values |
| 191 | + |
| 192 | + start_index = image_index + image_tokens[i] |
| 193 | + new_input_embeds = torch.cat( |
| 194 | + [ |
| 195 | + features[i][:image_index, :], |
| 196 | + features[i][top_rank_index, :], |
| 197 | + features[i][start_index:, :], |
| 198 | + ], |
| 199 | + dim=0, |
| 200 | + ) |
| 201 | + new_attention_mask = torch.cat( |
| 202 | + [ |
| 203 | + attention_mask[i][:image_index], |
| 204 | + attention_mask[i][top_rank_index], |
| 205 | + attention_mask[i][start_index:], |
| 206 | + ], |
| 207 | + dim=0, |
| 208 | + ) |
| 209 | + |
| 210 | + features_list.append(new_input_embeds) |
| 211 | + attention_mask_list.append(new_attention_mask) |
| 212 | + |
| 213 | + # Truncate sequences to max length as image embeddings can make the sequence longer |
| 214 | + tokenizer_model_max_length = getattr( |
| 215 | + self.model.vlm_model.language_model.model.config, |
| 216 | + 'tokenizer_model_max_length', |
| 217 | + 2048, |
| 218 | + ) |
| 219 | + if tokenizer_model_max_length is not None: |
| 220 | + new_input_embeds = [ |
| 221 | + x[:tokenizer_model_max_length] for x in features_list |
| 222 | + ] |
| 223 | + new_attention_mask = [ |
| 224 | + x[:tokenizer_model_max_length] for x in attention_mask_list |
| 225 | + ] |
| 226 | + |
| 227 | + max_len = max(x.shape[0] for x in new_input_embeds) |
| 228 | + |
| 229 | + # padding the sequences to form batch |
| 230 | + embeds_padded = [] |
| 231 | + attention_mask_padded = [] |
| 232 | + position_ids = torch.zeros( |
| 233 | + (batch_size, max_len), |
| 234 | + dtype=position_ids.dtype, |
| 235 | + device=position_ids.device, |
| 236 | + ) |
| 237 | + for i, cur_new_embed in enumerate(new_input_embeds): |
| 238 | + cur_len_emb = cur_new_embed.shape[0] |
| 239 | + dif = max_len - cur_len_emb # padding to longest seq |
| 240 | + |
| 241 | + cur_new_embed = torch.cat( |
| 242 | + [ |
| 243 | + cur_new_embed, |
| 244 | + torch.zeros( |
| 245 | + (dif, cur_new_embed.shape[1]), |
| 246 | + dtype=cur_new_embed.dtype, |
| 247 | + device=cur_new_embed.device, |
| 248 | + ), |
| 249 | + ], |
| 250 | + dim=0, |
| 251 | + ) |
| 252 | + cur_attention_mask = new_attention_mask[i] |
| 253 | + cur_attention_mask = torch.cat( |
| 254 | + [ |
| 255 | + cur_attention_mask, |
| 256 | + torch.full( |
| 257 | + (dif,), |
| 258 | + False, |
| 259 | + dtype=cur_attention_mask.dtype, |
| 260 | + device=cur_attention_mask.device, |
| 261 | + ), |
| 262 | + ], |
| 263 | + dim=0, |
| 264 | + ) |
| 265 | + |
| 266 | + embeds_padded.append(cur_new_embed) |
| 267 | + |
| 268 | + attention_mask_padded.append(cur_attention_mask) |
| 269 | + |
| 270 | + cur_len = new_attention_mask[i].sum().item() |
| 271 | + position_ids[i, :cur_len] = torch.arange( |
| 272 | + 0, cur_len, dtype=position_ids.dtype, device=position_ids.device |
| 273 | + ) |
| 274 | + |
| 275 | + new_input_embeds = torch.stack(embeds_padded, dim=0) |
| 276 | + new_input_embeds = new_input_embeds.to(features[0].dtype) |
| 277 | + |
| 278 | + new_attention_mask = torch.stack(attention_mask_padded, dim=0) |
| 279 | + |
| 280 | + if _position_ids is None: |
| 281 | + position_ids = None |
| 282 | + |
| 283 | + if _attention_mask is None: |
| 284 | + new_attention_mask = None |
| 285 | + else: |
| 286 | + new_attention_mask = new_attention_mask.to( |
| 287 | + dtype=_attention_mask.dtype |
| 288 | + ) |
| 289 | + |
| 290 | + kwargs['attention_mask'] = new_attention_mask |
| 291 | + kwargs['position_ids'] = position_ids |
| 292 | + kwargs['position_embeddings'] = None |
| 293 | + pruning_pars['attention_mask'] = new_attention_mask |
| 294 | + pruning_pars['position_ids'] = position_ids |
| 295 | + pruning_pars['position_embeddings'] = None |
| 296 | + |
| 297 | + return (new_input_embeds,), kwargs |
| 298 | + |
| 299 | + def input_hook(module, input_args, pruning_pars): |
| 300 | + input_ids = input_args[0] |
| 301 | + pre_prompt_length_list = [] |
| 302 | + image_token_posi = [] |
| 303 | + image_tokens = [] |
| 304 | + IMAGE_TOKEN_INDEX = pruning_pars['image_token_index'] |
| 305 | + |
| 306 | + # find the position of the first image token |
| 307 | + for seq in input_ids: |
| 308 | + image_token_idxs = (seq == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0] |
| 309 | + image_tokens.append(image_token_idxs.shape[0]) |
| 310 | + image_token_posi.append(image_token_idxs[0].item()) |
| 311 | + pre_prompt_length_list.append(seq.shape[0] - image_token_idxs.shape[0]) |
| 312 | + |
| 313 | + pruning_pars['prompt_len'] = pre_prompt_length_list |
| 314 | + pruning_pars['image_token_posi'] = image_token_posi |
| 315 | + pruning_pars['image_tokens'] = image_tokens |
| 316 | + |
| 317 | + return input_args |
| 318 | + |
| 319 | + def read_parameter_hook(module, args, kwargs, pruning_pars): |
| 320 | + kwargs['attention_mask'] = pruning_pars['attention_mask'] |
| 321 | + # kwargs['cache_position'] = pruning_pars['cache_position'] |
| 322 | + kwargs['position_ids'] = pruning_pars['position_ids'] |
| 323 | + kwargs['position_embeddings'] = pruning_pars['position_embeddings'] |
| 324 | + |
| 325 | + return args, kwargs |
| 326 | + |
| 327 | + self.model.embed_tokens.register_forward_pre_hook( |
| 328 | + functools.partial(input_hook, pruning_pars=self.model.model.parameters) |
| 329 | + ) |
| 330 | + |
| 331 | + for layer_idx in range(self.pruning_loc[0], len(self.blocks)): |
| 332 | + if layer_idx in self.pruning_loc: |
| 333 | + stage = self.pruning_loc.index(layer_idx) |
| 334 | + self.blocks[layer_idx].register_forward_pre_hook( |
| 335 | + functools.partial( |
| 336 | + pruning_hook, |
| 337 | + pruning_pars=self.model.model.parameters, |
| 338 | + cur_num=stage, |
| 339 | + layer_idx=layer_idx, |
| 340 | + ), |
| 341 | + with_kwargs=True, |
| 342 | + ) |
| 343 | + else: |
| 344 | + self.blocks[layer_idx].register_forward_pre_hook( |
| 345 | + functools.partial( |
| 346 | + read_parameter_hook, pruning_pars=self.model.model.parameters |
| 347 | + ), |
| 348 | + with_kwargs=True, |
| 349 | + ) |
0 commit comments