from paddle.distributed import fleet
from paddle.distributed.fleet.utils import recompute

+from ..segment_parallel_utils import sep_reshard_layer
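+# sep_reshard_layer redistributes a tensor across the sep (sequence-parallel) mesh
+# axis: the local shards along `concat_axis` are gathered while `split_axis` is
+# re-split, e.g. moving the shard from the sequence axis to the head axis before
+# attention (see the shape comments at the call sites below).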
+
try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
@@ -200,12 +202,24 @@ def scaled_dot_product_attention(
    return (attn_output, attn_weights) if output_attentions else attn_output


-def get_colwise_placement(has_seq_mesh):
-    return [dist.Replicate(), dist.Replicate(), dist.Shard(1)] if has_seq_mesh else [dist.Replicate(), dist.Shard(1)]
+def get_colwise_placement(has_seq_mesh, has_seq_parallel):
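+    # The placement list needs one entry per mesh dim: three when a sep mesh axis
+    # exists, otherwise two. Shard(1) on a Linear weight [in, out] shards the
+    # output dim, i.e. column-parallel on the mp axis.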
+    if has_seq_mesh:
+        if has_seq_parallel:  # mp + sep is not supported yet, so keep the weight replicated
+            return [dist.Replicate(), dist.Replicate(), dist.Replicate()]
+        else:
+            return [dist.Replicate(), dist.Replicate(), dist.Shard(1)]
+    else:
+        return [dist.Replicate(), dist.Shard(1)]


-def get_rowwise_placement(has_seq_mesh):
-    return [dist.Replicate(), dist.Replicate(), dist.Shard(0)] if has_seq_mesh else [dist.Replicate(), dist.Shard(0)]
+def get_rowwise_placement(has_seq_mesh, has_seq_parallel):
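+    # Same layout as get_colwise_placement, but Shard(0) shards the weight's input
+    # dim, i.e. row-parallel on the mp axis.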
+    if has_seq_mesh:
+        if has_seq_parallel:  # mp + sep is not supported yet, so keep the weight replicated
+            return [dist.Replicate(), dist.Replicate(), dist.Replicate()]
+        else:
+            return [dist.Replicate(), dist.Replicate(), dist.Shard(0)]
+    else:
+        return [dist.Replicate(), dist.Shard(0)]


def get_replicate_placement(has_seq_mesh):
@@ -266,28 +280,28 @@ def __init__(self, config, ipp: Optional[int] = None):
            self.gate_up_fused_proj.weight = dist.shard_tensor(
                self.gate_up_fused_proj.weight,
                get_mesh(self.ipp),
-                get_colwise_placement(has_seq_mesh),
+                get_colwise_placement(has_seq_mesh, self.config.sep_parallel_degree > 1),
            )
        else:
            self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
            self.gate_proj.weight = dist.shard_tensor(
                self.gate_proj.weight,
                get_mesh(self.ipp),
-                get_colwise_placement(has_seq_mesh),
+                get_colwise_placement(has_seq_mesh, self.config.sep_parallel_degree > 1),
            )

            self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
            self.up_proj.weight = dist.shard_tensor(
                self.up_proj.weight,
                get_mesh(self.ipp),
-                get_colwise_placement(has_seq_mesh),
+                get_colwise_placement(has_seq_mesh, self.config.sep_parallel_degree > 1),
            )

        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
        self.down_proj.weight = dist.shard_tensor(
            self.down_proj.weight,
            get_mesh(self.ipp),
-            get_rowwise_placement(has_seq_mesh),
+            get_rowwise_placement(has_seq_mesh, self.config.sep_parallel_degree > 1),
        )

    def forward(self, x):
@@ -348,7 +362,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
            self.qkv_proj.weight = dist.shard_tensor(
                self.qkv_proj.weight,
                get_mesh(self.ipp),
-                get_colwise_placement(self.has_seq_mesh),
+                get_colwise_placement(self.has_seq_mesh, self.config.sep_parallel_degree > 1),
            )

        else:
@@ -360,7 +374,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
            self.q_proj.weight = dist.shard_tensor(
                self.q_proj.weight,
                get_mesh(self.ipp),
-                get_colwise_placement(self.has_seq_mesh),
+                get_colwise_placement(self.has_seq_mesh, self.config.sep_parallel_degree > 1),
            )

            self.k_proj = nn.Linear(
@@ -371,7 +385,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
            self.k_proj.weight = dist.shard_tensor(
                self.k_proj.weight,
                get_mesh(self.ipp),
-                get_colwise_placement(self.has_seq_mesh),
+                get_colwise_placement(self.has_seq_mesh, self.config.sep_parallel_degree > 1),
            )

            self.v_proj = nn.Linear(
@@ -382,7 +396,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
            self.v_proj.weight = dist.shard_tensor(
                self.v_proj.weight,
                get_mesh(self.ipp),
-                get_colwise_placement(self.has_seq_mesh),
+                get_colwise_placement(self.has_seq_mesh, self.config.sep_parallel_degree > 1),
            )

        self.o_proj = nn.Linear(
@@ -393,13 +407,16 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
        self.o_proj.weight = dist.shard_tensor(
            self.o_proj.weight,
            get_mesh(self.ipp),
-            get_rowwise_placement(self.has_seq_mesh),
+            get_rowwise_placement(self.has_seq_mesh, self.config.sep_parallel_degree > 1),
        )

        if config.rope:
            self._init_rope()

        self.config = config
+        if config.sep_parallel_degree > 1:
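+            # sep resharding redistributes the head axis across the sep group, so
+            # both the query and key/value head counts must divide evenly.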
+            assert self.num_key_value_heads % config.sep_parallel_degree == 0
+            assert self.num_heads % config.sep_parallel_degree == 0

    def _init_rope(self):
        if self.config.rope_scaling_type is None:
@@ -456,37 +473,108 @@ def forward(
            )

        if self.fuse_attention_qkv and not enable_fuse_ffn_qkv_pass():
-            target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
            mix_layer = self.qkv_proj(hidden_states)
-            mix_layer = paddle.reshape_(mix_layer, target_shape)
+            # NOTE on GQA attention fusion (also compatible with MHA and MQA):
+            # The qkv_proj weight has shape [hidden_size, hidden_size + 2 * num_kv_heads * head_dim].
+            # After the projection, mix_layer has shape [b, s, hidden_size + 2 * num_kv_heads * head_dim].
+            # Reshape mix_layer to [b, s, num_kv_heads, (num_groups + 2) * head_dim],
+            # where num_groups = num_q_heads // num_kv_heads.
+            # Split mix_layer on the last axis into three sections [num_groups * head_dim, head_dim, head_dim]
+            # to obtain q, k and v respectively.
+            # q then has shape [b, s, num_kv_heads, num_groups * head_dim],
+            # while k and v have shape [b, s, num_kv_heads, head_dim].
+            # Under MHA, q is ready for the following computation since num_kv_heads == num_q_heads,
+            # but for GQA or MQA, q must be reshaped to [b, s, num_q_heads, head_dim].
+            if self.config.sep_parallel_degree > 1:
+                if self.config.sequence_parallel:
+                    raise ValueError(
+                        "Sep parallel cannot be used together with sequence parallel, "
+                        "because paddle auto parallel does not support "
+                        "resharding one dim twice."
+                    )
+
+                # reshard along the sep mesh axis: gather the sequence axis and split the
+                # fused qkv axis, [bs, seq_len / sep, fused_qkv_dim] -> [bs, seq_len, fused_qkv_dim / sep]
+                mix_layer = sep_reshard_layer(
+                    mix_layer,
+                    split_axis=2,
+                    concat_axis=1,
+                )
+                mix_layer = paddle.reshape_(
+                    mix_layer, [0, self.seq_length, -1, (self.num_key_value_groups + 2) * self.head_dim]
+                )  # [bs, seq_len, num_kv_heads / sep, (num_groups + 2) * head_dim], sep is the sep degree
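+                # Illustrative shapes (hypothetical config: bs=2, seq_len=4096, 32 q heads,
+                # 8 kv heads, head_dim=128, sep=4): qkv_proj output [2, 1024, 6144]
+                # -> sep reshard -> [2, 4096, 1536] -> reshape_ -> [2, 4096, 2, 768],
+                # i.e. num_kv_heads / sep = 2 local kv heads, each with (4 + 2) * 128 = 768 features.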
+            else:
+                target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
+                mix_layer = paddle.reshape_(mix_layer, target_shape)
+
            query_states, key_states, value_states = paddle.split(
                mix_layer,
                num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim],
                axis=-1,
            )
            if self.gqa_or_mqa:
                query_states = paddle.reshape(query_states, [0, 0, self.num_heads, self.head_dim])
+            if self.config.sequence_parallel and self.config.sep_parallel_degree <= 1:
+                # [seq_len, bs, num_head, head_dim] -> [bs, seq_len, num_head, head_dim] (if sequence_parallel)
+                # FA and rope do not support a sequence-first layout
+                query_states = paddle.transpose(query_states, [1, 0, 2, 3])
+                key_states = paddle.transpose(key_states, [1, 0, 2, 3])
+                value_states = paddle.transpose(value_states, [1, 0, 2, 3])
        else:
-            target_query_shape = [0, 0, self.num_heads, self.head_dim]
-            target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
-
-            query_states = self.q_proj(hidden_states).reshape(shape=target_query_shape)
-            key_states = self.k_proj(hidden_states).reshape(shape=target_key_value_shape)
-            value_states = self.v_proj(hidden_states).reshape(shape=target_key_value_shape)
-
-            if self.config.sequence_parallel:
-                # [seq_len, bs, num_head * head_dim] -> [bs, seq_len, num_head * head_dim] (if sequence_parallel)
-                # FA and rope not support sequence first
-                query_states = paddle.transpose(query_states, [1, 0, 2, 3])
-                key_states = paddle.transpose(key_states, [1, 0, 2, 3])
-                value_states = paddle.transpose(value_states, [1, 0, 2, 3])
+            if self.config.sep_parallel_degree > 1:
+                query_states = self.q_proj(hidden_states)
+                key_states = self.k_proj(hidden_states)
+                value_states = self.v_proj(hidden_states)
+                if self.config.sequence_parallel:
+                    raise ValueError(
+                        "Sep parallel cannot be used together with sequence parallel, "
+                        "because paddle auto parallel does not support "
+                        "resharding one dim twice."
+                    )

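+                # each projection output is [bs, seq_len / sep, local_hidden]
+                # (num_heads * head_dim for q, num_kv_heads * head_dim for k/v);
+                # the sep reshard gathers the sequence axis and splits the last axis
+                # by the sep degree before the reshapes below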
+                query_states = sep_reshard_layer(
+                    query_states,
+                    split_axis=2,
+                    concat_axis=1,
+                )
+                key_states = sep_reshard_layer(
+                    key_states,
+                    split_axis=2,
+                    concat_axis=1,
+                )
+                value_states = sep_reshard_layer(
+                    value_states,
+                    split_axis=2,
+                    concat_axis=1,
+                )
+                query_states = paddle.reshape(
+                    query_states, shape=[0, self.seq_length, -1, self.head_dim]
+                )  # [bs, seq_len, num_head / sep, head_dim], sep is the sep degree
+                key_states = paddle.reshape(key_states, shape=[0, self.seq_length, -1, self.head_dim])
+                value_states = paddle.reshape(value_states, shape=[0, self.seq_length, -1, self.head_dim])
+            else:
+                target_query_shape = [0, 0, self.num_heads, self.head_dim]
+                target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
+
+                query_states = self.q_proj(hidden_states).reshape(shape=target_query_shape)
+                key_states = self.k_proj(hidden_states).reshape(shape=target_key_value_shape)
+                value_states = self.v_proj(hidden_states).reshape(shape=target_key_value_shape)
+
+                if self.config.sequence_parallel:
+                    # [seq_len, bs, num_head, head_dim] -> [bs, seq_len, num_head, head_dim] (if sequence_parallel)
+                    # FA and rope do not support a sequence-first layout
+                    query_states = paddle.transpose(query_states, [1, 0, 2, 3])
+                    key_states = paddle.transpose(key_states, [1, 0, 2, 3])
+                    value_states = paddle.transpose(value_states, [1, 0, 2, 3])
        kv_seq_len = key_states.shape[-3]

        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-3]

        if self.config.rope:
+            query_placement = query_states.placements
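+            # remember the current placements so q/k can be resharded back to them
+            # after the context-parallel rope below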
+            if self.config.sep_parallel_degree > 1:
+                batch_size, seq_length, _, _ = query_states.shape
+                position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
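+                # after the sep reshard q/k/v carry the full sequence again, so
+                # position_ids is rebuilt over the complete (batch_size, seq_length) range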
            if self.config.context_parallel_degree > 1:
                mesh = dist.auto_parallel.get_mesh()
                group = mesh._get_group("sep")
@@ -516,16 +604,16 @@ def forward(
                    self.rotary_emb,
                    self.config.context_parallel_degree,
                )
-                if self.has_seq_mesh:
+                if self.config.context_parallel_degree > 1:
                    query_states = dist.reshard(
                        query_states,
                        get_mesh(self.ipp),
-                        [dist.Shard(0), dist.Shard(1), dist.Shard(2)],
+                        query_placement,  # [dist.Shard(0), dist.Shard(1), dist.Shard(2)]
                    )
                    key_states = dist.reshard(
                        key_states,
                        get_mesh(self.ipp),
-                        [dist.Shard(0), dist.Shard(1), dist.Shard(2)],
+                        query_placement,  # [dist.Shard(0), dist.Shard(1), dist.Shard(2)]
                    )
            else:
                cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
@@ -1282,7 +1370,7 @@ def __init__(self, config: LlamaConfig):
        self.weight = dist.shard_tensor(
            self.weight,
            get_mesh(-1),
-            get_colwise_placement(has_seq_mesh),
+            get_colwise_placement(has_seq_mesh, self.config.sep_parallel_degree > 1),
        )

    def forward(self, hidden_states, tensor_parallel_output=None):