From 8d84b93989f2511ab1b38ee0fb14406bb429f089 Mon Sep 17 00:00:00 2001 From: danleifeng Date: Tue, 23 Sep 2025 19:21:23 +0800 Subject: [PATCH 1/5] add ep subbatch to reduce memory --- examples/run_finetune.py | 1 + .../transformers/glm4_moe/configuration.py | 2 + .../transformers/glm4_moe/modeling.py | 264 +++++++++++++++--- paddleformers/transformers/moe_layer.py | 2 - 4 files changed, 224 insertions(+), 45 deletions(-) diff --git a/examples/run_finetune.py b/examples/run_finetune.py index f72edb0391..5e06dd6048 100644 --- a/examples/run_finetune.py +++ b/examples/run_finetune.py @@ -134,6 +134,7 @@ def main(): model_config.max_sequence_length = training_args.max_seq_len model_config.num_nextn_predict_layers = model_args.num_nextn_predict_layers model_config._attn_implementation = model_args.attn_impl + model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num logger.info(f"Final model config: {model_config}") logger.info("Creating model") diff --git a/paddleformers/transformers/glm4_moe/configuration.py b/paddleformers/transformers/glm4_moe/configuration.py index 2ad7a9f5e6..e4e325a6fa 100644 --- a/paddleformers/transformers/glm4_moe/configuration.py +++ b/paddleformers/transformers/glm4_moe/configuration.py @@ -158,6 +158,7 @@ def __init__( seq_aux=True, topk_method="noaux_tc", using_flex_token=True, + moe_subbatch_token_num=0, **kwargs, ): self.vocab_size = vocab_size @@ -200,6 +201,7 @@ def __init__( self.topk_method = topk_method self.using_flex_token = using_flex_token self.use_fp8 = False + self.moe_subbatch_token_num = moe_subbatch_token_num self.pp_seg_method = pp_seg_method self.disable_ffn_model_parallel = disable_ffn_model_parallel diff --git a/paddleformers/transformers/glm4_moe/modeling.py b/paddleformers/transformers/glm4_moe/modeling.py index 0b2b42ed29..ee018f5e42 100644 --- a/paddleformers/transformers/glm4_moe/modeling.py +++ b/paddleformers/transformers/glm4_moe/modeling.py @@ -34,7 +34,6 @@ from ...nn.norm import Norm as GeneralNorm from ...nn.pp_model import GeneralModelForCausalLMPipe from ...utils.log import logger -from ..llama.modeling import get_use_casual_mask from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ..model_utils import PretrainedModel, register_base_model from ..moe_gate import PretrainedMoEGate @@ -221,20 +220,24 @@ def forward( value_states = self.v_proj(hidden_states) if self.sequence_parallel: - target_query_shape = [batch_size, -1, self.num_heads, self.head_dim] - target_key_value_shape = [batch_size, -1, self.num_key_value_heads, self.head_dim] + max_sequence_length = self.config.max_sequence_length + bsz = hidden_states.shape[0] * self.config.tensor_parallel_degree // max_sequence_length + q_len = max_sequence_length else: - target_query_shape = [0, 0, self.num_heads, self.head_dim] - target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim] - query_states = query_states.reshape(target_query_shape) - key_states = key_states.reshape(target_key_value_shape) - value_states = value_states.reshape(target_key_value_shape) + bsz, q_len, _ = hidden_states.shape + # if self.sequence_parallel: + # q_len, bsz, _ = hidden_states.shape + # else: + # bsz, q_len, _ = hidden_states.shape + query_states = query_states.reshape([bsz, q_len, -1, self.head_dim]) + key_states = key_states.reshape([bsz, q_len, -1, self.head_dim]) + value_states = value_states.reshape([bsz, q_len, -1, self.head_dim]) else: mix_layer = self.qkv_proj(hidden_states) if self.sequence_parallel: target_shape = [ - batch_size, - 
-1, + bsz, + q_len, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim, ] @@ -313,14 +316,12 @@ def forward(self, hidden_states): hidden_states (_type_): [batch_size * seq_len, hidden_size] """ - _, _, h_dim = hidden_states.shape + # _, _, h_dim = hidden_states.shape # compute gating score with paddle.amp.auto_cast(False): hidden_states = hidden_states.cast(self.weight.dtype) - logits = F.linear(hidden_states.cast("float32"), self.weight.cast("float32").t()) - scores = self.gate_score_func(logits=logits) scores = scores.cast(paddle.float32) @@ -494,8 +495,18 @@ def __init__(self, config): moe_group=moe_group, ) if hasattr(dist, "fleet") and dist.is_initialized() and expert_parallel_degree > 1: + # for p in self.experts.parameters(): + # setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) + self.is_mp_moe = False + self.is_ep_moe = True for p in self.experts.parameters(): + setattr(p, "is_moe_param", True) setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) + p.no_sync = not self.is_mp_moe + p.expert = not self.is_mp_moe + logger.info(f"expert no-sync={p.no_sync}-{p.name}") + if self.is_mp_moe or self.is_ep_moe: + p.is_distributed = True self.shared_experts = Glm4MoeMLP( config=config, @@ -544,23 +555,25 @@ def __init__(self, config: Glm4MoeConfig, layer_idx: int): if not hasattr(config, "disable_ffn_model_parallel"): self.input_layernorm.enable_sequence_parallel() - def forward( + def subbatch_recompute_forward( self, hidden_states: paddle.Tensor, + batch_size: int, + hidden_size: int, position_ids: Optional[paddle.Tensor] = None, attention_mask: Optional[paddle.Tensor] = None, output_attentions: Optional[bool] = False, past_key_value: Optional[Tuple[paddle.Tensor]] = None, use_cache: Optional[bool] = False, - position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, - **kwargs, + position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, ) -> paddle.Tensor: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, + offload_kwargs = {} + offload_kwargs["offload_indices"] = [0] + assert self.config.recompute_granularity != "full_attn" + attn_outputs = recompute( + self.attn, + hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, attn_mask_startend_row_indices=attn_mask_startend_row_indices, @@ -568,25 +581,172 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, position_embeddings=position_embeddings, + **offload_kwargs, ) + + hidden_states = attn_outputs[0] + residual = attn_outputs[1] + self_attn_weights = attn_outputs[2] if output_attentions else None + present_key_value = attn_outputs[3] if use_cache else None + + if len(hidden_states.shape) != 3: + if self.config.sequence_parallel: + hidden_states = hidden_states.reshape([-1, batch_size, hidden_size]) + else: + hidden_states = hidden_states.reshape([batch_size, -1, hidden_size]) + sub_seq_len = self.config.moe_subbatch_token_num + seq_axis = 0 if self.config.sequence_parallel else 1 + seq_len = hidden_states.shape[seq_axis] + # seq_len = sequence_length + assert seq_len % sub_seq_len == 0 + num_chunks = seq_len // sub_seq_len + split_list = [sub_seq_len] * num_chunks + input_list = paddle.split(hidden_states, split_list, axis=seq_axis) + output_list = [] + + for chunk in input_list: + chunk = 
chunk.reshape([-1, hidden_size]) + out = recompute( + self.mlp.forward, + chunk, + **offload_kwargs, + ) + output_list.append(out) + hidden_states = paddle.concat(output_list, axis=seq_axis) + outputs = recompute( + self.post_process, + hidden_states, + residual, + output_attentions, + use_cache, + self_attn_weights, + present_key_value, + **offload_kwargs, + ) + return outputs + + def attn( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, + **kwargs, + ): + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + has_gradient = not hidden_states.stop_gradient + if self.config.recompute and has_gradient and self.config.recompute_granularity == "full_attn": + outputs = recompute( + self.self_attn, + hidden_states=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + past_key_value=past_key_value, + use_cache=use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + position_embeddings=position_embeddings, + **kwargs, + ) + else: + outputs = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + past_key_value=past_key_value, + use_cache=use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + position_embeddings=position_embeddings, + **kwargs, + ) + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - if isinstance(hidden_states, tuple): - hidden_states, _ = hidden_states - # else: - # router_logits = None + attn_outputs = (hidden_states, residual) + + if output_attentions: + self_attn_weights = outputs[1] + attn_outputs += (self_attn_weights,) + + if use_cache: + present_key_value = outputs[2 if output_attentions else 1] + attn_outputs += (present_key_value,) + + return attn_outputs + + def post_process( + self, + hidden_states, + residual, + output_attentions=False, + use_cache=False, + self_attn_weights=None, + present_key_value=None, + ): hidden_states = residual + hidden_states + outputs = (hidden_states,) + if output_attentions: outputs += (self_attn_weights,) + if use_cache: outputs += (present_key_value,) + if type(outputs) is tuple and len(outputs) == 1: outputs = outputs[0] + + return outputs + + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ) -> paddle.Tensor: + + attn_outputs = self.attn( + hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + 
position_ids=position_ids, + output_attentions=output_attentions, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = attn_outputs[0] + residual = attn_outputs[1] + self_attn_weights = attn_outputs[2] if output_attentions else None + present_key_value = attn_outputs[3] if use_cache else None + + hidden_states = self.mlp(hidden_states) + outputs = self.post_process( + hidden_states, residual, output_attentions, use_cache, self_attn_weights, present_key_value + ) return outputs @@ -1002,38 +1162,39 @@ def forward( cache_length = past_key_values[0][0].shape[1] seq_length_with_past += cache_length + print(f"input_ids:{input_ids}") + cnt = 0 + for idx in input_ids.flatten().numpy(): + if idx == self.padding_idx: + cnt += 1 + numel = paddle.numel(input_ids).numpy() + print(f"input_ids len:{input_ids.shape}, numel: {numel}, padding_idx:{cnt}, percent:{cnt/numel}") if inputs_embeds is None: # [bs, seq_len, dim] inputs_embeds = self.embed_tokens(input_ids) + bs, seq_len, hidden_size = inputs_embeds.shape if self.sequence_parallel: # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] bs, seq_len, hidden_size = inputs_embeds.shape + # !!!! inputs_embeds = paddle.reshape_(inputs_embeds, [bs * seq_len, hidden_size]) + # inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) # [B, S, H] --> [S, B, H] # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) inputs_embeds = ScatterOp.apply(inputs_embeds) - if attn_mask_startend_row_indices is not None or get_use_casual_mask(): - attention_mask = None - else: - # [bs, seq_len] - attention_mask = ( - paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) - if attention_mask is None - else attention_mask - ) + hidden_states = inputs_embeds + if attention_mask is not None: causal_mask = self._prepare_decoder_attention_mask( - attention_mask=attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=cache_length, - dtype=inputs_embeds.dtype, - ) # [bs, 1, seq_len, seq_len] + attention_mask, hidden_states.shape[:2], cache_length, hidden_states.dtype + ) + else: + causal_mask = None if position_ids is None: position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) - hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers @@ -1041,13 +1202,30 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None + moelayer_use_subbatch_recompute = ( + self.config.moe_subbatch_token_num > 0 if hasattr(self.config, "moe_subbatch_token_num") else False + ) + for idx, (decoder_layer) in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) past_key_value = past_key_values[idx] if past_key_values is not None else None has_gradient = not hidden_states.stop_gradient - if self.config.recompute and self.config.recompute_granularity == "full" and has_gradient: + if moelayer_use_subbatch_recompute: + layer_outputs = decoder_layer.subbatch_recompute_forward( + hidden_states, + bs, + hidden_size, + position_ids, + causal_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + position_embeddings, + ) + elif self.config.recompute and self.config.recompute_granularity == "full" and has_gradient: layer_outputs = self.recompute_training_full( layer_module=decoder_layer, hidden_states=hidden_states, @@ -1070,7 +1248,6 @@ def forward( use_cache=use_cache, 
position_embeddings=position_embeddings, ) - # # NOTE: clear outdate cache after it has been used for memory saving # past_key_value = past_key_values[idx] = None if isinstance(layer_outputs, (tuple, list)): @@ -1205,6 +1382,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, ) hidden_states = outputs[0] # [bs, seq_len, dim] diff --git a/paddleformers/transformers/moe_layer.py b/paddleformers/transformers/moe_layer.py index 5aa433168a..5db46c4367 100644 --- a/paddleformers/transformers/moe_layer.py +++ b/paddleformers/transformers/moe_layer.py @@ -380,8 +380,6 @@ def expert_forward(self, dispatched_input, tokens_per_expert): return paddle.cat(outputs, axis=0) def forward(self, hidden_states: paddle.Tensor): - _, _, d_model = hidden_states.shape - # reshaped_input = hidden_states.reshape([-1, d_model]) probs, routing_map, l_aux, l_zloss = self.gate(hidden_states) (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation( hidden_states, probs, routing_map From b79596d0b19e1d8bd2d883a47b76d3edae74b358 Mon Sep 17 00:00:00 2001 From: danleifeng Date: Tue, 23 Sep 2025 19:29:49 +0800 Subject: [PATCH 2/5] fix --- paddleformers/transformers/glm4_moe/modeling.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/paddleformers/transformers/glm4_moe/modeling.py b/paddleformers/transformers/glm4_moe/modeling.py index ee018f5e42..cb105e7024 100644 --- a/paddleformers/transformers/glm4_moe/modeling.py +++ b/paddleformers/transformers/glm4_moe/modeling.py @@ -225,10 +225,6 @@ def forward( q_len = max_sequence_length else: bsz, q_len, _ = hidden_states.shape - # if self.sequence_parallel: - # q_len, bsz, _ = hidden_states.shape - # else: - # bsz, q_len, _ = hidden_states.shape query_states = query_states.reshape([bsz, q_len, -1, self.head_dim]) key_states = key_states.reshape([bsz, q_len, -1, self.head_dim]) value_states = value_states.reshape([bsz, q_len, -1, self.head_dim]) @@ -1162,24 +1158,14 @@ def forward( cache_length = past_key_values[0][0].shape[1] seq_length_with_past += cache_length - print(f"input_ids:{input_ids}") - cnt = 0 - for idx in input_ids.flatten().numpy(): - if idx == self.padding_idx: - cnt += 1 - numel = paddle.numel(input_ids).numpy() - print(f"input_ids len:{input_ids.shape}, numel: {numel}, padding_idx:{cnt}, percent:{cnt/numel}") if inputs_embeds is None: # [bs, seq_len, dim] inputs_embeds = self.embed_tokens(input_ids) - bs, seq_len, hidden_size = inputs_embeds.shape if self.sequence_parallel: # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] bs, seq_len, hidden_size = inputs_embeds.shape - # !!!! 
inputs_embeds = paddle.reshape_(inputs_embeds, [bs * seq_len, hidden_size]) - # inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) # [B, S, H] --> [S, B, H] # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) inputs_embeds = ScatterOp.apply(inputs_embeds) From c851a4eaa72f3ae498c5a8cb7c350aece41f708e Mon Sep 17 00:00:00 2001 From: danleifeng Date: Tue, 23 Sep 2025 20:21:07 +0800 Subject: [PATCH 3/5] add tensorwise_offload_optimizer and pre_alloc_memory --- examples/run_finetune.py | 22 +++++ paddleformers/trainer/trainer.py | 2 +- paddleformers/trainer/training_args.py | 4 + .../trainer/utils/offload_optimizer.py | 81 +++++++++++++++++++ 4 files changed, 108 insertions(+), 1 deletion(-) create mode 100644 paddleformers/trainer/utils/offload_optimizer.py diff --git a/examples/run_finetune.py b/examples/run_finetune.py index 5e06dd6048..0d27796652 100644 --- a/examples/run_finetune.py +++ b/examples/run_finetune.py @@ -49,6 +49,19 @@ os.environ["USE_CASUAL_MASK"] = "False" +def mock_offload_optimizer(): + """ + mock offload optimizer + """ + try: + from paddleformers.trainer.utils.offload_optimizer import hack_offload_optimizer + + hack_offload_optimizer() + logger.warning("hack_offload_optimizer called.") + except ImportError: + logger.warning("hack_offload_optimizer is not imported") + + def main(): parser = PdArgumentParser((ModelConfig, DataConfig, SFTConfig)) if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): @@ -60,9 +73,18 @@ def main(): else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() + if training_args.tensorwise_offload_optimizer: + mock_offload_optimizer() + training_args.print_config(model_args, "Model") training_args.print_config(data_args, "Data") + if training_args.pre_alloc_memory > 0: + memory_size = int(training_args.pre_alloc_memory * 1024 * 1024 * 1024) + x = paddle.empty([memory_size], dtype=paddle.uint8) + logger.info(f"pre_alloc_memory size {x.shape}") + del x + # Setup GPU & distributed training paddle.set_device(training_args.device) set_seed(seed=training_args.seed) diff --git a/paddleformers/trainer/trainer.py b/paddleformers/trainer/trainer.py index 6cd0e2d66a..29286feb18 100644 --- a/paddleformers/trainer/trainer.py +++ b/paddleformers/trainer/trainer.py @@ -2788,7 +2788,7 @@ def _save_checkpoint(self, model, metrics=None): optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}") - if self.args.unified_checkpoint and self.args.offload_optim: + if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer): self._reload_optimizer() if self.args.use_hybrid_parallel: diff --git a/paddleformers/trainer/training_args.py b/paddleformers/trainer/training_args.py index 8ce3a9ffcc..640b93b9aa 100644 --- a/paddleformers/trainer/training_args.py +++ b/paddleformers/trainer/training_args.py @@ -1096,6 +1096,10 @@ class TrainingArguments: default=False, metadata={"help": "Controls the parallel execution order. 
False (pp first), True (sharding first)."}, ) + pre_alloc_memory: int = field( + default=0, + metadata={"help": "pre allocate memory size GB"}, + ) def __post_init__(self): world_size = paddle.distributed.get_world_size() diff --git a/paddleformers/trainer/utils/offload_optimizer.py b/paddleformers/trainer/utils/offload_optimizer.py new file mode 100644 index 0000000000..65f5b77e2e --- /dev/null +++ b/paddleformers/trainer/utils/offload_optimizer.py @@ -0,0 +1,81 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle import _C_ops +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import ( + HybridParallelOptimizer, +) +from paddle.optimizer import Optimizer + +from .sharding_io import to_device + + +def offload(tensor): + if paddle.is_compiled_with_cuda(): + place = paddle.CUDAPinnedPlace() + else: + place = paddle.CPUPlace() + + new_tensor = to_device(tensor, place) + assert new_tensor is tensor, "to_device must be inplace operation" + + +def reload(tensor): + new_tensor = to_device(tensor) + assert new_tensor is tensor, "to_device must be inplace operation" + + +def hack_offload_optimizer(): + # Step 1: mock _add_accumulator + origin_add_accumulator = getattr(Optimizer, "_add_accumulator") + + def new_add_accumulator(self, *args, **kwargs): + x = origin_add_accumulator(self, *args, **kwargs) + offload(x) + return x + + setattr(Optimizer, "_add_accumulator", new_add_accumulator) + + # Step 2: mock _C_ops.adamw_ and _C_ops.adamw + for name in ["adam_", "adamw_"]: + origin_op = getattr(_C_ops, name) + + def new_opt_op(*args): + for arg in args: + if isinstance(arg, paddle.Tensor): + reload(arg) + + ret = origin_op(*args) + + for i, arg in enumerate(args): + if i >= 2 and isinstance(arg, paddle.Tensor): # do not offload parameter and gradient + offload(arg) + return ret + + setattr(_C_ops, name, new_opt_op) + + # Step 3: mock _insert_sync + opt_type = HybridParallelOptimizer + origin_insert_sync = getattr(opt_type, "_insert_sync") + + def new_insert_sync(self, sync_var, *args, **kwargs): + origin_place = sync_var.place + reload(sync_var) + ret = origin_insert_sync(self, sync_var, *args, **kwargs) + new_sync_var = to_device(sync_var, origin_place) + assert new_sync_var is sync_var, "to_device must be inplace operation" + return ret + + setattr(opt_type, "_insert_sync", new_insert_sync) From efef08ef912482a04b186d018a59b09f6a12ba20 Mon Sep 17 00:00:00 2001 From: danleifeng Date: Wed, 24 Sep 2025 14:47:57 +0800 Subject: [PATCH 4/5] fix saveload bug; --- paddleformers/transformers/glm4_moe/modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddleformers/transformers/glm4_moe/modeling.py b/paddleformers/transformers/glm4_moe/modeling.py index cb105e7024..9582a2c3a3 100644 --- a/paddleformers/transformers/glm4_moe/modeling.py +++ b/paddleformers/transformers/glm4_moe/modeling.py @@ -292,7 +292,7 @@ def __init__(self, config, num_experts, 
expert_hidden_size, **kwargs): self.weight = paddle.create_parameter( shape=[num_experts, expert_hidden_size], - dtype="bfloat16", + dtype="float32", default_initializer=paddle.nn.initializer.Uniform(), ) From a05b9639ecad4cd5a032821f66a356d94b396f3e Mon Sep 17 00:00:00 2001 From: danleifeng Date: Wed, 24 Sep 2025 15:20:33 +0800 Subject: [PATCH 5/5] remove offload optimizer and useless code --- examples/run_finetune.py | 16 ---- paddleformers/trainer/trainer.py | 2 +- .../trainer/utils/offload_optimizer.py | 81 ------------------- .../transformers/glm4_moe/modeling.py | 12 +-- 4 files changed, 3 insertions(+), 108 deletions(-) delete mode 100644 paddleformers/trainer/utils/offload_optimizer.py diff --git a/examples/run_finetune.py b/examples/run_finetune.py index 0d27796652..2e147fd03c 100644 --- a/examples/run_finetune.py +++ b/examples/run_finetune.py @@ -49,19 +49,6 @@ os.environ["USE_CASUAL_MASK"] = "False" -def mock_offload_optimizer(): - """ - mock offload optimizer - """ - try: - from paddleformers.trainer.utils.offload_optimizer import hack_offload_optimizer - - hack_offload_optimizer() - logger.warning("hack_offload_optimizer called.") - except ImportError: - logger.warning("hack_offload_optimizer is not imported") - - def main(): parser = PdArgumentParser((ModelConfig, DataConfig, SFTConfig)) if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): @@ -73,9 +60,6 @@ def main(): else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() - if training_args.tensorwise_offload_optimizer: - mock_offload_optimizer() - training_args.print_config(model_args, "Model") training_args.print_config(data_args, "Data") diff --git a/paddleformers/trainer/trainer.py b/paddleformers/trainer/trainer.py index 29286feb18..6cd0e2d66a 100644 --- a/paddleformers/trainer/trainer.py +++ b/paddleformers/trainer/trainer.py @@ -2788,7 +2788,7 @@ def _save_checkpoint(self, model, metrics=None): optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}") - if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer): + if self.args.unified_checkpoint and self.args.offload_optim: self._reload_optimizer() if self.args.use_hybrid_parallel: diff --git a/paddleformers/trainer/utils/offload_optimizer.py b/paddleformers/trainer/utils/offload_optimizer.py deleted file mode 100644 index 65f5b77e2e..0000000000 --- a/paddleformers/trainer/utils/offload_optimizer.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import paddle -from paddle import _C_ops -from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import ( - HybridParallelOptimizer, -) -from paddle.optimizer import Optimizer - -from .sharding_io import to_device - - -def offload(tensor): - if paddle.is_compiled_with_cuda(): - place = paddle.CUDAPinnedPlace() - else: - place = paddle.CPUPlace() - - new_tensor = to_device(tensor, place) - assert new_tensor is tensor, "to_device must be inplace operation" - - -def reload(tensor): - new_tensor = to_device(tensor) - assert new_tensor is tensor, "to_device must be inplace operation" - - -def hack_offload_optimizer(): - # Step 1: mock _add_accumulator - origin_add_accumulator = getattr(Optimizer, "_add_accumulator") - - def new_add_accumulator(self, *args, **kwargs): - x = origin_add_accumulator(self, *args, **kwargs) - offload(x) - return x - - setattr(Optimizer, "_add_accumulator", new_add_accumulator) - - # Step 2: mock _C_ops.adamw_ and _C_ops.adamw - for name in ["adam_", "adamw_"]: - origin_op = getattr(_C_ops, name) - - def new_opt_op(*args): - for arg in args: - if isinstance(arg, paddle.Tensor): - reload(arg) - - ret = origin_op(*args) - - for i, arg in enumerate(args): - if i >= 2 and isinstance(arg, paddle.Tensor): # do not offload parameter and gradient - offload(arg) - return ret - - setattr(_C_ops, name, new_opt_op) - - # Step 3: mock _insert_sync - opt_type = HybridParallelOptimizer - origin_insert_sync = getattr(opt_type, "_insert_sync") - - def new_insert_sync(self, sync_var, *args, **kwargs): - origin_place = sync_var.place - reload(sync_var) - ret = origin_insert_sync(self, sync_var, *args, **kwargs) - new_sync_var = to_device(sync_var, origin_place) - assert new_sync_var is sync_var, "to_device must be inplace operation" - return ret - - setattr(opt_type, "_insert_sync", new_insert_sync) diff --git a/paddleformers/transformers/glm4_moe/modeling.py b/paddleformers/transformers/glm4_moe/modeling.py index 9582a2c3a3..f6bd2443e9 100644 --- a/paddleformers/transformers/glm4_moe/modeling.py +++ b/paddleformers/transformers/glm4_moe/modeling.py @@ -312,12 +312,12 @@ def forward(self, hidden_states): hidden_states (_type_): [batch_size * seq_len, hidden_size] """ - # _, _, h_dim = hidden_states.shape - # compute gating score with paddle.amp.auto_cast(False): hidden_states = hidden_states.cast(self.weight.dtype) + logits = F.linear(hidden_states.cast("float32"), self.weight.cast("float32").t()) + scores = self.gate_score_func(logits=logits) scores = scores.cast(paddle.float32) @@ -491,8 +491,6 @@ def __init__(self, config): moe_group=moe_group, ) if hasattr(dist, "fleet") and dist.is_initialized() and expert_parallel_degree > 1: - # for p in self.experts.parameters(): - # setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) self.is_mp_moe = False self.is_ep_moe = True for p in self.experts.parameters(): @@ -593,7 +591,6 @@ def subbatch_recompute_forward( sub_seq_len = self.config.moe_subbatch_token_num seq_axis = 0 if self.config.sequence_parallel else 1 seq_len = hidden_states.shape[seq_axis] - # seq_len = sequence_length assert seq_len % sub_seq_len == 0 num_chunks = seq_len // sub_seq_len split_list = [sub_seq_len] * num_chunks @@ -696,18 +693,13 @@ def post_process( present_key_value=None, ): hidden_states = residual + hidden_states - outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) - if use_cache: outputs += (present_key_value,) - if type(outputs) is tuple and len(outputs) == 1: outputs = 
outputs[0] - return outputs def forward(
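
Note (illustrative, not part of the patch series): the core memory optimization introduced in PATCH 1/5 is the `moe_subbatch_token_num` config knob together with `Glm4MoeDecoderLayer.subbatch_recompute_forward`, which splits the token dimension into fixed-size chunks before the MoE MLP so that only one chunk's activations are live at a time; the real code additionally wraps each chunk in `recompute` and handles the sequence-parallel layout. The minimal sketch below shows only the chunking idea with a stand-in dense MLP — the sizes and the toy model are assumptions, not values from the patch.

# Standalone sketch of the sub-batch idea behind subbatch_recompute_forward.
# The Sequential MLP is a stand-in for the MoE block; all sizes are made up.
import paddle

hidden_size = 64
moe_subbatch_token_num = 128          # tokens per sub-batch, as in the new config field
seq_len = 512                         # must be divisible by moe_subbatch_token_num

mlp = paddle.nn.Sequential(
    paddle.nn.Linear(hidden_size, 4 * hidden_size),
    paddle.nn.Silu(),
    paddle.nn.Linear(4 * hidden_size, hidden_size),
)

hidden_states = paddle.randn([seq_len, hidden_size])

# Full forward: peak activation memory is proportional to seq_len.
full_out = mlp(hidden_states)

# Sub-batch forward: peak activation memory is proportional to moe_subbatch_token_num.
assert seq_len % moe_subbatch_token_num == 0
num_chunks = seq_len // moe_subbatch_token_num
chunks = paddle.split(hidden_states, [moe_subbatch_token_num] * num_chunks, axis=0)
sub_out = paddle.concat([mlp(c) for c in chunks], axis=0)

# The two paths are numerically equivalent; only the memory profile differs.
print(paddle.allclose(full_out, sub_out, atol=1e-5).item())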