 from paddle.distributed.fleet.utils import recompute
 from paddle.utils import try_import
 
-try:
-    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
-        mark_as_sequence_parallel_parameter,
-    )
-except:
-    pass
-
 from ...utils.converter import StateDictNameMapping
 from .. import PretrainedModel, register_base_model
 from ..model_outputs import BaseModelOutputWithPastAndCrossAttentions
@@ -209,19 +202,19 @@ def __init__(self, config, ipp=None):
         )
 
     def _fuse_prepare_qkv(self, query, use_cache=False, past_key_value=None):
-        if self.config.sequence_parallel:
-            # [bs, seq_len, num_head * head_dim] -> [bs / n, seq_len, num_head, head_dim] (n is model parallelism)
-            target_shape = [-1, self.config.seq_length, self.num_attention_heads, 3 * self.head_dim]
-        else:
-            target_shape = [0, 0, self.num_attention_heads, 3 * self.head_dim]
-
+        target_shape = [0, 0, self.num_attention_heads, 3 * self.head_dim]
         # bs, seq_len, num_head * 3*head_dim
         mix_layer = self.qkv_proj(query)
         # bs, seq_len, num_head, 3*head_dim
         mix_layer = paddle.reshape_(mix_layer, target_shape)
         # query_states, key_states, value_states => bs, seq_len, num_head, head_dim
         query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1)
-
+        if self.config.sequence_parallel:
+            # [seq_len, bs, num_head * head_dim] -> [bs, seq_len, num_head * head_dim] (if sequence_parallel)
+            # FA and rope not support sequence first
+            query_states = paddle.transpose(query_states, [1, 0, 2, 3])
+            key_states = paddle.transpose(key_states, [1, 0, 2, 3])
+            value_states = paddle.transpose(value_states, [1, 0, 2, 3])
         # [bs, seq_len, num_head, head_dim]
         if past_key_value is not None:
             # reuse k, v, self_attention
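Reviewer sketch (not part of the diff): with sequence_parallel on, the attention block now receives activations in sequence-first layout [seq_len, bs, hidden], and Q/K/V are transposed back to batch-first after the split because FlashAttention and RoPE expect batch-first tensors. A minimal standalone illustration of the layout change, using toy shapes and plain paddle ops only (no distributed state assumed):

import paddle

# toy sizes; the real values come from the model config
seq_len, bs, num_head, head_dim = 8, 2, 4, 16

# sequence-first activation, as produced under sequence parallelism: [seq_len, bs, hidden]
query = paddle.randn([seq_len, bs, num_head * 3 * head_dim])

# same reshape/split pattern as _fuse_prepare_qkv (0 keeps the original dim)
mix_layer = paddle.reshape(query, [0, 0, num_head, 3 * head_dim])
q, k, v = paddle.split(mix_layer, num_or_sections=3, axis=-1)

# [seq_len, bs, num_head, head_dim] -> [bs, seq_len, num_head, head_dim]
q = paddle.transpose(q, [1, 0, 2, 3])
print(q.shape)  # [2, 8, 4, 16]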
@@ -326,6 +319,8 @@ def forward(
         Applies multi-head attention to map queries and a set of key-value pairs
         to outputs.
         """
+        if self.config.sequence_parallel:
+            query = dist.reshard(query, get_mesh(self.ipp), [dist.Shard(1), dist.Replicate()])
         key = query if key is None else key
         value = query if value is None else value
         if self.config.fuse_attention_qkv:
@@ -363,11 +358,11 @@ def forward(
         # else their shape are [bs, q_len, num_head * head_dim / n], n is mp parallelism.
 
         if self.config.sequence_parallel:
-            bs, seq_len, dim = out.shape
-            out = out.reshape([bs * seq_len, dim])  # [bs, seq_len, dim / n] => [bs * seq_len, dim / n]
-
+            out = paddle.transpose(out, [1, 0, 2])
         # project to output
         out = self.out_proj(out)
+        if self.config.sequence_parallel:
+            out = dist.reshard(out, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(0)])
         # if sequence_parallel is true, out shape are [bs * seq_len / n, dim]
         # else their shape are [bs, seq_len, dim], n is mp parallelism.
         outs = [out]
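Reviewer sketch (not part of the diff): the placements list passed to dist.reshard is indexed by mesh axis. Assuming get_mesh() returns a 2-D [dp, mp] mesh (that helper lives elsewhere in this file and is not reproduced here), [dist.Shard(1), dist.Shard(0)] on a sequence-first [seq_len, bs, hidden] tensor splits the batch dim over the dp axis and the sequence dim over the mp axis, i.e. the sequence-parallel activation layout, while [dist.Shard(1), dist.Replicate()] gathers the full sequence back before the tensor-parallel projections. A hedged sketch with an explicit 2x2 mesh, intended to run under paddle.distributed.launch on 4 devices:

import paddle
import paddle.distributed as dist

# assumed 2x2 mesh: axis "dp" = data parallel, axis "mp" = model parallel
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

# sequence-first activation: [seq_len, bs, hidden]
out = paddle.randn([1024, 8, 4096])

# gathered layout used inside attention: batch split over dp, replicated over mp
out = dist.shard_tensor(out, mesh, [dist.Shard(1), dist.Replicate()])

# sequence-parallel layout after out_proj: batch over dp, sequence over mp
out = dist.reshard(out, mesh, [dist.Shard(1), dist.Shard(0)])
print(out.shape)  # global shape stays [1024, 8, 4096]; each rank holds a local shard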
@@ -390,9 +385,6 @@ def __init__(self, config, decoder_layers, norm=None, hidden_size=None):
         self.layers = decoder_layers
 
         self.norm = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)
-        if config.sequence_parallel:
-            mark_as_sequence_parallel_parameter(self.norm.weight)
-            mark_as_sequence_parallel_parameter(self.norm.bias)
 
         # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
         # Enable_recompute defaults to False and is controlled by Trainer
@@ -536,11 +528,6 @@ def __init__(self, config: GPTConfig, ipp=None):
         self.norm1 = GPTLayerNorm(config, config.hidden_size, self.ipp, epsilon=1e-5, bias_attr=True)
         self.norm2 = GPTLayerNorm(config, config.hidden_size, self.ipp, epsilon=1e-5, bias_attr=True)
 
-        if config.sequence_parallel:
-            mark_as_sequence_parallel_parameter(self.norm1.weight)
-            mark_as_sequence_parallel_parameter(self.norm1.bias)
-            mark_as_sequence_parallel_parameter(self.norm2.weight)
-            mark_as_sequence_parallel_parameter(self.norm2.bias)
         if config.use_fused_dropout_add:
             self.fused_dropout_add1 = FusedDropoutAdd(config.attention_probs_dropout_prob, mode="upscale_in_train")
             self.fused_dropout_add2 = FusedDropoutAdd(config.hidden_dropout_prob, mode="upscale_in_train")
@@ -593,6 +580,12 @@ def forward(
 
         # Use a ternary operator for a more concise assignment of current_seed
         current_seed = "local_seed" if self.config.sequence_parallel else "global_seed"
+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(
+                hidden_states,
+                get_mesh(self.ipp),
+                [dist.Shard(1), dist.Shard(0)],
+            )
 
         # The 'with' block ensures the correct seed context is used
         with seed_guard_context(current_seed):
@@ -607,14 +600,17 @@ def forward(
         residual = hidden_states
         if self.config.normalize_before:
             hidden_states = self.norm2(hidden_states)
-
+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(hidden_states, get_mesh(self.ipp), [dist.Shard(1), dist.Replicate()])
         # when sequence_parallel=True:
         # hidden_states => [bs * seq_len / n, embed_dim]
         with seed_guard_context(current_seed):
             if not self.config.use_fused_dropout_add:
                 l_1 = self.linear1(hidden_states)
                 act = self.activation(l_1, approximate=True)
                 l_2 = self.linear2(act)
+                if self.config.sequence_parallel:
+                    l_2 = dist.reshard(l_2, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(0)])
                 hidden_states = residual + self.dropout2(l_2)
             else:
                 hidden_states = self.fused_dropout_add2(
@@ -658,7 +654,7 @@ def __init__(
             config.hidden_size,
         )
         self.word_embeddings.weight = dist.shard_tensor(
-            self.word_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Replicate()]
+            self.word_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(1)]
        )
         self.position_embeddings.weight = dist.shard_tensor(
             self.position_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(1)]
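Reviewer sketch (not part of the diff): the change above makes the word-embedding table [vocab_size, hidden_size] replicated along the dp mesh axis and column-sharded along the hidden dimension on the mp axis, matching how the position-embedding table was already placed. A small illustration under the same assumed 2x2 mesh, with toy vocab/hidden sizes:

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])  # assumed 2x2 mesh

# toy embedding table: [vocab_size, hidden_size]
weight = paddle.create_parameter(shape=[50304, 4096], dtype=paddle.get_default_dtype())

# replicate over dp, shard dim 1 (hidden) over mp:
# each mp rank holds a [50304, 2048] column slice of the table
weight = dist.shard_tensor(weight, mesh, [dist.Replicate(), dist.Shard(1)])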
@@ -685,18 +681,15 @@ def forward(self, input_ids, position_ids=None, inputs_embeddings=None):
         position_embeddings = self.position_embeddings(position_ids)
         embeddings = inputs_embeddings + position_embeddings
 
-        # exit()
-        if self.config.sequence_parallel:
-            # embeddings = dist.shard_tensor(embeddings, get_mesh(), [dist.Replicate(), dist.Replicate()])
-            bs, seq_len, hidden_size = embeddings.shape
-            # [bs, seq_len, dim] -> [bs * seq_len, dim]
-            embeddings = paddle.reshape_(embeddings, [bs * seq_len, hidden_size])
-            # [bs * seq_len / n, dim] (n is mp parallelism)
-            # embeddings = ScatterOp.apply(embeddings)
-            embeddings = dist.reshard(embeddings, get_mesh(), [dist.Replicate(), dist.Shard(0)])
         # Use a ternary operator for a more concise assignment of current_seed
         current_seed = "local_seed" if self.config.sequence_parallel else "global_seed"
         # The 'with' block ensures the correct seed context is used
+        if self.config.sequence_parallel:
+            # [B, S, H] -> [S, B, H]
+            embeddings = paddle.transpose(embeddings, [1, 0, 2])
+            embeddings = dist.reshard(embeddings, get_mesh(), [dist.Shard(1), dist.Shard(0)])
+        else:
+            embeddings = dist.reshard(embeddings, get_mesh(), [dist.Shard(0), dist.Replicate()])
         with seed_guard_context(current_seed):
             embeddings = self.dropout(embeddings)
         return embeddings
@@ -1176,13 +1169,16 @@ def __init__(self, config: GPTConfig, embedding_weights=None, ipp=None):
                 shape=[config.vocab_size, config.hidden_size],
                 dtype=paddle.get_default_dtype(),
             )
-            self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)])
+            self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)])
 
     def forward(self, hidden_states, tensor_parallel_output=None):
-
         if self.config.sequence_parallel:
-            hidden_states = dist.reshard(hidden_states, get_mesh(self.ipp), [dist.Replicate(), dist.Replicate()])
-            hidden_states = paddle.reshape(hidden_states, [-1, self.config.seq_length, self.config.hidden_size])
+            hidden_states = dist.reshard(
+                hidden_states,
+                get_mesh(self.ipp),
+                [dist.Shard(1), dist.Shard(0)],
+            )
+            hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
 
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output