50 changes: 8 additions & 42 deletions paddlenlp/experimental/transformers/deepseek_v2/modeling.py
@@ -47,9 +47,6 @@
yarn_get_mscale,
yarn_linear_ramp_mask,
)
from paddlenlp.transformers.model_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
)
from paddlenlp.transformers.model_utils import (
dy2st_nocheck_guard_context,
register_base_model,
@@ -266,7 +263,6 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
self.weight_block_size = config.weight_block_size
self.moe_quant_type = config.moe_quant_type
self.rope_theta = config.rope_theta
self.return_full_hidden_states = config.get("return_full_hidden_states", False)

self.use_weight_only = False
self.weightonly_group_size = -1
@@ -591,7 +587,6 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
speculate_config = SpeculateConfig(
speculate_method=config.get("speculate_method", None),
speculate_max_draft_token_num=config.get("speculate_max_draft_token_num", 5),
return_full_hidden_states=config.get("return_full_hidden_states", False),
)

transformer_config = FusedMultiTransformerConfig(
@@ -622,9 +617,9 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
rotary_emb=self.rotary_emb,
norm_type="rmsnorm",
rank_id=config.tensor_parallel_rank,
append_attn=config.append_attn,
moe_config=moe_config,
mla_config=mla_config,
append_attn=config.append_attn,
speculate_config=speculate_config,
)

@@ -1289,7 +1284,7 @@ def forward(
inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[2]])

with dy2st_nocheck_guard_context():
hidden_states, _ = self.transformer_block(
hidden_states, full_hidden_states = self.transformer_block(
input_ids=input_ids,
src=inputs_embeds,
cum_offsets=cum_offsets,
@@ -1301,13 +1296,7 @@
)
hidden_states = self.norm(hidden_states)

return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=None,
hidden_states=None,
attentions=None,
cum_offsets=cum_offsets,
)
return hidden_states, full_hidden_states


@register_base_model
@@ -1967,7 +1956,7 @@ def forward(
inputs_embeds = self.eh_proj(inputs_embeds)

with dy2st_nocheck_guard_context():
hidden_states, _ = self.transformer_block(
hidden_states, full_hidden_states = self.transformer_block(
input_ids=input_ids,
src=inputs_embeds,
cum_offsets=cum_offsets,
@@ -1980,12 +1969,7 @@
)
hidden_states = self.norm(hidden_states)

return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=None,
hidden_states=None,
attentions=None,
)
return hidden_states, full_hidden_states


class DeepseekV2ForCausalLMBlockInferenceModel(GenerationBlockInferenceModel, DeepseekV2PretrainedModel):
@@ -2212,7 +2196,7 @@ def forward(
draft_tokens=None,
output_padding_offset=None,
):
outputs = self.deepseek_v2(
hidden_states, full_hidden_states = self.deepseek_v2(
input_ids,
inputs_embeds=inputs_embeds,
src_mask=src_mask,
@@ -2230,21 +2214,7 @@ def forward(
draft_tokens=draft_tokens,
output_padding_offset=output_padding_offset,
)
if self.return_full_hidden_states:
from paddlenlp_ops import rebuild_padding_v2

full_hidden_states = outputs[0]
cum_offsets = outputs[1]
hidden_states = rebuild_padding_v2(
full_hidden_states,
cum_offsets,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
self.max_seq_len,
)
else:
hidden_states = outputs[0]

logits = self.lm_head(
hidden_states,
tensor_parallel_output=False,
@@ -2254,8 +2224,6 @@ def forward(
else:
return logits

return logits

@paddle.no_grad()
def set_state_dict(self, state_dict):
if "lm_head.weight" in state_dict:
@@ -2363,7 +2331,7 @@ def forward(
output_padding_offset=None,
pre_hidden_states=None,
):
outputs = self.mtp(
hidden_states, _ = self.mtp(
input_ids,
src_mask=src_mask,
caches=caches,
@@ -2382,8 +2350,6 @@ def forward(
pre_hidden_states=pre_hidden_states,
)

hidden_states = outputs[0]

logits = self.lm_head(
hidden_states,
tensor_parallel_output=False,
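Note on the deepseek_v2 hunks above: the block-inference base model now returns a plain (hidden_states, full_hidden_states) tuple instead of a BaseModelOutputWithPastAndCrossAttentions, and the causal-LM wrapper no longer indexes outputs[...] or calls rebuild_padding_v2 itself. A minimal sketch of the new contract, with toy class names and a trimmed argument list (not the real PaddleNLP signatures):

# Illustrative sketch only; the real classes take many more arguments and run on
# Paddle tensors.

class BlockInferenceModelSketch:
    """Toy stand-in for DeepseekV2BlockInferenceModel.forward after this change."""

    def __init__(self, transformer_block, norm):
        self.transformer_block = transformer_block
        self.norm = norm

    def forward(self, input_ids, inputs_embeds, **kwargs):
        # The fused transformer block is assumed to hand back both the per-sequence
        # output and the full per-token hidden states (see the
        # FusedMultiTransformerBase.forward hunk further down).
        hidden_states, full_hidden_states = self.transformer_block(
            input_ids=input_ids, src=inputs_embeds, **kwargs
        )
        return self.norm(hidden_states), full_hidden_states


class CausalLMSketch:
    """Toy stand-in for DeepseekV2ForCausalLMBlockInferenceModel.forward."""

    def __init__(self, base_model, lm_head):
        self.base_model = base_model
        self.lm_head = lm_head

    def forward(self, input_ids, inputs_embeds, **kwargs):
        # The tuple is unpacked directly; no rebuild_padding_v2 call remains here.
        hidden_states, _full_hidden_states = self.base_model.forward(
            input_ids, inputs_embeds, **kwargs
        )
        return self.lm_head(hidden_states)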
@@ -157,7 +157,6 @@ class AvxConfig:
class SpeculateConfig:
speculate_max_draft_token_num: int = 5
speculate_method: str = None
return_full_hidden_states: bool = False


@dataclass
@@ -1589,7 +1588,7 @@ def forward(
kwargs["input_ids"] = input_ids

out = self.post_process(**kwargs)
return out, caches
return out, kwargs["multi_block_output"]


class FusedMultiTransformerPostLayernorm(FusedMultiTransformerBase):
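The forward hunk above is where the second return value changes meaning: FusedMultiTransformerBase.forward now returns kwargs["multi_block_output"] instead of the caches as its second element. A toy sketch of the new return shape, assuming multi_block_output is filled in by the layer stack earlier in forward (names are illustrative only):

# Toy sketch, not the real FusedMultiTransformerBase; only the return shape matters.

def fused_forward_sketch(input_ids, src, caches, **kwargs):
    kwargs["input_ids"] = input_ids
    # ... attention / FFN layers would run here; they are assumed to stash the
    # per-token hidden states under kwargs["multi_block_output"] ...
    kwargs.setdefault("multi_block_output", src)
    out = src  # placeholder for self.post_process(**kwargs)
    # Previously the second element was `caches`; any caller relying on that slot
    # needs the updated contract.
    return out, kwargs["multi_block_output"]

hidden_states, full_hidden_states = fused_forward_sketch([1, 2, 3], [[0.0]], caches=None)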
25 changes: 12 additions & 13 deletions paddlenlp/experimental/transformers/generation_utils.py
@@ -295,7 +295,7 @@ def _forward_(**args):
model_inputs = self.prepare_inputs_for_generation(input_ids, cache_kvs, **args)
return self(**model_inputs)

def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
def _post_process_(logits, top_p, temperature, step_idx_ori, model_kwargs):
cache = model_kwargs.get("cache", None)
just_decoder = model_kwargs["seq_len_encoder"] == 0
if cache is None: # first decoder
@@ -314,7 +314,6 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
step_idx,
model_kwargs["stop_flags"],
)
logits = outputs[0] if isinstance(outputs, tuple) else outputs

logits = paddle.cast(logits, paddle.float32)
logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori)
@@ -373,7 +372,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
outputs = _forward_(**model_kwargs)
# first decoder
next_tokens, model_kwargs = _post_process_(
outputs,
outputs[0] if isinstance(outputs, tuple) else outputs,
top_p,
temperature,
step_idx_ori,
@@ -389,8 +388,9 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
paddle.sum(paddle.cast(model_kwargs["stop_flags"], "int64")),
model_kwargs["stop_nums"],
):
outputs = _forward_(**model_kwargs)
next_tokens, model_kwargs = _post_process_(
_forward_(**model_kwargs),
outputs[0] if isinstance(outputs, tuple) else outputs,
top_p,
temperature,
step_idx_ori,
@@ -692,7 +692,7 @@ def _forward_(**args):
return self(**model_inputs)

def _post_process_(
outputs,
logits,
top_k,
top_p,
penalty_score,
@@ -702,7 +702,7 @@
model_kwargs,
):
step_idx = model_kwargs["step_idx"]
logits = paddle.cast(outputs, paddle.float32)
logits = paddle.cast(logits, paddle.float32)

from paddlenlp_ops import set_preids_token_penalty_multi_scores

@@ -777,7 +777,7 @@ def _post_process_(
outputs = _forward_(**model_kwargs) # [bs, 1, dim_embed]
# first decoder
next_tokens = _post_process_(
outputs,
outputs[0] if isinstance(outputs, tuple) else outputs,
top_k,
top_p,
penalty_score,
@@ -806,7 +806,7 @@ def _forward_(**args):
return self(**model_inputs)

def _post_process_(
outputs,
logits,
top_k,
top_p,
penalty_score,
@@ -816,7 +816,7 @@
model_kwargs,
):
step_idx = model_kwargs["step_idx"]
logits = paddle.cast(outputs, paddle.float32)
logits = paddle.cast(logits, paddle.float32)

from paddlenlp_ops import speculate_get_token_penalty_multi_scores

@@ -959,7 +959,7 @@ def _forward_(**args):
return self(**model_inputs)

def _post_process_(
outputs,
logits,
top_k,
top_p,
penalty_score,
@@ -968,7 +968,7 @@
temperature,
model_kwargs,
):
logits = paddle.cast(outputs, paddle.float32)
logits = paddle.cast(logits, paddle.float32)

probs = F.softmax(logits)

@@ -1191,7 +1191,7 @@ def _forward_(**args):
model_inputs = self.prepare_inputs_for_generation(input_ids, **args)
return self(**model_inputs)

def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
def _post_process_(logits, top_p, temperature, step_idx_ori, model_kwargs):
cache = model_kwargs.get("cache", None)
just_decoder = model_kwargs["seq_len_encoder"] == 0
if cache is None: # first decoder
@@ -1211,7 +1211,6 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
step_idx,
model_kwargs["stop_flags"],
)
logits = outputs[0] if isinstance(outputs, tuple) else outputs
logits = paddle.cast(logits, paddle.float32)
logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori)

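Throughout generation_utils.py the pattern of the change is the same: _post_process_ now receives logits directly, and the tuple handling (outputs[0] if isinstance(outputs, tuple) else outputs) moves to the call sites, so the sampling code no longer cares whether the model also returned full hidden states. A self-contained toy version of that loop, with plain Python lists standing in for tensors:

# Toy illustration of the updated call pattern; the real loop uses Paddle tensors,
# penalty ops and stop conditions that are not reproduced here.

def _forward_(step):
    logits = [0.1 * step, 0.2 * step]
    # Pretend the model sometimes returns (logits, full_hidden_states), as the
    # block-inference models now do, and sometimes bare logits.
    return (logits, "full_hidden_states") if step % 2 == 0 else logits

def _post_process_(logits, temperature):
    # Receives logits only; no isinstance() check is needed in here any more.
    return [x / temperature for x in logits]

for step_idx in range(3):
    outputs = _forward_(step_idx)
    scores = _post_process_(
        outputs[0] if isinstance(outputs, tuple) else outputs,
        temperature=0.7,
    )
    print(step_idx, scores)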
49 changes: 7 additions & 42 deletions paddlenlp/experimental/transformers/llama/modeling.py
@@ -257,19 +257,12 @@ def forward(
# merge batch and seq_len dimension.
inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])

output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

hidden_states = inputs_embeds

# decoder layers
all_hidden_states = () if output_hidden_states else None
with dy2st_nocheck_guard_context():
hidden_states = self.transformer_block(
input_ids=input_ids,
@@ -280,11 +273,7 @@
)
hidden_states = self.norm(hidden_states)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(v for v in [hidden_states, None, all_hidden_states, None] if v is not None)
return hidden_states

@paddle.no_grad()
# avx
@@ -401,7 +390,6 @@ def __init__(self, config: LlamaConfig):
self.epsilon = config.rms_norm_eps
self.max_position_embeddings = config.max_position_embeddings
self.quant_type = config.get("quant_type", "")
self.return_full_hidden_states = config.get("return_full_hidden_states", False)

self.rope_theta = config.rope_theta
self.use_neox = True
@@ -617,7 +605,6 @@ def __init__(self, config: LlamaConfig):
speculate_config = SpeculateConfig(
speculate_method=config.get("speculate_method", None),
speculate_max_draft_token_num=config.get("speculate_max_draft_token_num", 5),
return_full_hidden_states=config.get("return_full_hidden_states", False),
)

hpu_config = HpuConfig(
@@ -1496,7 +1483,7 @@ def forward(
inputs_embeds = self.embed_tokens(ids_remove_padding)

with dy2st_nocheck_guard_context():
hidden_states, _ = self.transformer_block(
hidden_states, full_hidden_states = self.transformer_block(
input_ids=input_ids,
src=inputs_embeds,
cum_offsets=cum_offsets,
@@ -1508,17 +1495,11 @@
)
hidden_states = self.norm(hidden_states)

return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=None,
hidden_states=None,
attentions=None,
cum_offsets=cum_offsets,
)
return hidden_states, full_hidden_states


@register_base_model
class EagleForLlamaInferenceModel(LlamaBlockInferenceModel):
class EagleForLlamaBlockInferenceModel(LlamaBlockInferenceModel):
def __init__(self, config: LlamaConfig):
self.append_attn = config.append_attn
super().__init__(config)
@@ -2079,7 +2060,7 @@ def forward(
draft_tokens=None,
output_padding_offset=None,
):
outputs = self.llama(
hidden_states, full_hidden_states = self.llama(
input_ids,
src_mask=src_mask,
caches=caches,
@@ -2097,23 +2078,7 @@ def forward(
draft_tokens=draft_tokens,
output_padding_offset=output_padding_offset,
)
# hidden_states = outputs[0]
if self.return_full_hidden_states:
from paddlenlp_ops import rebuild_padding_v2

# full_hidden_states = outputs[1]
full_hidden_states = outputs[0]
cum_offsets = outputs[1]
hidden_states = rebuild_padding_v2(
full_hidden_states,
cum_offsets,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
self.max_seq_len,
)
else:
hidden_states = outputs[0]

logits = self.lm_head(
hidden_states,
tensor_parallel_output=False,
@@ -2139,7 +2104,7 @@ def __init__(self, config):
self.verify_window = config.get("speculate_verify_window", 2)
self.max_seq_len = config.max_seq_len

self.eagle = EagleForLlamaInferenceModel(config)
self.eagle = EagleForLlamaBlockInferenceModel(config)
if config.tie_word_embeddings:
self.lm_head = LlamaLMHead(config, embedding_weights=self.llama.embed_tokens.weight, transpose_y=True)
self.tie_weights()
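In llama/modeling.py the first hunk drops the output_hidden_states / return_dict bookkeeping from the non-block forward, which now returns the normalized hidden states directly; the block-inference model mirrors the deepseek_v2 tuple contract, and the Eagle wrapper class is renamed to EagleForLlamaBlockInferenceModel to match its base class. A minimal sketch of the simplified non-block forward (illustrative names and argument list):

# Minimal sketch of the simplified non-block Llama inference forward; not the real
# signature, which also threads caches, masks and rotary embeddings through.

class LlamaInferenceForwardSketch:
    def __init__(self, embed_tokens, transformer_block, norm):
        self.embed_tokens = embed_tokens
        self.transformer_block = transformer_block
        self.norm = norm

    def forward(self, input_ids=None, inputs_embeds=None, **kwargs):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = self.transformer_block(
            input_ids=input_ids, src=inputs_embeds, **kwargs
        )
        # No output_hidden_states / return_dict branches: callers always get the
        # normalized hidden states back.
        return self.norm(hidden_states)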