
Commit e479070

cherry pick from pr #2622 (#2625)
1 parent fd6dce7 commit e479070

10 files changed: +21 additions, -33 deletions

paddleformers/nn/criterion/loss_utils.py
Lines changed: 1 addition & 1 deletion

@@ -55,7 +55,7 @@ def calc_lm_head_logits(
         hidden_states,
         weight,
         bias=bias,
-        transpose_y=config.get("tie_word_embeddings", False),
+        transpose_y=True,
         tensor_parallel_degree=config.tensor_parallel_degree,
         tensor_parallel_output=tensor_parallel_output,
         fuse_linear=config.get("fuse_linear", False),
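Note: with the lm_head weight now always stored as [vocab_size, hidden_size], transpose_y=True makes the logits projection multiply by the transposed weight regardless of tie_word_embeddings. A minimal sketch of that shape convention, with illustrative names and sizes only; calc_lm_head_logits itself additionally handles tensor parallelism and fused linear:

import paddle

# Illustrative sizes; the real values come from the model config.
batch_size, seq_len, hidden_size, vocab_size = 2, 8, 64, 100

hidden_states = paddle.randn([batch_size, seq_len, hidden_size])
weight = paddle.randn([vocab_size, hidden_size])  # lm_head weight layout after this change

# transpose_y=True multiplies by weight.T, so the last dim becomes vocab_size.
logits = paddle.matmul(hidden_states, weight, transpose_y=True)
assert logits.shape == [batch_size, seq_len, vocab_size]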

paddleformers/nn/lm_head.py
Lines changed: 9 additions & 17 deletions

@@ -24,13 +24,9 @@
 
 class LMHead(nn.Layer):
     def __init__(self, config: PretrainedConfig):
-        """
-        transpose_y (bool): Whether to transpose the lm_head weight matrix before matrix multiplication.
-        """
         super().__init__()
         self.config = config
         self.use_bias = config.get("lm_head_bias", False)
-        self.transpose_y = config.get("tie_word_embeddings", False)
         self.vocab_parallel = False
 
         # apply vocab tensor parallel
@@ -45,21 +41,15 @@ def __init__(self, config: PretrainedConfig):
                 vocab_size,
                 config.tensor_parallel_degree,
             )
-        self.lm_head_shape = (
-            [config.hidden_size, vocab_size] if not self.transpose_y else [vocab_size, config.hidden_size]
-        )
 
         self.weight = self.create_parameter(
-            shape=self.lm_head_shape,
+            shape=[vocab_size, config.hidden_size],
             dtype=paddle.get_default_dtype(),
             default_initializer=nn.initializer.XavierNormal(1.0),
         )
 
         # setting distributed attr for tensor parallel
-        self.weight.is_distributed = self.vocab_parallel
-
-        if self.weight.is_distributed:
-            self.weight.split_axis = 0 if self.transpose_y else 1
+        self._set_distributed_attr(self.weight)
 
         if self.use_bias:
             self.bias = self.create_parameter(
@@ -69,12 +59,15 @@ def __init__(self, config: PretrainedConfig):
             )
 
             # setting distributed attr for tensor parallel
-            self.bias.is_distributed = self.vocab_parallel
-            if self.bias.is_distributed:
-                self.bias.split_axis = 0
+            self._set_distributed_attr(self.bias)
         else:
             self.bias = None
 
+    def _set_distributed_attr(self, param):
+        param.is_distributed = self.vocab_parallel
+        if param.is_distributed:
+            param.split_axis = 0
+
     def forward(self, hidden_states, tensor_parallel_output=None):
         """Project hidden states to vocabulary logits.
 
@@ -114,5 +107,4 @@ def forward(self, hidden_states, tensor_parallel_output=None):
         )
 
     def extra_repr(self):
-        hidden_size, vocab_size = self.lm_head_shape if not self.transpose_y else self.lm_head_shape[::-1]
-        return f"hidden_size={hidden_size}, vocab_size={vocab_size}, dtype={self.weight.dtype}, vocab_parallel={self.vocab_parallel}"
+        return f"hidden_size={self.weight.shape[1]}, vocab_size={self.weight.shape[0]}, dtype={self.weight.dtype}, vocab_parallel={self.vocab_parallel}"

paddleformers/transformers/ernie4_5/modeling.py
Lines changed: 1 addition & 1 deletion

@@ -439,7 +439,7 @@ def _get_tensor_parallel_mappings(cls, config, is_split=True):
 
         def make_base_actions():
             actions = {
-                "lm_head.weight": partial(fn, is_column=not config.tie_word_embeddings),
+                "lm_head.weight": partial(fn, is_column=False),
                 "embed_tokens.weight": partial(fn, is_column=False),
             }
             for layer_idx in range(config.num_hidden_layers):
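Note: is_column=False now matches embed_tokens.weight because the lm_head weight is stored as [vocab_size, hidden_size]; a row split partitions the vocab dimension across tensor-parallel ranks. A hedged sketch of that effect; the split below is illustrative, the real fn comes from paddleformers' tensor-parallel utilities:

import paddle

# Illustrative sizes; real values come from the config.
vocab_size, hidden_size, tp_degree = 100, 64, 4

lm_head_weight = paddle.randn([vocab_size, hidden_size])

# A row split (is_column=False) partitions axis 0, the vocab dimension,
# so each tensor-parallel rank holds vocab_size // tp_degree rows.
shards = paddle.split(lm_head_weight, num_or_sections=tp_degree, axis=0)
assert shards[0].shape == [vocab_size // tp_degree, hidden_size]

The same reasoning applies to the identical one-line changes in the other model files below.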

paddleformers/transformers/ernie4_5_moe/modeling.py
Lines changed: 1 addition & 1 deletion

@@ -503,7 +503,7 @@ def _get_tensor_parallel_mappings(cls, config, is_split=True):
 
         def make_base_actions():
             actions = {
-                "lm_head.weight": partial(fn, is_column=not config.tie_word_embeddings),
+                "lm_head.weight": partial(fn, is_column=False),
                 "embed_tokens.weight": partial(fn, is_column=False),
             }
             for layer_idx in range(config.num_hidden_layers):

paddleformers/transformers/gpt_oss/modeling.py
Lines changed: 2 additions & 2 deletions

@@ -650,7 +650,7 @@ class GptOssPreTrainedModel(PretrainedModel):
     config_class = GptOssConfig
     base_model_prefix = "model"
     keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"]
-    transpose_weight_keys = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head"]
+    transpose_weight_keys = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
 
     @classmethod
     def _get_tensor_parallel_mappings(cls, config: GptOssConfig, is_split=True):
@@ -667,7 +667,7 @@ def get_tensor_parallel_split_mappings(num_layers, num_experts):
         final_actions = {}
 
         base_actions = {
-            "lm_head.weight": partial(fn, is_column=True),
+            "lm_head.weight": partial(fn, is_column=False),
             # Row Linear
             "embed_tokens.weight": partial(fn, is_column=False),
             "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),

paddleformers/transformers/qwen2/modeling.py
Lines changed: 2 additions & 2 deletions

@@ -348,7 +348,7 @@ class Qwen2PretrainedModel(PretrainedModel):
     config_class = Qwen2Config
     base_model_prefix = "model"
     _keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"]
-    transpose_weight_keys = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head"]
+    transpose_weight_keys = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
 
     @classmethod
     def _get_tensor_parallel_mappings(cls, config: Qwen2Config, is_split=True):
@@ -380,7 +380,7 @@ def _get_tensor_parallel_mappings(cls, config: Qwen2Config, is_split=True):
 
         def make_base_actions():
             actions = {
-                "lm_head.weight": partial(fn, is_column=not config.tie_word_embeddings),
+                "lm_head.weight": partial(fn, is_column=False),
                 "embed_tokens.weight": partial(fn, is_column=False),
             }
             for layer_idx in range(config.num_hidden_layers):

paddleformers/transformers/qwen2_moe/modeling.py
Lines changed: 1 addition & 2 deletions

@@ -388,7 +388,6 @@ class Qwen2MoePretrainedModel(PretrainedModel):
         "down_proj",
         "gate",
         "shared_expert_gate",
-        "lm_head",
     ]
 
     @classmethod
@@ -433,7 +432,7 @@ def _get_tensor_parallel_mappings(cls, config: Qwen2MoeConfig, is_split=True):
 
         def make_base_actions():
             actions = {
-                "lm_head.weight": partial(fn, is_column=not config.tie_word_embeddings),
+                "lm_head.weight": partial(fn, is_column=False),
                 "embed_tokens.weight": partial(fn, is_column=False),
             }
             for layer_idx in range(config.num_hidden_layers):

paddleformers/transformers/qwen3/modeling.py
Lines changed: 2 additions & 2 deletions

@@ -247,7 +247,7 @@ class Qwen3PretrainedModel(PretrainedModel):
     config_class = Qwen3Config
     base_model_prefix = "model"
     _keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"]
-    transpose_weight_keys = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head"]
+    transpose_weight_keys = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
 
     @classmethod
     def _get_tensor_parallel_mappings(cls, config: Qwen3Config, is_split=True):
@@ -279,7 +279,7 @@ def _get_tensor_parallel_mappings(cls, config: Qwen3Config, is_split=True):
 
         def make_base_actions():
             actions = {
-                "lm_head.weight": partial(fn, is_column=not config.tie_word_embeddings),
+                "lm_head.weight": partial(fn, is_column=False),
                 "embed_tokens.weight": partial(fn, is_column=False),
             }
             for layer_idx in range(config.num_hidden_layers):

paddleformers/transformers/qwen3_moe/modeling.py
Lines changed: 1 addition & 2 deletions

@@ -273,7 +273,6 @@ class Qwen3MoePretrainedModel(PretrainedModel):
         "up_proj",
         "down_proj",
         "gate",
-        "lm_head",
     ]
 
     @classmethod
@@ -311,7 +310,7 @@ def _get_tensor_parallel_mappings(cls, config: Qwen3MoeConfig, is_split=True):
 
         def make_base_actions():
             actions = {
-                "lm_head.weight": partial(fn, is_column=not config.tie_word_embeddings),
+                "lm_head.weight": partial(fn, is_column=False),
                 "embed_tokens.weight": partial(fn, is_column=False),
             }
             for layer_idx in range(config.num_hidden_layers):

tests/nn/test_lm_head.py
Lines changed: 1 addition & 3 deletions

@@ -27,11 +27,10 @@ def test_initialization_default(self):
         lm_head = LMHead(config)
 
         # Check weight shape and attributes
-        self.assertEqual(lm_head.weight.shape, [config.hidden_size, config.vocab_size])
+        self.assertEqual(lm_head.weight.shape, [config.vocab_size, config.hidden_size])
         self.assertFalse(lm_head.weight.is_distributed)
         self.assertIsNone(lm_head.bias)
         self.assertFalse(lm_head.vocab_parallel)
-        self.assertFalse(lm_head.transpose_y)
 
     def test_initialization_with_tie_word_embeddings(self):
         # Test initialization with tied embeddings
@@ -40,7 +39,6 @@ def test_initialization_with_tie_word_embeddings(self):
         lm_head = LMHead(config)
 
         self.assertEqual(lm_head.weight.shape, [config.vocab_size, config.hidden_size])
-        self.assertTrue(lm_head.transpose_y)
 
     def test_forward_normal(self):
         # Test normal forward pass
