diff --git a/paddlemix/models/qwen2_vl/modeling_qwen2_vl_network.py b/paddlemix/models/qwen2_vl/modeling_qwen2_vl_network.py
index 66a609fe7..63c91cbaa 100644
--- a/paddlemix/models/qwen2_vl/modeling_qwen2_vl_network.py
+++ b/paddlemix/models/qwen2_vl/modeling_qwen2_vl_network.py
@@ -1228,32 +1228,10 @@ def forward(
         return inputs_embeds
 
 
-class Qwen2VLModel(Qwen2VLPreTrainedModel):
-    def __init__(self, config: Qwen2VLConfig):
-        super().__init__(config)
-        self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
-        self.hidden_size = config.hidden_size
-        # Recompute defaults to False and is controlled by Trainer
-
-        self.embed_tokens = nn.Embedding(
-            self.vocab_size,
-            self.hidden_size,
-        )
-
-        # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-        self.layers = nn.LayerList(
-            [Qwen2VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
-        )
-        self.norm = Qwen2RMSNorm(config, config.hidden_size, eps=config.rms_norm_eps)
-
-        self.enamble_recompute = False
-
-    def get_input_embeddings(self):
-        return self.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.embed_tokens = value
+class GlobalOutputNet(nn.Layer):
+    def __init__(self, config) -> None:
+        super().__init__()
+        self.config = config
 
     @staticmethod
     def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype):
@@ -1283,6 +1261,68 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
             expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
         return expanded_attn_mask
 
+    def forward(
+        self,
+        position_ids,
+        attention_mask,
+        cache_position,
+        past_key_values,
+        seq_length,
+        batch_size,
+        seq_length_with_past,
+        cache_length,
+        emb_dtype,
+    ):
+        if attention_mask is None:
+            # [bs, seq_len]
+            attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
+
+        if flash_attn_varlen_func:
+            causal_mask = attention_mask
+        else:
+            causal_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), cache_length, emb_dtype
+            )  # [bs, 1, seq_len, seq_len]
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
+            cache_position = paddle.arange(past_seen_tokens, past_seen_tokens + seq_length)
+
+        if position_ids is None:
+            # the hard coded `3` is for temporal, height and width.
+            position_ids = cache_position.reshape([1, 1, -1]).expand([3, batch_size, -1])
+
+        return position_ids, causal_mask, cache_position
+
+
+class Qwen2VLModel(Qwen2VLPreTrainedModel):
+    def __init__(self, config: Qwen2VLConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.hidden_size = config.hidden_size
+        # Recompute defaults to False and is controlled by Trainer
+
+        self.embed_tokens = nn.Embedding(
+            self.vocab_size,
+            self.hidden_size,
+        )
+
+        # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.global_layer = GlobalOutputNet(config=config)
+        self.layers = nn.LayerList(
+            [Qwen2VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = Qwen2RMSNorm(config, config.hidden_size, eps=config.rms_norm_eps)
+
+        self.enamble_recompute = False
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
     @paddle.jit.not_to_static
     def recompute_training_full(
         self,
@@ -1360,26 +1400,37 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        # embed positions
-        if attention_mask is None:
-            # [bs, seq_len]
-            attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
+        # # embed positions
+        # if attention_mask is None:
+        #     # [bs, seq_len]
+        #     attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
 
-        if flash_attn_varlen_func:
-            causal_mask = attention_mask
-        else:
-            causal_mask = self._prepare_decoder_attention_mask(
-                attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
-            )  # [bs, 1, seq_len, seq_len]
+        # if flash_attn_varlen_func:
+        #     causal_mask = attention_mask
+        # else:
+        #     causal_mask = self._prepare_decoder_attention_mask(
+        #         attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
+        #     )  # [bs, 1, seq_len, seq_len]
 
-        if cache_position is None:
-            past_seen_tokens = past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
-            cache_position = paddle.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])
+        # if cache_position is None:
+        #     past_seen_tokens = past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
+        #     cache_position = paddle.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])
 
-        if position_ids is None:
-            # the hard coded `3` is for temporal, height and width.
-            position_ids = cache_position.reshape([1, 1, -1]).expand([3, inputs_embeds.shape[0], -1])
+        # if position_ids is None:
+        #     # the hard coded `3` is for temporal, height and width.
+        #     position_ids = cache_position.reshape([1, 1, -1]).expand([3, inputs_embeds.shape[0], -1])
+        position_ids, causal_mask, cache_position = self.global_layer(
+            position_ids,
+            attention_mask,
+            cache_position,
+            past_key_values,
+            seq_length,
+            batch_size,
+            seq_length_with_past,
+            cache_length,
+            inputs_embeds.dtype,
+        )
 
         hidden_states = inputs_embeds
 
         # decoder layers
@@ -1963,3 +2014,48 @@ def prepare_inputs_for_generation(
             }
         )
         return model_inputs
+
+    def auto_dist_config(self, prefix=""):
+        if prefix != "":
+            assert prefix.endswith(".")
+        config = {
+            "sp_config": {
+                "parallelize_plan": {
+                    f"{prefix}model.embed_tokens": [
+                        dist.RowWiseParallel(),
+                        dist.SequenceParallelBegin(),
+                    ],
+                    f"{prefix}model.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
+                    f"{prefix}model.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
+                    f"{prefix}model.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
+                    f"{prefix}model.layers.*.self_attn.v_proj": dist.ColWiseParallel(),
+                    f"{prefix}model.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
+                    f"{prefix}model.layers.*.self_attn": dist.SequenceParallelDisable(),
+                    f"{prefix}model.layers.*.mlp.gate_proj": dist.ColWiseParallel(),
+                    f"{prefix}model.layers.*.mlp.up_proj": dist.ColWiseParallel(),
+                    f"{prefix}model.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
+                    f"{prefix}model.layers.*.mlp.down_proj": dist.RowWiseParallel(),
+                    f"{prefix}model.layers.*.mlp": dist.SequenceParallelDisable(need_transpose=False),
+                    f"{prefix}lm_head.weight": dist.ColWiseParallel(),
+                    f"{prefix}lm_head": dist.SequenceParallelEnd(),
+                }
+            },
+            "mp_config": {
+                "parallelize_plan": {
+                    f"{prefix}model.embed_tokens": dist.RowWiseParallel(),
+                    f"{prefix}model.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
+                    f"{prefix}model.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
+                    f"{prefix}model.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
+                    f"{prefix}model.layers.*.self_attn.v_proj": dist.ColWiseParallel(),
+                    f"{prefix}model.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
+                    f"{prefix}model.layers.*.mlp.gate_proj": dist.ColWiseParallel(),
+                    f"{prefix}model.layers.*.mlp.up_proj": dist.ColWiseParallel(),
+                    f"{prefix}model.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
+                    f"{prefix}model.layers.*.mlp.down_proj": dist.RowWiseParallel(),
+                    f"{prefix}lm_head.weight": dist.ColWiseParallel(),
+                }
+            },
+            "pp_config": {"split_spec": f"{prefix}model.layers", "global_spec": f"{prefix}model.global_layer"},
+        }
+
+        return config
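
Reviewer note (not part of the patch): the refactor moves attention-mask and M-RoPE position-id preparation out of Qwen2VLModel.forward into the new GlobalOutputNet layer, which the added pp_config then references through "global_spec": f"{prefix}model.global_layer", presumably so each pipeline stage can recompute these small global tensors locally. The sketch below is a self-contained approximation of the shapes GlobalOutputNet.forward returns on the dense-attention prefill path (no KV cache, flash_attn_varlen_func unavailable); the toy sizes and the simplified mask construction are illustrative assumptions, not the PaddleMIX implementation itself.

# Standalone sketch (assumption: dense-mask prefill path, toy sizes) of the
# tensors GlobalOutputNet.forward hands back to the decoder loop.
import paddle

batch_size, seq_length, cache_length = 2, 8, 0      # prefill: nothing cached yet
seq_length_with_past = seq_length + cache_length
dtype = paddle.float32

# attention_mask defaults to all-ones [bs, seq_len_with_past] when the caller passes None
attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)

# dense path: additive [bs, 1, seq_len, seq_len] mask with finfo(dtype).min on
# disallowed positions (rough stand-in for _prepare_decoder_attention_mask)
tril = paddle.tril(paddle.ones((seq_length, seq_length))).astype(paddle.bool)
allowed = paddle.logical_and(
    tril.reshape([1, 1, seq_length, seq_length]),
    attention_mask.reshape([batch_size, 1, 1, seq_length_with_past]),
)
causal_mask = paddle.where(
    allowed,
    paddle.zeros(allowed.shape, dtype),
    paddle.full(allowed.shape, paddle.finfo(dtype).min, dtype),
)

# cache_position starts counting at the number of tokens already in the KV cache
cache_position = paddle.arange(cache_length, cache_length + seq_length)

# the leading 3 carries the temporal / height / width axes of Qwen2-VL's M-RoPE
position_ids = cache_position.reshape([1, 1, -1]).expand([3, batch_size, -1])

print(causal_mask.shape)   # [2, 1, 8, 8]
print(position_ids.shape)  # [3, 2, 8]

Usage-wise, auto_dist_config is just a plan factory: calling it with a prefix such as "qwen2_vl." (a hypothetical example; the assert only requires a trailing dot) returns the nested sp/mp/pp dictionaries shown in the patch, and the consuming trainer is expected to feed the parallelize_plan entries to Paddle's auto-parallel machinery while splitting the pipeline at model.layers.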