diff --git a/paddlenlp/experimental/transformers/deepseek_v2/modeling.py b/paddlenlp/experimental/transformers/deepseek_v2/modeling.py
index dca9623f0cb4..1db636a6ff29 100644
--- a/paddlenlp/experimental/transformers/deepseek_v2/modeling.py
+++ b/paddlenlp/experimental/transformers/deepseek_v2/modeling.py
@@ -47,9 +47,6 @@
     yarn_get_mscale,
     yarn_linear_ramp_mask,
 )
-from paddlenlp.transformers.model_outputs import (
-    BaseModelOutputWithPastAndCrossAttentions,
-)
 from paddlenlp.transformers.model_utils import (
     dy2st_nocheck_guard_context,
     register_base_model,
@@ -266,7 +263,6 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
         self.weight_block_size = config.weight_block_size
         self.moe_quant_type = config.moe_quant_type
         self.rope_theta = config.rope_theta
-        self.return_full_hidden_states = config.get("return_full_hidden_states", False)
 
         self.use_weight_only = False
         self.weightonly_group_size = -1
@@ -591,7 +587,6 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
         speculate_config = SpeculateConfig(
             speculate_method=config.get("speculate_method", None),
             speculate_max_draft_token_num=config.get("speculate_max_draft_token_num", 5),
-            return_full_hidden_states=config.get("return_full_hidden_states", False),
         )
 
         transformer_config = FusedMultiTransformerConfig(
@@ -622,9 +617,9 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
             rotary_emb=self.rotary_emb,
             norm_type="rmsnorm",
             rank_id=config.tensor_parallel_rank,
+            append_attn=config.append_attn,
             moe_config=moe_config,
             mla_config=mla_config,
-            append_attn=config.append_attn,
             speculate_config=speculate_config,
         )
 
@@ -1289,7 +1284,7 @@ def forward(
             inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[2]])
 
         with dy2st_nocheck_guard_context():
-            hidden_states, _ = self.transformer_block(
+            hidden_states, full_hidden_states = self.transformer_block(
                 input_ids=input_ids,
                 src=inputs_embeds,
                 cum_offsets=cum_offsets,
@@ -1301,13 +1296,7 @@ def forward(
             )
         hidden_states = self.norm(hidden_states)
 
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=None,
-            hidden_states=None,
-            attentions=None,
-            cum_offsets=cum_offsets,
-        )
+        return hidden_states, full_hidden_states
 
 
 @register_base_model
@@ -1967,7 +1956,7 @@ def forward(
         inputs_embeds = self.eh_proj(inputs_embeds)
 
         with dy2st_nocheck_guard_context():
-            hidden_states, _ = self.transformer_block(
+            hidden_states, full_hidden_states = self.transformer_block(
                 input_ids=input_ids,
                 src=inputs_embeds,
                 cum_offsets=cum_offsets,
@@ -1980,12 +1969,7 @@ def forward(
             )
         hidden_states = self.norm(hidden_states)
 
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=None,
-            hidden_states=None,
-            attentions=None,
-        )
+        return hidden_states, full_hidden_states
 
 
 class DeepseekV2ForCausalLMBlockInferenceModel(GenerationBlockInferenceModel, DeepseekV2PretrainedModel):
@@ -2212,7 +2196,7 @@ def forward(
         draft_tokens=None,
         output_padding_offset=None,
     ):
-        outputs = self.deepseek_v2(
+        hidden_states, full_hidden_states = self.deepseek_v2(
             input_ids,
             inputs_embeds=inputs_embeds,
             src_mask=src_mask,
@@ -2230,21 +2214,7 @@ def forward(
             draft_tokens=draft_tokens,
             output_padding_offset=output_padding_offset,
         )
-        if self.return_full_hidden_states:
-            from paddlenlp_ops import rebuild_padding_v2
-
-            full_hidden_states = outputs[0]
-            cum_offsets = outputs[1]
-            hidden_states = rebuild_padding_v2(
-                full_hidden_states,
-                cum_offsets,
-                seq_lens_decoder,
-                seq_lens_encoder,
-                output_padding_offset,
-                self.max_seq_len,
-            )
-        else:
-            hidden_states = outputs[0]
+
         logits = self.lm_head(
             hidden_states,
             tensor_parallel_output=False,
@@ -2254,8 +2224,6 @@ def forward(
         else:
             return logits
 
-        return logits
-
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
         if "lm_head.weight" in state_dict:
@@ -2363,7 +2331,7 @@ def forward(
         output_padding_offset=None,
         pre_hidden_states=None,
     ):
-        outputs = self.mtp(
+        hidden_states, _ = self.mtp(
             input_ids,
             src_mask=src_mask,
             caches=caches,
@@ -2382,8 +2350,6 @@ def forward(
             pre_hidden_states=pre_hidden_states,
         )
 
-        hidden_states = outputs[0]
-
         logits = self.lm_head(
             hidden_states,
             tensor_parallel_output=False,
diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py
index 3276b9ac9663..b1da0deed6ba 100644
--- a/paddlenlp/experimental/transformers/fused_transformer_layers.py
+++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -157,7 +157,6 @@ class AvxConfig:
 class SpeculateConfig:
     speculate_max_draft_token_num: int = 5
     speculate_method: str = None
-    return_full_hidden_states: bool = False
 
 
 @dataclass
@@ -1589,7 +1588,7 @@ def forward(
         kwargs["input_ids"] = input_ids
 
         out = self.post_process(**kwargs)
-        return out, caches
+        return out, kwargs["multi_block_output"]
 
 
 class FusedMultiTransformerPostLayernorm(FusedMultiTransformerBase):
diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py
index 0a09ceaa1f9c..f93102d0fde5 100644
--- a/paddlenlp/experimental/transformers/generation_utils.py
+++ b/paddlenlp/experimental/transformers/generation_utils.py
@@ -295,7 +295,7 @@ def _forward_(**args):
             model_inputs = self.prepare_inputs_for_generation(input_ids, cache_kvs, **args)
             return self(**model_inputs)
 
-        def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
+        def _post_process_(logits, top_p, temperature, step_idx_ori, model_kwargs):
             cache = model_kwargs.get("cache", None)
             just_decoder = model_kwargs["seq_len_encoder"] == 0
             if cache is None:  # first decoder
@@ -314,7 +314,6 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
                     step_idx,
                     model_kwargs["stop_flags"],
                 )
-            logits = outputs[0] if isinstance(outputs, tuple) else outputs
             logits = paddle.cast(logits, paddle.float32)
             logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori)
 
@@ -373,7 +372,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
 
         outputs = _forward_(**model_kwargs)  # first decoder
         next_tokens, model_kwargs = _post_process_(
-            outputs,
+            outputs[0] if isinstance(outputs, tuple) else outputs,
             top_p,
             temperature,
             step_idx_ori,
@@ -389,8 +388,9 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
                 paddle.sum(paddle.cast(model_kwargs["stop_flags"], "int64")),
                 model_kwargs["stop_nums"],
             ):
+                outputs = _forward_(**model_kwargs)
                 next_tokens, model_kwargs = _post_process_(
-                    _forward_(**model_kwargs),
+                    outputs[0] if isinstance(outputs, tuple) else outputs,
                     top_p,
                     temperature,
                     step_idx_ori,
@@ -692,7 +692,7 @@ def _forward_(**args):
             return self(**model_inputs)
 
         def _post_process_(
-            outputs,
+            logits,
             top_k,
             top_p,
             penalty_score,
@@ -702,7 +702,7 @@ def _post_process_(
             model_kwargs,
         ):
             step_idx = model_kwargs["step_idx"]
-            logits = paddle.cast(outputs, paddle.float32)
+            logits = paddle.cast(logits, paddle.float32)
 
             from paddlenlp_ops import set_preids_token_penalty_multi_scores
 
@@ -777,7 +777,7 @@ def _post_process_(
         outputs = _forward_(**model_kwargs)  # [bs, 1, dim_embed]
         # first decoder
         next_tokens = _post_process_(
-            outputs,
+            outputs[0] if isinstance(outputs, tuple) else outputs,
             top_k,
             top_p,
             penalty_score,
@@ -806,7 +806,7 @@ def _forward_(**args):
             return self(**model_inputs)
 
         def _post_process_(
-            outputs,
+            logits,
             top_k,
             top_p,
             penalty_score,
@@ -816,7 +816,7 @@ def _post_process_(
             model_kwargs,
         ):
             step_idx = model_kwargs["step_idx"]
-            logits = paddle.cast(outputs, paddle.float32)
+            logits = paddle.cast(logits, paddle.float32)
 
             from paddlenlp_ops import speculate_get_token_penalty_multi_scores
 
@@ -959,7 +959,7 @@ def _forward_(**args):
             return self(**model_inputs)
 
         def _post_process_(
-            outputs,
+            logits,
             top_k,
             top_p,
             penalty_score,
@@ -968,7 +968,7 @@ def _post_process_(
             temperature,
             model_kwargs,
         ):
-            logits = paddle.cast(outputs, paddle.float32)
+            logits = paddle.cast(logits, paddle.float32)
 
             probs = F.softmax(logits)
 
@@ -1191,7 +1191,7 @@ def _forward_(**args):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **args)
             return self(**model_inputs)
 
-        def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
+        def _post_process_(logits, top_p, temperature, step_idx_ori, model_kwargs):
             cache = model_kwargs.get("cache", None)
             just_decoder = model_kwargs["seq_len_encoder"] == 0
             if cache is None:  # first decoder
@@ -1211,7 +1211,6 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
                     step_idx,
                     model_kwargs["stop_flags"],
                 )
-            logits = outputs[0] if isinstance(outputs, tuple) else outputs
             logits = paddle.cast(logits, paddle.float32)
             logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori)
 
diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py
index 8ada8a726b86..da5a12b6000d 100644
--- a/paddlenlp/experimental/transformers/llama/modeling.py
+++ b/paddlenlp/experimental/transformers/llama/modeling.py
@@ -257,19 +257,12 @@ def forward(
             # merge batch and seq_len dimension.
             inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])
 
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
         hidden_states = inputs_embeds
 
         # decoder layers
-        all_hidden_states = () if output_hidden_states else None
         with dy2st_nocheck_guard_context():
             hidden_states = self.transformer_block(
                 input_ids=input_ids,
@@ -280,11 +273,7 @@ def forward(
             )
         hidden_states = self.norm(hidden_states)
 
-        if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
-
-        if not return_dict:
-            return tuple(v for v in [hidden_states, None, all_hidden_states, None] if v is not None)
+        return hidden_states
 
     @paddle.no_grad()
     # avx
@@ -401,7 +390,6 @@ def __init__(self, config: LlamaConfig):
         self.epsilon = config.rms_norm_eps
         self.max_position_embeddings = config.max_position_embeddings
         self.quant_type = config.get("quant_type", "")
-        self.return_full_hidden_states = config.get("return_full_hidden_states", False)
         self.rope_theta = config.rope_theta
         self.use_neox = True
 
@@ -617,7 +605,6 @@ def __init__(self, config: LlamaConfig):
         speculate_config = SpeculateConfig(
             speculate_method=config.get("speculate_method", None),
             speculate_max_draft_token_num=config.get("speculate_max_draft_token_num", 5),
-            return_full_hidden_states=config.get("return_full_hidden_states", False),
         )
 
         hpu_config = HpuConfig(
@@ -1496,7 +1483,7 @@ def forward(
             inputs_embeds = self.embed_tokens(ids_remove_padding)
 
         with dy2st_nocheck_guard_context():
-            hidden_states, _ = self.transformer_block(
+            hidden_states, full_hidden_states = self.transformer_block(
                 input_ids=input_ids,
                 src=inputs_embeds,
                 cum_offsets=cum_offsets,
@@ -1508,17 +1495,11 @@ def forward(
             )
         hidden_states = self.norm(hidden_states)
 
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=None,
-            hidden_states=None,
-            attentions=None,
-            cum_offsets=cum_offsets,
-        )
+        return hidden_states, full_hidden_states
 
 
 @register_base_model
-class EagleForLlamaInferenceModel(LlamaBlockInferenceModel):
+class EagleForLlamaBlockInferenceModel(LlamaBlockInferenceModel):
     def __init__(self, config: LlamaConfig):
         self.append_attn = config.append_attn
         super().__init__(config)
@@ -2079,7 +2060,7 @@ def forward(
         draft_tokens=None,
         output_padding_offset=None,
     ):
-        outputs = self.llama(
+        hidden_states, full_hidden_states = self.llama(
             input_ids,
             src_mask=src_mask,
             caches=caches,
@@ -2097,23 +2078,7 @@ def forward(
             draft_tokens=draft_tokens,
             output_padding_offset=output_padding_offset,
         )
-        # hidden_states = outputs[0]
-        if self.return_full_hidden_states:
-            from paddlenlp_ops import rebuild_padding_v2
-
-            # full_hidden_states = outputs[1]
-            full_hidden_states = outputs[0]
-            cum_offsets = outputs[1]
-            hidden_states = rebuild_padding_v2(
-                full_hidden_states,
-                cum_offsets,
-                seq_lens_decoder,
-                seq_lens_encoder,
-                output_padding_offset,
-                self.max_seq_len,
-            )
-        else:
-            hidden_states = outputs[0]
+
         logits = self.lm_head(
             hidden_states,
             tensor_parallel_output=False,
@@ -2139,7 +2104,7 @@ def __init__(self, config):
         self.verify_window = config.get("speculate_verify_window", 2)
         self.max_seq_len = config.max_seq_len
 
-        self.eagle = EagleForLlamaInferenceModel(config)
+        self.eagle = EagleForLlamaBlockInferenceModel(config)
         if config.tie_word_embeddings:
             self.lm_head = LlamaLMHead(config, embedding_weights=self.llama.embed_tokens.weight, transpose_y=True)
             self.tie_weights()
diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py
index 141bd4df1bc1..ab1a5ae04c74 100644
--- a/paddlenlp/experimental/transformers/qwen2/modeling.py
+++ b/paddlenlp/experimental/transformers/qwen2/modeling.py
@@ -54,7 +54,6 @@
 from paddlenlp.transformers.conversion_utils import split_param_func
 from paddlenlp.transformers.model_outputs import (  # CausalLMOutputWithCrossAttentions,
     BaseModelOutputWithPast,
-    BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithPast,
 )
 from paddlenlp.transformers.model_utils import (
@@ -1325,7 +1324,7 @@ def forward(
             inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[2]])
 
         with dy2st_nocheck_guard_context():
-            hidden_states, _ = self.transformer_block(
+            hidden_states, full_hidden_states = self.transformer_block(
                 input_ids=input_ids,
                 src=inputs_embeds,
                 cum_offsets=cum_offsets,
@@ -1337,12 +1336,7 @@ def forward(
             )
         hidden_states = self.norm(hidden_states)
 
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=None,
-            hidden_states=None,
-            attentions=None,
-        )
+        return hidden_states, full_hidden_states
 
 
 class Qwen2ForCausalLMBlockInferenceModel(GenerationBlockInferenceModel, Qwen2PretrainedModel):
@@ -1359,6 +1353,7 @@ def __init__(self, config: Qwen2Config, base_model_prefix: str = "qwen2"):
         self.max_candidate_len = config.get("speculate_max_candidate_len", 5)
         self.verify_window = config.get("speculate_verify_window", 2)
         self.max_seq_len = config.max_seq_len
+        self.return_full_hidden_states = config.get("return_full_hidden_states", False)
 
         self.qwen2 = Qwen2BlockInferenceModel(config, base_model_prefix)
         if config.tie_word_embeddings:
@@ -1555,7 +1550,7 @@ def forward(
         draft_tokens=None,
         output_padding_offset=None,
     ):
-        outputs = self.qwen2(
+        hidden_states, full_hidden_states = self.qwen2(
             input_ids,
             inputs_embeds=inputs_embeds,
             src_mask=src_mask,
@@ -1575,13 +1570,15 @@ def forward(
             output_padding_offset=output_padding_offset,
         )
 
-        hidden_states = outputs[0]
         logits = self.lm_head(
             hidden_states,
             tensor_parallel_output=False,
         )
 
-        return logits
+        if self.return_full_hidden_states:
+            return logits, full_hidden_states
+        else:
+            return logits
 
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
diff --git a/paddlenlp/transformers/model_outputs.py b/paddlenlp/transformers/model_outputs.py
index d14f40e9a215..9836aff252fe 100644
--- a/paddlenlp/transformers/model_outputs.py
+++ b/paddlenlp/transformers/model_outputs.py
@@ -662,10 +662,6 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
             Attentions weights of the decoder's cross-attention layer, after the attention softmax,
             used to compute the weighted average in the cross-attention heads.
 
-        cum_offsets (`tuple(paddle.Tensor)`, *optional*, needed when `return_full_hidden_states=True`:
-            Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, 1)`.
-
-            Offset of the current batch.
""" last_hidden_state: paddle.Tensor = None @@ -673,7 +669,6 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): hidden_states: Optional[Tuple[paddle.Tensor]] = None attentions: Optional[Tuple[paddle.Tensor]] = None cross_attentions: Optional[Tuple[paddle.Tensor]] = None - cum_offsets: Optional[Tuple[paddle.Tensor]] = None @dataclass diff --git a/tests/llm/test_predictor_v1.py b/tests/llm/test_predictor_v1.py index 98cd7b55c51d..c75ac81a729e 100644 --- a/tests/llm/test_predictor_v1.py +++ b/tests/llm/test_predictor_v1.py @@ -119,6 +119,7 @@ def setUp(self) -> None: ( { "append_attn": True, + "return_full_hidden_states": True, }, ), ]