from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.shard import ShardConfig
-
+from ..layer import cross_entropy_1d

logger = logging.get_logger(__name__)

@@ -336,8 +336,22 @@ def opt_for_causal_lm_forward(
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
+            if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+                new_vocab_size = logits.shape[-1]
+                shift_logits = shift_logits.view(-1, new_vocab_size)
+                shift_labels = shift_labels.view(-1)
+                loss = cross_entropy_1d(
+                    shift_logits,
+                    shift_labels,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    vocab_size=self.lm_head.out_features,
+                )
+            else:
+                loss_fct = CrossEntropyLoss()
+                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                loss = loss_fct(shift_logits, shift_labels.view(-1))
+
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
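For readers unfamiliar with vocab-parallel losses: when `parallel_output` is enabled, each tensor-parallel rank holds only a slice of the vocabulary dimension of the logits, so the loss must be reduced across the process group rather than computed rank-locally. The sketch below is a simplified, unfused illustration of that technique under the assumption of an even vocabulary split; it is not ColossalAI's `cross_entropy_1d`, and the helper name `naive_vocab_parallel_ce` and its `vocab_start` argument are hypothetical.

```python
import torch
import torch.distributed as dist


def naive_vocab_parallel_ce(shard_logits, labels, process_group, vocab_start):
    """Cross entropy over vocab-sharded logits (illustrative only).

    shard_logits: (N, V_local) slice of the full-vocab logits owned by this rank,
                  covering global token ids [vocab_start, vocab_start + V_local).
    labels:       (N,) global token ids, with -100 marking ignored positions.
    """
    n_tokens, v_local = shard_logits.shape

    # Stable log-sum-exp over the *global* vocabulary: reduce the max first,
    # then the summed exponentials.
    local_max = shard_logits.max(dim=-1, keepdim=True).values
    dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=process_group)
    sum_exp = (shard_logits - local_max).exp().sum(dim=-1, keepdim=True)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=process_group)
    logsumexp = (local_max + sum_exp.log()).squeeze(1)  # (N,), identical on all ranks

    # Each target logit lives on exactly one rank; pick it up locally (zero
    # elsewhere) and sum-reduce so every rank ends up with the full value.
    in_shard = (labels >= vocab_start) & (labels < vocab_start + v_local)
    local_idx = (labels - vocab_start).clamp(0, v_local - 1)
    target_logit = shard_logits.gather(1, local_idx.unsqueeze(1)).squeeze(1)
    target_logit = torch.where(in_shard, target_logit, torch.zeros_like(target_logit))
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=process_group)

    valid = labels != -100
    return (logsumexp - target_logit)[valid].mean()
```

The two reductions (a MAX for the stabilizing constant, then SUMs for the partition function and the target logit) are the essential communication pattern; a fused implementation follows the same idea while avoiding ever materializing the full-vocab softmax on a single rank.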
@@ -844,3 +858,148 @@ def forward(
        return outputs

    return forward
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    def forward(
+        self: OPTForCausalLM,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, OPTForCausalLM
+
+        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model.decoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        logits = self.lm_head(outputs[0]).contiguous()
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+            )
+            # loss_fct = CrossEntropyLoss()
+            # loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    return forward
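As a quick orientation for what the factory above returns, the snippet below binds the produced forward onto a model instance. It is a minimal sketch under assumptions: `ShardConfig()` is constructed with defaults purely for illustration, and in ColossalAI this wiring normally happens through the model policies and booster setup (which also shard `lm_head` and set `parallel_output`) rather than by manual binding.

```python
import types

from transformers import OPTForCausalLM

from colossalai.shardformer.shard import ShardConfig

# Placeholder setup: in practice the model has already been sharded and
# shard_config carries the tensor-parallel process group.
model = OPTForCausalLM.from_pretrained("facebook/opt-125m")
shard_config = ShardConfig()  # illustrative defaults only

# Swap in the forward that computes the loss with cross_entropy_1d.
dist_ce_forward = get_lm_forward_with_dist_cross_entropy(shard_config)
model.forward = types.MethodType(dist_ce_forward, model)
```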