Fix hidden states and quant kv cache #10854

Open · wants to merge 7 commits into develop
6 changes: 6 additions & 0 deletions llm/predict/predictor.py
@@ -196,6 +196,7 @@ class PredictorArgument:
)
dynamic_insert: bool = field(default=False, metadata={"help": "whether use dynamic insert"})
total_request_num: int = field(default=None, metadata={"help": "The total number of request data"})
kv_cache_reuse: int = field(default=1)

def __post_init__(self):
if self.speculate_method is not None:
@@ -1155,6 +1156,11 @@ def init_cache_kvs(self):
for cache_k_shape, cache_v_shape in zip(self.cache_k_shapes, self.cache_v_shapes):
self.cache_kvs.append(paddle.zeros(cache_k_shape, dtype=cachekv_dtype))
self.cache_kvs.append(paddle.zeros(cache_v_shape, dtype=cachekv_dtype))
if self.config.kv_cache_reuse:
logger.warning(
f"self.config.kv_cache_reuse = {self.config.kv_cache_reuse}, break, len(self.cache_kvs) = {len(self.cache_kvs)}"
)
break
else:
# for mla's absorption
assert self.cache_v_shapes is None
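A minimal sketch of the allocation behavior this hunk introduces (not the PR code itself: `paddle.zeros` is replaced by plain placeholders and the shapes are hypothetical). With `kv_cache_reuse` enabled, the loop allocates only the first layer's K/V buffers and breaks, so every layer reuses the same pair:

```python
# Illustrative sketch only; shapes and buffers are stand-ins for paddle tensors.
def init_cache_kvs(cache_k_shapes, cache_v_shapes, kv_cache_reuse):
    cache_kvs = []
    for cache_k_shape, cache_v_shape in zip(cache_k_shapes, cache_v_shapes):
        # In the real predictor these are paddle.zeros(...) tensors.
        cache_kvs.append(cache_k_shape)
        cache_kvs.append(cache_v_shape)
        if kv_cache_reuse:
            # Only the first layer's K/V buffers are allocated; all layers share them.
            break
    return cache_kvs

# With 4 layers, kv_cache_reuse=1 leaves a single shared K/V pair:
shapes = [(1, 8, 64, 128)] * 4  # hypothetical block-cache shapes
assert len(init_cache_kvs(shapes, shapes, kv_cache_reuse=1)) == 2
assert len(init_cache_kvs(shapes, shapes, kv_cache_reuse=0)) == 8
```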
50 changes: 8 additions & 42 deletions paddlenlp/experimental/transformers/deepseek_v2/modeling.py
@@ -47,9 +47,6 @@
yarn_get_mscale,
yarn_linear_ramp_mask,
)
from paddlenlp.transformers.model_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
)
from paddlenlp.transformers.model_utils import (
dy2st_nocheck_guard_context,
register_base_model,
@@ -266,7 +263,6 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
self.weight_block_size = config.weight_block_size
self.moe_quant_type = config.moe_quant_type
self.rope_theta = config.rope_theta
self.return_full_hidden_states = config.get("return_full_hidden_states", False)

self.use_weight_only = False
self.weightonly_group_size = -1
@@ -591,7 +587,6 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
speculate_config = SpeculateConfig(
speculate_method=config.get("speculate_method", None),
speculate_max_draft_token_num=config.get("speculate_max_draft_token_num", 5),
return_full_hidden_states=config.get("return_full_hidden_states", False),
)

transformer_config = FusedMultiTransformerConfig(
@@ -622,9 +617,9 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
rotary_emb=self.rotary_emb,
norm_type="rmsnorm",
rank_id=config.tensor_parallel_rank,
append_attn=config.append_attn,
moe_config=moe_config,
mla_config=mla_config,
append_attn=config.append_attn,
speculate_config=speculate_config,
)

@@ -1289,7 +1284,7 @@ def forward(
inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[2]])

with dy2st_nocheck_guard_context():
hidden_states, _ = self.transformer_block(
hidden_states, full_hidden_states = self.transformer_block(
input_ids=input_ids,
src=inputs_embeds,
cum_offsets=cum_offsets,
@@ -1301,13 +1296,7 @@
)
hidden_states = self.norm(hidden_states)

return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=None,
hidden_states=None,
attentions=None,
cum_offsets=cum_offsets,
)
return hidden_states, full_hidden_states


@register_base_model
@@ -1967,7 +1956,7 @@ def forward(
inputs_embeds = self.eh_proj(inputs_embeds)

with dy2st_nocheck_guard_context():
hidden_states, _ = self.transformer_block(
hidden_states, full_hidden_states = self.transformer_block(
input_ids=input_ids,
src=inputs_embeds,
cum_offsets=cum_offsets,
@@ -1980,12 +1969,7 @@
)
hidden_states = self.norm(hidden_states)

return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=None,
hidden_states=None,
attentions=None,
)
return hidden_states, full_hidden_states


class DeepseekV2ForCausalLMBlockInferenceModel(GenerationBlockInferenceModel, DeepseekV2PretrainedModel):
@@ -2212,7 +2196,7 @@ def forward(
draft_tokens=None,
output_padding_offset=None,
):
outputs = self.deepseek_v2(
hidden_states, full_hidden_states = self.deepseek_v2(
input_ids,
inputs_embeds=inputs_embeds,
src_mask=src_mask,
@@ -2230,21 +2214,7 @@ def forward(
draft_tokens=draft_tokens,
output_padding_offset=output_padding_offset,
)
if self.return_full_hidden_states:
from paddlenlp_ops import rebuild_padding_v2

full_hidden_states = outputs[0]
cum_offsets = outputs[1]
hidden_states = rebuild_padding_v2(
full_hidden_states,
cum_offsets,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
self.max_seq_len,
)
else:
hidden_states = outputs[0]

logits = self.lm_head(
hidden_states,
tensor_parallel_output=False,
@@ -2254,8 +2224,6 @@
else:
return logits

return logits

@paddle.no_grad()
def set_state_dict(self, state_dict):
if "lm_head.weight" in state_dict:
@@ -2363,7 +2331,7 @@ def forward(
output_padding_offset=None,
pre_hidden_states=None,
):
outputs = self.mtp(
hidden_states, _ = self.mtp(
input_ids,
src_mask=src_mask,
caches=caches,
@@ -2382,8 +2350,6 @@
pre_hidden_states=pre_hidden_states,
)

hidden_states = outputs[0]

logits = self.lm_head(
hidden_states,
tensor_parallel_output=False,
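A hedged reading of the return-type change in this file: the DeepseekV2 inference backbones now return a plain `(hidden_states, full_hidden_states)` tuple from `self.transformer_block` instead of wrapping the result in `BaseModelOutputWithPastAndCrossAttentions`, and the `rebuild_padding_v2` step no longer lives in the causal-LM head. The sketch below only illustrates that calling convention; `backbone`, `lm_head`, and the inputs are stand-ins, not the PR's actual classes:

```python
# Sketch of the new calling convention only; every name here is illustrative.
def causal_lm_forward(backbone, lm_head, **inputs):
    # The backbone now returns a tuple, not a dataclass output.
    hidden_states, full_hidden_states = backbone(**inputs)
    logits = lm_head(hidden_states)
    # full_hidden_states stays available for speculative-decoding / MTP consumers.
    return logits, full_hidden_states
```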
16 changes: 12 additions & 4 deletions paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -1481,7 +1481,8 @@ def forward(
self.pre_process(**kwargs)
kwargs["cum_offsets"] = cum_offsets

if caches is not None:
kv_cache_reuse = kwargs.get("kv_cache_reuse", None)
if caches is not None and kv_cache_reuse is None:
assert len(caches) == len(self.linear_weights) or len(caches) == 2 * len(self.linear_weights)

assert self.num_layers == len(self.linear_weights)
@@ -1589,7 +1590,7 @@ def forward(
kwargs["input_ids"] = input_ids

out = self.post_process(**kwargs)
return out, caches
return out, kwargs["multi_block_output"]


class FusedMultiTransformerPostLayernorm(FusedMultiTransformerBase):
@@ -3172,10 +3173,17 @@ def compute_attn(
k_dequant_scales = kwargs.get("k_dequant_scales", None)
v_dequant_scales = kwargs.get("v_dequant_scales", None)

kv_cache_reuse = kwargs.get("kv_cache_reuse", None)
if kv_cache_reuse:
k_cache_index = 0
v_cache_index = 1
else:
k_cache_index = 2 * i
v_cache_index = 2 * i + 1
fmha_out = paddle.incubate.nn.functional.block_multihead_attention(
qkv_out,
caches[2 * i],
caches[2 * i + 1],
caches[k_cache_index],
caches[v_cache_index],
kwargs.get("seq_lens_encoder", None),
kwargs.get("seq_lens_decoder", None),
kwargs.get("seq_lens_this_time", None),
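A self-contained sketch of the index selection added in `compute_attn` above (the helper name is hypothetical; the mapping itself mirrors the diff): when `kv_cache_reuse` is set, every layer reads and writes the single shared K/V pair, otherwise each layer gets its own pair.

```python
# Hypothetical helper mirroring the cache-index logic added in compute_attn.
def kv_cache_indices(layer_idx: int, kv_cache_reuse: bool) -> tuple:
    if kv_cache_reuse:
        # All layers share the single allocated pair: caches[0] (K) and caches[1] (V).
        return 0, 1
    # Default layout: one (K, V) pair per layer.
    return 2 * layer_idx, 2 * layer_idx + 1

assert kv_cache_indices(3, kv_cache_reuse=False) == (6, 7)
assert kv_cache_indices(3, kv_cache_reuse=True) == (0, 1)
```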
25 changes: 12 additions & 13 deletions paddlenlp/experimental/transformers/generation_utils.py
@@ -295,7 +295,7 @@ def _forward_(**args):
model_inputs = self.prepare_inputs_for_generation(input_ids, cache_kvs, **args)
return self(**model_inputs)

def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
def _post_process_(logits, top_p, temperature, step_idx_ori, model_kwargs):
cache = model_kwargs.get("cache", None)
just_decoder = model_kwargs["seq_len_encoder"] == 0
if cache is None: # first decoder
@@ -314,7 +314,6 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
step_idx,
model_kwargs["stop_flags"],
)
logits = outputs[0] if isinstance(outputs, tuple) else outputs

logits = paddle.cast(logits, paddle.float32)
logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori)
@@ -373,7 +372,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
outputs = _forward_(**model_kwargs)
# first decoder
next_tokens, model_kwargs = _post_process_(
outputs,
outputs[0] if isinstance(outputs, tuple) else outputs,
top_p,
temperature,
step_idx_ori,
@@ -389,8 +388,9 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
paddle.sum(paddle.cast(model_kwargs["stop_flags"], "int64")),
model_kwargs["stop_nums"],
):
outputs = _forward_(**model_kwargs)
next_tokens, model_kwargs = _post_process_(
_forward_(**model_kwargs),
outputs[0] if isinstance(outputs, tuple) else outputs,
top_p,
temperature,
step_idx_ori,
@@ -692,7 +692,7 @@ def _forward_(**args):
return self(**model_inputs)

def _post_process_(
outputs,
logits,
top_k,
top_p,
penalty_score,
@@ -702,7 +702,7 @@
model_kwargs,
):
step_idx = model_kwargs["step_idx"]
logits = paddle.cast(outputs, paddle.float32)
logits = paddle.cast(logits, paddle.float32)

from paddlenlp_ops import set_preids_token_penalty_multi_scores

@@ -777,7 +777,7 @@ def _post_process_(
outputs = _forward_(**model_kwargs) # [bs, 1, dim_embed]
# first decoder
next_tokens = _post_process_(
outputs,
outputs[0] if isinstance(outputs, tuple) else outputs,
top_k,
top_p,
penalty_score,
@@ -806,7 +806,7 @@ def _forward_(**args):
return self(**model_inputs)

def _post_process_(
outputs,
logits,
top_k,
top_p,
penalty_score,
@@ -816,7 +816,7 @@
model_kwargs,
):
step_idx = model_kwargs["step_idx"]
logits = paddle.cast(outputs, paddle.float32)
logits = paddle.cast(logits, paddle.float32)

from paddlenlp_ops import speculate_get_token_penalty_multi_scores

@@ -959,7 +959,7 @@ def _forward_(**args):
return self(**model_inputs)

def _post_process_(
outputs,
logits,
top_k,
top_p,
penalty_score,
@@ -968,7 +968,7 @@
temperature,
model_kwargs,
):
logits = paddle.cast(outputs, paddle.float32)
logits = paddle.cast(logits, paddle.float32)

probs = F.softmax(logits)

@@ -1191,7 +1191,7 @@ def _forward_(**args):
model_inputs = self.prepare_inputs_for_generation(input_ids, **args)
return self(**model_inputs)

def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
def _post_process_(logits, top_p, temperature, step_idx_ori, model_kwargs):
cache = model_kwargs.get("cache", None)
just_decoder = model_kwargs["seq_len_encoder"] == 0
if cache is None: # first decoder
@@ -1211,7 +1211,6 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
step_idx,
model_kwargs["stop_flags"],
)
logits = outputs[0] if isinstance(outputs, tuple) else outputs
logits = paddle.cast(logits, paddle.float32)
logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori)

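A small sketch of the updated `_post_process_` contract in this file (an assumption drawn only from the diff: the tuple-vs-plain-output unwrapping now happens once at each call site instead of inside `_post_process_`, which receives logits directly):

```python
# Illustrative only: mirrors the `outputs[0] if isinstance(outputs, tuple) else outputs`
# unwrapping each call site now performs before _post_process_ sees the logits.
def unwrap_logits(outputs):
    return outputs[0] if isinstance(outputs, tuple) else outputs

assert unwrap_logits(("logits", "full_hidden_states")) == "logits"
assert unwrap_logits("logits") == "logits"
```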