Commit 71b231d

kjohew authored and BernardZach committed
Fix the memory usage issue of logits in generate() (huggingface#34813)
1 parent 9413b8b commit 71b231d

1 file changed: +1 -1 lines changed

src/transformers/generation/utils.py

Lines changed: 1 addition & 1 deletion
@@ -3246,7 +3246,7 @@ def _sample(

 # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
 # (the clone itself is always small)
-next_token_logits = outputs.logits.clone()[:, -1, :].float()
+next_token_logits = outputs.logits[:, -1, :].clone().float()
 next_token_logits = next_token_logits.to(input_ids.device)

 # pre-process distribution
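
The reordering matters because `.clone()` copies whatever tensor it is called on. In the old ordering it copies the full `(batch, seq_len, vocab)` logits before the slice is taken. In the new ordering it copies only the last position. A rough, pure-Python estimate of the size of that copy is sketched below. The shapes, the float16 element size, and the helper name `clone_bytes` are assumptions chosen for illustration and do not come from the commit.

```python
# Back-of-the-envelope estimate of the copy made by .clone() in each ordering.
# All shapes and the helper name are hypothetical, chosen only for illustration.

def clone_bytes(batch: int, seq_len: int, vocab: int,
                bytes_per_elem: int, slice_first: bool) -> int:
    """Bytes copied by .clone() for logits of shape (batch, seq_len, vocab)."""
    if slice_first:
        # logits[:, -1, :].clone(): only the last position is copied
        return batch * vocab * bytes_per_elem
    # logits.clone()[:, -1, :]: the whole tensor is copied, then sliced
    return batch * seq_len * vocab * bytes_per_elem


batch, seq_len, vocab = 1, 4096, 32000  # assumed prompt length and vocabulary size
fp16 = 2                                # bytes per float16 element

before = clone_bytes(batch, seq_len, vocab, fp16, slice_first=False)
after = clone_bytes(batch, seq_len, vocab, fp16, slice_first=True)

print(f"clone then slice: {before / 2**20:.1f} MiB")
print(f"slice then clone: {after / 2**20:.3f} MiB")
print(f"ratio: {before // after}x")
```

Under these assumed shapes the old ordering copies `seq_len` times more data than the new one. That is why the cost shows up mainly on the first iteration, when the logits cover the whole prompt.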
