27 | 27 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
28 | 28 |
29 | 29 | from ...activations import ACT2FN, get_activation |
30 | | -from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache |
| 30 | +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache |
31 | 31 | from ...generation import GenerationMixin |
32 | | -from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_attention_mask_for_sdpa |
| 32 | +from ...masking_utils import create_causal_mask |
| 33 | +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa |
33 | 34 | from ...modeling_layers import GradientCheckpointingLayer |
34 | 35 | from ...modeling_outputs import ( |
35 | 36 | BaseModelOutputWithPastAndCrossAttentions, |
@@ -278,53 +279,62 @@ def forward( |
278 | 279 | **kwargs, |
279 | 280 | ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]: |
280 | 281 | is_cross_attention = encoder_hidden_states is not None |
| 282 | + if past_key_value is not None: |
| 283 | + if isinstance(past_key_value, EncoderDecoderCache): |
| 284 | + is_updated = past_key_value.is_updated.get(self.layer_idx) |
| 285 | + if is_cross_attention: |
| 286 | + # after the first generated id, we can subsequently re-use all key/value_layer from cache |
| 287 | + curr_past_key_value = past_key_value.cross_attention_cache |
| 288 | + else: |
| 289 | + curr_past_key_value = past_key_value.self_attention_cache |
| 290 | + else: |
| 291 | + curr_past_key_value = past_key_value |
| 292 | + |
281 | 293 | if is_cross_attention: |
282 | 294 | if not hasattr(self, "q_attn"): |
283 | 295 | raise ValueError( |
284 | 296 | "If class is used as cross attention, the weights `q_attn` have to be defined. " |
285 | 297 | "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." |
286 | 298 | ) |
287 | | - |
288 | 299 | query_states = self.q_attn(hidden_states) |
289 | | - key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) |
290 | 300 | attention_mask = encoder_attention_mask |
| 301 | + |
| 302 | + # Try to get key/value states from cache if possible |
| 303 | + if past_key_value is not None and is_updated: |
| 304 | + key_states = curr_past_key_value.layers[self.layer_idx].keys |
| 305 | + value_states = curr_past_key_value.layers[self.layer_idx].values |
| 306 | + else: |
| 307 | + key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) |
| 308 | + shape_kv = (*key_states.shape[:-1], -1, self.head_dim) |
| 309 | + key_states = key_states.view(shape_kv).transpose(1, 2) |
| 310 | + value_states = value_states.view(shape_kv).transpose(1, 2) |
291 | 311 | else: |
292 | 312 | query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2) |
| 313 | + shape_kv = (*key_states.shape[:-1], -1, self.head_dim) |
| 314 | + key_states = key_states.view(shape_kv).transpose(1, 2) |
| 315 | + value_states = value_states.view(shape_kv).transpose(1, 2) |
293 | 316 |
294 | 317 | shape_q = (*query_states.shape[:-1], -1, self.head_dim) |
295 | | - shape_kv = (*key_states.shape[:-1], -1, self.head_dim) |
296 | | - |
297 | 318 | query_states = query_states.view(shape_q).transpose(1, 2) |
298 | | - key_states = key_states.view(shape_kv).transpose(1, 2) |
299 | | - value_states = value_states.view(shape_kv).transpose(1, 2) |
300 | 319 |
301 | | - if past_key_value is not None: |
302 | | - if isinstance(past_key_value, EncoderDecoderCache): |
303 | | - if is_cross_attention: |
304 | | - past_key_value = past_key_value.cross_attention_cache |
305 | | - else: |
306 | | - past_key_value = past_key_value.self_attention_cache |
307 | | - cache_kwargs = {"cache_position": cache_position} |
308 | | - key_states, value_states = past_key_value.update( |
309 | | - key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs |
| 320 | + if (past_key_value is not None and not is_cross_attention) or ( |
| 321 | + past_key_value is not None and is_cross_attention and not is_updated |
| 322 | + ): |
| 323 | + # save all key/value_layer to cache to be re-used for fast auto-regressive generation |
| 324 | + cache_position = cache_position if not is_cross_attention else None |
| 325 | + key_states, value_states = curr_past_key_value.update( |
| 326 | + key_states, value_states, self.layer_idx, {"cache_position": cache_position} |
310 | 327 | ) |
| 328 | + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls |
| 329 | + if is_cross_attention: |
| 330 | + past_key_value.is_updated[self.layer_idx] = True |
311 | 331 |
312 | 332 | is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention |
313 | 333 |
314 | 334 | using_eager = self.config._attn_implementation == "eager" |
315 | 335 | attention_interface: Callable = eager_attention_forward |
316 | 336 | if self.config._attn_implementation != "eager": |
317 | | - if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None): |
318 | | - using_eager = True |
319 | | - logger.warning_once( |
320 | | - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " |
321 | | - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' |
322 | | - ) |
323 | | - else: |
324 | | - # Attention functions are consistent with previous equivalent attention classes, however they do not support some options |
325 | | - # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but |
326 | | - # not necessarily to eager (if mentioned options are provided). |
327 | | - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] |
| 337 | + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] |
328 | 338 |
329 | 339 | if using_eager and self.reorder_and_upcast_attn: |
330 | 340 | attn_output, attn_weights = self._upcast_and_reordered_attn( |
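Aside (not part of the diff): the restructured branches above ensure that cross-attention key/value states are projected from the encoder output at most once per generation and then re-used via the `is_updated` flag. The toy sketch below illustrates that idea in isolation; the class and variable names are purely illustrative, not the library's `EncoderDecoderCache` API.

```python
import torch


class TinyCrossAttnCache:
    """Illustrative cache: one (keys, values) pair per layer plus an is_updated flag."""

    def __init__(self):
        self.layers = {}       # layer_idx -> (keys, values)
        self.is_updated = {}   # layer_idx -> bool

    def update(self, keys, values, layer_idx):
        self.layers[layer_idx] = (keys, values)
        self.is_updated[layer_idx] = True
        return keys, values


encoder_states = torch.randn(1, 7, 16)   # fixed encoder output for the whole generation
k_proj, v_proj = torch.nn.Linear(16, 16), torch.nn.Linear(16, 16)
cache = TinyCrossAttnCache()

for step in range(3):                      # decoding loop
    if cache.is_updated.get(0):            # after the first step, re-use cached K/V
        key_states, value_states = cache.layers[0]
    else:                                  # first step only: project encoder states
        key_states, value_states = cache.update(
            k_proj(encoder_states), v_proj(encoder_states), layer_idx=0
        )
```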
@@ -861,8 +871,14 @@ def forward( |
861 | 871 | # ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel |
862 | 872 | if attention_mask is not None and attention_mask.ndim < 4: |
863 | 873 | attention_mask = attention_mask.view(batch_size, -1) |
864 | | - causal_mask = self._update_causal_mask( |
865 | | - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions |
| 874 | + |
| 875 | + causal_mask = create_causal_mask( |
| 876 | + config=self.config, |
| 877 | + input_embeds=inputs_embeds, |
| 878 | + attention_mask=attention_mask, |
| 879 | + cache_position=cache_position, |
| 880 | + past_key_values=past_key_values, |
| 881 | + position_ids=position_ids, |
866 | 882 | ) |
867 | 883 |
868 | 884 | # If a 2D or 3D attention mask is provided for the cross-attention |
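Aside (not part of the diff): the call above swaps the model-local mask construction for the shared `create_causal_mask` helper. As a rough, self-contained illustration of the kind of mask such a helper builds, a 4D additive mask combining causality and padding can be sketched as below; this is not the library implementation, which additionally handles SDPA short-circuits, cache lengths, and dtype/device placement.

```python
import torch


def toy_causal_mask(attention_mask_2d: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    """Build a (batch, 1, seq_len, seq_len) additive mask: large negative values above
    the diagonal (future positions) and on padded key positions, zeros elsewhere."""
    bsz, seq_len = attention_mask_2d.shape
    min_val = torch.finfo(dtype).min
    mask = torch.triu(torch.full((seq_len, seq_len), min_val, dtype=dtype), diagonal=1)
    mask = mask[None, None, :, :].expand(bsz, 1, seq_len, seq_len).clone()
    padding = attention_mask_2d[:, None, None, :] == 0   # masked-out key positions
    return mask.masked_fill(padding, min_val)


mask = toy_causal_mask(torch.tensor([[1, 1, 1, 0]]))     # last token is padding
print(mask.shape)   # torch.Size([1, 1, 4, 4])
```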
@@ -903,9 +919,6 @@ def forward( |
903 | 919 | # Model parallel |
904 | 920 | if self.model_parallel: |
905 | 921 | torch.cuda.set_device(hidden_states.device) |
906 | | - # Ensure that attention_mask is always on the same device as hidden_states |
907 | | - if attention_mask is not None: |
908 | | - attention_mask = attention_mask.to(hidden_states.device) |
909 | 922 | if isinstance(head_mask, torch.Tensor): |
910 | 923 | head_mask = head_mask.to(hidden_states.device) |
911 | 924 | if output_hidden_states: |
@@ -966,123 +979,6 @@ def forward( |
966 | 979 | cross_attentions=all_cross_attentions, |
967 | 980 | ) |
968 | 981 |
969 | | - def _update_causal_mask( |
970 | | - self, |
971 | | - attention_mask: torch.Tensor, |
972 | | - input_tensor: torch.Tensor, |
973 | | - cache_position: torch.Tensor, |
974 | | - past_key_values: Cache, |
975 | | - output_attentions: bool, |
976 | | - ): |
977 | | - if self.config._attn_implementation == "flash_attention_2": |
978 | | - if attention_mask is not None and 0.0 in attention_mask: |
979 | | - return attention_mask |
980 | | - return None |
981 | | - |
982 | | - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in |
983 | | - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail |
984 | | - # to infer the attention mask. |
985 | | - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
986 | | - using_static_cache = isinstance(past_key_values, StaticCache) |
987 | | - |
988 | | - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward |
989 | | - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: |
990 | | - if AttentionMaskConverter._ignore_causal_mask_sdpa( |
991 | | - attention_mask, |
992 | | - inputs_embeds=input_tensor, |
993 | | - past_key_values_length=past_seen_tokens, |
994 | | - is_training=self.training, |
995 | | - ): |
996 | | - return None |
997 | | - |
998 | | - dtype = input_tensor.dtype |
999 | | - sequence_length = input_tensor.shape[1] |
1000 | | - if using_static_cache: |
1001 | | - target_length = past_key_values.get_max_cache_shape() |
1002 | | - else: |
1003 | | - target_length = ( |
1004 | | - attention_mask.shape[-1] |
1005 | | - if isinstance(attention_mask, torch.Tensor) |
1006 | | - else past_seen_tokens + sequence_length + 1 |
1007 | | - ) |
1008 | | - |
1009 | | - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). |
1010 | | - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( |
1011 | | - attention_mask, |
1012 | | - sequence_length=sequence_length, |
1013 | | - target_length=target_length, |
1014 | | - dtype=dtype, |
1015 | | - cache_position=cache_position, |
1016 | | - batch_size=input_tensor.shape[0], |
1017 | | - ) |
1018 | | - |
1019 | | - if ( |
1020 | | - self.config._attn_implementation == "sdpa" |
1021 | | - and attention_mask is not None |
1022 | | - and attention_mask.device.type == "cuda" |
1023 | | - and not output_attentions |
1024 | | - ): |
1025 | | - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when |
1026 | | - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. |
1027 | | - # Details: https://github.com/pytorch/pytorch/issues/110213 |
1028 | | - min_dtype = torch.finfo(dtype).min |
1029 | | - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) |
1030 | | - |
1031 | | - return causal_mask |
1032 | | - |
1033 | | - @staticmethod |
1034 | | - def _prepare_4d_causal_attention_mask_with_cache_position( |
1035 | | - attention_mask: torch.Tensor, |
1036 | | - sequence_length: int, |
1037 | | - target_length: int, |
1038 | | - dtype: torch.dtype, |
1039 | | - cache_position: torch.Tensor, |
1040 | | - batch_size: int, |
1041 | | - **kwargs, |
1042 | | - ): |
1043 | | - """ |
1044 | | - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
1045 | | - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. |
1046 | | - |
1047 | | - Args: |
1048 | | - attention_mask (`torch.Tensor`): |
1049 | | - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape |
1050 | | - `(batch_size, 1, query_length, key_value_length)`. |
1051 | | - sequence_length (`int`): |
1052 | | - The sequence length being processed. |
1053 | | - target_length (`int`): |
1054 | | - The target length: when generating with static cache, the mask should be as long as the static cache, |
1055 | | - to account for the 0 padding, the part of the cache that is not filled yet. |
1056 | | - dtype (`torch.dtype`): |
1057 | | - The dtype to use for the 4D attention mask. |
1058 | | - cache_position (`torch.Tensor`): |
1059 | | - Indices depicting the position of the input sequence tokens in the sequence. |
1060 | | - batch_size (`torch.Tensor`): |
1061 | | - Batch size. |
1062 | | - """ |
1063 | | - if attention_mask is not None and attention_mask.dim() == 4: |
1064 | | - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. |
1065 | | - causal_mask = attention_mask |
1066 | | - else: |
1067 | | - min_dtype = torch.finfo(dtype).min |
1068 | | - causal_mask = torch.full( |
1069 | | - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device |
1070 | | - ) |
1071 | | - if sequence_length != 1: |
1072 | | - causal_mask = torch.triu(causal_mask, diagonal=1) |
1073 | | - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) |
1074 | | - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
1075 | | - if attention_mask is not None: |
1076 | | - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit |
1077 | | - mask_length = attention_mask.shape[-1] |
1078 | | - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] |
1079 | | - padding_mask = padding_mask == 0 |
1080 | | - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
1081 | | - padding_mask, min_dtype |
1082 | | - ) |
1083 | | - |
1084 | | - return causal_mask |
1085 | | - |
1086 | 982 |
1087 | 983 | @auto_docstring( |
1088 | 984 | custom_intro=""" |
|