Skip to content

Commit 1e1db12

Browse files
authored
(small) fix conditional for input_ids and input_embeds in marian (#40045)
* (small) fix conditional for input_ids and input_embeds in marian * address comment
1 parent 7f2f534 commit 1e1db12

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

src/transformers/models/marian/modeling_marian.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -979,16 +979,14 @@ def forward(
979979

980980
# retrieve input_ids and inputs_embeds
981981
if (input_ids is None) ^ (inputs_embeds is not None):
982-
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
982+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
983983
elif input_ids is not None:
984984
input = input_ids
985985
input_shape = input.shape
986986
input_ids = input_ids.view(-1, input_shape[-1])
987987
elif inputs_embeds is not None:
988988
input_shape = inputs_embeds.size()[:-1]
989989
input = inputs_embeds[:, :, -1]
990-
else:
991-
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
992990

993991
if inputs_embeds is None:
994992
inputs_embeds = self.embed_tokens(input)

0 commit comments

Comments
 (0)