@@ -297,11 +297,14 @@ def sample(
297
297
step_idx_ori = paddle .full (shape = [1 ], dtype = "int64" , fill_value = 1 )
298
298
batch_idx = paddle .full (shape = [1 ], dtype = "int32" , fill_value = - 1 )
299
299
300
+ # fake temp next_tokens
301
+ next_tokens = paddle .full (shape = [paddle .shape (input_ids ).shape [0 ], 1 ], dtype = "int32" , fill_value = 0 )
302
+
300
303
# let inputs_embeds enter into model_kwargs.
301
304
# because the code below directly use the model_kwargs as a parameter without using inputs_embeds.
302
305
model_kwargs ["inputs_embeds" ] = inputs_embeds
303
306
model_kwargs ["all_input_ids" ] = input_ids
304
- logits_processors = model_kwargs [ "logits_processors" ]
307
+ logits_processors = model_kwargs . pop ( "logits_processors" )
305
308
306
309
def _forward_ (** args ):
307
310
# cache_kvs is never empty because it is passed as a parameter in def sample.
@@ -367,18 +370,25 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
367
370
368
371
return next_tokens , model_kwargs
369
372
370
- # encoder
371
- outputs = _forward_ (** model_kwargs )
372
- # first decoder
373
- next_tokens , model_kwargs = _post_process_ (
374
- outputs ,
375
- top_p ,
376
- temperature ,
377
- step_idx_ori ,
378
- model_kwargs ,
379
- )
380
- step_idx_ori += 1
381
- encoder_output = outputs
373
+ if paddle .max (model_kwargs ["seq_len_encoder" ]) > 0 :
374
+ # encoder
375
+ outputs = _forward_ (** model_kwargs )
376
+ # first decoder
377
+ next_tokens , model_kwargs = _post_process_ (
378
+ outputs ,
379
+ top_p ,
380
+ temperature ,
381
+ step_idx_ori ,
382
+ model_kwargs ,
383
+ )
384
+ step_idx_ori += 1
385
+ else :
386
+ outputs = None
387
+ # first decoder
388
+ next_tokens = None
389
+ model_kwargs ["next_tokens" ] = next_tokens
390
+ step_idx_ori += 0
391
+
382
392
# gives it a value, means we will entered into decoder phase.
383
393
model_kwargs ["cache" ] = 0
384
394
@@ -402,5 +412,4 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
402
412
paddle .cast (model_kwargs ["stop_flags" ], "int32" ),
403
413
model_kwargs ["seq_len_decoder" ],
404
414
model_kwargs ["tgt_pos" ],
405
- encoder_output ,
406
415
)
0 commit comments