
Commit 2229778

Merge pull request #5684 from wangbluo/parallel_output

[Shardformer] Add Parallel output for shardformer models

2 parents 58954b2 + 4e50cce

File tree: 4 files changed, 288 additions and 10 deletions

- colossalai/shardformer/modeling/mistral.py
- colossalai/shardformer/modeling/opt.py
- colossalai/shardformer/policies/mistral.py
- colossalai/shardformer/policies/opt.py
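At a glance, this merge wires ShardConfig's `parallel_output` option into the Mistral and OPT causal-LM paths: when tensor parallelism is enabled and `parallel_output` is set, the vocab-parallel `lm_head` no longer gathers its output across the vocabulary dimension, and the loss is computed by `cross_entropy_1d` directly on the vocab-sharded logits. The sketch below illustrates the general vocab-parallel cross-entropy technique such a loss is built on; it is an assumption-level illustration with hypothetical names, not ColossalAI's actual `cross_entropy_1d` kernel (which also handles ignored labels and a custom backward).

```python
# Illustrative sketch of vocab-parallel cross-entropy (an assumption of how a
# cross_entropy_1d-style loss works, not ColossalAI's kernel). Each rank holds
# a contiguous vocab shard: local_logits is (N, V_local) and this rank owns
# vocab ids [vocab_start, vocab_start + V_local).
import torch
import torch.distributed as dist


def vocab_parallel_cross_entropy(local_logits, targets, vocab_start, process_group=None):
    # 1) Global max over the vocab dimension for numerical stability.
    global_max = local_logits.max(dim=-1).values
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=process_group)
    shifted = local_logits - global_max.unsqueeze(-1)

    # 2) Global softmax denominator: sum of exp across all vocab shards.
    sum_exp = shifted.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=process_group)

    # 3) Shifted logit of the target token; only the rank owning it contributes.
    vocab_end = vocab_start + local_logits.size(-1)
    in_shard = (targets >= vocab_start) & (targets < vocab_end)
    local_idx = (targets - vocab_start).clamp(0, local_logits.size(-1) - 1)
    target_logit = shifted.gather(-1, local_idx.unsqueeze(-1)).squeeze(-1)
    target_logit = torch.where(in_shard, target_logit, torch.zeros_like(target_logit))
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=process_group)

    # Per-token loss: log(sum_exp) - shifted target logit, averaged over tokens.
    return (sum_exp.log() - target_logit).mean()
```

The payoff is that no rank ever needs the full `(tokens, vocab_size)` logits tensor; only per-token scalars are reduced across the tensor-parallel group.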

colossalai/shardformer/modeling/mistral.py

Lines changed: 110 additions & 3 deletions
@@ -16,7 +16,7 @@
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import ColoAttention
+from ..layer import ColoAttention, cross_entropy_1d

 logger = logging.get_logger(__name__)

@@ -270,11 +270,21 @@ def mistral_for_causal_lm_forward(
             shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
             loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
             shift_labels = shift_labels.view(-1)
             # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
+            if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+                new_vocab_size = logits.shape[-1]
+                shift_logits = shift_logits.view(-1, new_vocab_size)
+                loss = cross_entropy_1d(
+                    shift_logits,
+                    shift_labels,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    vocab_size=self.lm_head.out_features,
+                )
+            else:
+                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                loss = loss_fct(shift_logits, shift_labels)

         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -609,3 +619,100 @@ def forward(
         return attn_output, None, past_key_value

     return forward
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    from transformers import MistralForCausalLM
+
+    def forward(
+        self: MistralForCausalLM,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, MistralForCausalLM
+
+        >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    return forward
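For orientation, enabling this path from user code might look roughly like the sketch below. It is a hedged example: it assumes `torch.distributed` is already initialized, that the whole world forms a single tensor-parallel group, and that `ShardConfig`/`ShardFormer` accept these arguments in your ColossalAI version; consult the ShardFormer documentation for the exact signatures.

```python
# Hypothetical wiring sketch, not an official recipe. Assumes torch.distributed
# is initialized (e.g. via torchrun) and all ranks form one tensor-parallel group.
import torch.distributed as dist
from transformers import MistralForCausalLM

from colossalai.shardformer import ShardConfig, ShardFormer

tp_group = dist.group.WORLD  # assumption: tensor parallelism spans the whole world

shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,
    enable_tensor_parallelism=True,
    parallel_output=True,  # keep lm_head output vocab-sharded and use cross_entropy_1d
)

model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
shard_former = ShardFormer(shard_config=shard_config)
model, _ = shard_former.optimize(model)  # policy resolved automatically for supported models

# After optimization, model(input_ids=..., labels=...) returns a loss computed by the
# replaced forward, without gathering logits across the vocabulary dimension.
```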

colossalai/shardformer/modeling/opt.py

Lines changed: 161 additions & 2 deletions
@@ -22,6 +22,8 @@
 from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.shard import ShardConfig

+from ..layer import cross_entropy_1d
+
 logger = logging.get_logger(__name__)


@@ -336,8 +338,22 @@ def opt_for_causal_lm_forward(
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
+            if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+                new_vocab_size = logits.shape[-1]
+                shift_logits = shift_logits.view(-1, new_vocab_size)
+                shift_labels = shift_labels.view(-1)
+                loss = cross_entropy_1d(
+                    shift_logits,
+                    shift_labels,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    vocab_size=self.lm_head.out_features,
+                )
+            else:
+                loss_fct = CrossEntropyLoss()
+                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
@@ -844,3 +860,146 @@ def forward(
         return outputs

     return forward
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    def forward(
+        self: OPTForCausalLM,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, OPTForCausalLM
+
+        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model.decoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )

+        logits = self.lm_head(outputs[0]).contiguous()
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    return forward
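The OPT version above follows the same pattern as the Mistral one: logits come from `self.model.decoder` plus `lm_head`, and the loss goes through `cross_entropy_1d` with the vocab-sharded logits. To see why combining per-shard statistics reproduces the usual loss, here is a small single-process check of the math; it simulates the vocabulary shards with `chunk` and is only an illustration, not a test of `cross_entropy_1d` itself.

```python
# Single-process illustration: splitting the vocab into tp shards and combining
# per-shard max / sum-exp / target-logit gives the same value as the ordinary
# cross-entropy over the full logits.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, V, tp = 8, 32, 4
logits = torch.randn(N, V)
targets = torch.randint(0, V, (N,))

shards = logits.chunk(tp, dim=-1)  # each shard is (N, V // tp)
global_max = torch.stack([s.max(-1).values for s in shards]).max(0).values
sum_exp = sum((s - global_max[:, None]).exp().sum(-1) for s in shards)
target_logit = logits.gather(-1, targets[:, None]).squeeze(-1) - global_max

sharded_loss = (sum_exp.log() - target_logit).mean()
reference = F.cross_entropy(logits, targets)
assert torch.allclose(sharded_loss, reference, atol=1e-5)
```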

colossalai/shardformer/policies/mistral.py

Lines changed: 9 additions & 4 deletions
@@ -18,6 +18,7 @@

 from ..modeling.mistral import (
     MistralForwards,
+    get_lm_forward_with_dist_cross_entropy,
     get_mistral_flash_attention_forward,
     get_mistral_model_forward_for_flash_attn,
 )
@@ -275,14 +276,18 @@ def module_policy(self):
                         SubModuleReplacementDescription(
                             suffix="lm_head",
                             target_module=VocabParallelLMHead1D,
-                            kwargs=dict(
-                                gather_output=True,
-                                make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
-                            ),
+                            kwargs={
+                                "gather_output": not self.shard_config.parallel_output,
+                                "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            },
                         )
                     ]
                 )
             }
+            if self.shard_config.parallel_output:
+                new_item[MistralForCausalLM].method_replacement = {
+                    "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+                }
         else:
             new_item = {
                 MistralForCausalLM: ModulePolicyDescription(

colossalai/shardformer/policies/opt.py

Lines changed: 8 additions & 1 deletion
@@ -21,6 +21,7 @@
 from ..modeling.opt import (
     OPTPipelineForwards,
     get_jit_fused_opt_decoder_layer_forward,
+    get_lm_forward_with_dist_cross_entropy,
     get_opt_decoder_forward_for_flash_attention,
     get_opt_flash_attention_forward,
 )
@@ -269,12 +270,18 @@ def module_policy(self):
                     suffix="lm_head",
                     target_module=VocabParallelLMHead1D,
                     kwargs=dict(
-                        gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
+                        gather_output=not self.shard_config.parallel_output,
+                        make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
                     ),
                 ),
                 policy=policy,
                 target_key=OPTForCausalLM,
             )
+            if self.shard_config.parallel_output:
+                method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
+                self.append_or_create_method_replacement(
+                    description=method_replacement, policy=policy, target_key=OPTForCausalLM
+                )
         else:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
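Both policy changes hinge on `gather_output=not self.shard_config.parallel_output`: with `parallel_output` off, the vocab-parallel `lm_head` gathers its columns so downstream code sees full-vocabulary logits, and with it on, each rank keeps only its shard and the replaced forward consumes them via `cross_entropy_1d`. A rough sketch of what that flag toggles in a column-parallel head follows; the names and layout are assumptions for illustration, not the `VocabParallelLMHead1D` implementation.

```python
# Conceptual sketch of a column-parallel (vocab-parallel) lm_head and the effect
# of gather_output; names are hypothetical and the layout is assumed.
import torch
import torch.distributed as dist


def vocab_parallel_lm_head(hidden, local_weight, gather_output, process_group=None):
    # local_weight holds this rank's vocab shard of the projection: (V_local, hidden_size)
    local_logits = hidden @ local_weight.t()              # (N, V_local)
    if not gather_output:
        return local_logits                               # parallel_output path: stay sharded
    world_size = dist.get_world_size(process_group)
    shards = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(shards, local_logits, group=process_group)
    return torch.cat(shards, dim=-1)                      # (N, V_full) on every rank
```

Skipping the all-gather saves both the communication volume and the memory for a full `(tokens, vocab_size)` logits tensor on every rank, which is where most of the benefit of `parallel_output` comes from.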
