
class LMHead(nn.Layer):
    def __init__(self, config: PretrainedConfig):
-        """
-        transpose_y (bool): Whether to transpose the lm_head weight matrix before matrix multiplication.
-        """
        super().__init__()
        self.config = config
        self.use_bias = config.get("lm_head_bias", False)
-        self.transpose_y = config.get("tie_word_embeddings", False)
        self.vocab_parallel = False

        # apply vocab tensor parallel
@@ -45,21 +41,15 @@ def __init__(self, config: PretrainedConfig):
            vocab_size,
            config.tensor_parallel_degree,
        )
-        self.lm_head_shape = (
-            [config.hidden_size, vocab_size] if not self.transpose_y else [vocab_size, config.hidden_size]
-        )

        self.weight = self.create_parameter(
-            shape=self.lm_head_shape,
+            shape=[vocab_size, config.hidden_size],
            dtype=paddle.get_default_dtype(),
            default_initializer=nn.initializer.XavierNormal(1.0),
        )

        # setting distributed attr for tensor parallel
-        self.weight.is_distributed = self.vocab_parallel
-
-        if self.weight.is_distributed:
-            self.weight.split_axis = 0 if self.transpose_y else 1
+        self._set_distributed_attr(self.weight)

        if self.use_bias:
            self.bias = self.create_parameter(
@@ -69,12 +59,15 @@ def __init__(self, config: PretrainedConfig):
            )

            # setting distributed attr for tensor parallel
-            self.bias.is_distributed = self.vocab_parallel
-            if self.bias.is_distributed:
-                self.bias.split_axis = 0
+            self._set_distributed_attr(self.bias)
        else:
            self.bias = None

+    def _set_distributed_attr(self, param):
+        param.is_distributed = self.vocab_parallel
+        if param.is_distributed:
+            param.split_axis = 0
+
    def forward(self, hidden_states, tensor_parallel_output=None):
        """Project hidden states to vocabulary logits.

@@ -114,5 +107,4 @@ def forward(self, hidden_states, tensor_parallel_output=None):
        )

    def extra_repr(self):
-        hidden_size, vocab_size = self.lm_head_shape if not self.transpose_y else self.lm_head_shape[::-1]
-        return f"hidden_size={hidden_size}, vocab_size={vocab_size}, dtype={self.weight.dtype}, vocab_parallel={self.vocab_parallel}"
+        return f"hidden_size={self.weight.shape[1]}, vocab_size={self.weight.shape[0]}, dtype={self.weight.dtype}, vocab_parallel={self.vocab_parallel}"
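
For context, a minimal sketch (not part of the diff) of the projection implied by the new fixed weight layout. It assumes the forward pass applies the [vocab_size, hidden_size] weight with transpose_y=True; the sizes below are illustrative and not taken from this commit.

    import paddle

    # illustrative sizes (assumptions, not from the diff)
    hidden_size, vocab_size = 768, 32000
    weight = paddle.randn([vocab_size, hidden_size])    # new fixed layout: [vocab_size, hidden_size]
    hidden_states = paddle.randn([2, 16, hidden_size])  # [batch, seq, hidden_size]

    # presumed lm_head projection: transpose inside matmul instead of storing a transposed weight
    logits = paddle.matmul(hidden_states, weight, transpose_y=True)
    print(logits.shape)  # [2, 16, 32000] -> [batch, seq, vocab_size]

    # _set_distributed_attr (added above) would additionally mark the real parameter as
    # distributed with split_axis=0 when vocab tensor parallelism is enabled, i.e. the
    # vocab dimension is the one sharded across ranks.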