Commit 5021d94

[Fix] Move attention mask to the model device type (#180)
The attention mask needs to be on the same device as the model and the other inputs; otherwise there will be a device mismatch error.
1 parent: 3725600

File tree

1 file changed (+1 −1 lines)


training/generate.py

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ def _forward(self, model_inputs, **generate_kwargs):
 
         generated_sequence = self.model.generate(
             input_ids=input_ids.to(self.model.device),
-            attention_mask=attention_mask,
+            attention_mask=attention_mask.to(self.model.device),
             pad_token_id=self.tokenizer.pad_token_id,
             **generate_kwargs,
         )
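
For context, a minimal self-contained sketch of the failure mode and the fix. The model name and setup here are hypothetical (not the repository's actual pipeline): tokenizer outputs start on the CPU, so when the model lives on a GPU, both input_ids and attention_mask must be moved to model.device before calling generate(), or PyTorch raises a device-mismatch RuntimeError.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical setup for illustration only.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.to("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer output tensors start on the CPU.
inputs = tokenizer("Hello, world", return_tensors="pt")

generated_sequence = model.generate(
    input_ids=inputs["input_ids"].to(model.device),
    # The fix from this commit: move the attention mask to the model's
    # device as well. Leaving it on the CPU while the model is on a GPU
    # raises a device-mismatch RuntimeError.
    attention_mask=inputs["attention_mask"].to(model.device),
    pad_token_id=tokenizer.pad_token_id,
)
print(tokenizer.decode(generated_sequence[0]))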
