2 changes: 2 additions & 0 deletions paddleformers/transformers/glm4_moe/configuration.py
@@ -159,6 +159,7 @@ def __init__(
topk_method="noaux_tc",
using_flex_token=True,
moe_subbatch_token_num=0,
moe_aux_loss_coeff=0,
**kwargs,
):
self.vocab_size = vocab_size
@@ -202,6 +203,7 @@ def __init__(
self.using_flex_token = using_flex_token
self.use_fp8 = False
self.moe_subbatch_token_num = moe_subbatch_token_num
self.moe_aux_loss_coeff = moe_aux_loss_coeff

self.pp_seg_method = pp_seg_method
self.disable_ffn_model_parallel = disable_ffn_model_parallel
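The new `moe_aux_loss_coeff` field defaults to 0, so the auxiliary load-balancing loss stays disabled unless it is set explicitly. A minimal usage sketch, assuming the config class is importable from the module path in the file header above; the coefficient and token count are illustrative placeholders, not recommended values:

```python
from paddleformers.transformers.glm4_moe.configuration import Glm4MoeConfig

# Hypothetical values: enable the MoE auxiliary (load-balancing) loss and
# sub-batched expert computation; 1e-3 and 1024 are placeholders.
config = Glm4MoeConfig(
    moe_aux_loss_coeff=1e-3,
    moe_subbatch_token_num=1024,
)
```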
70 changes: 42 additions & 28 deletions paddleformers/transformers/glm4_moe/modeling.py
@@ -312,12 +312,12 @@ def forward(self, hidden_states):
hidden_states (_type_): [batch_size * seq_len, hidden_size]
"""

# _, _, h_dim = hidden_states.shape

# compute gating score
with paddle.amp.auto_cast(False):
hidden_states = hidden_states.cast(self.weight.dtype)

logits = F.linear(hidden_states.cast("float32"), self.weight.cast("float32").t())

scores = self.gate_score_func(logits=logits)
scores = scores.cast(paddle.float32)
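The gate hunk above casts both the hidden states and the gate weight to float32 before the routing projection, so expert scores are computed outside the AMP cast. A condensed sketch of that pattern; the helper name is hypothetical and the softmax scoring stands in for the configurable `gate_score_func`:

```python
import paddle
import paddle.nn.functional as F

def gate_scores(hidden_states, gate_weight):
    # Compute router logits in float32 so bf16/fp16 autocasting of the matmul
    # cannot perturb the per-expert scores.
    with paddle.amp.auto_cast(False):
        logits = F.linear(hidden_states.cast("float32"), gate_weight.cast("float32").t())
        return F.softmax(logits, axis=-1)
```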

@@ -491,6 +491,8 @@ def __init__(self, config):
moe_group=moe_group,
)
if hasattr(dist, "fleet") and dist.is_initialized() and expert_parallel_degree > 1:
# for p in self.experts.parameters():
# setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})
self.is_mp_moe = False
self.is_ep_moe = True
for p in self.experts.parameters():
@@ -509,7 +511,7 @@
)

def forward(self, hidden_states):
final_hidden_states, _, _ = super().forward(hidden_states)
final_hidden_states, aux_loss, _ = super().forward(hidden_states)
final_hidden_states = final_hidden_states + self.shared_experts(hidden_states)
return final_hidden_states

@@ -519,6 +521,7 @@ def __init__(self, config: Glm4MoeConfig, layer_idx: int):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.layer_idx = layer_idx

self.self_attn = Glm4MoeAttention(config=config, layer_idx=layer_idx)

@@ -591,21 +594,25 @@ def subbatch_recompute_forward(
sub_seq_len = self.config.moe_subbatch_token_num
seq_axis = 0 if self.config.sequence_parallel else 1
seq_len = hidden_states.shape[seq_axis]
# seq_len = sequence_length
assert seq_len % sub_seq_len == 0
num_chunks = seq_len // sub_seq_len
split_list = [sub_seq_len] * num_chunks
input_list = paddle.split(hidden_states, split_list, axis=seq_axis)
output_list = []
hidden_states_output_list = []
aux_loss_output_list = []

for chunk in input_list:
chunk = chunk.reshape([-1, hidden_size])
out = recompute(
hidden_states_out, aux_loss_out = recompute(
self.mlp.forward,
chunk,
**offload_kwargs,
)
output_list.append(out)
hidden_states = paddle.concat(output_list, axis=seq_axis)
hidden_states_output_list.append(hidden_states_out)
aux_loss_output_list.append(aux_loss_out)
hidden_states = paddle.cat(hidden_states_output_list, axis=seq_axis)
aux_loss = paddle.cat(aux_loss_output_list).sum()
outputs = recompute(
self.post_process,
hidden_states,
@@ -616,7 +623,7 @@
present_key_value,
**offload_kwargs,
)
return outputs
return outputs, aux_loss
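For reference, the sub-batch path above splits the token dimension into equal chunks, recomputes the MoE MLP per chunk, then concatenates the hidden states and sums the per-chunk auxiliary losses. A self-contained sketch under the assumption (taken from the hunk) that the MLP returns a `(hidden_states, aux_loss)` pair; the helper name is hypothetical:

```python
import paddle
from paddle.distributed.fleet.utils import recompute

def subbatched_moe_mlp(mlp, hidden_states, sub_seq_len, seq_axis=1):
    # Split the sequence into equal chunks so expert activations for only one
    # chunk are alive at a time; recompute trades memory for extra FLOPs.
    seq_len = hidden_states.shape[seq_axis]
    assert seq_len % sub_seq_len == 0, "sequence length must be divisible by the sub-batch size"
    chunks = paddle.split(hidden_states, [sub_seq_len] * (seq_len // sub_seq_len), axis=seq_axis)
    outs, aux_losses = [], []
    for chunk in chunks:
        out, aux = recompute(mlp.forward, chunk)
        outs.append(out)
        aux_losses.append(aux)
    return paddle.concat(outs, axis=seq_axis), paddle.add_n(aux_losses)
```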

def attn(
self,
@@ -629,7 +636,7 @@ def attn(
attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,
**kwargs,
):
) -> paddle.Tensor:
residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)
@@ -693,13 +700,18 @@ def post_process(
present_key_value=None,
):
hidden_states = residual + hidden_states

outputs = (hidden_states,)

if output_attentions:
outputs += (self_attn_weights,)

if use_cache:
outputs += (present_key_value,)

if type(outputs) is tuple and len(outputs) == 1:
outputs = outputs[0]

return outputs

def forward(
@@ -1030,24 +1042,19 @@ def __init__(self, config: Glm4MoeConfig, device=None):

@paddle.no_grad()
def forward(self, x, position_ids):
# NOTE: Paddle's Automatic Mixed Precision (AMP) has a default op whitelist that may automatically cast
# certain operations (like matmul) to FP16/BF16 for performance optimization. However, in scenarios where
# numerical stability is critical (e.g., RoPE init/compute), this conversion can lead to precision loss.
# Disabling auto_cast here ensures the matmul operation runs in the original precision (FP32) as intended.
with paddle.amp.auto_cast(False):
Review comment (repository member): this change will make training and inference inconsistent here; do not modify it.

inv_freq_expanded = (
self.inv_freq.unsqueeze(0)
.unsqueeze(-1)
.cast(paddle.float32)
.expand([position_ids.shape[0], -1, 1])
.to(x.place)
)
position_ids_expanded = position_ids.unsqueeze(1).cast(paddle.float32)
inv_freq_expanded = (
self.inv_freq.unsqueeze(0)
.unsqueeze(-1)
.cast(paddle.float32)
.expand([position_ids.shape[0], -1, 1])
.to(x.place)
)
position_ids_expanded = position_ids.unsqueeze(1).cast(paddle.float32)

freqs = paddle.matmul(inv_freq_expanded, position_ids_expanded).transpose([0, 2, 1])
emb = paddle.cat((freqs, freqs), axis=-1)
cos = paddle.cos(emb) * self.attention_scaling
sin = paddle.sin(emb) * self.attention_scaling
freqs = paddle.matmul(inv_freq_expanded, position_ids_expanded).transpose([0, 2, 1])
emb = paddle.cat((freqs, freqs), axis=-1)
cos = paddle.cos(emb) * self.attention_scaling
sin = paddle.sin(emb) * self.attention_scaling

return cos.cast(dtype=x.dtype), sin.cast(dtype=x.dtype)
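The rotary-embedding hunk wraps the frequency-position outer product in `paddle.amp.auto_cast(False)` so the cos/sin tables are built in float32 before being cast back to the model dtype; the review comment above objects that this makes training and inference inconsistent. A standalone sketch of the fp32 table computation shown in the added lines; the helper signature is illustrative:

```python
import paddle

def build_rope_tables(inv_freq, position_ids, attention_scaling=1.0, out_dtype="bfloat16"):
    # Build cos/sin rotary tables in float32, then cast back to the model dtype.
    with paddle.amp.auto_cast(False):
        inv = inv_freq.unsqueeze(0).unsqueeze(-1).cast("float32")   # [1, dim/2, 1]
        inv = inv.expand([position_ids.shape[0], -1, 1])            # [bs, dim/2, 1]
        pos = position_ids.unsqueeze(1).cast("float32")             # [bs, 1, seq_len]
        freqs = paddle.matmul(inv, pos).transpose([0, 2, 1])        # [bs, seq_len, dim/2]
        emb = paddle.concat([freqs, freqs], axis=-1)                # [bs, seq_len, dim]
        cos = paddle.cos(emb) * attention_scaling
        sin = paddle.sin(emb) * attention_scaling
    return cos.cast(out_dtype), sin.cast(out_dtype)
```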

@@ -1184,6 +1191,7 @@ def forward(
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
aux_loss = 0

moelayer_use_subbatch_recompute = (
self.config.moe_subbatch_token_num > 0 if hasattr(self.config, "moe_subbatch_token_num") else False
@@ -1237,6 +1245,8 @@
hidden_states = layer_outputs[0]
else:
hidden_states = layer_outputs
if moelayer_use_subbatch_recompute:
aux_loss += layer_outputs[1]

if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -1252,7 +1262,7 @@

next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache] if v is not None)
return tuple(v for v in [hidden_states, next_cache, aux_loss] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
@@ -1367,20 +1377,24 @@ def forward(
return_dict=return_dict,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
)

# output: hidden_states, next_cache, aux_loss
hidden_states = outputs[0] # [bs, seq_len, dim]
logits = self.lm_head(hidden_states)

loss = None
if labels is not None:
loss, _ = self.criterion(logits, labels)
if self.config.moe_aux_loss_coeff:
aux_loss = outputs[2]
loss += self.moe_aux_loss_coeff * aux_loss

if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output

return CausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
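In the causal-LM head above, the auxiliary loss is read from the model outputs and folded into the training loss only when `moe_aux_loss_coeff` is non-zero, so the objective is a weighted sum of the cross-entropy and the load-balancing term. A minimal sketch of that combination:

```python
def combine_losses(lm_loss, aux_loss, moe_aux_loss_coeff):
    # Total objective: language-model cross-entropy plus the scaled MoE
    # load-balancing loss; a coefficient of 0 leaves the LM loss unchanged.
    if moe_aux_loss_coeff:
        return lm_loss + moe_aux_loss_coeff * aux_loss
    return lm_loss
```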
33 changes: 25 additions & 8 deletions paddleformers/transformers/moe_gate.py
@@ -139,12 +139,29 @@ def _cal_seq_aux_loss(self, gates, top_k, topk_idx) -> paddle.Tensor:
Returns:
paddle.Tensor: The value of sequence auxiliary loss.
"""
batch_size, seq_len, _ = gates.shape
ce = paddle.zeros([batch_size, self.num_experts])
topk_idx = topk_idx.reshape([batch_size, -1])
ce.put_along_axis_(indices=topk_idx, values=paddle.ones([batch_size, seq_len * top_k]), axis=1, reduce="add")
ce = ce / (seq_len * top_k / self.num_experts)
aux_loss = (ce * paddle.mean(gates, axis=1)).sum(axis=1).mean()
if self.config.sequence_parallel:
# [bs * seq_len, dim]
max_sequence_length = self.config.max_sequence_length
local_batch_seq, num_experts = gates.shape
global_batch_seq = local_batch_seq * self.config.tensor_parallel_degree
local_batch_size = global_batch_seq // max_sequence_length
ce = paddle.zeros([local_batch_size, num_experts])
topk_idx = topk_idx.reshape([local_batch_size, -1])
ones = paddle.ones([local_batch_size, max_sequence_length * top_k // self.config.tensor_parallel_degree])
ce.put_along_axis_(indices=topk_idx, values=ones, axis=1, reduce="add")
ce = ce / (max_sequence_length * top_k / num_experts)
avg_gates = paddle.mean(gates, axis=0) # [num_experts]
aux_loss = (ce * avg_gates).sum(axis=1).mean()
else:
# [bs, seq_len, dim]
batch_size, seq_len, num_experts = gates.shape
ce = paddle.zeros([batch_size, self.num_experts])
topk_idx = topk_idx.reshape([batch_size, -1])
ce.put_along_axis_(
indices=topk_idx, values=paddle.ones([batch_size, seq_len * top_k]), axis=1, reduce="add"
)
ce = ce / (seq_len * top_k / self.num_experts)
aux_loss = (ce * paddle.mean(gates, axis=1)).sum(axis=1).mean()
return aux_loss

def _cal_z_loss(self, logits) -> paddle.Tensor:
@@ -473,7 +490,7 @@ def topkgating(
gates: paddle.Tensor,
) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Implements TopKGating on logits."""
batch_size, seq_len, d_model = gates.shape
d_model = gates.shape[-1]
gates_ori = gates
gates = gates.reshape([-1, d_model])

@@ -553,7 +570,7 @@

def topkgating_nodrop(self, gates: paddle.Tensor):
"""Implements TopKGating on logits."""
batch_size, seq_len, d_model = gates.shape
d_model = gates.shape[-1]
gates_ori = gates
gates = gates.reshape([-1, d_model])

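The reworked `_cal_seq_aux_loss` adds a sequence-parallel branch that rebuilds per-sample expert counts from the locally sharded gates, while the dense branch keeps the original formula: the normalized expert-selection frequency times the mean routing probability, summed over experts and averaged over the batch. A standalone sketch of the dense-layout computation, with shapes following the hunk above:

```python
import paddle

def seq_aux_loss(gates, topk_idx, top_k, num_experts):
    # gates:    [batch_size, seq_len, num_experts] routing probabilities
    # topk_idx: [batch_size, seq_len, top_k] selected expert ids
    batch_size, seq_len, _ = gates.shape
    ce = paddle.zeros([batch_size, num_experts])
    ce.put_along_axis_(
        indices=topk_idx.reshape([batch_size, -1]),
        values=paddle.ones([batch_size, seq_len * top_k]),
        axis=1,
        reduce="add",
    )
    # Normalize counts so perfectly uniform routing gives a frequency of 1.
    ce = ce / (seq_len * top_k / num_experts)
    return (ce * gates.mean(axis=1)).sum(axis=1).mean()
```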