
Commit 22ce873

[Shardformer] Add parallel output for shardformer models (bloom, falcon) (#5702)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* add parallel cross entropy output for the falcon model & fix some typos in bloom.py
* fix module name error, self.model -> self.transformer in the bloom and falcon models
* fix the overflow bug of the distributed cross entropy loss function when training with fp16
* add dtype to the parallel cross entropy loss function
* fix dtype-related typos and prettify loss.py
* fix grad dtype and update the dtype mismatch error
* fix typo bugs
1 parent 9d83c6d commit 22ce873


9 files changed: +230 lines, -17 lines

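For context (not part of the diff): the overflow described in the commit message comes from the softmax denominator. After the per-token maximum is subtracted, every exp(logit) is at most 1.0, but summing those values over a large vocabulary still exceeds the fp16 maximum of 65504. The sketch below, using a Bloom-sized vocabulary purely for illustration, shows why loss.py now accumulates the sum in fp32.

import torch

# Minimal illustration (not from the commit): ~250k fp16 values of 1.0 sum to a
# number fp16 cannot represent, so the result becomes inf, while forcing the
# reduction to fp32 keeps it finite.
vocab_size = 250_880  # roughly Bloom's vocabulary size, used here only as an example
exp_logits = torch.ones(2, vocab_size, dtype=torch.float16)
print(torch.sum(exp_logits, dim=-1))                       # tensor([inf, inf], dtype=torch.float16)
print(torch.sum(exp_logits, dim=-1, dtype=torch.float32))  # tensor([250880., 250880.])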

colossalai/shardformer/layer/loss.py

Lines changed: 9 additions & 6 deletions
@@ -22,6 +22,7 @@ def forward(
         ignore_index: int,
         process_group: ProcessGroup,
         vocab_size: int,
+        dtype=torch.float32,
     ):
         r"""
         Calculate the cross entropy loss before gather, the origin loss function is as follows:
@@ -34,7 +35,7 @@ def forward(
         Args:
             vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
              [batch_size, seq_len, vocab_size]
-            labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is
+            target (:class:`torch.Tensor`): The labels of the vocabulary, shape is
              [batch_size, seq_len]

         Returns:
@@ -86,7 +87,7 @@ def forward(
         dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group)
         exp_logits = vocab_logits
         torch.exp(vocab_logits, out=exp_logits)
-        sum_exp_logits = torch.sum(exp_logits, dim=-1)
+        sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)
         dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)

         # calculate the loss
@@ -97,9 +98,10 @@ def forward(
         loss = torch.sum(loss).div_(num_non_zero)

         # calculate the softmax
-        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+        exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype)
         exp_logits[target == ignore_index] = 0.0
         ctx.save_for_backward(exp_logits, mask, masked_target_1d)
+        ctx.dtype = dtype

         return loss

@@ -114,11 +116,11 @@ def backward(ctx, grad_output):
         partion_vocab_size = grad_logits.shape[-1]
         grad_logits_2d = grad_logits.view(-1, partion_vocab_size)

-        update = 1.0 - mask.view(-1).float()
+        update = 1.0 - mask.view(-1).float().to(ctx.dtype)
         grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update

         grad_logits.mul_(grad_output.unsqueeze(dim=-1))
-        return grad_logits, None, None, None, None
+        return grad_logits, None, None, None, None, None


 def cross_entropy_1d(
@@ -127,5 +129,6 @@ def cross_entropy_1d(
     ignore_index: int = -100,
     process_group: ProcessGroup = None,
     vocab_size: int = None,
+    dtype: torch.dtype = None,
 ) -> torch.Tensor:
-    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size)
+    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
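A usage sketch of the updated cross_entropy_1d (not from the commit). It assumes torch.distributed is already initialized, tp_group is the tensor-parallel process group, each rank holds the logits of its own contiguous vocabulary shard, and labels carry global vocabulary indices; these names are placeholders.

import torch
from colossalai.shardformer.layer import cross_entropy_1d

def sharded_causal_lm_loss(shard_logits: torch.Tensor, labels: torch.Tensor, tp_group, full_vocab_size: int):
    # shard_logits: [batch * seq_len, vocab_size // tp_size], labels: [batch * seq_len]
    return cross_entropy_1d(
        shard_logits,
        labels,
        process_group=tp_group,
        vocab_size=full_vocab_size,
        dtype=shard_logits.dtype,  # e.g. torch.float16; the softmax saved for backward is cast back to this dtype
    )

# With the new dtype argument the gradient w.r.t. fp16 logits also comes back in
# fp16, which is the "grad dtype" fix mentioned in the commit message:
#   loss = sharded_causal_lm_loss(shard_logits.requires_grad_(), labels, tp_group, full_vocab_size)
#   loss.backward()
#   assert shard_logits.grad.dtype == shard_logits.dtype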

colossalai/shardformer/modeling/bloom.py

Lines changed: 95 additions & 5 deletions
@@ -10,6 +10,7 @@
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
+    CausalLMOutputWithPast,
     QuestionAnsweringModelOutput,
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
@@ -27,6 +28,8 @@
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig

+from ..layer import cross_entropy_1d
+
 logger = logging.get_logger(__name__)


@@ -354,7 +357,7 @@ def bloom_for_causal_lm_forward(
     past_key_values = None
     if stage_manager.is_last_stage():
         hidden_states = transformer_outputs[0]
-        lm_logits = self.lm_head(hidden_states)
+        lm_logits = self.lm_head(hidden_states).contiguous()

         loss = None
         if labels is not None:
@@ -365,10 +368,21 @@ def bloom_for_causal_lm_forward(
             shift_labels = labels[..., 1:].contiguous()
             batch_size, seq_length, vocab_size = shift_logits.shape
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(
-                shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
-            )
+            if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+                new_vocab_size = lm_logits.shape[-1]
+                shift_logits = shift_logits.view(-1, new_vocab_size)
+                shift_labels = shift_labels.view(-1)
+                loss = cross_entropy_1d(
+                    shift_logits,
+                    shift_labels,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    vocab_size=self.lm_head.out_features,
+                    dtype=self.transformer.dtype,
+                )
+            else:
+                loss_fct = CrossEntropyLoss()
+                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                loss = loss_fct(shift_logits, shift_labels.view(-1))

         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
@@ -1065,3 +1079,79 @@ def forward(
         )

     return forward
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    from transformers import BloomForCausalLM
+
+    def forward(
+        self: BloomForCausalLM,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids=input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        past_key_values = None
+        hidden_states = transformer_outputs[0]
+        lm_logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(lm_logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            new_vocab_size = lm_logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            shift_labels = shift_labels.view(-1)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+                dtype=self.transformer.dtype,
+            )
+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output

+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+    return forward
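The factory above returns an unbound forward that closes over the ShardConfig; within Shardformer it is intended to be installed by the corresponding policy's method replacement when parallel_output is enabled. A minimal sketch (not from the commit), showing a manual binding only to illustrate what the factory produces:

import types
from transformers import BloomForCausalLM

def install_dist_ce_forward(model: BloomForCausalLM, shard_config):
    # bind the returned closure as this instance's forward method
    forward_fn = get_lm_forward_with_dist_cross_entropy(shard_config)
    model.forward = types.MethodType(forward_fn, model)
    return model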

colossalai/shardformer/modeling/falcon.py

Lines changed: 96 additions & 3 deletions
@@ -14,6 +14,7 @@
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
+    CausalLMOutputWithPast,
     QuestionAnsweringModelOutput,
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
@@ -31,6 +32,8 @@
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig

+from ..layer import cross_entropy_1d
+

 def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
     def build_falcon_alibi_tensor(
@@ -437,14 +440,28 @@ def falcon_for_causal_lm_forward(
         loss = None
         if labels is not None:
             # Shift so that tokens < n predict n
+            labels = labels.to(lm_logits.device)
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
             batch_size, seq_length, vocab_size = shift_logits.shape
             # Flatten the tokens
             loss_fct = CrossEntropyLoss()
-            loss = loss_fct(
-                shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
-            )
+            if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+                new_vocab_size = shift_logits.shape[-1]
+                shift_logits = shift_logits.view(-1, new_vocab_size)
+                shift_labels = shift_labels.view(-1)
+                loss = cross_entropy_1d(
+                    shift_logits,
+                    shift_labels,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    vocab_size=self.lm_head.out_features,
+                    dtype=self.transformer.dtype,
+                )
+            else:
+                loss = loss_fct(
+                    shift_logits.view(batch_size * seq_length, vocab_size),
+                    shift_labels.view(batch_size * seq_length),
+                )

         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
@@ -747,3 +764,79 @@ def falcon_for_question_answering_forward(
     else:
         hidden_states = outputs.get("hidden_states")
         return {"hidden_states": hidden_states}
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    from transformers import FalconForCausalLM
+
+    def forward(
+        self: FalconForCausalLM,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        past_key_values = None
+        hidden_states = transformer_outputs[0]
+        lm_logits = self.lm_head(hidden_states)
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            labels = labels.to(lm_logits.device)
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            batch_size, seq_length, vocab_size = shift_logits.shape
+            # Flatten the tokens
+            new_vocab_size = shift_logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            shift_labels = shift_labels.view(-1)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+                dtype=self.transformer.dtype,
+            )
+
+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+    return forward
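A consistency-check sketch (not from the commit). It assumes an initialized tensor-parallel group tp_group in which rank r holds the r-th contiguous vocabulary shard of the logits; it gathers the shards and compares the distributed loss against the stock CrossEntropyLoss on the full-vocabulary logits.

import torch
import torch.distributed as dist
from torch.nn import CrossEntropyLoss
from colossalai.shardformer.layer import cross_entropy_1d

def check_parallel_ce(shard_logits, labels, tp_group, full_vocab_size, atol=1e-3):
    # reference: gather the vocabulary shards and use the stock loss in fp32
    shards = [torch.empty_like(shard_logits) for _ in range(dist.get_world_size(tp_group))]
    dist.all_gather(shards, shard_logits.contiguous(), group=tp_group)
    ref_loss = CrossEntropyLoss()(torch.cat(shards, dim=-1).float(), labels)

    # distributed loss computed on this rank's shard only
    dist_loss = cross_entropy_1d(
        shard_logits,
        labels,
        process_group=tp_group,
        vocab_size=full_vocab_size,
        dtype=torch.float32,
    )
    return torch.allclose(dist_loss, ref_loss, atol=atol)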

colossalai/shardformer/modeling/gpt2.py

Lines changed: 2 additions & 0 deletions
@@ -389,6 +389,7 @@ def gpt2_lmhead_model_forward(
                     shift_labels,
                     process_group=shard_config.tensor_parallel_process_group,
                     vocab_size=self.lm_head.out_features,
+                    dtype=self.transformer.dtype,
                 )
             else:
                 loss = loss_fct(shift_logits, shift_labels)
@@ -1294,6 +1295,7 @@ def forward(
                 shift_labels,
                 process_group=shard_config.tensor_parallel_process_group,
                 vocab_size=self.lm_head.out_features,
+                dtype=self.transformer.dtype,
             )

         if not return_dict:
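The only change here is threading the module dtype into the distributed loss. self.transformer.dtype is the Hugging Face module property that reports the dtype of the module's parameters, so under fp16 or bf16 training the loss casts its saved softmax back to that dtype; the llama, mistral and opt diffs below do the same via self.model.dtype and self.model.decoder.dtype. A tiny sketch (not from the commit):

import torch
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(n_layer=1, n_head=2, n_embd=64)  # deliberately tiny, illustration only
model = GPT2LMHeadModel(config).half()
print(model.transformer.dtype)  # torch.float16, i.e. the value passed as dtype= to cross_entropy_1d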

colossalai/shardformer/modeling/llama.py

Lines changed: 2 additions & 0 deletions
@@ -332,6 +332,7 @@ def llama_for_causal_lm_forward(
                    shift_labels,
                    process_group=shard_config.tensor_parallel_process_group,
                    vocab_size=self.lm_head.out_features,
+                    dtype=self.model.dtype,
                )
            else:
                shift_logits = shift_logits.view(-1, self.config.vocab_size)
@@ -768,6 +769,7 @@ def forward(
                shift_labels,
                process_group=shard_config.tensor_parallel_process_group,
                vocab_size=self.lm_head.out_features,
+                dtype=self.model.dtype,
            )

        if not return_dict:

colossalai/shardformer/modeling/mistral.py

Lines changed: 2 additions & 0 deletions
@@ -281,6 +281,7 @@ def mistral_for_causal_lm_forward(
                    shift_labels,
                    process_group=shard_config.tensor_parallel_process_group,
                    vocab_size=self.lm_head.out_features,
+                    dtype=self.model.dtype,
                )
            else:
                shift_logits = shift_logits.view(-1, self.config.vocab_size)
@@ -701,6 +702,7 @@ def forward(
                shift_labels,
                process_group=shard_config.tensor_parallel_process_group,
                vocab_size=self.lm_head.out_features,
+                dtype=self.model.dtype,
            )

        if not return_dict:

colossalai/shardformer/modeling/opt.py

Lines changed: 2 additions & 0 deletions
@@ -348,6 +348,7 @@ def opt_for_causal_lm_forward(
                    shift_labels,
                    process_group=shard_config.tensor_parallel_process_group,
                    vocab_size=self.lm_head.out_features,
+                    dtype=self.model.decoder.dtype,
                )
            else:
                loss_fct = CrossEntropyLoss()
@@ -988,6 +989,7 @@ def forward(
                shift_labels,
                process_group=shard_config.tensor_parallel_process_group,
                vocab_size=self.lm_head.out_features,
+                dtype=self.model.decoder.dtype,
            )

        if not return_dict:
