|
| 1 | +"""Llama CCE patch. Adapted from transformers v4.51.2""" |
| 2 | + |
| 3 | +# pylint: disable=duplicate-code |
| 4 | + |
| 5 | + |
| 6 | +from types import MethodType |
| 7 | +from typing import Optional, Union |
| 8 | + |
| 9 | +import torch |
| 10 | +import transformers |
| 11 | +from cut_cross_entropy.transformers.utils import ( |
| 12 | + PatchOptions, |
| 13 | + TransformersModelT, |
| 14 | + apply_lce, |
| 15 | +) |
| 16 | +from transformers.cache_utils import Cache |
| 17 | +from transformers.modeling_outputs import ( |
| 18 | + BaseModelOutputWithPast, |
| 19 | + CausalLMOutputWithPast, |
| 20 | +) |
| 21 | +from transformers.models.llama.modeling_llama import ( |
| 22 | + _CONFIG_FOR_DOC, |
| 23 | + LLAMA_INPUTS_DOCSTRING, |
| 24 | + KwargsForCausalLM, |
| 25 | +) |
| 26 | +from transformers.processing_utils import Unpack |
| 27 | +from transformers.utils import ( |
| 28 | + add_start_docstrings_to_model_forward, |
| 29 | + replace_return_docstrings, |
| 30 | +) |
| 31 | +from transformers.utils.deprecation import deprecate_kwarg |
| 32 | +from transformers.utils.generic import can_return_tuple |
| 33 | + |
| 34 | +_PATCH_OPTS: PatchOptions | None = None |
| 35 | + |
| 36 | + |
| 37 | +@can_return_tuple |
| 38 | +@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") |
| 39 | +@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) |
| 40 | +@replace_return_docstrings( |
| 41 | + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC |
| 42 | +) |
| 43 | +def cce_forward( |
| 44 | + self, |
| 45 | + input_ids: Optional[torch.LongTensor] = None, |
| 46 | + attention_mask: Optional[torch.Tensor] = None, |
| 47 | + position_ids: Optional[torch.LongTensor] = None, |
| 48 | + past_key_values: Optional[Cache] = None, |
| 49 | + inputs_embeds: Optional[torch.FloatTensor] = None, |
| 50 | + labels: Optional[torch.LongTensor] = None, |
| 51 | + use_cache: Optional[bool] = None, |
| 52 | + output_attentions: Optional[bool] = None, |
| 53 | + output_hidden_states: Optional[bool] = None, |
| 54 | + cache_position: Optional[torch.LongTensor] = None, |
| 55 | + logits_to_keep: Union[int, torch.Tensor] = 0, |
| 56 | + **kwargs: Unpack[KwargsForCausalLM], |
| 57 | +) -> CausalLMOutputWithPast: |
| 58 | + r""" |
| 59 | + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| 60 | + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., |
| 61 | + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored |
| 62 | + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. |
| 63 | +
|
| 64 | + logits_to_keep (`int` or `torch.Tensor`, *optional*): |
| 65 | + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all |
| 66 | + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that |
| 67 | + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. |
| 68 | + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. |
| 69 | + This is useful when using packed tensor format (single dimension for batch and sequence length). |
| 70 | +
|
| 71 | + Returns: |
| 72 | +
|
| 73 | + Example: |
| 74 | +
|
| 75 | + ```python |
| 76 | + >>> from transformers import AutoTokenizer, LlamaForCausalLM |
| 77 | +
|
| 78 | + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") |
| 79 | + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") |
| 80 | +
|
| 81 | + >>> prompt = "Hey, are you conscious? Can you talk to me?" |
| 82 | + >>> inputs = tokenizer(prompt, return_tensors="pt") |
| 83 | +
|
| 84 | + >>> # Generate |
| 85 | + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) |
| 86 | + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] |
| 87 | + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." |
| 88 | + ```""" |
| 89 | + output_attentions = ( |
| 90 | + output_attentions |
| 91 | + if output_attentions is not None |
| 92 | + else self.config.output_attentions |
| 93 | + ) |
| 94 | + output_hidden_states = ( |
| 95 | + output_hidden_states |
| 96 | + if output_hidden_states is not None |
| 97 | + else self.config.output_hidden_states |
| 98 | + ) |
| 99 | + |
| 100 | + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) |
| 101 | + outputs: BaseModelOutputWithPast = self.model( |
| 102 | + input_ids=input_ids, |
| 103 | + attention_mask=attention_mask, |
| 104 | + position_ids=position_ids, |
| 105 | + past_key_values=past_key_values, |
| 106 | + inputs_embeds=inputs_embeds, |
| 107 | + use_cache=use_cache, |
| 108 | + output_attentions=output_attentions, |
| 109 | + output_hidden_states=output_hidden_states, |
| 110 | + cache_position=cache_position, |
| 111 | + **kwargs, |
| 112 | + ) |
| 113 | + |
| 114 | + hidden_states = outputs.last_hidden_state |
| 115 | + if hidden_states is None: |
| 116 | + raise ValueError("hidden_states is None") |
| 117 | + |
| 118 | + loss = None |
| 119 | + logits = None |
| 120 | + |
| 121 | + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss |
| 122 | + slice_indices = ( |
| 123 | + slice(-logits_to_keep, None) |
| 124 | + if isinstance(logits_to_keep, int) |
| 125 | + else logits_to_keep |
| 126 | + ) |
| 127 | + if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): |
| 128 | + assert labels is not None |
| 129 | + loss = apply_lce( |
| 130 | + hidden_states[:, slice_indices, :], |
| 131 | + self.lm_head.weight, |
| 132 | + labels, |
| 133 | + _PATCH_OPTS, |
| 134 | + **kwargs, |
| 135 | + ) |
| 136 | + else: |
| 137 | + logits = self.lm_head(hidden_states[:, slice_indices, :]) |
| 138 | + |
| 139 | + if labels is not None: |
| 140 | + loss = self.loss_function( |
| 141 | + logits=logits, |
| 142 | + labels=labels, |
| 143 | + vocab_size=self.config.vocab_size, |
| 144 | + **kwargs, |
| 145 | + ) |
| 146 | + |
| 147 | + return CausalLMOutputWithPast( |
| 148 | + loss=loss, |
| 149 | + logits=logits, |
| 150 | + past_key_values=outputs.past_key_values, |
| 151 | + hidden_states=outputs.hidden_states, |
| 152 | + attentions=outputs.attentions, |
| 153 | + ) |
| 154 | + |
| 155 | + |
| 156 | +def patch_llama( |
| 157 | + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, |
| 158 | + patch_options: PatchOptions, |
| 159 | +) -> TransformersModelT | None: |
| 160 | + """Patch Llama for CCE.""" |
| 161 | + global _PATCH_OPTS # pylint: disable=global-statement |
| 162 | + from transformers.models.llama import modeling_llama |
| 163 | + |
| 164 | + _PATCH_OPTS = patch_options |
| 165 | + |
| 166 | + if isinstance(maybe_model, transformers.PreTrainedModel): |
| 167 | + assert isinstance( |
| 168 | + maybe_model, modeling_llama.LlamaForCausalLM |
| 169 | + ), f"Expected a LlamaForCausalLM model. Got {type(maybe_model)}." |
| 170 | + maybe_model.forward = MethodType(cce_forward, maybe_model) |
| 171 | + return maybe_model |
| 172 | + |
| 173 | + modeling_llama.LlamaForCausalLM.forward = cce_forward |
| 174 | + return None |
0 commit comments