File tree Expand file tree Collapse file tree 2 files changed +7
-4
lines changed
src/transformers/generation Expand file tree Collapse file tree 2 files changed +7
-4
lines changed Original file line number Diff line number Diff line change @@ -531,13 +531,16 @@ def _get_logits_processor(
531531 if (input_ids_seq_length > 1 or generation_config .forced_bos_token_id is None )
532532 else begin_index + 1
533533 )
534- if generation_config .forced_decoder_ids is not None and len (generation_config .forced_decoder_ids ) > 0 :
534+ if (
535+ getattr (generation_config , "forced_decoder_ids" , None ) is not None
536+ and len (generation_config .forced_decoder_ids ) > 0
537+ ):
535538 # generation starts after the last token that is forced
536539 begin_index += generation_config .forced_decoder_ids [- 1 ][0 ]
537540 processors .append (
538541 FlaxSuppressTokensAtBeginLogitsProcessor (generation_config .begin_suppress_tokens , begin_index )
539542 )
540- if generation_config . forced_decoder_ids is not None :
543+ if getattr ( generation_config , " forced_decoder_ids" , None ) is not None :
541544 forced_decoder_ids = [
542545 [input_ids_seq_length + i [0 ] - 1 , i [1 ]] for i in generation_config .forced_decoder_ids
543546 ]
Original file line number Diff line number Diff line change @@ -1490,14 +1490,14 @@ def _get_logits_processor(
14901490 if (input_ids_seq_length > 1 or generation_config .forced_bos_token_id is None )
14911491 else begin_index + 1
14921492 )
1493- if generation_config . forced_decoder_ids is not None :
1493+ if getattr ( generation_config , " forced_decoder_ids" , None ) is not None :
14941494 begin_index += generation_config .forced_decoder_ids [- 1 ][
14951495 0
14961496 ] # generation starts after the last token that is forced
14971497 processors .append (
14981498 TFSuppressTokensAtBeginLogitsProcessor (generation_config .begin_suppress_tokens , begin_index )
14991499 )
1500- if generation_config . forced_decoder_ids is not None :
1500+ if getattr ( generation_config , " forced_decoder_ids" , None ) is not None :
15011501 processors .append (TFForceTokensLogitsProcessor (generation_config .forced_decoder_ids ))
15021502
15031503 processors = self ._merge_criteria_processor_list (processors , logits_processor )
You can’t perform that action at this time.
0 commit comments