Merged
6 changes: 6 additions & 0 deletions llm/predict/predictor.py
@@ -196,6 +196,7 @@ class PredictorArgument:
)
dynamic_insert: bool = field(default=False, metadata={"help": "whether use dynamic insert"})
total_request_num: int = field(default=None, metadata={"help": "The total number of request data"})
kv_cache_reuse: int = field(default=0, metadata={"help": "whether to reuse a single shared kv cache across layers"})

def __post_init__(self):
if self.speculate_method is not None:
@@ -1155,6 +1156,11 @@ def init_cache_kvs(self):
for cache_k_shape, cache_v_shape in zip(self.cache_k_shapes, self.cache_v_shapes):
self.cache_kvs.append(paddle.zeros(cache_k_shape, dtype=cachekv_dtype))
self.cache_kvs.append(paddle.zeros(cache_v_shape, dtype=cachekv_dtype))
if self.config.kv_cache_reuse:
logger.warning(
f"kv_cache_reuse = {self.config.kv_cache_reuse}: allocated a single shared kv cache pair (len(self.cache_kvs) = {len(self.cache_kvs)}); skipping cache allocation for the remaining layers."
)
break
else:
# for mla's absorption
assert self.cache_v_shapes is None
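Note: with `kv_cache_reuse` enabled, `init_cache_kvs` stops after allocating the first layer's K/V buffers, so one cache pair is shared by every layer. A minimal sketch of that allocation behavior (the standalone function form and its name are illustrative, not the repository code):

```python
import paddle


def init_shared_cache_kvs(cache_k_shapes, cache_v_shapes, cachekv_dtype, kv_cache_reuse):
    """Sketch: allocate per-layer K/V caches, or a single shared pair when
    kv_cache_reuse is set (all layers then index caches[0] / caches[1])."""
    cache_kvs = []
    for cache_k_shape, cache_v_shape in zip(cache_k_shapes, cache_v_shapes):
        cache_kvs.append(paddle.zeros(cache_k_shape, dtype=cachekv_dtype))
        cache_kvs.append(paddle.zeros(cache_v_shape, dtype=cachekv_dtype))
        if kv_cache_reuse:
            # Only the first layer's buffers are created; the attention kernel
            # reuses them for every layer (see fused_transformer_layers.py below).
            break
    return cache_kvs
```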
50 changes: 8 additions & 42 deletions paddlenlp/experimental/transformers/deepseek_v2/modeling.py
@@ -47,9 +47,6 @@
yarn_get_mscale,
yarn_linear_ramp_mask,
)
from paddlenlp.transformers.model_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
)
from paddlenlp.transformers.model_utils import (
dy2st_nocheck_guard_context,
register_base_model,
@@ -266,7 +263,6 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
self.weight_block_size = config.weight_block_size
self.moe_quant_type = config.moe_quant_type
self.rope_theta = config.rope_theta
self.return_full_hidden_states = config.get("return_full_hidden_states", False)

self.use_weight_only = False
self.weightonly_group_size = -1
@@ -591,7 +587,6 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
speculate_config = SpeculateConfig(
speculate_method=config.get("speculate_method", None),
speculate_max_draft_token_num=config.get("speculate_max_draft_token_num", 5),
return_full_hidden_states=config.get("return_full_hidden_states", False),
)

transformer_config = FusedMultiTransformerConfig(
@@ -622,9 +617,9 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
rotary_emb=self.rotary_emb,
norm_type="rmsnorm",
rank_id=config.tensor_parallel_rank,
append_attn=config.append_attn,
moe_config=moe_config,
mla_config=mla_config,
append_attn=config.append_attn,
speculate_config=speculate_config,
)

@@ -1289,7 +1284,7 @@ def forward(
inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[2]])

with dy2st_nocheck_guard_context():
hidden_states, _ = self.transformer_block(
hidden_states, full_hidden_states = self.transformer_block(
input_ids=input_ids,
src=inputs_embeds,
cum_offsets=cum_offsets,
@@ -1301,13 +1296,7 @@
)
hidden_states = self.norm(hidden_states)

return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=None,
hidden_states=None,
attentions=None,
cum_offsets=cum_offsets,
)
return hidden_states, full_hidden_states


@register_base_model
@@ -1967,7 +1956,7 @@ def forward(
inputs_embeds = self.eh_proj(inputs_embeds)

with dy2st_nocheck_guard_context():
hidden_states, _ = self.transformer_block(
hidden_states, full_hidden_states = self.transformer_block(
input_ids=input_ids,
src=inputs_embeds,
cum_offsets=cum_offsets,
@@ -1980,12 +1969,7 @@
)
hidden_states = self.norm(hidden_states)

return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=None,
hidden_states=None,
attentions=None,
)
return hidden_states, full_hidden_states


class DeepseekV2ForCausalLMBlockInferenceModel(GenerationBlockInferenceModel, DeepseekV2PretrainedModel):
@@ -2212,7 +2196,7 @@ def forward(
draft_tokens=None,
output_padding_offset=None,
):
outputs = self.deepseek_v2(
hidden_states, full_hidden_states = self.deepseek_v2(
input_ids,
inputs_embeds=inputs_embeds,
src_mask=src_mask,
@@ -2230,21 +2214,7 @@
draft_tokens=draft_tokens,
output_padding_offset=output_padding_offset,
)
if self.return_full_hidden_states:
from paddlenlp_ops import rebuild_padding_v2

full_hidden_states = outputs[0]
cum_offsets = outputs[1]
hidden_states = rebuild_padding_v2(
full_hidden_states,
cum_offsets,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
self.max_seq_len,
)
else:
hidden_states = outputs[0]

logits = self.lm_head(
hidden_states,
tensor_parallel_output=False,
@@ -2254,8 +2224,6 @@
else:
return logits

return logits

@paddle.no_grad()
def set_state_dict(self, state_dict):
if "lm_head.weight" in state_dict:
@@ -2363,7 +2331,7 @@ def forward(
output_padding_offset=None,
pre_hidden_states=None,
):
outputs = self.mtp(
hidden_states, _ = self.mtp(
input_ids,
src_mask=src_mask,
caches=caches,
@@ -2382,8 +2350,6 @@
pre_hidden_states=pre_hidden_states,
)

hidden_states = outputs[0]

logits = self.lm_head(
hidden_states,
tensor_parallel_output=False,
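Note: the DeepseekV2 inference models now return a plain `(hidden_states, full_hidden_states)` tuple instead of wrapping the result in `BaseModelOutputWithPastAndCrossAttentions`, and the `rebuild_padding_v2` branch is dropped from the causal-LM forward. A rough sketch of how the block-inference head consumes the new return value (a simplified, hypothetical wrapper, not the repository code):

```python
def causal_lm_forward_sketch(deepseek_v2, lm_head, **model_inputs):
    # The base model now yields both the hidden states used by the LM head and
    # the full per-token hidden states produced by the transformer block.
    hidden_states, full_hidden_states = deepseek_v2(**model_inputs)
    # Logits are still computed from the first element, as before.
    logits = lm_head(hidden_states, tensor_parallel_output=False)
    return logits
```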
16 changes: 12 additions & 4 deletions paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -1481,7 +1481,8 @@ def forward(
self.pre_process(**kwargs)
kwargs["cum_offsets"] = cum_offsets

if caches is not None:
kv_cache_reuse = kwargs.get("kv_cache_reuse", None)
if caches is not None and not kv_cache_reuse:
assert len(caches) == len(self.linear_weights) or len(caches) == 2 * len(self.linear_weights)

assert self.num_layers == len(self.linear_weights)
@@ -1589,7 +1590,7 @@ def forward(
kwargs["input_ids"] = input_ids

out = self.post_process(**kwargs)
return out, caches
return out, kwargs["multi_block_output"]


class FusedMultiTransformerPostLayernorm(FusedMultiTransformerBase):
@@ -3172,10 +3173,17 @@ def compute_attn(
k_dequant_scales = kwargs.get("k_dequant_scales", None)
v_dequant_scales = kwargs.get("v_dequant_scales", None)

kv_cache_reuse = kwargs.get("kv_cache_reuse", None)
if kv_cache_reuse:
k_cache_index = 0
v_cache_index = 1
else:
k_cache_index = 2 * i
v_cache_index = 2 * i + 1
fmha_out = paddle.incubate.nn.functional.block_multihead_attention(
qkv_out,
caches[2 * i],
caches[2 * i + 1],
caches[k_cache_index],
caches[v_cache_index],
kwargs.get("seq_lens_encoder", None),
kwargs.get("seq_lens_decoder", None),
kwargs.get("seq_lens_this_time", None),
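Note: inside `compute_attn`, the cache indices now depend on `kv_cache_reuse`: with reuse enabled every layer reads and writes `caches[0]` / `caches[1]`, otherwise layer `i` keeps using `caches[2 * i]` / `caches[2 * i + 1]`. A small sketch of that selection logic (the helper name is illustrative only):

```python
def select_kv_cache_indices(layer_idx, kv_cache_reuse):
    """Return (k_cache_index, v_cache_index) passed to block_multihead_attention."""
    if kv_cache_reuse:
        # One shared K/V pair for all layers.
        return 0, 1
    # Default layout: two cache tensors per layer.
    return 2 * layer_idx, 2 * layer_idx + 1
```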
25 changes: 12 additions & 13 deletions paddlenlp/experimental/transformers/generation_utils.py
@@ -295,7 +295,7 @@ def _forward_(**args):
model_inputs = self.prepare_inputs_for_generation(input_ids, cache_kvs, **args)
return self(**model_inputs)

def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
def _post_process_(logits, top_p, temperature, step_idx_ori, model_kwargs):
cache = model_kwargs.get("cache", None)
just_decoder = model_kwargs["seq_len_encoder"] == 0
if cache is None: # first decoder
@@ -314,7 +313,6 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
step_idx,
model_kwargs["stop_flags"],
)
logits = outputs[0] if isinstance(outputs, tuple) else outputs

logits = paddle.cast(logits, paddle.float32)
logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori)
@@ -373,7 +372,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
outputs = _forward_(**model_kwargs)
# first decoder
next_tokens, model_kwargs = _post_process_(
outputs,
outputs[0] if isinstance(outputs, tuple) else outputs,
top_p,
temperature,
step_idx_ori,
@@ -389,8 +388,9 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
paddle.sum(paddle.cast(model_kwargs["stop_flags"], "int64")),
model_kwargs["stop_nums"],
):
outputs = _forward_(**model_kwargs)
next_tokens, model_kwargs = _post_process_(
_forward_(**model_kwargs),
outputs[0] if isinstance(outputs, tuple) else outputs,
top_p,
temperature,
step_idx_ori,
@@ -692,7 +692,7 @@ def _forward_(**args):
return self(**model_inputs)

def _post_process_(
outputs,
logits,
top_k,
top_p,
penalty_score,
@@ -702,7 +702,7 @@
model_kwargs,
):
step_idx = model_kwargs["step_idx"]
logits = paddle.cast(outputs, paddle.float32)
logits = paddle.cast(logits, paddle.float32)

from paddlenlp_ops import set_preids_token_penalty_multi_scores

@@ -777,7 +777,7 @@ def _post_process_(
outputs = _forward_(**model_kwargs) # [bs, 1, dim_embed]
# first decoder
next_tokens = _post_process_(
outputs,
outputs[0] if isinstance(outputs, tuple) else outputs,
top_k,
top_p,
penalty_score,
@@ -806,7 +806,7 @@ def _forward_(**args):
return self(**model_inputs)

def _post_process_(
outputs,
logits,
top_k,
top_p,
penalty_score,
@@ -816,7 +816,7 @@
model_kwargs,
):
step_idx = model_kwargs["step_idx"]
logits = paddle.cast(outputs, paddle.float32)
logits = paddle.cast(logits, paddle.float32)

from paddlenlp_ops import speculate_get_token_penalty_multi_scores

@@ -959,7 +959,7 @@ def _forward_(**args):
return self(**model_inputs)

def _post_process_(
outputs,
logits,
top_k,
top_p,
penalty_score,
@@ -968,7 +968,7 @@
temperature,
model_kwargs,
):
logits = paddle.cast(outputs, paddle.float32)
logits = paddle.cast(logits, paddle.float32)

probs = F.softmax(logits)

@@ -1191,7 +1191,7 @@ def _forward_(**args):
model_inputs = self.prepare_inputs_for_generation(input_ids, **args)
return self(**model_inputs)

def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
def _post_process_(logits, top_p, temperature, step_idx_ori, model_kwargs):
cache = model_kwargs.get("cache", None)
just_decoder = model_kwargs["seq_len_encoder"] == 0
if cache is None: # first decoder
@@ -1211,7 +1211,6 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
step_idx,
model_kwargs["stop_flags"],
)
logits = outputs[0] if isinstance(outputs, tuple) else outputs
logits = paddle.cast(logits, paddle.float32)
logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori)

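Note: `_post_process_` now receives the logits tensor directly; unwrapping a possible tuple return from `_forward_` happens once at the call sites. A sketch of the resulting decode-step calling convention (a standalone helper written for illustration only):

```python
def decode_step_sketch(_forward_, _post_process_, top_p, temperature, step_idx_ori, model_kwargs):
    """Run one forward pass and post-process its logits.

    The forward pass may return either a bare logits tensor or a
    (logits, ...) tuple, so the caller unwraps it once before handing
    plain logits to _post_process_.
    """
    outputs = _forward_(**model_kwargs)
    logits = outputs[0] if isinstance(outputs, tuple) else outputs
    return _post_process_(logits, top_p, temperature, step_idx_ori, model_kwargs)
```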