Skip to content

Commit e1033aa

Browse files
committed
Update (base update)
[ghstack-poisoned]
1 parent dae765e commit e1033aa

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/models/llama2/runner/generation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ def sample_top_p(probs, p):
4545

4646
def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
4747
if temperature > 0:
48-
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
48+
probs = torch.softmax(logits / temperature, dim=-1)
4949
return sample_top_p(probs, top_p).item()
50-
return torch.argmax(logits[:, -1], dim=-1).item()
50+
return torch.argmax(logits, dim=-1).item()
5151

5252

5353
class LlamaRunner(ABC):

0 commit comments

Comments
 (0)