@@ -299,8 +299,9 @@ class LlamaRMSNorm(nn.Layer):
     def __init__(self, config):
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.embedding_output_size = config.embedding_output_size
         self.weight = paddle.create_parameter(
-            shape=[self.hidden_size],
+            shape=[self.embedding_output_size],
             dtype=paddle.get_default_dtype(),
             default_initializer=nn.initializer.Constant(1.0),
         )
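
This hunk (and the ones below) assumes that LlamaConfig exposes an embedding_output_size field; that config change is not part of this excerpt. A minimal sketch of how such a field could be added, defaulting to hidden_size so existing configs keep working (hypothetical, not the actual PaddleNLP change):

# Hypothetical sketch only -- the real LlamaConfig change is not shown in this diff.
class LlamaConfigSketch:
    def __init__(self, hidden_size=4096, embedding_output_size=None, **kwargs):
        self.hidden_size = hidden_size
        # Fall back to hidden_size so configs without the new field behave as before.
        self.embedding_output_size = (
            embedding_output_size if embedding_output_size is not None else hidden_size
        )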
@@ -465,6 +466,7 @@ class LlamaMLP(nn.Layer):
     def __init__(self, config):
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.embedding_output_size = config.embedding_output_size
         self.intermediate_size = config.intermediate_size
         self.tensor_parallel_degree = config.tensor_parallel_degree
         self.fuse_attention_ffn = config.fuse_attention_ffn
@@ -479,39 +481,41 @@ def __init__(self, config):
         if config.tensor_parallel_degree > 1:
             if config.fuse_attention_ffn:
                 self.gate_up_fused_proj = ColumnParallelLinear(
-                    self.hidden_size,
+                    self.embedding_output_size,
                     self.intermediate_size * 2,
                     gather_output=False,
                     has_bias=False,
                 )
             else:
                 self.gate_proj = ColumnParallelLinear(
-                    self.hidden_size,
+                    self.embedding_output_size,
                     self.intermediate_size,
                     gather_output=False,
                     has_bias=False,
                 )
                 self.up_proj = ColumnParallelLinear(
-                    self.hidden_size,
+                    self.embedding_output_size,
                     self.intermediate_size,
                     gather_output=False,
                     has_bias=False,
                 )

             self.down_proj = RowParallelLinear(
                 self.intermediate_size,
-                self.hidden_size,
+                self.embedding_output_size,
                 input_is_parallel=True,
                 has_bias=False,
             )
         else:
             if config.fuse_attention_ffn:
-                self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
+                self.gate_up_fused_proj = nn.Linear(
+                    self.embedding_output_size, self.intermediate_size * 2, bias_attr=False
+                )
             else:
-                self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
-                self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.gate_proj = nn.Linear(self.embedding_output_size, self.intermediate_size, bias_attr=False)
+                self.up_proj = nn.Linear(self.embedding_output_size, self.intermediate_size, bias_attr=False)

-            self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
+            self.down_proj = nn.Linear(self.intermediate_size, self.embedding_output_size, bias_attr=False)

     def forward(self, x):
         if self.fuse_attention_ffn:
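
With these changes the MLP reads and writes activations of width embedding_output_size, while intermediate_size stays internal to the block. A small shape check mirroring the non-fused SwiGLU path, with illustrative sizes (assumes only that paddle is installed):

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

# Illustrative sizes only; in the model they come from LlamaConfig.
embedding_output_size, intermediate_size = 1024, 2752

gate_proj = nn.Linear(embedding_output_size, intermediate_size, bias_attr=False)
up_proj = nn.Linear(embedding_output_size, intermediate_size, bias_attr=False)
down_proj = nn.Linear(intermediate_size, embedding_output_size, bias_attr=False)

x = paddle.randn([2, 16, embedding_output_size])  # [batch, seq_len, embedding_output_size]
y = down_proj(F.silu(gate_proj(x)) * up_proj(x))  # SwiGLU, as in the non-fused forward
print(y.shape)  # [2, 16, 1024] -- same width as the input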
@@ -530,6 +534,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):

         self.config = config
         self.hidden_size = config.hidden_size
+        self.embedding_output_size = config.embedding_output_size
         self.num_heads = config.num_attention_heads

         self.head_dim = self.hidden_size // config.num_attention_heads
@@ -590,78 +595,78 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         if config.tensor_parallel_degree > 1:
             if self.fuse_attention_qkv:
                 self.qkv_proj = ColumnParallelLinear(
-                    self.hidden_size,
+                    self.embedding_output_size,
                     3 * self.hidden_size,
                     has_bias=False,
                     gather_output=False,
                 )
             else:
                 self.q_proj = ColumnParallelLinear(
-                    self.hidden_size,
+                    self.embedding_output_size,
                     self.hidden_size,
                     has_bias=False,
                     gather_output=False,
                 )
                 if self.kv_indices is None:
                     self.k_proj = ColumnParallelLinear(
-                        self.hidden_size,
+                        self.embedding_output_size,
                         self.config.num_key_value_heads * self.head_dim,
                         has_bias=False,
                         gather_output=False,
                     )
                     self.v_proj = ColumnParallelLinear(
-                        self.hidden_size,
+                        self.embedding_output_size,
                         self.config.num_key_value_heads * self.head_dim,
                         has_bias=False,
                         gather_output=False,
                     )
                 else:
                     self.k_proj = nn.Linear(
-                        self.hidden_size,
+                        self.embedding_output_size,
                         self.config.num_key_value_heads * self.head_dim,
                         bias_attr=False,
                     )
                     self.v_proj = nn.Linear(
-                        self.hidden_size,
+                        self.embedding_output_size,
                         self.config.num_key_value_heads * self.head_dim,
                         bias_attr=False,
                     )

         else:
             if self.fuse_attention_qkv:
                 self.qkv_proj = nn.Linear(
-                    self.hidden_size,
+                    self.embedding_output_size,
                     3 * self.hidden_size,
                     bias_attr=False,
                 )
             else:
                 self.q_proj = nn.Linear(
-                    self.hidden_size,
+                    self.embedding_output_size,
                     self.hidden_size,
                     bias_attr=False,
                 )
                 self.k_proj = nn.Linear(
-                    self.hidden_size,
+                    self.embedding_output_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
                 self.v_proj = nn.Linear(
-                    self.hidden_size,
+                    self.embedding_output_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )

         if config.tensor_parallel_degree > 1:
             self.o_proj = RowParallelLinear(
                 self.hidden_size,
-                self.hidden_size,
+                self.embedding_output_size,
                 has_bias=False,
                 input_is_parallel=True,
             )
         else:
             self.o_proj = nn.Linear(
                 self.hidden_size,
-                self.hidden_size,
+                self.embedding_output_size,
                 bias_attr=False,
             )

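
After this hunk every input projection (qkv_proj, q_proj, k_proj, v_proj) consumes an embedding_output_size-wide hidden state, while o_proj still takes hidden_size internally and now writes back to embedding_output_size. A minimal shape sketch of the non-parallel, non-fused branch with illustrative sizes:

import paddle
import paddle.nn as nn

# Illustrative sizes only; in the model they come from LlamaConfig.
embedding_output_size, hidden_size = 1024, 2048
num_heads, num_key_value_heads = 16, 4
head_dim = hidden_size // num_heads  # 128

q_proj = nn.Linear(embedding_output_size, hidden_size, bias_attr=False)
k_proj = nn.Linear(embedding_output_size, num_key_value_heads * head_dim, bias_attr=False)
v_proj = nn.Linear(embedding_output_size, num_key_value_heads * head_dim, bias_attr=False)
o_proj = nn.Linear(hidden_size, embedding_output_size, bias_attr=False)

h = paddle.randn([2, 16, embedding_output_size])
print(q_proj(h).shape)  # [2, 16, 2048] -> reshaped into 16 query heads of 128
print(k_proj(h).shape)  # [2, 16, 512]  -> 4 key/value heads of 128
print(o_proj(paddle.randn([2, 16, hidden_size])).shape)  # [2, 16, 1024] back to the embedding width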
@@ -1078,6 +1083,7 @@ def __init__(self, config: LlamaConfig):
         super().__init__(config)
         self.vocab_size = config.vocab_size
         self.hidden_size = config.hidden_size
+        self.embedding_output_size = config.embedding_output_size
         self.sequence_parallel = config.sequence_parallel
         self.recompute_granularity = config.recompute_granularity
         self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else []
@@ -1087,13 +1093,13 @@ def __init__(self, config: LlamaConfig):
         if config.tensor_parallel_degree > 1:
             self.embed_tokens = mpu.VocabParallelEmbedding(
                 self.vocab_size,
-                self.hidden_size,
+                self.embedding_output_size,
                 weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
             )
         else:
             self.embed_tokens = nn.Embedding(
                 self.vocab_size,
-                self.hidden_size,
+                self.embedding_output_size,
             )

         self.layers = nn.LayerList(
@@ -1115,12 +1121,10 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         if len(attention_mask.shape) == 2:
             expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1])
-            # For decoding phase in generation, seq_length = 1, we don't need to add causal mask
-            if input_shape[-1] > 1:
-                combined_attention_mask = _make_causal_mask(
-                    input_shape, past_key_values_length=past_key_values_length
-                )
-                expanded_attn_mask = expanded_attn_mask & combined_attention_mask
+            # In the decoding phase of generation seq_length = 1 and no causal mask is needed; since we only run pretraining here, the guard below is temporarily disabled.
+            # if input_shape[-1] > 1:
+            combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+            expanded_attn_mask = expanded_attn_mask & combined_attention_mask
         # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
         elif len(attention_mask.shape) == 3:
             expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
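
With the guard commented out, the causal mask is built even when input_shape[-1] == 1. The toy mask below (not the file's _make_causal_mask, just an illustration) shows why that is harmless: for a single target position the mask degenerates to a single True entry.

import paddle

def toy_causal_mask(target_length, past_key_values_length=0):
    # True where attention is allowed: each position may see itself, earlier
    # positions, and the whole cached (past) prefix.
    mask = paddle.tril(paddle.ones([target_length, target_length], dtype="int32")).astype("bool")
    if past_key_values_length > 0:
        past = paddle.ones([target_length, past_key_values_length], dtype="int32").astype("bool")
        mask = paddle.concat([past, mask], axis=-1)
    return mask  # [tgt_seq_len, past_key_values_length + tgt_seq_len]

print(toy_causal_mask(4, past_key_values_length=2).astype("int32"))
print(toy_causal_mask(1))  # a single True: building the mask for seq_length == 1 changes nothing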
@@ -1359,7 +1363,7 @@ def __init__(self, config: LlamaConfig):
         vocab_size = config.vocab_size

         self.weight = self.create_parameter(
-            shape=[config.hidden_size, vocab_size],
+            shape=[config.embedding_output_size, vocab_size],
             dtype=paddle.get_default_dtype(),
         )
         # Must set distributed attr for Tensor Parallel !
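
With the weight stored as [embedding_output_size, vocab_size], the logits come from a plain matmul of the final hidden states against this matrix. A quick sketch of that contraction with illustrative sizes (this ignores the tensor-parallel matmul path):

import paddle

# Illustrative sizes only.
embedding_output_size, vocab_size = 1024, 32000

hidden_states = paddle.randn([2, 16, embedding_output_size])        # decoder output
lm_head_weight = paddle.randn([embedding_output_size, vocab_size])  # same layout as self.weight above

logits = paddle.matmul(hidden_states, lm_head_weight)
print(logits.shape)  # [2, 16, 32000] -- one score per vocabulary token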