From 0da81305a74aa012ab2e5fb47b570b3ef1ce147c Mon Sep 17 00:00:00 2001
From: wangyanbo05
Date: Wed, 24 Sep 2025 22:23:25 +0800
Subject: [PATCH] Support the MoE auxiliary loss under sequence parallelism

Wire a load-balancing (auxiliary) loss for the GLM-4 MoE model that also
works when sequence parallelism shards the token dimension:

* add a `moe_aux_loss_coeff` config field (0 disables the term);
* teach `_cal_seq_aux_loss` the [bs * seq_len / tp, num_experts] gate
  layout used under sequence parallelism;
* return the aux loss from the MoE block, accumulate it across layers on
  the subbatch-recompute path, and fold it into the causal LM loss.

---
 .../transformers/glm4_moe/configuration.py |  2 +
 .../transformers/glm4_moe/modeling.py      | 70 +++++++++++--------
 paddleformers/transformers/moe_gate.py     | 33 ++++++---
 3 files changed, 69 insertions(+), 36 deletions(-)

diff --git a/paddleformers/transformers/glm4_moe/configuration.py b/paddleformers/transformers/glm4_moe/configuration.py
index e4e325a6fa..cc6df746df 100644
--- a/paddleformers/transformers/glm4_moe/configuration.py
+++ b/paddleformers/transformers/glm4_moe/configuration.py
@@ -159,6 +159,7 @@ def __init__(
         topk_method="noaux_tc",
         using_flex_token=True,
         moe_subbatch_token_num=0,
+        moe_aux_loss_coeff=0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -202,6 +203,7 @@ def __init__(
         self.using_flex_token = using_flex_token
         self.use_fp8 = False
         self.moe_subbatch_token_num = moe_subbatch_token_num
+        self.moe_aux_loss_coeff = moe_aux_loss_coeff
         self.pp_seg_method = pp_seg_method
         self.disable_ffn_model_parallel = disable_ffn_model_parallel
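
[Reviewer note] A minimal sketch of how the new knob is meant to be used. The
import path follows this patch's file layout; the coefficient value and the
`sequence_parallel` kwarg are illustrative assumptions (the latter presumed to
be accepted via the base config), and the remaining constructor arguments are
assumed to keep their defaults:

    from paddleformers.transformers.glm4_moe.configuration import Glm4MoeConfig

    config = Glm4MoeConfig(
        moe_aux_loss_coeff=0.001,     # 0 (the default) disables the aux-loss term
        moe_subbatch_token_num=1024,  # aux losses are accumulated on the subbatch path
        sequence_parallel=True,       # assumed to be forwarded through **kwargs
    )
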
diff --git a/paddleformers/transformers/glm4_moe/modeling.py b/paddleformers/transformers/glm4_moe/modeling.py
index d22eb077aa..91353a9eaf 100644
--- a/paddleformers/transformers/glm4_moe/modeling.py
+++ b/paddleformers/transformers/glm4_moe/modeling.py
@@ -509,7 +509,8 @@ def __init__(self, config):
         )
 
     def forward(self, hidden_states):
-        final_hidden_states, _, _ = super().forward(hidden_states)
+        final_hidden_states, aux_loss, _ = super().forward(hidden_states)
         final_hidden_states = final_hidden_states + self.shared_experts(hidden_states)
-        return final_hidden_states
+        # return the aux loss alongside the activations so the decoder layer can accumulate it
+        return final_hidden_states, aux_loss
 
@@ -519,6 +520,7 @@ def __init__(self, config: Glm4MoeConfig, layer_idx: int):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
+        self.layer_idx = layer_idx
 
         self.self_attn = Glm4MoeAttention(config=config, layer_idx=layer_idx)
@@ -591,21 +593,25 @@ def subbatch_recompute_forward(
         sub_seq_len = self.config.moe_subbatch_token_num
         seq_axis = 0 if self.config.sequence_parallel else 1
         seq_len = hidden_states.shape[seq_axis]
         assert seq_len % sub_seq_len == 0
         num_chunks = seq_len // sub_seq_len
         split_list = [sub_seq_len] * num_chunks
         input_list = paddle.split(hidden_states, split_list, axis=seq_axis)
-        output_list = []
+        hidden_states_output_list = []
+        aux_loss_output_list = []
 
         for chunk in input_list:
             chunk = chunk.reshape([-1, hidden_size])
-            out = recompute(
+            hidden_states_out, aux_loss_out = recompute(
                 self.mlp.forward,
                 chunk,
                 **offload_kwargs,
             )
-            output_list.append(out)
-        hidden_states = paddle.concat(output_list, axis=seq_axis)
+            hidden_states_output_list.append(hidden_states_out)
+            aux_loss_output_list.append(aux_loss_out)
+        hidden_states = paddle.concat(hidden_states_output_list, axis=seq_axis)
+        # the per-chunk aux losses are 0-D tensors, so stack (not concat) before reducing
+        aux_loss = paddle.stack(aux_loss_output_list).sum()
         outputs = recompute(
             self.post_process,
             hidden_states,
@@ -616,7 +622,7 @@ def subbatch_recompute_forward(
             present_key_value,
             **offload_kwargs,
         )
-        return outputs
+        return outputs, aux_loss
 
     def attn(
         self,
@@ -1184,6 +1190,7 @@ def forward(
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = () if use_cache else None
+        aux_loss = 0
 
         moelayer_use_subbatch_recompute = (
             self.config.moe_subbatch_token_num > 0 if hasattr(self.config, "moe_subbatch_token_num") else False
@@ -1237,6 +1244,9 @@ def forward(
                 hidden_states = layer_outputs[0]
             else:
                 hidden_states = layer_outputs
+            if moelayer_use_subbatch_recompute:
+                # subbatch_recompute_forward returns (layer outputs, aux loss)
+                aux_loss += layer_outputs[1]
 
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
@@ -1252,7 +1262,8 @@ def forward(
         next_cache = next_decoder_cache if use_cache else None
 
         if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache] if v is not None)
+            # keep aux_loss in a fixed (last) position so callers can fetch it as outputs[-1]
+            return tuple(v for v in [hidden_states, next_cache] if v is not None) + (aux_loss,)
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
@@ -1367,13 +1378,17 @@ def forward(
             return_dict=return_dict,
             attn_mask_startend_row_indices=attn_mask_startend_row_indices,
         )
-
+        # tuple output layout: hidden_states, [next_cache,] aux_loss
        hidden_states = outputs[0]  # [bs, seq_len, dim]
         logits = self.lm_head(hidden_states)
 
         loss = None
+        aux_loss = None
         if labels is not None:
             loss, _ = self.criterion(logits, labels)
+            if self.config.moe_aux_loss_coeff:
+                aux_loss = outputs[-1]
+                loss += self.config.moe_aux_loss_coeff * aux_loss
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -1381,6 +1396,7 @@ def forward(
 
         return CausalLMOutputWithPast(
             loss=loss,
+            aux_loss=aux_loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
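
[Reviewer note] The loss plumbing above, reduced to a standalone sketch: on
the subbatch-recompute path every MoE layer yields a scalar load-balancing
loss, the model sums them over layers, and the LM head folds the sum into the
training loss. Function and variable names here are illustrative, not from
the codebase:

    import paddle

    def combine_losses(lm_loss, per_layer_aux_losses, moe_aux_loss_coeff=0.001):
        # Mirrors Glm4MoeForCausalLM.forward: per-layer aux losses are summed
        # and scaled by moe_aux_loss_coeff before joining the LM loss.
        aux_loss = paddle.stack(per_layer_aux_losses).sum()
        return lm_loss + moe_aux_loss_coeff * aux_loss, aux_loss

    lm_loss = paddle.to_tensor(2.31)
    layer_aux = [paddle.to_tensor(0.12), paddle.to_tensor(0.08)]  # one scalar per MoE layer
    total_loss, aux_sum = combine_losses(lm_loss, layer_aux)
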
diff --git a/paddleformers/transformers/moe_gate.py b/paddleformers/transformers/moe_gate.py
index d666515c44..4350281de0 100644
--- a/paddleformers/transformers/moe_gate.py
+++ b/paddleformers/transformers/moe_gate.py
@@ -139,12 +139,32 @@ def _cal_seq_aux_loss(self, gates, top_k, topk_idx) -> paddle.Tensor:
         Returns:
             paddle.Tensor: The value of sequence auxiliary loss.
         """
-        batch_size, seq_len, _ = gates.shape
-        ce = paddle.zeros([batch_size, self.num_experts])
-        topk_idx = topk_idx.reshape([batch_size, -1])
-        ce.put_along_axis_(indices=topk_idx, values=paddle.ones([batch_size, seq_len * top_k]), axis=1, reduce="add")
-        ce = ce / (seq_len * top_k / self.num_experts)
-        aux_loss = (ce * paddle.mean(gates, axis=1)).sum(axis=1).mean()
+        if self.config.sequence_parallel:
+            # sequence-parallel layout: gates is [bs * seq_len / tp, num_experts],
+            # i.e. each tensor-parallel rank only holds its shard of the tokens
+            max_sequence_length = self.config.max_sequence_length
+            local_batch_seq, num_experts = gates.shape
+            global_batch_seq = local_batch_seq * self.config.tensor_parallel_degree
+            local_batch_size = global_batch_seq // max_sequence_length
+            ce = paddle.zeros([local_batch_size, num_experts])
+            topk_idx = topk_idx.reshape([local_batch_size, -1])
+            ones = paddle.ones([local_batch_size, max_sequence_length * top_k // self.config.tensor_parallel_degree])
+            ce.put_along_axis_(indices=topk_idx, values=ones, axis=1, reduce="add")
+            # each rank counts only seq_len / tp tokens but divides by the full
+            # sequence length, so each shard contributes 1/tp of the global statistic
+            ce = ce / (max_sequence_length * top_k / num_experts)
+            avg_gates = paddle.mean(gates, axis=0)  # [num_experts]
+            aux_loss = (ce * avg_gates).sum(axis=1).mean()
+        else:
+            # dense layout: gates is [bs, seq_len, num_experts]
+            batch_size, seq_len, num_experts = gates.shape
+            ce = paddle.zeros([batch_size, num_experts])
+            topk_idx = topk_idx.reshape([batch_size, -1])
+            ce.put_along_axis_(
+                indices=topk_idx, values=paddle.ones([batch_size, seq_len * top_k]), axis=1, reduce="add"
+            )
+            ce = ce / (seq_len * top_k / num_experts)
+            aux_loss = (ce * paddle.mean(gates, axis=1)).sum(axis=1).mean()
         return aux_loss
 
     def _cal_z_loss(self, logits) -> paddle.Tensor:
@@ -473,7 +493,8 @@ def topkgating(
         gates: paddle.Tensor,
     ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Implements TopKGating on logits."""
-        batch_size, seq_len, d_model = gates.shape
+        # gates may be 2-D ([tokens, d_model]) under sequence parallelism
+        d_model = gates.shape[-1]
 
         gates_ori = gates
         gates = gates.reshape([-1, d_model])
@@ -553,7 +574,7 @@
 
     def topkgating_nodrop(self, gates: paddle.Tensor):
         """Implements TopKGating on logits."""
-        batch_size, seq_len, d_model = gates.shape
+        d_model = gates.shape[-1]
 
         gates_ori = gates
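
[Reviewer note] A self-contained numeric check of the sequence-parallel branch
of `_cal_seq_aux_loss`. All sizes are made up, and gates are held uniform so
the per-rank and global gate means coincide. It shows why the branch divides
by the full sequence length even though each rank only counts seq_len / tp
tokens: the per-rank losses then sum to the dense, full-sequence value.

    import paddle

    bs, seq_len, num_experts, top_k, tp = 2, 8, 4, 2, 2
    paddle.seed(0)
    topk_idx = paddle.randint(0, num_experts, [bs, seq_len * top_k])
    gates = paddle.full([bs, seq_len, num_experts], 1.0 / num_experts)

    # dense reference (the else-branch)
    ce = paddle.zeros([bs, num_experts])
    ce.put_along_axis_(
        indices=topk_idx, values=paddle.ones_like(topk_idx, dtype="float32"), axis=1, reduce="add"
    )
    ce = ce / (seq_len * top_k / num_experts)
    dense = (ce * gates.mean(axis=1)).sum(axis=1).mean()

    # sequence-parallel shards: count seq_len / tp tokens per rank, keep the
    # full-sequence normalizer, then sum the per-rank losses
    shard = seq_len * top_k // tp
    total = paddle.to_tensor(0.0)
    for r in range(tp):
        idx_r = topk_idx[:, r * shard:(r + 1) * shard]
        ce_r = paddle.zeros([bs, num_experts])
        ce_r.put_along_axis_(
            indices=idx_r, values=paddle.ones_like(idx_r, dtype="float32"), axis=1, reduce="add"
        )
        ce_r = ce_r / (seq_len * top_k / num_experts)
        total += (ce_r * gates.mean(axis=1)).sum(axis=1).mean()

    print(float(dense), float(total))  # both 1.0 here: the shards sum to the dense value
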