
Commit 83e0b47

[Bug fixes] fix inference dybatch d2s (#6998)
* fix inference dybatch d2s
* remove encoder-output

1 parent: a43138b

File tree

paddlenlp/experimental/transformers/generation_utils.py
paddlenlp/experimental/transformers/llama/modeling.py

2 files changed: +43 -15 lines

paddlenlp/experimental/transformers/generation_utils.py

Lines changed: 23 additions & 14 deletions
@@ -297,11 +297,14 @@ def sample(
         step_idx_ori = paddle.full(shape=[1], dtype="int64", fill_value=1)
         batch_idx = paddle.full(shape=[1], dtype="int32", fill_value=-1)
 
+        # fake temp next_tokens
+        next_tokens = paddle.full(shape=[paddle.shape(input_ids).shape[0], 1], dtype="int32", fill_value=0)
+
         # let inputs_embeds enter into model_kwargs.
         # because the code below directly use the model_kwargs as a parameter without using inputs_embeds.
         model_kwargs["inputs_embeds"] = inputs_embeds
         model_kwargs["all_input_ids"] = input_ids
-        logits_processors = model_kwargs["logits_processors"]
+        logits_processors = model_kwargs.pop("logits_processors")
 
         def _forward_(**args):
             # cache_kvs is never empty because it is passed as a parameter in def sample.
@@ -367,18 +370,25 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
 
             return next_tokens, model_kwargs
 
-        # encoder
-        outputs = _forward_(**model_kwargs)
-        # first decoder
-        next_tokens, model_kwargs = _post_process_(
-            outputs,
-            top_p,
-            temperature,
-            step_idx_ori,
-            model_kwargs,
-        )
-        step_idx_ori += 1
-        encoder_output = outputs
+        if paddle.max(model_kwargs["seq_len_encoder"]) > 0:
+            # encoder
+            outputs = _forward_(**model_kwargs)
+            # first decoder
+            next_tokens, model_kwargs = _post_process_(
+                outputs,
+                top_p,
+                temperature,
+                step_idx_ori,
+                model_kwargs,
+            )
+            step_idx_ori += 1
+        else:
+            outputs = None
+            # first decoder
+            next_tokens = None
+            model_kwargs["next_tokens"] = next_tokens
+            step_idx_ori += 0
+
         # gives it a value, means we will entered into decoder phase.
         model_kwargs["cache"] = 0
 
@@ -402,5 +412,4 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
             paddle.cast(model_kwargs["stop_flags"], "int32"),
             model_kwargs["seq_len_decoder"],
             model_kwargs["tgt_pos"],
-            encoder_output,
         )
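
For context, the generation_utils.py change guards the encoder (prefill) pass so that, under dynamic batching, a step with no encoder work skips _forward_ entirely while still leaving next_tokens and outputs bound on every path, which keeps the traced program consistent for dynamic-to-static (d2s) conversion. Below is a minimal sketch of that guard pattern, assuming model_kwargs carries a seq_len_encoder tensor and treating _forward_ and _post_process_ as opaque callables; the helper name first_step and the placeholder shape are illustrative, not part of the commit.

import paddle

def first_step(model_kwargs, _forward_, _post_process_, top_p, temperature, step_idx_ori):
    # Placeholder so `next_tokens` is defined on both branches (shape is illustrative).
    next_tokens = paddle.full(shape=[1, 1], dtype="int32", fill_value=0)
    if paddle.max(model_kwargs["seq_len_encoder"]) > 0:
        # At least one sequence in the batch still needs its prefill/encoder pass.
        outputs = _forward_(**model_kwargs)
        next_tokens, model_kwargs = _post_process_(
            outputs, top_p, temperature, step_idx_ori, model_kwargs
        )
        step_idx_ori += 1
    else:
        # Pure decode step: skip the encoder pass but keep the same names bound.
        outputs = None
        model_kwargs["next_tokens"] = next_tokens
    return next_tokens, model_kwargs

The second commit message bullet ("remove encoder-output") shows up in the last hunk above: encoder_output is no longer assigned on every path, so it is dropped from the argument list.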

paddlenlp/experimental/transformers/llama/modeling.py

Lines changed: 20 additions & 1 deletion
@@ -286,7 +286,26 @@ def forward(
             input_ids, position_ids, self.head_dim_shape_tensor, position_offset, True
         )
 
-        with paddle.base.framework._stride_in_no_check_dy2st_diff():
+        if hasattr(paddle.framework, "_no_check_dy2st_diff"):
+            # TODO(daisiming): _no_check_dy2st_diff is used to turn off the checking of behavior
+            # inconsistency between dynamic graph and static graph. _no_check_dy2st_diff should be
+            # removed after static graphs support inplace and stride.
+            with paddle.framework._no_check_dy2st_diff():
+                hidden_states, _ = self.transformer_block(
+                    input_ids,
+                    hidden_states,
+                    cum_offsets=cum_offsets,
+                    padding_offset=padding_offset,
+                    attn_mask=paddle.cast(attention_mask, dtype=hidden_states.dtype),
+                    caches=cache_kvs,
+                    pre_caches=pre_caches,
+                    pre_caches_length=position_offset,
+                    seq_lens=seq_lens,
+                    rotary_embs=new_rope,
+                    rotary_emb_dims=1,
+                    time_step=paddle.increment(paddle.shape(attention_mask)[-1], -1) if is_decoder else None,
+                )
+        else:
             hidden_states, _ = self.transformer_block(
                 input_ids,
                 hidden_states,
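
The modeling.py change makes the dy2st guard version-tolerant: it feature-checks for paddle.framework._no_check_dy2st_diff with hasattr and only enters the context manager when it exists, instead of depending on paddle.base.framework._stride_in_no_check_dy2st_diff. The commit repeats the transformer_block call in both branches; the sketch below shows the same compatibility idea with a single call site, using contextlib.nullcontext as the no-op fallback. The helper names are hypothetical and the nullcontext variant is an assumption about intent, not what the commit writes.

import contextlib

import paddle

def _dy2st_diff_guard():
    # Prefer the private context manager when this Paddle build provides it;
    # otherwise return a no-op context so the surrounding code runs unchanged.
    if hasattr(paddle.framework, "_no_check_dy2st_diff"):
        return paddle.framework._no_check_dy2st_diff()
    return contextlib.nullcontext()

def call_transformer_block(transformer_block, *args, **kwargs):
    # Single call site instead of repeating the transformer_block invocation per branch.
    with _dy2st_diff_guard():
        return transformer_block(*args, **kwargs)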

0 commit comments
