We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent dae765e · commit c58b2e4 — Copy full SHA for c58b2e4
examples/models/llama2/runner/generation.py
@@ -45,9 +45,9 @@ def sample_top_p(probs, p):
45
46
def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
47
if temperature > 0:
48
- probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
+ probs = torch.softmax(logits / temperature, dim=-1)
49
return sample_top_p(probs, top_p).item()
50
- return torch.argmax(logits[:, -1], dim=-1).item()
+ return torch.argmax(logits, dim=-1).item()
51
52
53
class LlamaRunner(ABC):
0 commit comments