180 changes: 138 additions & 42 deletions paddlemix/models/qwen2_vl/modeling_qwen2_vl_network.py
@@ -1228,32 +1228,10 @@ def forward(
return inputs_embeds


class Qwen2VLModel(Qwen2VLPreTrainedModel):
def __init__(self, config: Qwen2VLConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.hidden_size = config.hidden_size
# Recompute defaults to False and is controlled by Trainer

self.embed_tokens = nn.Embedding(
self.vocab_size,
self.hidden_size,
)

# self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.LayerList(
[Qwen2VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Qwen2RMSNorm(config, config.hidden_size, eps=config.rms_norm_eps)

self.enamble_recompute = False

def get_input_embeddings(self):
return self.embed_tokens

def set_input_embeddings(self, value):
self.embed_tokens = value
class GlobalOutputNet(nn.Layer):
def __init__(self, config) -> None:
super().__init__()
self.config = config

@staticmethod
def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype):
@@ -1283,6 +1261,68 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
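# 0.0 keeps a position visible; paddle.finfo(dtype).min effectively removes it
# once the mask is added to the attention logits.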
return expanded_attn_mask

def forward(
self,
position_ids,
attention_mask,
cache_position,
past_key_values,
seq_length,
batch_size,
seq_length_with_past,
cache_length,
emb_dtype,
):
if attention_mask is None:
# [bs, seq_len]
attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)

if flash_attn_varlen_func:
causal_mask = attention_mask
else:
causal_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), cache_length, emb_dtype
) # [bs, 1, seq_len, seq_len]

if cache_position is None:
past_seen_tokens = past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
cache_position = paddle.arange(past_seen_tokens, past_seen_tokens + seq_length)

if position_ids is None:
# the hard coded `3` is for temporal, height and width.
position_ids = cache_position.reshape([1, 1, -1]).expand([3, batch_size, -1])

return position_ids, causal_mask, cache_position
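
# A minimal usage sketch of the new GlobalOutputNet, assuming a default
# Qwen2VLConfig() and toy sizes; the values below are illustrative, not taken
# from this diff.
import paddle

sketch_layer = GlobalOutputNet(config=Qwen2VLConfig())
bs, seq = 2, 8
position_ids, causal_mask, cache_position = sketch_layer(
    position_ids=None,        # derived from cache_position when omitted
    attention_mask=None,      # defaults to an all-ones [bs, seq_len] bool mask
    cache_position=None,      # becomes arange(0, seq) with an empty cache
    past_key_values=[None],   # no KV cache yet
    seq_length=seq,
    batch_size=bs,
    seq_length_with_past=seq,
    cache_length=0,
    emb_dtype=paddle.bfloat16,
)
# position_ids: [3, bs, seq] (temporal, height, width axes)
# causal_mask: the raw attention_mask when flash_attn_varlen_func is available,
#              otherwise an additive [bs, 1, seq, seq] mask
# cache_position: [seq]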


class Qwen2VLModel(Qwen2VLPreTrainedModel):
def __init__(self, config: Qwen2VLConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.hidden_size = config.hidden_size
# Recompute defaults to False and is controlled by Trainer

self.embed_tokens = nn.Embedding(
self.vocab_size,
self.hidden_size,
)

# self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.global_layer = GlobalOutputNet(config=config)
self.layers = nn.LayerList(
[Qwen2VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Qwen2RMSNorm(config, config.hidden_size, eps=config.rms_norm_eps)

self.enamble_recompute = False

def get_input_embeddings(self):
return self.embed_tokens

def set_input_embeddings(self, value):
self.embed_tokens = value

@paddle.jit.not_to_static
def recompute_training_full(
self,
@@ -1360,26 +1400,37 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

# embed positions
if attention_mask is None:
# [bs, seq_len]
attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
# # embed positions
# if attention_mask is None:
# # [bs, seq_len]
# attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)

if flash_attn_varlen_func:
causal_mask = attention_mask
else:
causal_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]
# if flash_attn_varlen_func:
# causal_mask = attention_mask
# else:
# causal_mask = self._prepare_decoder_attention_mask(
# attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
# ) # [bs, 1, seq_len, seq_len]

if cache_position is None:
past_seen_tokens = past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
cache_position = paddle.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])
# if cache_position is None:
# past_seen_tokens = past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
# cache_position = paddle.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])

if position_ids is None:
# the hard coded `3` is for temporal, height and width.
position_ids = cache_position.reshape([1, 1, -1]).expand([3, inputs_embeds.shape[0], -1])
# if position_ids is None:
# # the hard coded `3` is for temporal, height and width.
# position_ids = cache_position.reshape([1, 1, -1]).expand([3, inputs_embeds.shape[0], -1])
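# The mask / position-id setup that used to be inlined here (left commented out
# above) now runs inside self.global_layer, so the whole step can be addressed
# as `model.global_layer` by auto_dist_config's pp_config further down.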

position_ids, causal_mask, cache_position = self.global_layer(
position_ids,
attention_mask,
cache_position,
past_key_values,
seq_length,
batch_size,
seq_length_with_past,
cache_length,
inputs_embeds.dtype,
)
hidden_states = inputs_embeds

# decoder layers
@@ -1963,3 +2014,48 @@ def prepare_inputs_for_generation(
}
)
return model_inputs

def auto_dist_config(self, prefix=""):
if prefix != "":
assert prefix.endswith(".")
config = {
"sp_config": {
"parallelize_plan": {
f"{prefix}model.embed_tokens": [
dist.RowWiseParallel(),
dist.SequenceParallelBegin(),
],
f"{prefix}model.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
f"{prefix}model.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
f"{prefix}model.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
f"{prefix}model.layers.*.self_attn.v_proj": dist.ColWiseParallel(),
f"{prefix}model.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
f"{prefix}model.layers.*.self_attn": dist.SequenceParallelDisable(),
f"{prefix}model.layers.*.mlp.gate_proj": dist.ColWiseParallel(),
f"{prefix}model.layers.*.mlp.up_proj": dist.ColWiseParallel(),
f"{prefix}model.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
f"{prefix}model.layers.*.mlp.down_proj": dist.RowWiseParallel(),
f"{prefix}model.layers.*.mlp": dist.SequenceParallelDisable(need_transpose=False),
f"{prefix}lm_head.weight": dist.ColWiseParallel(),
f"{prefix}lm_head": dist.SequenceParallelEnd(),
}
},
"mp_config": {
"parallelize_plan": {
f"{prefix}model.embed_tokens": dist.RowWiseParallel(),
f"{prefix}model.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
f"{prefix}model.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
f"{prefix}model.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
f"{prefix}model.layers.*.self_attn.v_proj": dist.ColWiseParallel(),
f"{prefix}model.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
f"{prefix}model.layers.*.mlp.gate_proj": dist.ColWiseParallel(),
f"{prefix}model.layers.*.mlp.up_proj": dist.ColWiseParallel(),
f"{prefix}model.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
f"{prefix}model.layers.*.mlp.down_proj": dist.RowWiseParallel(),
f"{prefix}lm_head.weight": dist.ColWiseParallel(),
}
},
"pp_config": {"split_spec": f"{prefix}model.layers", "global_spec": f"{prefix}model.global_layer"},
}

return config
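
# A minimal, self-contained sketch of how the optional prefix shapes the plan
# keys above (the "qwen2_vl." prefix and the tiny stand-in plan are
# illustrative assumptions, as is the note on the consumer).
prefix = "qwen2_vl."
assert prefix == "" or prefix.endswith(".")
mp_plan_keys = [
    f"{prefix}model.embed_tokens",
    f"{prefix}model.layers.*.self_attn.qkv_proj",
    f"{prefix}model.layers.*.mlp.down_proj",
    f"{prefix}lm_head.weight",
]
pp_config = {"split_spec": f"{prefix}model.layers", "global_spec": f"{prefix}model.global_layer"}
# mp_plan_keys[0] == "qwen2_vl.model.embed_tokens"; with prefix="" the keys
# address the bare submodule paths. The returned dict is presumably handed to
# Paddle's auto-parallel entry point (e.g. paddle.distributed.parallelize),
# which this diff does not show.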