
Commit d250b34

[Auto Parallel] Fix sp in gpt modeling_auto (#10835)
* Update modeling_auto.py
* Update pretrain-gpt3_13b_dynamic_auto.json
* update
* fix gpt sp model
* Update loss base
* update loss base
* disable benchmark SP for now
1 parent b8f2101 commit d250b34
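
At a high level, the commit drops the old sequence-parallel plumbing in modeling_auto.py (the mark_as_sequence_parallel_parameter bookkeeping and the reshape-to-[bs * seq_len, hidden] scatter) and instead keeps activations sequence-first ([seq_len, bs, hidden]) inside the sequence-parallel region, switching layouts with dist.reshard. The sketch below only illustrates that placement pattern; it assumes a toy 2x2 ["dp", "mp"] ProcessMesh and random tensors, while the model itself builds its meshes through PaddleNLP's get_mesh() helper.

    # Illustrative sketch of the layout convention used in this commit (not the model code).
    # Assumes 4 GPUs launched via `python -m paddle.distributed.launch` and a 2x2 ["dp", "mp"] mesh.
    import paddle
    import paddle.distributed as dist

    mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])
    bs, seq_len, hidden = 2, 8, 16

    # Batch-first activations: batch split over "dp", replicated over "mp".
    x = dist.shard_tensor(paddle.randn([bs, seq_len, hidden]), mesh, [dist.Shard(0), dist.Replicate()])

    # Enter the sequence-parallel region: go sequence-first [S, B, H], then split the
    # sequence dim over "mp" while the batch dim stays split over "dp"
    # (the [dist.Shard(1), dist.Shard(0)] placements added throughout modeling_auto.py).
    x = paddle.transpose(x, [1, 0, 2])
    x = dist.reshard(x, mesh, [dist.Shard(1), dist.Shard(0)])

    # Before attention (FlashAttention / RoPE need the full sequence), gather the
    # sequence dim back on "mp"; the batch dim remains split over "dp".
    x = dist.reshard(x, mesh, [dist.Shard(1), dist.Replicate()])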

File tree (2 files changed, +44 -48 lines):

  paddlenlp/transformers/gpt/modeling_auto.py
  scripts/distribute/ci_case_auto.sh


paddlenlp/transformers/gpt/modeling_auto.py

Lines changed: 36 additions & 40 deletions
@@ -32,13 +32,6 @@
 from paddle.distributed.fleet.utils import recompute
 from paddle.utils import try_import
 
-try:
-    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
-        mark_as_sequence_parallel_parameter,
-    )
-except:
-    pass
-
 from ...utils.converter import StateDictNameMapping
 from .. import PretrainedModel, register_base_model
 from ..model_outputs import BaseModelOutputWithPastAndCrossAttentions
@@ -209,19 +202,19 @@ def __init__(self, config, ipp=None):
         )
 
     def _fuse_prepare_qkv(self, query, use_cache=False, past_key_value=None):
-        if self.config.sequence_parallel:
-            # [bs, seq_len, num_head * head_dim] -> [bs / n, seq_len, num_head, head_dim] (n is model parallelism)
-            target_shape = [-1, self.config.seq_length, self.num_attention_heads, 3 * self.head_dim]
-        else:
-            target_shape = [0, 0, self.num_attention_heads, 3 * self.head_dim]
-
+        target_shape = [0, 0, self.num_attention_heads, 3 * self.head_dim]
         # bs, seq_len, num_head * 3*head_dim
         mix_layer = self.qkv_proj(query)
         # bs, seq_len, num_head, 3*head_dim
         mix_layer = paddle.reshape_(mix_layer, target_shape)
         # query_states, key_states, value_states => bs, seq_len, num_head, head_dim
         query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1)
-
+        if self.config.sequence_parallel:
+            # [seq_len, bs, num_head * head_dim] -> [bs, seq_len, num_head * head_dim] (if sequence_parallel)
+            # FA and rope not support sequence first
+            query_states = paddle.transpose(query_states, [1, 0, 2, 3])
+            key_states = paddle.transpose(key_states, [1, 0, 2, 3])
+            value_states = paddle.transpose(value_states, [1, 0, 2, 3])
         # [bs, seq_len, num_head, head_dim]
         if past_key_value is not None:
             # reuse k, v, self_attention
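
For reference, the fused-QKV shape handling above can be exercised on a single process; the sizes below are toy assumptions, not values taken from the model config.

    # Single-process sketch of the reshape / split / transpose sequence above.
    import paddle

    bs, seq_len, num_heads, head_dim = 2, 8, 4, 16

    # With sequence_parallel the layer receives a sequence-first input, so the fused
    # projection output is [seq_len, bs, num_heads * 3 * head_dim].
    mix_layer = paddle.randn([seq_len, bs, num_heads * 3 * head_dim])

    # target_shape [0, 0, num_heads, 3 * head_dim]: a 0 keeps the corresponding input
    # dim, so the same shape works for batch-first and sequence-first layouts.
    mix_layer = paddle.reshape_(mix_layer, [0, 0, num_heads, 3 * head_dim])

    q, k, v = paddle.split(mix_layer, num_or_sections=3, axis=-1)
    print(q.shape)  # [8, 2, 4, 16] -> [seq_len, bs, num_heads, head_dim]

    # FlashAttention and RoPE expect batch-first, hence the transposes in the diff.
    q, k, v = (paddle.transpose(t, [1, 0, 2, 3]) for t in (q, k, v))
    print(q.shape)  # [2, 8, 4, 16]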
@@ -326,6 +319,8 @@ def forward(
         Applies multi-head attention to map queries and a set of key-value pairs
         to outputs.
         """
+        if self.config.sequence_parallel:
+            query = dist.reshard(query, get_mesh(self.ipp), [dist.Shard(1), dist.Replicate()])
         key = query if key is None else key
         value = query if value is None else value
         if self.config.fuse_attention_qkv:
@@ -363,11 +358,11 @@ def forward(
         # else their shape are [bs, q_len, num_head * head_dim / n], n is mp parallelism.
 
         if self.config.sequence_parallel:
-            bs, seq_len, dim = out.shape
-            out = out.reshape([bs * seq_len, dim])  # [bs, seq_len, dim / n] => [bs * seq_len, dim / n]
-
+            out = paddle.transpose(out, [1, 0, 2])
         # project to output
         out = self.out_proj(out)
+        if self.config.sequence_parallel:
+            out = dist.reshard(out, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(0)])
         # if sequence_parallel is true, out shape are [bs * seq_len / n, dim]
         # else their shape are [bs, seq_len, dim], n is mp parallelism.
         outs = [out]
@@ -390,9 +385,6 @@ def __init__(self, config, decoder_layers, norm=None, hidden_size=None):
         self.layers = decoder_layers
 
         self.norm = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)
-        if config.sequence_parallel:
-            mark_as_sequence_parallel_parameter(self.norm.weight)
-            mark_as_sequence_parallel_parameter(self.norm.bias)
 
         # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
         # Enable_recompute defaults to False and is controlled by Trainer
@@ -536,11 +528,6 @@ def __init__(self, config: GPTConfig, ipp=None):
         self.norm1 = GPTLayerNorm(config, config.hidden_size, self.ipp, epsilon=1e-5, bias_attr=True)
         self.norm2 = GPTLayerNorm(config, config.hidden_size, self.ipp, epsilon=1e-5, bias_attr=True)
 
-        if config.sequence_parallel:
-            mark_as_sequence_parallel_parameter(self.norm1.weight)
-            mark_as_sequence_parallel_parameter(self.norm1.bias)
-            mark_as_sequence_parallel_parameter(self.norm2.weight)
-            mark_as_sequence_parallel_parameter(self.norm2.bias)
         if config.use_fused_dropout_add:
             self.fused_dropout_add1 = FusedDropoutAdd(config.attention_probs_dropout_prob, mode="upscale_in_train")
             self.fused_dropout_add2 = FusedDropoutAdd(config.hidden_dropout_prob, mode="upscale_in_train")
@@ -593,6 +580,12 @@ def forward(
 
         # Use a ternary operator for a more concise assignment of current_seed
         current_seed = "local_seed" if self.config.sequence_parallel else "global_seed"
+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(
+                hidden_states,
+                get_mesh(self.ipp),
+                [dist.Shard(1), dist.Shard(0)],
+            )
 
         # The 'with' block ensures the correct seed context is used
         with seed_guard_context(current_seed):
@@ -607,14 +600,17 @@ def forward(
         residual = hidden_states
         if self.config.normalize_before:
             hidden_states = self.norm2(hidden_states)
-
+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(hidden_states, get_mesh(self.ipp), [dist.Shard(1), dist.Replicate()])
         # when sequence_parallel=True:
         #     hidden_states => [bs * seq_len / n, embed_dim]
         with seed_guard_context(current_seed):
             if not self.config.use_fused_dropout_add:
                 l_1 = self.linear1(hidden_states)
                 act = self.activation(l_1, approximate=True)
                 l_2 = self.linear2(act)
+                if self.config.sequence_parallel:
+                    l_2 = dist.reshard(l_2, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(0)])
                 hidden_states = residual + self.dropout2(l_2)
             else:
                 hidden_states = self.fused_dropout_add2(
@@ -658,7 +654,7 @@ def __init__(
             config.hidden_size,
         )
         self.word_embeddings.weight = dist.shard_tensor(
-            self.word_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Replicate()]
+            self.word_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(1)]
         )
         self.position_embeddings.weight = dist.shard_tensor(
             self.position_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(1)]
@@ -685,18 +681,15 @@ def forward(self, input_ids, position_ids=None, inputs_embeddings=None):
         position_embeddings = self.position_embeddings(position_ids)
         embeddings = inputs_embeddings + position_embeddings
 
-        # exit()
-        if self.config.sequence_parallel:
-            # embeddings = dist.shard_tensor(embeddings,get_mesh(),[dist.Replicate(),dist.Replicate()])
-            bs, seq_len, hidden_size = embeddings.shape
-            # [bs, seq_len, dim] -> [bs * seq_len, dim]
-            embeddings = paddle.reshape_(embeddings, [bs * seq_len, hidden_size])
-            # [bs * seq_len / n, dim] (n is mp parallelism)
-            # embeddings = ScatterOp.apply(embeddings)
-            embeddings = dist.reshard(embeddings, get_mesh(), [dist.Replicate(), dist.Shard(0)])
         # Use a ternary operator for a more concise assignment of current_seed
         current_seed = "local_seed" if self.config.sequence_parallel else "global_seed"
         # The 'with' block ensures the correct seed context is used
+        if self.config.sequence_parallel:
+            # [B, S, H] -> [S, B, H]
+            embeddings = paddle.transpose(embeddings, [1, 0, 2])
+            embeddings = dist.reshard(embeddings, get_mesh(), [dist.Shard(1), dist.Shard(0)])
+        else:
+            embeddings = dist.reshard(embeddings, get_mesh(), [dist.Shard(0), dist.Replicate()])
         with seed_guard_context(current_seed):
             embeddings = self.dropout(embeddings)
         return embeddings
@@ -1176,13 +1169,16 @@ def __init__(self, config: GPTConfig, embedding_weights=None, ipp=None):
             shape=[config.vocab_size, config.hidden_size],
             dtype=paddle.get_default_dtype(),
         )
-        self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)])
+        self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)])
 
     def forward(self, hidden_states, tensor_parallel_output=None):
-
         if self.config.sequence_parallel:
-            hidden_states = dist.reshard(hidden_states, get_mesh(self.ipp), [dist.Replicate(), dist.Replicate()])
-            hidden_states = paddle.reshape(hidden_states, [-1, self.config.seq_length, self.config.hidden_size])
+            hidden_states = dist.reshard(
+                hidden_states,
+                get_mesh(self.ipp),
+                [dist.Shard(1), dist.Shard(0)],
+            )
+            hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
 
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output
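
Besides the activation layout, the diff also changes static weight placements: word_embeddings.weight goes from fully replicated to [Replicate(), Shard(1)], and the GPTLMHead weight moves from Shard(0) (vocab dim) to Shard(1) (hidden dim). A minimal sketch of what those placements mean, again assuming a toy 2x2 ["dp", "mp"] mesh in place of get_mesh():

    # Weight-placement sketch (illustrative; needs a distributed launch, as above).
    import paddle
    import paddle.distributed as dist

    mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])
    vocab_size, hidden_size = 1024, 64

    weight = paddle.randn([vocab_size, hidden_size])

    # [Replicate(), Shard(1)]: replicated across "dp", hidden dim split across "mp".
    # This is the new placement for both word_embeddings.weight and the LM-head weight.
    weight = dist.shard_tensor(weight, mesh, [dist.Replicate(), dist.Shard(1)])

    # Previous placements, for comparison:
    #   word_embeddings.weight : [Replicate(), Replicate()]  (fully replicated)
    #   GPTLMHead weight       : [Replicate(), Shard(0)]     (vocab dim split across "mp")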

scripts/distribute/ci_case_auto.sh

Lines changed: 8 additions & 8 deletions
@@ -2492,11 +2492,11 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5"
-    loss_base=10.55853653 # output of dropout is different after supporting spmd
+    loss_base=10.55727577 # output of dropout is different after supporting spmd
     ips_base=-1
     mem_base=-1
     if [ $IS_A100 -ne 0 ];then
-        loss_base=10.56019211 # after add dropout spmd
+        loss_base=10.56668472 # after add dropout spmd
     fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
@@ -2564,11 +2564,11 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5"
-    loss_base=10.5657959 # output of dropout is different after supporting spmd
+    loss_base=10.57985115 # output of dropout is different after supporting spmd
     ips_base=-1
     mem_base=-1
     if [ $IS_A100 -ne 0 ];then
-        loss_base=10.5760107 # after add dropout spmd
+        loss_base=10.57280159 # after add dropout spmd
     fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
@@ -2637,11 +2637,11 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5"
     # loss_base=10.59993172 # note: need to debug
-    loss_base=10.57174778 # output of dropout is different after supporting spmd
+    loss_base=10.57274055 # output of dropout is different after supporting spmd
     ips_base=-1
     mem_base=-1
     if [ $IS_A100 -ne 0 ];then
-        loss_base=10.57701015 # after add dropout spmd
+        loss_base=10.57785797 # after add dropout spmd
     fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
@@ -2710,11 +2710,11 @@ function llm_gpt_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5"
     # loss_base=10.58456802 # note: need to debug
-    loss_base=10.57304478
+    loss_base=10.57409477
     ips_base=-1
     mem_base=-1
     if [ $IS_A100 -ne 0 ];then
-        loss_base=10.57861042 # after add dropout spmd
+        loss_base=10.57924652 # after add dropout spmd
     fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
