Skip to content

Commit 68b30a3

Browse files
authored
fix: prepare_inputs_for_inference (this also supports 4.11.3)
1 parent 3e80921 commit 68b30a3

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

donut/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def add_special_tokens(self, list_of_tokens: List[str]):
206206
if newly_added_num > 0:
207207
self.model.resize_token_embeddings(len(self.tokenizer))
208208

209-
def prepare_inputs_for_inference(self, input_ids: torch.Tensor, past=None, use_cache: bool = None, encoder_outputs: torch.Tensor = None):
209+
def prepare_inputs_for_inference(self, input_ids: torch.Tensor, encoder_outputs: torch.Tensor, past=None, use_cache: bool = None, attention_mask: torch.Tensor = None):
210210
"""
211211
Args:
212212
input_ids: (batch_size, sequence_lenth)

0 commit comments

Comments
 (0)