@@ -22,6 +22,8 @@
 from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.shard import ShardConfig
 
+from ..layer import cross_entropy_1d
+
 logger = logging.get_logger(__name__)
 
 
@@ -336,8 +338,22 @@ def opt_for_causal_lm_forward(
         shift_logits = logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
         # Flatten the tokens
-        loss_fct = CrossEntropyLoss()
-        loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
+        if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            shift_labels = shift_labels.view(-1)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+            )
+        else:
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            loss = loss_fct(shift_logits, shift_labels.view(-1))
+
     if not return_dict:
         output = (logits,) + outputs[1:]
         return (loss,) + output if loss is not None else output
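When `parallel_output` is enabled, each tensor-parallel rank holds only a slice of the vocabulary logits, so the softmax statistics and the target logit must be combined across ranks instead of calling `CrossEntropyLoss` on the full logits. The single-process sketch below illustrates the idea behind `cross_entropy_1d`; the two-way split and the Python loop stand in for what would be collective operations over `shard_config.tensor_parallel_process_group`, and none of the variable names are taken from the library.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, vocab = 4, 10
logits = torch.randn(batch, vocab)          # full, unsharded logits
labels = torch.randint(0, vocab, (batch,))

# Pretend the vocabulary is column-split across two tensor-parallel ranks.
shards = logits.chunk(2, dim=-1)            # each "rank" holds vocab // 2 columns
shard_size = vocab // 2

# Per-rank softmax statistics; in the distributed kernel these would be
# combined with all-reduce(MAX) and all-reduce(SUM) over the process group.
local_max = torch.stack([s.max(dim=-1).values for s in shards])                      # (ranks, batch)
global_max = local_max.max(dim=0).values                                             # all-reduce(MAX)
local_sum = torch.stack([(s - global_max[:, None]).exp().sum(-1) for s in shards])
global_sum = local_sum.sum(dim=0)                                                    # all-reduce(SUM)

# Only the rank that owns a label's vocab slice contributes its target logit.
target_logit = torch.zeros(batch)
for rank, s in enumerate(shards):
    lo = rank * shard_size
    in_shard = (labels >= lo) & (labels < lo + shard_size)
    target_logit[in_shard] = s[in_shard, labels[in_shard] - lo]                      # all-reduce(SUM) in practice

# -log softmax at the target, averaged over the batch, matches plain cross entropy.
loss = (global_sum.log() + global_max - target_logit).mean()
assert torch.allclose(loss, F.cross_entropy(logits, labels))
```

In the distributed version the three combine steps run over `shard_config.tensor_parallel_process_group`, which is why the diff passes the process group and the full `vocab_size` explicitly to `cross_entropy_1d`.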
@@ -844,3 +860,146 @@ def forward(
         return outputs
 
     return forward
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    def forward(
+        self: OPTForCausalLM,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, OPTForCausalLM
+
+        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model.decoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        logits = self.lm_head(outputs[0]).contiguous()
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    return forward
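For context, `get_lm_forward_with_dist_cross_entropy` is a factory whose returned `forward` is meant to replace `OPTForCausalLM.forward` when the shardformer policy enables tensor parallelism with `parallel_output`. The sketch below shows one way such a replacement could be bound onto an already-sharded model; the helper name `apply_dist_cross_entropy_forward` and the direct `MethodType` binding are illustrative only, since the real policy registers forwards through its own method-replacement mechanism.

```python
# Hypothetical convenience wrapper, not part of the library API.
# Assumes get_lm_forward_with_dist_cross_entropy (defined above) is in scope.
from types import MethodType

from transformers import OPTForCausalLM

from colossalai.shardformer.shard import ShardConfig


def apply_dist_cross_entropy_forward(model: OPTForCausalLM, shard_config: ShardConfig) -> None:
    """Bind the tensor-parallel cross-entropy forward to an already-sharded OPT model."""
    if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
        model.forward = MethodType(get_lm_forward_with_dist_cross_entropy(shard_config), model)
```

Because `shard_config` is captured in the factory's closure, the replacement keeps the original Hugging Face `forward` signature while still reaching the tensor-parallel process group at loss time.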