
Commit 9efc79e

add parallel output for mistral model
1 parent d3f34ee commit 9efc79e
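
This commit lets the Mistral policy keep the `lm_head` output sharded across tensor-parallel ranks (`parallel_output`) and compute the loss with `cross_entropy_1d` instead of gathering full-vocabulary logits on every rank. A rough usage sketch follows; it is not part of this commit. The field names (`enable_tensor_parallelism`, `parallel_output`, `tensor_parallel_process_group`) come from the diff below, while the surrounding `ShardConfig`/`ShardFormer` usage and the helper function are assumptions that may differ across ColossalAI versions:

```python
# Hypothetical usage sketch -- not part of this commit. The ShardConfig field
# names are taken from the diff; constructor and ShardFormer details may vary.
import torch.distributed as dist
from transformers import MistralForCausalLM
from colossalai.shardformer import ShardConfig, ShardFormer


def shard_mistral_with_parallel_output(model: MistralForCausalLM):
    # assumption: the distributed environment (colossalai.launch / torch.distributed)
    # is already initialized before this is called
    tp_group = dist.group.WORLD  # one tensor-parallel group spanning all ranks
    shard_config = ShardConfig(
        tensor_parallel_process_group=tp_group,
        enable_tensor_parallelism=True,
        parallel_output=True,  # lm_head keeps sharded logits; loss uses cross_entropy_1d
    )
    shardformer = ShardFormer(shard_config=shard_config)
    model, _ = shardformer.optimize(model)  # applies the Mistral policy changed below
    return model
```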

2 files changed: +126 −7 lines

colossalai/shardformer/modeling/mistral.py

Lines changed: 116 additions & 3 deletions
@@ -16,7 +16,7 @@
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import ColoAttention
+from ..layer import ColoAttention, cross_entropy_1d
 
 logger = logging.get_logger(__name__)
 
@@ -270,11 +270,22 @@ def mistral_for_causal_lm_forward(
                 shift_labels = labels[..., 1:].contiguous()
                 # Flatten the tokens
                 loss_fct = CrossEntropyLoss()
-                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                #shift_logits = shift_logits.view(-1, self.config.vocab_size)
                 shift_labels = shift_labels.view(-1)
                 # Enable model parallelism
                 shift_labels = shift_labels.to(shift_logits.device)
-                loss = loss_fct(shift_logits, shift_labels)
+                if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+                    new_vocab_size = logits.shape[-1]
+                    shift_logits = shift_logits.view(-1, new_vocab_size)
+                    loss = cross_entropy_1d(
+                        shift_logits,
+                        shift_labels,
+                        process_group=shard_config.tensor_parallel_process_group,
+                        vocab_size=self.lm_head.out_features,
+                    )
+                else:
+                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                    loss = loss_fct(shift_logits, shift_labels)
 
             if not return_dict:
                 output = (logits,) + outputs[1:]
@@ -609,3 +620,105 @@ def forward(
         return attn_output, None, past_key_value
 
     return forward
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    from transformers import MistralForCausalLM
+
+    def forward(
+        self: MistralForCausalLM,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, MistralForCausalLM
+
+        >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        if self.config.pretraining_tp > 1:
+            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
+            logits = torch.cat(logits, dim=-1)
+        else:
+            logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    return forward
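
For context, `cross_entropy_1d` consumes logits that are still sharded along the vocabulary dimension. Below is a minimal, forward-only sketch of the general technique (Megatron-style vocab-parallel cross entropy), assuming equal, contiguous vocabulary shards per rank. It is an illustration only, not ColossalAI's `cross_entropy_1d`; a real implementation also wraps the reductions in a custom `autograd.Function` for the backward pass:

```python
# Illustrative sketch of vocab-parallel cross entropy (forward only).
# Assumptions: each rank holds a contiguous, equal-sized shard of the vocabulary,
# labels use -100 as the ignore index, and gradients are not needed here.
import torch
import torch.distributed as dist


def vocab_parallel_cross_entropy(local_logits, labels, process_group=None):
    # local_logits: (N, V_local) shard of the full (N, V) logits on this rank
    # labels: (N,) global vocabulary indices
    rank = dist.get_rank(process_group)
    v_local = local_logits.size(-1)
    vocab_start = rank * v_local          # contiguous shard owned by this rank
    vocab_end = vocab_start + v_local

    # 1) global max for numerical stability
    logits_max = local_logits.max(dim=-1).values
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group)
    shifted = local_logits - logits_max.unsqueeze(-1)

    # 2) global softmax denominator
    sum_exp = shifted.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=process_group)

    # 3) pick the target logit on the rank that owns it, contribute zero elsewhere
    in_shard = (labels >= vocab_start) & (labels < vocab_end)
    local_idx = (labels - vocab_start).clamp(min=0, max=v_local - 1)
    target_logit = shifted.gather(-1, local_idx.unsqueeze(-1)).squeeze(-1)
    target_logit = torch.where(in_shard, target_logit, torch.zeros_like(target_logit))
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=process_group)

    # 4) per-token loss = log(sum exp) - target logit, averaged over valid tokens
    loss = sum_exp.log() - target_logit
    valid = labels != -100
    return loss[valid].mean()
```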

colossalai/shardformer/policies/mistral.py

Lines changed: 10 additions & 4 deletions
@@ -20,6 +20,7 @@
     MistralForwards,
     get_mistral_flash_attention_forward,
     get_mistral_model_forward_for_flash_attn,
+    get_lm_forward_with_dist_cross_entropy,
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
@@ -275,14 +276,19 @@ def module_policy(self):
                         SubModuleReplacementDescription(
                             suffix="lm_head",
                             target_module=VocabParallelLMHead1D,
-                            kwargs=dict(
-                                gather_output=True,
-                                make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
-                            ),
+                            kwargs={
+                                #gather_output=True,
+                                "gather_output": not self.shard_config.parallel_output,
+                                "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            },
                         )
                     ]
                 )
             }
+            if self.shard_config.parallel_output:
+                new_item[MistralForCausalLM].method_replacement = {
+                    "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+                }
         else:
             new_item = {
                 MistralForCausalLM: ModulePolicyDescription(
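
The policy now sets `gather_output` to `not self.shard_config.parallel_output`, so with parallel output enabled the LM head skips the all-gather and hands vocabulary-sharded logits straight to the distributed loss. The sketch below shows that switch in a simplified column-parallel head; it is illustrative only, not ColossalAI's `VocabParallelLMHead1D`, and a real implementation would use an autograd-aware all-gather:

```python
# Illustrative sketch of a column-parallel LM head with a gather_output switch.
# Assumptions: vocab_size divides evenly across the process group, and gradient
# flow through the all-gather is ignored (a real head uses an autograd-aware op).
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist


class ColumnParallelLMHead(nn.Module):
    def __init__(self, hidden_size, vocab_size, process_group=None, gather_output=True):
        super().__init__()
        world_size = dist.get_world_size(process_group)
        assert vocab_size % world_size == 0, "pad the vocabulary so it shards evenly"
        # each rank owns vocab_size // world_size output rows
        self.weight = nn.Parameter(torch.empty(vocab_size // world_size, hidden_size))
        nn.init.normal_(self.weight, std=0.02)
        self.process_group = process_group
        self.gather_output = gather_output

    def forward(self, hidden_states):
        # each rank produces logits for its own vocabulary shard
        local_logits = F.linear(hidden_states, self.weight)
        if not self.gather_output:
            # parallel_output=True path: keep logits sharded for a distributed loss
            return local_logits
        # gather_output=True path: reassemble the full vocabulary dimension
        world_size = dist.get_world_size(self.process_group)
        gathered = [torch.empty_like(local_logits) for _ in range(world_size)]
        dist.all_gather(gathered, local_logits, group=self.process_group)
        return torch.cat(gathered, dim=-1)
```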
