Skip to content

Commit 3e960e0

Browse files
authored
[tf/flax] handle forced_decoder_ids deletion (#38316)
fix tf/flax, attr checks
1 parent 9eb0a37 commit 3e960e0

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

src/transformers/generation/flax_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -531,13 +531,16 @@ def _get_logits_processor(
531531
if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
532532
else begin_index + 1
533533
)
534-
if generation_config.forced_decoder_ids is not None and len(generation_config.forced_decoder_ids) > 0:
534+
if (
535+
getattr(generation_config, "forced_decoder_ids", None) is not None
536+
and len(generation_config.forced_decoder_ids) > 0
537+
):
535538
# generation starts after the last token that is forced
536539
begin_index += generation_config.forced_decoder_ids[-1][0]
537540
processors.append(
538541
FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
539542
)
540-
if generation_config.forced_decoder_ids is not None:
543+
if getattr(generation_config, "forced_decoder_ids", None) is not None:
541544
forced_decoder_ids = [
542545
[input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids
543546
]

src/transformers/generation/tf_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1490,14 +1490,14 @@ def _get_logits_processor(
14901490
if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
14911491
else begin_index + 1
14921492
)
1493-
if generation_config.forced_decoder_ids is not None:
1493+
if getattr(generation_config, "forced_decoder_ids", None) is not None:
14941494
begin_index += generation_config.forced_decoder_ids[-1][
14951495
0
14961496
] # generation starts after the last token that is forced
14971497
processors.append(
14981498
TFSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
14991499
)
1500-
if generation_config.forced_decoder_ids is not None:
1500+
if getattr(generation_config, "forced_decoder_ids", None) is not None:
15011501
processors.append(TFForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
15021502

15031503
processors = self._merge_criteria_processor_list(processors, logits_processor)

0 commit comments

Comments
 (0)