Commit c83684f

Fix hidden states and quant kv cache (#10854)
* fix hidden states
* fix quant kv_cache
* fix Lint style
* fix kv_cache_reuse key error
* fix kv_cache_reuse key error
* remove unused code
* fix kv_cache_reuse default
1 parent 887b5d2 commit c83684f

File tree

11 files changed (+472, -140 lines)


llm/predict/predictor.py

Lines changed: 6 additions & 0 deletions
@@ -196,6 +196,7 @@ class PredictorArgument:
     )
     dynamic_insert: bool = field(default=False, metadata={"help": "whether use dynamic insert"})
     total_request_num: int = field(default=None, metadata={"help": "The total number of request data"})
+    kv_cache_reuse: int = field(default=0)
 
     def __post_init__(self):
         if self.speculate_method is not None:
@@ -1155,6 +1156,11 @@ def init_cache_kvs(self):
             for cache_k_shape, cache_v_shape in zip(self.cache_k_shapes, self.cache_v_shapes):
                 self.cache_kvs.append(paddle.zeros(cache_k_shape, dtype=cachekv_dtype))
                 self.cache_kvs.append(paddle.zeros(cache_v_shape, dtype=cachekv_dtype))
+                if self.config.kv_cache_reuse:
+                    logger.warning(
+                        f"self.config.kv_cache_reuse = {self.config.kv_cache_reuse}, break, len(self.cache_kvs) = {len(self.cache_kvs)}"
+                    )
+                    break
         else:
             # for mla's absorption
             assert self.cache_v_shapes is None
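Taken together with the new `kv_cache_reuse` field, the early `break` above means that when reuse is enabled only the first K/V buffer pair is allocated and then shared by every layer. A minimal sketch of the resulting allocation behavior, with hypothetical cache shapes standing in for the ones derived from the model config:

```python
import paddle

def allocate_cache_kvs(cache_k_shapes, cache_v_shapes, cachekv_dtype, kv_cache_reuse):
    # Simplified sketch of init_cache_kvs after this change: normally one K/V
    # pair is allocated per layer, but with kv_cache_reuse enabled the loop
    # breaks after the first pair, which every layer then shares.
    cache_kvs = []
    for cache_k_shape, cache_v_shape in zip(cache_k_shapes, cache_v_shapes):
        cache_kvs.append(paddle.zeros(cache_k_shape, dtype=cachekv_dtype))
        cache_kvs.append(paddle.zeros(cache_v_shape, dtype=cachekv_dtype))
        if kv_cache_reuse:
            break
    return cache_kvs

# Hypothetical shapes for 4 layers: [num_blocks, num_kv_heads, block_size, head_dim].
k_shapes = [[64, 8, 64, 128]] * 4
v_shapes = [[64, 8, 64, 128]] * 4
print(len(allocate_cache_kvs(k_shapes, v_shapes, "float16", kv_cache_reuse=0)))  # 8 tensors
print(len(allocate_cache_kvs(k_shapes, v_shapes, "float16", kv_cache_reuse=1)))  # 2 tensors
```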

paddlenlp/experimental/transformers/deepseek_v2/modeling.py

Lines changed: 8 additions & 42 deletions
@@ -47,9 +47,6 @@
     yarn_get_mscale,
     yarn_linear_ramp_mask,
 )
-from paddlenlp.transformers.model_outputs import (
-    BaseModelOutputWithPastAndCrossAttentions,
-)
 from paddlenlp.transformers.model_utils import (
     dy2st_nocheck_guard_context,
     register_base_model,
@@ -266,7 +263,6 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
         self.weight_block_size = config.weight_block_size
         self.moe_quant_type = config.moe_quant_type
         self.rope_theta = config.rope_theta
-        self.return_full_hidden_states = config.get("return_full_hidden_states", False)
 
         self.use_weight_only = False
         self.weightonly_group_size = -1
@@ -591,7 +587,6 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
         speculate_config = SpeculateConfig(
             speculate_method=config.get("speculate_method", None),
             speculate_max_draft_token_num=config.get("speculate_max_draft_token_num", 5),
-            return_full_hidden_states=config.get("return_full_hidden_states", False),
         )
 
         transformer_config = FusedMultiTransformerConfig(
@@ -622,9 +617,9 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
             rotary_emb=self.rotary_emb,
             norm_type="rmsnorm",
             rank_id=config.tensor_parallel_rank,
+            append_attn=config.append_attn,
             moe_config=moe_config,
             mla_config=mla_config,
-            append_attn=config.append_attn,
             speculate_config=speculate_config,
         )
 
@@ -1289,7 +1284,7 @@ def forward(
         inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[2]])
 
         with dy2st_nocheck_guard_context():
-            hidden_states, _ = self.transformer_block(
+            hidden_states, full_hidden_states = self.transformer_block(
                 input_ids=input_ids,
                 src=inputs_embeds,
                 cum_offsets=cum_offsets,
@@ -1301,13 +1296,7 @@
             )
         hidden_states = self.norm(hidden_states)
 
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=None,
-            hidden_states=None,
-            attentions=None,
-            cum_offsets=cum_offsets,
-        )
+        return hidden_states, full_hidden_states
 
 
 @register_base_model
@@ -1967,7 +1956,7 @@ def forward(
         inputs_embeds = self.eh_proj(inputs_embeds)
 
         with dy2st_nocheck_guard_context():
-            hidden_states, _ = self.transformer_block(
+            hidden_states, full_hidden_states = self.transformer_block(
                 input_ids=input_ids,
                 src=inputs_embeds,
                 cum_offsets=cum_offsets,
@@ -1980,12 +1969,7 @@
             )
         hidden_states = self.norm(hidden_states)
 
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=None,
-            hidden_states=None,
-            attentions=None,
-        )
+        return hidden_states, full_hidden_states
 
 
 class DeepseekV2ForCausalLMBlockInferenceModel(GenerationBlockInferenceModel, DeepseekV2PretrainedModel):
@@ -2212,7 +2196,7 @@ def forward(
         draft_tokens=None,
         output_padding_offset=None,
     ):
-        outputs = self.deepseek_v2(
+        hidden_states, full_hidden_states = self.deepseek_v2(
            input_ids,
            inputs_embeds=inputs_embeds,
            src_mask=src_mask,
@@ -2230,21 +2214,7 @@
            draft_tokens=draft_tokens,
            output_padding_offset=output_padding_offset,
        )
-        if self.return_full_hidden_states:
-            from paddlenlp_ops import rebuild_padding_v2
-
-            full_hidden_states = outputs[0]
-            cum_offsets = outputs[1]
-            hidden_states = rebuild_padding_v2(
-                full_hidden_states,
-                cum_offsets,
-                seq_lens_decoder,
-                seq_lens_encoder,
-                output_padding_offset,
-                self.max_seq_len,
-            )
-        else:
-            hidden_states = outputs[0]
+
         logits = self.lm_head(
             hidden_states,
             tensor_parallel_output=False,
@@ -2254,8 +2224,6 @@
         else:
             return logits
 
-        return logits
-
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
         if "lm_head.weight" in state_dict:
@@ -2363,7 +2331,7 @@ def forward(
         output_padding_offset=None,
         pre_hidden_states=None,
     ):
-        outputs = self.mtp(
+        hidden_states, _ = self.mtp(
            input_ids,
            src_mask=src_mask,
            caches=caches,
@@ -2382,8 +2350,6 @@
            pre_hidden_states=pre_hidden_states,
        )
 
-        hidden_states = outputs[0]
-
        logits = self.lm_head(
            hidden_states,
            tensor_parallel_output=False,
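The net effect in this file: both block-inference models now hand back a plain `(hidden_states, full_hidden_states)` tuple from the fused transformer block straight through `forward`, and the causal-LM wrappers unpack that tuple directly instead of indexing a `BaseModelOutputWithPastAndCrossAttentions`, with the `rebuild_padding_v2` branch dropped from this forward pass. A toy sketch of the changed calling convention, using stand-in modules rather than the real DeepseekV2 layers:

```python
import paddle

class ToyInnerModel(paddle.nn.Layer):
    """Stand-in for the inner block-inference model after this change."""

    def forward(self, inputs_embeds):
        # The fused transformer block now yields (hidden_states, full_hidden_states).
        hidden_states = inputs_embeds * 2.0
        full_hidden_states = inputs_embeds * 2.0
        return hidden_states, full_hidden_states


class ToyForCausalLM(paddle.nn.Layer):
    """Stand-in for the ForCausalLM wrapper: it unpacks the tuple directly."""

    def __init__(self, hidden_size=8, vocab_size=16):
        super().__init__()
        self.inner = ToyInnerModel()
        self.lm_head = paddle.nn.Linear(hidden_size, vocab_size)

    def forward(self, inputs_embeds):
        hidden_states, full_hidden_states = self.inner(inputs_embeds)  # direct tuple unpacking
        return self.lm_head(hidden_states)


logits = ToyForCausalLM()(paddle.randn([2, 8]))
print(logits.shape)  # [2, 16]
```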

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 12 additions & 4 deletions
@@ -1481,7 +1481,8 @@ def forward(
         self.pre_process(**kwargs)
         kwargs["cum_offsets"] = cum_offsets
 
-        if caches is not None:
+        kv_cache_reuse = kwargs.get("kv_cache_reuse", None)
+        if caches is not None and kv_cache_reuse is not True:
             assert len(caches) == len(self.linear_weights) or len(caches) == 2 * len(self.linear_weights)
 
             assert self.num_layers == len(self.linear_weights)
@@ -1589,7 +1590,7 @@
         kwargs["input_ids"] = input_ids
 
         out = self.post_process(**kwargs)
-        return out, caches
+        return out, kwargs["multi_block_output"]
 
 
 class FusedMultiTransformerPostLayernorm(FusedMultiTransformerBase):
@@ -3172,10 +3173,17 @@ def compute_attn(
         k_dequant_scales = kwargs.get("k_dequant_scales", None)
         v_dequant_scales = kwargs.get("v_dequant_scales", None)
 
+        kv_cache_reuse = kwargs.get("kv_cache_reuse", None)
+        if kv_cache_reuse:
+            k_cache_index = 0
+            v_cache_index = 1
+        else:
+            k_cache_index = 2 * i
+            v_cache_index = 2 * i + 1
         fmha_out = paddle.incubate.nn.functional.block_multihead_attention(
             qkv_out,
-            caches[2 * i],
-            caches[2 * i + 1],
+            caches[k_cache_index],
+            caches[v_cache_index],
             kwargs.get("seq_lens_encoder", None),
             kwargs.get("seq_lens_decoder", None),
             kwargs.get("seq_lens_this_time", None),
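Two things change here: `forward` now returns `kwargs["multi_block_output"]` (the per-token block output that the DeepseekV2 model unpacks as `full_hidden_states`) instead of the caches, and `compute_attn` selects its K/V cache indices based on `kv_cache_reuse`. The index selection is the layer-side counterpart of the single-pair allocation in predictor.py; a small sketch of just that logic:

```python
def kv_cache_indices(layer_idx: int, kv_cache_reuse: bool):
    # Mirrors the branch added in compute_attn: with reuse enabled, every layer
    # points at the single shared K/V pair (indices 0 and 1); otherwise layer i
    # owns the pair at 2 * i and 2 * i + 1.
    if kv_cache_reuse:
        return 0, 1
    return 2 * layer_idx, 2 * layer_idx + 1


assert kv_cache_indices(3, kv_cache_reuse=False) == (6, 7)
assert kv_cache_indices(3, kv_cache_reuse=True) == (0, 1)
```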

paddlenlp/experimental/transformers/generation_utils.py

Lines changed: 12 additions & 13 deletions
@@ -295,7 +295,7 @@ def _forward_(**args):
             model_inputs = self.prepare_inputs_for_generation(input_ids, cache_kvs, **args)
             return self(**model_inputs)
 
-        def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
+        def _post_process_(logits, top_p, temperature, step_idx_ori, model_kwargs):
             cache = model_kwargs.get("cache", None)
             just_decoder = model_kwargs["seq_len_encoder"] == 0
             if cache is None:  # first decoder
@@ -314,7 +314,6 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
                 step_idx,
                 model_kwargs["stop_flags"],
             )
-            logits = outputs[0] if isinstance(outputs, tuple) else outputs
 
             logits = paddle.cast(logits, paddle.float32)
             logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori)
@@ -373,7 +372,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
         outputs = _forward_(**model_kwargs)
         # first decoder
         next_tokens, model_kwargs = _post_process_(
-            outputs,
+            outputs[0] if isinstance(outputs, tuple) else outputs,
             top_p,
             temperature,
             step_idx_ori,
@@ -389,8 +388,9 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
             paddle.sum(paddle.cast(model_kwargs["stop_flags"], "int64")),
             model_kwargs["stop_nums"],
         ):
+            outputs = _forward_(**model_kwargs)
             next_tokens, model_kwargs = _post_process_(
-                _forward_(**model_kwargs),
+                outputs[0] if isinstance(outputs, tuple) else outputs,
                 top_p,
                 temperature,
                 step_idx_ori,
@@ -692,7 +692,7 @@ def _forward_(**args):
             return self(**model_inputs)
 
         def _post_process_(
-            outputs,
+            logits,
             top_k,
             top_p,
             penalty_score,
@@ -702,7 +702,7 @@ def _post_process_(
             model_kwargs,
         ):
             step_idx = model_kwargs["step_idx"]
-            logits = paddle.cast(outputs, paddle.float32)
+            logits = paddle.cast(logits, paddle.float32)
 
             from paddlenlp_ops import set_preids_token_penalty_multi_scores
 
@@ -777,7 +777,7 @@ def _post_process_(
         outputs = _forward_(**model_kwargs)  # [bs, 1, dim_embed]
         # first decoder
         next_tokens = _post_process_(
-            outputs,
+            outputs[0] if isinstance(outputs, tuple) else outputs,
             top_k,
             top_p,
             penalty_score,
@@ -806,7 +806,7 @@ def _forward_(**args):
             return self(**model_inputs)
 
         def _post_process_(
-            outputs,
+            logits,
             top_k,
             top_p,
             penalty_score,
@@ -816,7 +816,7 @@ def _post_process_(
             model_kwargs,
         ):
             step_idx = model_kwargs["step_idx"]
-            logits = paddle.cast(outputs, paddle.float32)
+            logits = paddle.cast(logits, paddle.float32)
 
             from paddlenlp_ops import speculate_get_token_penalty_multi_scores
 
@@ -959,7 +959,7 @@ def _forward_(**args):
             return self(**model_inputs)
 
         def _post_process_(
-            outputs,
+            logits,
             top_k,
             top_p,
             penalty_score,
@@ -968,7 +968,7 @@ def _post_process_(
             temperature,
             model_kwargs,
         ):
-            logits = paddle.cast(outputs, paddle.float32)
+            logits = paddle.cast(logits, paddle.float32)
 
             probs = F.softmax(logits)
 
@@ -1191,7 +1191,7 @@ def _forward_(**args):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **args)
             return self(**model_inputs)
 
-        def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
+        def _post_process_(logits, top_p, temperature, step_idx_ori, model_kwargs):
             cache = model_kwargs.get("cache", None)
             just_decoder = model_kwargs["seq_len_encoder"] == 0
             if cache is None:  # first decoder
@@ -1211,7 +1211,6 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
                 step_idx,
                 model_kwargs["stop_flags"],
             )
-            logits = outputs[0] if isinstance(outputs, tuple) else outputs
             logits = paddle.cast(logits, paddle.float32)
             logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori)

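Across the generation_utils.py hunks the pattern is the same: `_post_process_` now takes `logits` directly, and every call site unwraps the forward result, which after the modeling change may be a tuple, before passing it in. A minimal sketch of the new call-site convention, with placeholder forward and post-process bodies:

```python
import paddle
import paddle.nn.functional as F

def _forward_():
    # Placeholder: the real _forward_ runs the model; since the modeling change
    # it may return a (logits, full_hidden_states) tuple instead of bare logits.
    logits = paddle.randn([2, 100])
    full_hidden_states = paddle.randn([2, 16])
    return logits, full_hidden_states

def _post_process_(logits):
    # Placeholder: the real _post_process_ applies penalties and sampling; the
    # point here is that it now receives logits, not the raw model output.
    probs = F.softmax(paddle.cast(logits, paddle.float32))
    return paddle.argmax(probs, axis=-1)

outputs = _forward_()
next_tokens = _post_process_(outputs[0] if isinstance(outputs, tuple) else outputs)
print(next_tokens.shape)  # [2]
```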