From 74080d88e06c773904eb38afea8522a1f9abe8b1 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 2 Apr 2025 17:21:49 +0800 Subject: [PATCH 1/5] refactor return_full_hidden_states --- .../transformers/deepseek_v2/modeling.py | 50 +++---------------- .../transformers/fused_transformer_layers.py | 47 ++++++++--------- .../transformers/llama/modeling.py | 49 +++--------------- paddlenlp/transformers/model_outputs.py | 5 -- 4 files changed, 37 insertions(+), 114 deletions(-) diff --git a/paddlenlp/experimental/transformers/deepseek_v2/modeling.py b/paddlenlp/experimental/transformers/deepseek_v2/modeling.py index be9eb634d402..a019952e28c7 100644 --- a/paddlenlp/experimental/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/experimental/transformers/deepseek_v2/modeling.py @@ -43,9 +43,6 @@ yarn_get_mscale, yarn_linear_ramp_mask, ) -from paddlenlp.transformers.model_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, -) from paddlenlp.transformers.model_utils import ( dy2st_nocheck_guard_context, register_base_model, @@ -276,7 +273,6 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str): self.weight_block_size = config.weight_block_size self.moe_quant_type = config.moe_quant_type self.rope_theta = config.rope_theta - self.return_full_hidden_states = config.get("return_full_hidden_states", False) self.use_weight_only = False self.weightonly_group_size = -1 @@ -610,7 +606,6 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str): speculate_config = SpeculateConfig( speculate_method=config.get("speculate_method", None), speculate_max_draft_token_num=config.get("speculate_max_draft_token_num", 5), - return_full_hidden_states=config.get("return_full_hidden_states", False), ) transformer_config = FusedMultiTransformerConfig( @@ -641,9 +636,9 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str): rotary_emb=self.rotary_emb, norm_type="rmsnorm", rank_id=config.tensor_parallel_rank, + append_attn=config.append_attn, moe_config=moe_config, mla_config=mla_config, - append_attn=config.append_attn, speculate_config=speculate_config, ) @@ -1316,7 +1311,7 @@ def forward( inputs_embeds = self.embed_tokens(ids_remove_padding) with dy2st_nocheck_guard_context(): - hidden_states, _ = self.transformer_block( + hidden_states, full_hidden_states = self.transformer_block( input_ids=input_ids, src=inputs_embeds, cum_offsets=cum_offsets, @@ -1328,13 +1323,7 @@ def forward( ) hidden_states = self.norm(hidden_states) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=None, - hidden_states=None, - attentions=None, - cum_offsets=cum_offsets, - ) + return hidden_states, full_hidden_states @register_base_model @@ -1386,7 +1375,7 @@ def forward( inputs_embeds = self.eh_proj(inputs_embeds) with dy2st_nocheck_guard_context(): - hidden_states, _ = self.transformer_block( + hidden_states, full_hidden_states = self.transformer_block( input_ids=input_ids, src=inputs_embeds, cum_offsets=cum_offsets, @@ -1399,12 +1388,7 @@ def forward( ) hidden_states = self.norm(hidden_states) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=None, - hidden_states=None, - attentions=None, - ) + return hidden_states, full_hidden_states class DeepseekV2ForCausalLMBlockInferenceModel(GenerationBlockInferenceModel, DeepseekV2PretrainedModel): @@ -1621,7 +1605,7 @@ def forward( draft_tokens=None, output_padding_offset=None, ): - outputs = self.deepseek_v2( + hidden_states, full_hidden_states = self.deepseek_v2( input_ids, src_mask=src_mask, caches=caches, @@ -1638,21 +1622,7 @@ def forward( draft_tokens=draft_tokens, output_padding_offset=output_padding_offset, ) - if self.return_full_hidden_states: - from paddlenlp_ops import rebuild_padding_v2 - - full_hidden_states = outputs[0] - cum_offsets = outputs[1] - hidden_states = rebuild_padding_v2( - full_hidden_states, - cum_offsets, - seq_lens_decoder, - seq_lens_encoder, - output_padding_offset, - self.max_seq_len, - ) - else: - hidden_states = outputs[0] + logits = self.lm_head( hidden_states, tensor_parallel_output=False, @@ -1662,8 +1632,6 @@ def forward( else: return logits - return logits - @paddle.no_grad() def set_state_dict(self, state_dict): if "lm_head.weight" in state_dict: @@ -1771,7 +1739,7 @@ def forward( output_padding_offset=None, pre_hidden_states=None, ): - outputs = self.mtp( + hidden_states, _ = self.mtp( input_ids, src_mask=src_mask, caches=caches, @@ -1790,8 +1758,6 @@ def forward( pre_hidden_states=pre_hidden_states, ) - hidden_states = outputs[0] - logits = self.lm_head( hidden_states, tensor_parallel_output=False, diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index a7f0d4bd6e4f..e464417469f7 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -158,7 +158,6 @@ class AvxConfig: class SpeculateConfig: speculate_max_draft_token_num: int = 5 speculate_method: str = None - return_full_hidden_states: bool = False @dataclass @@ -1774,7 +1773,7 @@ def forward( kwargs["input_ids"] = input_ids out = self.post_process(**kwargs) - return out, caches + return out, kwargs["multi_block_output"] class FusedMultiTransformerPostLayernorm(FusedMultiTransformerBase): @@ -3469,30 +3468,28 @@ def post_process(self, **kwargs): max_input_length = kwargs.get("max_input_length", -1) output_padding_offset = kwargs.get("output_padding_offset", None) # only used in speculative decoding - if self.config.speculate_config.return_full_hidden_states: - return multi_block_output + if paddle.is_compiled_with_xpu(): + from paddlenlp_ops import gather_next_token + + out = gather_next_token( + multi_block_output, + cum_offsets, + seq_lens_decoder, + seq_lens_encoder, + output_padding_offset, + max_input_length, + ) else: - if paddle.is_compiled_with_xpu(): - from paddlenlp_ops import gather_next_token - - out = gather_next_token( - multi_block_output, - cum_offsets, - seq_lens_decoder, - seq_lens_encoder, - output_padding_offset, - max_input_length, - ) - else: - out = rebuild_padding_v2( - multi_block_output, - cum_offsets, - seq_lens_decoder, - seq_lens_encoder, - output_padding_offset, - max_input_length, - ) - return out + out = rebuild_padding_v2( + multi_block_output, + cum_offsets, + seq_lens_decoder, + seq_lens_encoder, + output_padding_offset, + max_input_length, + ) + + return out class FusedBlockMultiTransformerWeightOnly(FusedBlockMultiTransformer, FusedMultiTransformerWeightOnly): diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py index 83b3b7e858fc..94d523217717 100644 --- a/paddlenlp/experimental/transformers/llama/modeling.py +++ b/paddlenlp/experimental/transformers/llama/modeling.py @@ -253,19 +253,12 @@ def forward( # merge batch and seq_len dimension. inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim]) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds # decoder layers - all_hidden_states = () if output_hidden_states else None with dy2st_nocheck_guard_context(): hidden_states = self.transformer_block( input_ids=input_ids, @@ -276,11 +269,7 @@ def forward( ) hidden_states = self.norm(hidden_states) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, None, all_hidden_states, None] if v is not None) + return hidden_states @paddle.no_grad() # avx @@ -397,7 +386,6 @@ def __init__(self, config: LlamaConfig): self.epsilon = config.rms_norm_eps self.max_position_embeddings = config.max_position_embeddings self.quant_type = config.get("quant_type", "") - self.return_full_hidden_states = config.get("return_full_hidden_states", False) self.rope_theta = config.rope_theta self.use_neox = True @@ -613,7 +601,6 @@ def __init__(self, config: LlamaConfig): speculate_config = SpeculateConfig( speculate_method=config.get("speculate_method", None), speculate_max_draft_token_num=config.get("speculate_max_draft_token_num", 5), - return_full_hidden_states=config.get("return_full_hidden_states", False), ) transformer_config = FusedMultiTransformerConfig( embed_dim=self.hidden_size, @@ -1445,7 +1432,7 @@ def forward( inputs_embeds = self.embed_tokens(ids_remove_padding) with dy2st_nocheck_guard_context(): - hidden_states, _ = self.transformer_block( + hidden_states, full_hidden_states = self.transformer_block( input_ids=input_ids, src=inputs_embeds, cum_offsets=cum_offsets, @@ -1457,17 +1444,11 @@ def forward( ) hidden_states = self.norm(hidden_states) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=None, - hidden_states=None, - attentions=None, - cum_offsets=cum_offsets, - ) + return hidden_states, full_hidden_states @register_base_model -class EagleForLlamaInferenceModel(LlamaBlockInferenceModel): +class EagleForLlamaBlockInferenceModel(LlamaBlockInferenceModel): def __init__(self, config: LlamaConfig): self.append_attn = config.append_attn super().__init__(config) @@ -1999,7 +1980,7 @@ def forward( draft_tokens=None, output_padding_offset=None, ): - outputs = self.llama( + hidden_states, full_hidden_states = self.llama( input_ids, src_mask=src_mask, caches=caches, @@ -2016,23 +1997,7 @@ def forward( draft_tokens=draft_tokens, output_padding_offset=output_padding_offset, ) - # hidden_states = outputs[0] - if self.return_full_hidden_states: - from paddlenlp_ops import rebuild_padding_v2 - - # full_hidden_states = outputs[1] - full_hidden_states = outputs[0] - cum_offsets = outputs[1] - hidden_states = rebuild_padding_v2( - full_hidden_states, - cum_offsets, - seq_lens_decoder, - seq_lens_encoder, - output_padding_offset, - self.max_seq_len, - ) - else: - hidden_states = outputs[0] + logits = self.lm_head( hidden_states, tensor_parallel_output=False, @@ -2058,7 +2023,7 @@ def __init__(self, config): self.verify_window = config.get("speculate_verify_window", 2) self.max_seq_len = config.max_seq_len - self.eagle = EagleForLlamaInferenceModel(config) + self.eagle = EagleForLlamaBlockInferenceModel(config) if config.tie_word_embeddings: self.lm_head = LlamaLMHead(config, embedding_weights=self.llama.embed_tokens.weight, transpose_y=True) self.tie_weights() diff --git a/paddlenlp/transformers/model_outputs.py b/paddlenlp/transformers/model_outputs.py index 07c80e238b41..a987ec922608 100644 --- a/paddlenlp/transformers/model_outputs.py +++ b/paddlenlp/transformers/model_outputs.py @@ -662,10 +662,6 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads. - cum_offsets (`tuple(paddle.Tensor)`, *optional*, needed when `return_full_hidden_states=True`: - Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, 1)`. - - Offset of the current batch. """ last_hidden_state: paddle.Tensor = None @@ -673,7 +669,6 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): hidden_states: Optional[Tuple[paddle.Tensor]] = None attentions: Optional[Tuple[paddle.Tensor]] = None cross_attentions: Optional[Tuple[paddle.Tensor]] = None - cum_offsets: Optional[Tuple[paddle.Tensor]] = None @dataclass From aec02959fd074b7a4113f82ea72867b743c30c78 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 2 Apr 2025 17:33:53 +0800 Subject: [PATCH 2/5] add ut --- .../experimental/transformers/qwen2/modeling.py | 17 +++++++---------- tests/llm/test_predictor_v1.py | 1 + 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py index 8ca063fbc599..3213be73bf17 100644 --- a/paddlenlp/experimental/transformers/qwen2/modeling.py +++ b/paddlenlp/experimental/transformers/qwen2/modeling.py @@ -1292,7 +1292,7 @@ def forward( inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[2]]) with dy2st_nocheck_guard_context(): - hidden_states, _ = self.transformer_block( + hidden_states, full_hidden_states = self.transformer_block( input_ids=input_ids, src=inputs_embeds, cum_offsets=cum_offsets, @@ -1304,12 +1304,7 @@ def forward( ) hidden_states = self.norm(hidden_states) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=None, - hidden_states=None, - attentions=None, - ) + return hidden_states, full_hidden_states class Qwen2ForCausalLMBlockInferenceModel(GenerationBlockInferenceModel, Qwen2PretrainedModel): @@ -1515,7 +1510,7 @@ def forward( draft_tokens=None, output_padding_offset=None, ): - outputs = self.qwen2( + hidden_states, full_hidden_states = self.qwen2( input_ids, inputs_embeds=inputs_embeds, src_mask=src_mask, @@ -1534,13 +1529,15 @@ def forward( output_padding_offset=output_padding_offset, ) - hidden_states = outputs[0] logits = self.lm_head( hidden_states, tensor_parallel_output=False, ) - return logits + if self.return_full_hidden_states: + return logits, full_hidden_states + else: + return logits @paddle.no_grad() def set_state_dict(self, state_dict): diff --git a/tests/llm/test_predictor_v1.py b/tests/llm/test_predictor_v1.py index 98cd7b55c51d..c75ac81a729e 100644 --- a/tests/llm/test_predictor_v1.py +++ b/tests/llm/test_predictor_v1.py @@ -119,6 +119,7 @@ def setUp(self) -> None: ( { "append_attn": True, + "return_full_hidden_states": True, }, ), ] From 28eab834024f20a1938af4693a0419f0d3043794 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 2 Apr 2025 18:37:52 +0800 Subject: [PATCH 3/5] add ut --- paddlenlp/experimental/transformers/qwen2/modeling.py | 1 + 1 file changed, 1 insertion(+) diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py index 3213be73bf17..fb800a81a120 100644 --- a/paddlenlp/experimental/transformers/qwen2/modeling.py +++ b/paddlenlp/experimental/transformers/qwen2/modeling.py @@ -1321,6 +1321,7 @@ def __init__(self, config: Qwen2Config, base_model_prefix: str = "qwen2"): self.max_candidate_len = config.get("speculate_max_candidate_len", 5) self.verify_window = config.get("speculate_verify_window", 2) self.max_seq_len = config.max_seq_len + self.return_full_hidden_states = config.get("return_full_hidden_states", False) self.qwen2 = Qwen2BlockInferenceModel(config, base_model_prefix) if config.tie_word_embeddings: From d5ec524a76daecba085dbdc0c0e052f97704fcec Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 2 Apr 2025 19:46:09 +0800 Subject: [PATCH 4/5] fix ut --- .../transformers/generation_utils.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py index e66ce33ac470..13f711bc74d2 100644 --- a/paddlenlp/experimental/transformers/generation_utils.py +++ b/paddlenlp/experimental/transformers/generation_utils.py @@ -294,7 +294,7 @@ def _forward_(**args): model_inputs = self.prepare_inputs_for_generation(input_ids, cache_kvs, **args) return self(**model_inputs) - def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): + def _post_process_(logits, top_p, temperature, step_idx_ori, model_kwargs): cache = model_kwargs.get("cache", None) just_decoder = model_kwargs["seq_len_encoder"] == 0 if cache is None: # first decoder @@ -313,7 +313,6 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): step_idx, model_kwargs["stop_flags"], ) - logits = outputs[0] if isinstance(outputs, tuple) else outputs logits = paddle.cast(logits, paddle.float32) logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori) @@ -372,7 +371,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): outputs = _forward_(**model_kwargs) # first decoder next_tokens, model_kwargs = _post_process_( - outputs, + outputs[0] if isinstance(outputs, tuple) else outputs, top_p, temperature, step_idx_ori, @@ -388,8 +387,9 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): paddle.sum(paddle.cast(model_kwargs["stop_flags"], "int64")), model_kwargs["stop_nums"], ): + outputs = _forward_(**model_kwargs) next_tokens, model_kwargs = _post_process_( - _forward_(**model_kwargs), + outputs[0] if isinstance(outputs, tuple) else outputs, top_p, temperature, step_idx_ori, @@ -691,7 +691,7 @@ def _forward_(**args): return self(**model_inputs) def _post_process_( - outputs, + logits, top_k, top_p, penalty_score, @@ -701,7 +701,7 @@ def _post_process_( model_kwargs, ): step_idx = model_kwargs["step_idx"] - logits = paddle.cast(outputs, paddle.float32) + logits = paddle.cast(logits, paddle.float32) from paddlenlp_ops import set_preids_token_penalty_multi_scores @@ -769,7 +769,7 @@ def _post_process_( outputs = _forward_(**model_kwargs) # [bs, 1, dim_embed] # first decoder next_tokens = _post_process_( - outputs, + outputs[0] if isinstance(outputs, tuple) else outputs, top_k, top_p, penalty_score, @@ -798,7 +798,7 @@ def _forward_(**args): return self(**model_inputs) def _post_process_( - outputs, + logits, top_k, top_p, penalty_score, @@ -808,7 +808,7 @@ def _post_process_( model_kwargs, ): step_idx = model_kwargs["step_idx"] - logits = paddle.cast(outputs, paddle.float32) + logits = paddle.cast(logits, paddle.float32) from paddlenlp_ops import speculate_get_token_penalty_multi_scores @@ -951,7 +951,7 @@ def _forward_(**args): return self(**model_inputs) def _post_process_( - outputs, + logits, top_k, top_p, penalty_score, @@ -960,7 +960,7 @@ def _post_process_( temperature, model_kwargs, ): - logits = paddle.cast(outputs, paddle.float32) + logits = paddle.cast(logits, paddle.float32) probs = F.softmax(logits) @@ -1184,7 +1184,7 @@ def _forward_(**args): model_inputs = self.prepare_inputs_for_generation(input_ids, **args) return self(**model_inputs) - def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): + def _post_process_(logits, top_p, temperature, step_idx_ori, model_kwargs): cache = model_kwargs.get("cache", None) just_decoder = model_kwargs["seq_len_encoder"] == 0 if cache is None: # first decoder @@ -1204,7 +1204,6 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): step_idx, model_kwargs["stop_flags"], ) - logits = outputs[0] if isinstance(outputs, tuple) else outputs logits = paddle.cast(logits, paddle.float32) logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori) From edec57345e5c7fe6b88ece01f98661d4a07a6dd9 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 2 Apr 2025 19:51:03 +0800 Subject: [PATCH 5/5] fix ci --- paddlenlp/experimental/transformers/qwen2/modeling.py | 1 - 1 file changed, 1 deletion(-) diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py index fb800a81a120..5447f3389686 100644 --- a/paddlenlp/experimental/transformers/qwen2/modeling.py +++ b/paddlenlp/experimental/transformers/qwen2/modeling.py @@ -53,7 +53,6 @@ from paddlenlp.transformers.conversion_utils import split_param_func from paddlenlp.transformers.model_outputs import ( # CausalLMOutputWithCrossAttentions, BaseModelOutputWithPast, - BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithPast, ) from paddlenlp.transformers.model_utils import (