Commit 108ddfb (parent 88f057c)

add parallel_output for the opt model

2 files changed (+174, -4 lines)

colossalai/shardformer/modeling/opt.py

Lines changed: 162 additions & 3 deletions
@@ -21,7 +21,7 @@
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.shard import ShardConfig
-
+from ..layer import cross_entropy_1d
 logger = logging.get_logger(__name__)
 
 
@@ -336,8 +336,22 @@ def opt_for_causal_lm_forward(
                 shift_logits = logits[..., :-1, :].contiguous()
                 shift_labels = labels[..., 1:].contiguous()
                 # Flatten the tokens
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
+                if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+                    new_vocab_size = logits.shape[-1]
+                    shift_logits = shift_logits.view(-1, new_vocab_size)
+                    shift_labels = shift_labels.view(-1)
+                    loss = cross_entropy_1d(
+                        shift_logits,
+                        shift_labels,
+                        process_group=shard_config.tensor_parallel_process_group,
+                        vocab_size=self.lm_head.out_features,
+                    )
+                else:
+                    loss_fct = CrossEntropyLoss()
+                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                    loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
             if not return_dict:
                 output = (logits,) + outputs[1:]
                 return (loss,) + output if loss is not None else output
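For readers new to vocab-parallel losses, the toy routine below spells out what a "1D" cross entropy over sharded logits has to compute: each rank holds only a slice of the vocabulary dimension, so the softmax denominator and the target logit are assembled with all-reduces over the tensor-parallel group. This is an illustrative sketch only, not ColossalAI's `cross_entropy_1d`; the function name, the `vocab_start` argument, and the mean reduction over non-ignored tokens are assumptions made for the example.

```python
import torch
import torch.distributed as dist


def naive_vocab_parallel_cross_entropy(local_logits, labels, vocab_start,
                                        process_group=None, ignore_index=-100):
    """Toy vocab-parallel cross entropy (illustration only, not ColossalAI's kernel).

    local_logits: (N, V_local) shard of the full logits held by this rank.
    labels:       (N,) global target ids.
    vocab_start:  first global vocab id owned by this rank.
    """
    vocab_local = local_logits.size(-1)

    # 1) Global max over the full vocabulary, for numerical stability.
    logits_max = local_logits.max(dim=-1)[0]
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group)
    shifted = local_logits - logits_max.unsqueeze(-1)

    # 2) Global softmax denominator: sum of exp over every shard.
    sum_exp = shifted.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=process_group)

    # 3) Shifted logit of the target class; only the owning rank contributes.
    in_shard = (labels >= vocab_start) & (labels < vocab_start + vocab_local)
    local_idx = (labels - vocab_start).clamp(0, vocab_local - 1)
    target_logit = shifted.gather(-1, local_idx.unsqueeze(-1)).squeeze(-1)
    target_logit = torch.where(in_shard, target_logit, torch.zeros_like(target_logit))
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=process_group)

    # 4) per-token loss = log(sum exp) - shifted target logit; average over valid tokens.
    loss = sum_exp.log() - target_logit
    valid = labels != ignore_index
    return (loss * valid).sum() / valid.sum().clamp(min=1)
```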
@@ -844,3 +858,148 @@ def forward(
         return outputs
 
     return forward
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    def forward(
+        self: OPTForCausalLM,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, OPTForCausalLM
+
+        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model.decoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        logits = self.lm_head(outputs[0]).contiguous()
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+            )
+            #loss_fct = CrossEntropyLoss()
+            #loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    return forward
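As a quick plausibility check on the toy routine sketched after the previous hunk (not a test from this repository): in a single-process group every `all_reduce` is a no-op, so the sharded loss must reduce to ordinary `F.cross_entropy` over the full vocabulary.

```python
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F

# One-process "group": every all_reduce in the toy routine becomes a no-op,
# so the sharded loss must equal the ordinary cross entropy over the full vocab.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

logits = torch.randn(8, 32)          # (tokens, full vocab) -- the only shard
labels = torch.randint(0, 32, (8,))

# naive_vocab_parallel_cross_entropy is the toy helper defined in the earlier sketch.
out = naive_vocab_parallel_cross_entropy(logits, labels, vocab_start=0)
ref = F.cross_entropy(logits, labels)
torch.testing.assert_close(out, ref)

dist.destroy_process_group()
```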

colossalai/shardformer/policies/opt.py

Lines changed: 12 additions & 1 deletion
@@ -23,6 +23,7 @@
     get_jit_fused_opt_decoder_layer_forward,
     get_opt_decoder_forward_for_flash_attention,
     get_opt_flash_attention_forward,
+    get_lm_forward_with_dist_cross_entropy
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
@@ -269,12 +270,22 @@ def module_policy(self):
                     suffix="lm_head",
                     target_module=VocabParallelLMHead1D,
                     kwargs=dict(
-                        gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
+                        gather_output=not self.shard_config.parallel_output,
+                        make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
                     ),
                 ),
                 policy=policy,
                 target_key=OPTForCausalLM,
             )
+            if self.shard_config.parallel_output:
+                method_replacement = {
+                    "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+                }
+                self.append_or_create_method_replacement(
+                    description=method_replacement,
+                    policy=policy,
+                    target_key=OPTForCausalLM
+                )
         else:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
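For orientation, the flag this policy branch reads lives on the `ShardConfig` that drives sharding. Below is a minimal, hypothetical sketch of turning it on; the field names come from the diff, but passing them as constructor keywords and the process-group setup are assumptions about the surrounding API, not something this commit shows.

```python
import torch.distributed as dist

from colossalai.shardformer.shard import ShardConfig

# Assumes torch.distributed is already initialized and that ShardConfig accepts
# these fields as keyword arguments (the field names themselves appear in the diff).
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,
    enable_tensor_parallelism=True,
    parallel_output=True,  # keep lm_head output vocab-sharded; loss goes through cross_entropy_1d
)
```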
