Skip to content

Commit 5ff2b3a

Browse files
authored
Merge pull request #56 from SamSamhuns/fix_for_new_transformers_lib_ver
Change model_kwargs argument to encoder_outputs to support transformers>=4.22.1
2 parents 362f844 + 68b30a3 commit 5ff2b3a

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

donut/model.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -206,7 +206,7 @@ def add_special_tokens(self, list_of_tokens: List[str]):
206206
if newly_added_num > 0:
207207
self.model.resize_token_embeddings(len(self.tokenizer))
208208

209-
def prepare_inputs_for_inference(self, input_ids: torch.Tensor, past=None, use_cache: bool = None, **model_kwargs):
209+
def prepare_inputs_for_inference(self, input_ids: torch.Tensor, encoder_outputs: torch.Tensor, past=None, use_cache: bool = None, attention_mask: torch.Tensor = None):
210210
"""
211211
Args:
212212
input_ids: (batch_size, sequence_lenth)
@@ -223,7 +223,7 @@ def prepare_inputs_for_inference(self, input_ids: torch.Tensor, past=None, use_c
223223
"attention_mask": attention_mask,
224224
"past_key_values": past,
225225
"use_cache": use_cache,
226-
"encoder_hidden_states": model_kwargs["encoder_outputs"].last_hidden_state,
226+
"encoder_hidden_states": encoder_outputs.last_hidden_state,
227227
}
228228
return output
229229

0 commit comments

Comments (0)