diff --git a/examples/run_finetune.py b/examples/run_finetune.py index f72edb0391..2e147fd03c 100644 --- a/examples/run_finetune.py +++ b/examples/run_finetune.py @@ -63,6 +63,12 @@ def main(): training_args.print_config(model_args, "Model") training_args.print_config(data_args, "Data") + if training_args.pre_alloc_memory > 0: + memory_size = int(training_args.pre_alloc_memory * 1024 * 1024 * 1024) + x = paddle.empty([memory_size], dtype=paddle.uint8) + logger.info(f"pre_alloc_memory size {x.shape}") + del x + # Setup GPU & distributed training paddle.set_device(training_args.device) set_seed(seed=training_args.seed) @@ -134,6 +140,7 @@ def main(): model_config.max_sequence_length = training_args.max_seq_len model_config.num_nextn_predict_layers = model_args.num_nextn_predict_layers model_config._attn_implementation = model_args.attn_impl + model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num logger.info(f"Final model config: {model_config}") logger.info("Creating model") diff --git a/paddleformers/trainer/training_args.py b/paddleformers/trainer/training_args.py index 8ce3a9ffcc..640b93b9aa 100644 --- a/paddleformers/trainer/training_args.py +++ b/paddleformers/trainer/training_args.py @@ -1096,6 +1096,10 @@ class TrainingArguments: default=False, metadata={"help": "Controls the parallel execution order. False (pp first), True (sharding first)."}, ) + pre_alloc_memory: int = field( + default=0, + metadata={"help": "pre allocate memory size GB"}, + ) def __post_init__(self): world_size = paddle.distributed.get_world_size() diff --git a/paddleformers/transformers/glm4_moe/configuration.py b/paddleformers/transformers/glm4_moe/configuration.py index 2ad7a9f5e6..e4e325a6fa 100644 --- a/paddleformers/transformers/glm4_moe/configuration.py +++ b/paddleformers/transformers/glm4_moe/configuration.py @@ -158,6 +158,7 @@ def __init__( seq_aux=True, topk_method="noaux_tc", using_flex_token=True, + moe_subbatch_token_num=0, **kwargs, ): self.vocab_size = vocab_size @@ -200,6 +201,7 @@ def __init__( self.topk_method = topk_method self.using_flex_token = using_flex_token self.use_fp8 = False + self.moe_subbatch_token_num = moe_subbatch_token_num self.pp_seg_method = pp_seg_method self.disable_ffn_model_parallel = disable_ffn_model_parallel diff --git a/paddleformers/transformers/glm4_moe/modeling.py b/paddleformers/transformers/glm4_moe/modeling.py index 0b2b42ed29..f6bd2443e9 100644 --- a/paddleformers/transformers/glm4_moe/modeling.py +++ b/paddleformers/transformers/glm4_moe/modeling.py @@ -34,7 +34,6 @@ from ...nn.norm import Norm as GeneralNorm from ...nn.pp_model import GeneralModelForCausalLMPipe from ...utils.log import logger -from ..llama.modeling import get_use_casual_mask from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ..model_utils import PretrainedModel, register_base_model from ..moe_gate import PretrainedMoEGate @@ -221,20 +220,20 @@ def forward( value_states = self.v_proj(hidden_states) if self.sequence_parallel: - target_query_shape = [batch_size, -1, self.num_heads, self.head_dim] - target_key_value_shape = [batch_size, -1, self.num_key_value_heads, self.head_dim] + max_sequence_length = self.config.max_sequence_length + bsz = hidden_states.shape[0] * self.config.tensor_parallel_degree // max_sequence_length + q_len = max_sequence_length else: - target_query_shape = [0, 0, self.num_heads, self.head_dim] - target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim] - query_states = 
query_states.reshape(target_query_shape) - key_states = key_states.reshape(target_key_value_shape) - value_states = value_states.reshape(target_key_value_shape) + bsz, q_len, _ = hidden_states.shape + query_states = query_states.reshape([bsz, q_len, -1, self.head_dim]) + key_states = key_states.reshape([bsz, q_len, -1, self.head_dim]) + value_states = value_states.reshape([bsz, q_len, -1, self.head_dim]) else: mix_layer = self.qkv_proj(hidden_states) if self.sequence_parallel: target_shape = [ - batch_size, - -1, + bsz, + q_len, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim, ] @@ -293,7 +292,7 @@ def __init__(self, config, num_experts, expert_hidden_size, **kwargs): self.weight = paddle.create_parameter( shape=[num_experts, expert_hidden_size], - dtype="bfloat16", + dtype="float32", default_initializer=paddle.nn.initializer.Uniform(), ) @@ -313,8 +312,6 @@ def forward(self, hidden_states): hidden_states (_type_): [batch_size * seq_len, hidden_size] """ - _, _, h_dim = hidden_states.shape - # compute gating score with paddle.amp.auto_cast(False): hidden_states = hidden_states.cast(self.weight.dtype) @@ -494,8 +491,16 @@ def __init__(self, config): moe_group=moe_group, ) if hasattr(dist, "fleet") and dist.is_initialized() and expert_parallel_degree > 1: + self.is_mp_moe = False + self.is_ep_moe = True for p in self.experts.parameters(): + setattr(p, "is_moe_param", True) setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) + p.no_sync = not self.is_mp_moe + p.expert = not self.is_mp_moe + logger.info(f"expert no-sync={p.no_sync}-{p.name}") + if self.is_mp_moe or self.is_ep_moe: + p.is_distributed = True self.shared_experts = Glm4MoeMLP( config=config, @@ -544,23 +549,25 @@ def __init__(self, config: Glm4MoeConfig, layer_idx: int): if not hasattr(config, "disable_ffn_model_parallel"): self.input_layernorm.enable_sequence_parallel() - def forward( + def subbatch_recompute_forward( self, hidden_states: paddle.Tensor, + batch_size: int, + hidden_size: int, position_ids: Optional[paddle.Tensor] = None, attention_mask: Optional[paddle.Tensor] = None, output_attentions: Optional[bool] = False, past_key_value: Optional[Tuple[paddle.Tensor]] = None, use_cache: Optional[bool] = False, - position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, - **kwargs, + position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, ) -> paddle.Tensor: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, + offload_kwargs = {} + offload_kwargs["offload_indices"] = [0] + assert self.config.recompute_granularity != "full_attn" + attn_outputs = recompute( + self.attn, + hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, attn_mask_startend_row_indices=attn_mask_startend_row_indices, @@ -568,17 +575,123 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, position_embeddings=position_embeddings, + **offload_kwargs, + ) + + hidden_states = attn_outputs[0] + residual = attn_outputs[1] + self_attn_weights = attn_outputs[2] if output_attentions else None + present_key_value = attn_outputs[3] if use_cache else None + + if len(hidden_states.shape) != 3: + if self.config.sequence_parallel: + hidden_states = hidden_states.reshape([-1, batch_size, hidden_size]) + else: + hidden_states = 
hidden_states.reshape([batch_size, -1, hidden_size]) + sub_seq_len = self.config.moe_subbatch_token_num + seq_axis = 0 if self.config.sequence_parallel else 1 + seq_len = hidden_states.shape[seq_axis] + assert seq_len % sub_seq_len == 0 + num_chunks = seq_len // sub_seq_len + split_list = [sub_seq_len] * num_chunks + input_list = paddle.split(hidden_states, split_list, axis=seq_axis) + output_list = [] + + for chunk in input_list: + chunk = chunk.reshape([-1, hidden_size]) + out = recompute( + self.mlp.forward, + chunk, + **offload_kwargs, + ) + output_list.append(out) + hidden_states = paddle.concat(output_list, axis=seq_axis) + outputs = recompute( + self.post_process, + hidden_states, + residual, + output_attentions, + use_cache, + self_attn_weights, + present_key_value, + **offload_kwargs, ) + return outputs + + def attn( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, + **kwargs, + ): + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + has_gradient = not hidden_states.stop_gradient + if self.config.recompute and has_gradient and self.config.recompute_granularity == "full_attn": + outputs = recompute( + self.self_attn, + hidden_states=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + past_key_value=past_key_value, + use_cache=use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + position_embeddings=position_embeddings, + **kwargs, + ) + else: + outputs = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + past_key_value=past_key_value, + use_cache=use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + position_embeddings=position_embeddings, + **kwargs, + ) + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - if isinstance(hidden_states, tuple): - hidden_states, _ = hidden_states - # else: - # router_logits = None + attn_outputs = (hidden_states, residual) + + if output_attentions: + self_attn_weights = outputs[1] + attn_outputs += (self_attn_weights,) + + if use_cache: + present_key_value = outputs[2 if output_attentions else 1] + attn_outputs += (present_key_value,) + + return attn_outputs + + def post_process( + self, + hidden_states, + residual, + output_attentions=False, + use_cache=False, + self_attn_weights=None, + present_key_value=None, + ): hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: @@ -589,6 +702,41 @@ def forward( outputs = outputs[0] return outputs + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + position_embeddings: 
Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ) -> paddle.Tensor: + + attn_outputs = self.attn( + hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + position_ids=position_ids, + output_attentions=output_attentions, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = attn_outputs[0] + residual = attn_outputs[1] + self_attn_weights = attn_outputs[2] if output_attentions else None + present_key_value = attn_outputs[3] if use_cache else None + + hidden_states = self.mlp(hidden_states) + outputs = self.post_process( + hidden_states, residual, output_attentions, use_cache, self_attn_weights, present_key_value + ) + return outputs + class Glm4MoePreTrainedModel(PretrainedModel): config: Glm4MoeConfig @@ -1013,27 +1161,18 @@ def forward( # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) inputs_embeds = ScatterOp.apply(inputs_embeds) - if attn_mask_startend_row_indices is not None or get_use_casual_mask(): - attention_mask = None - else: - # [bs, seq_len] - attention_mask = ( - paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) - if attention_mask is None - else attention_mask - ) + hidden_states = inputs_embeds + if attention_mask is not None: causal_mask = self._prepare_decoder_attention_mask( - attention_mask=attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=cache_length, - dtype=inputs_embeds.dtype, - ) # [bs, 1, seq_len, seq_len] + attention_mask, hidden_states.shape[:2], cache_length, hidden_states.dtype + ) + else: + causal_mask = None if position_ids is None: position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) - hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers @@ -1041,13 +1180,30 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None + moelayer_use_subbatch_recompute = ( + self.config.moe_subbatch_token_num > 0 if hasattr(self.config, "moe_subbatch_token_num") else False + ) + for idx, (decoder_layer) in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) past_key_value = past_key_values[idx] if past_key_values is not None else None has_gradient = not hidden_states.stop_gradient - if self.config.recompute and self.config.recompute_granularity == "full" and has_gradient: + if moelayer_use_subbatch_recompute: + layer_outputs = decoder_layer.subbatch_recompute_forward( + hidden_states, + bs, + hidden_size, + position_ids, + causal_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + position_embeddings, + ) + elif self.config.recompute and self.config.recompute_granularity == "full" and has_gradient: layer_outputs = self.recompute_training_full( layer_module=decoder_layer, hidden_states=hidden_states, @@ -1070,7 +1226,6 @@ def forward( use_cache=use_cache, position_embeddings=position_embeddings, ) - # # NOTE: clear outdate cache after it has been used for memory saving # past_key_value = past_key_values[idx] = None if isinstance(layer_outputs, (tuple, list)): @@ -1205,6 +1360,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, ) 
hidden_states = outputs[0] # [bs, seq_len, dim] diff --git a/paddleformers/transformers/moe_layer.py b/paddleformers/transformers/moe_layer.py index 5aa433168a..5db46c4367 100644 --- a/paddleformers/transformers/moe_layer.py +++ b/paddleformers/transformers/moe_layer.py @@ -380,8 +380,6 @@ def expert_forward(self, dispatched_input, tokens_per_expert): return paddle.cat(outputs, axis=0) def forward(self, hidden_states: paddle.Tensor): - _, _, d_model = hidden_states.shape - # reshaped_input = hidden_states.reshape([-1, d_model]) probs, routing_map, l_aux, l_zloss = self.gate(hidden_states) (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation( hidden_states, probs, routing_map
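Note on `pre_alloc_memory` (the new `TrainingArguments` field used in `run_finetune.py` above): allocating a large `uint8` tensor once and then deleting it reserves that memory in Paddle's caching allocator before training starts, which can reduce fragmentation and later allocation stalls. Below is a minimal sketch of the same idea, assuming the value is given in GB as the help text states; the helper name `warm_up_allocator` is illustrative and not part of the PR.

```python
import paddle

def warm_up_allocator(gigabytes: int) -> None:
    """Reserve `gigabytes` GB up front so Paddle's caching allocator keeps the block."""
    if gigabytes > 0:
        n_bytes = int(gigabytes) * 1024 * 1024 * 1024  # uint8 => one byte per element
        buf = paddle.empty([n_bytes], dtype=paddle.uint8)
        del buf  # released to the allocator's pool, not back to the device
```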
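Note on `moe_subbatch_token_num`: when it is set, `subbatch_recompute_forward` above splits the token dimension into fixed-size chunks and runs the MoE MLP on one chunk at a time under `recompute`, so only a single sub-batch of expert activations is live at once. The following is a self-contained sketch of just that splitting step, assuming a generic `mlp` callable that maps `[tokens, hidden] -> [tokens, hidden]`; the helper name is illustrative, and the PR additionally recomputes attention and offloads activations, which this sketch omits.

```python
import paddle
from paddle.distributed.fleet.utils import recompute

def run_mlp_in_subbatches(mlp, hidden_states, sub_seq_len, sequence_parallel=False):
    """Run `mlp` over `hidden_states` in chunks of `sub_seq_len` tokens.

    With sequence parallelism the layout is [seq/mp, bs, h] (sequence on axis 0),
    otherwise [bs, seq, h] (sequence on axis 1), matching the layer code above.
    """
    hidden_size = hidden_states.shape[-1]
    seq_axis = 0 if sequence_parallel else 1
    seq_len = hidden_states.shape[seq_axis]
    assert seq_len % sub_seq_len == 0, "moe_subbatch_token_num must divide the sequence length"

    outputs = []
    for chunk in paddle.split(hidden_states, seq_len // sub_seq_len, axis=seq_axis):
        flat = chunk.reshape([-1, hidden_size])   # [sub_seq_len * bs, hidden_size]
        out = recompute(mlp, flat)                # recompute drops this chunk's activations
        outputs.append(out.reshape(chunk.shape))  # restore the chunk's original layout
    return paddle.concat(outputs, axis=seq_axis)
```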