File tree Expand file tree Collapse file tree 1 file changed +10
-8
lines changed Expand file tree Collapse file tree 1 file changed +10
-8
lines changed Original file line number Diff line number Diff line change @@ -535,14 +535,16 @@ def eval(self, tokens: Sequence[int]):
535535 # Save tokens
536536 self .input_ids [n_past : n_past + n_tokens ] = batch
537537 # Save logits
538- rows = n_tokens
539- cols = self ._n_vocab
540- offset = (
541- 0 if self .context_params .logits_all else n_tokens - 1
542- ) # NOTE: Only save the last token logits if logits_all is False
543- self .scores [n_past + offset : n_past + n_tokens , :].reshape (- 1 )[
544- :
545- ] = self ._ctx .get_logits ()[offset * cols : rows * cols ]
538+ if self .context_params .logits_all :
539+ rows = n_tokens
540+ cols = self ._n_vocab
541+ logits = self ._ctx .get_logits ()[: rows * cols ]
542+ self .scores [n_past : n_past + n_tokens , :].reshape (- 1 )[: :] = logits
543+ else :
544+ rows = 1
545+ cols = self ._n_vocab
546+ logits = self ._ctx .get_logits ()[: rows * cols ]
547+ self .scores [n_past + n_tokens - 1 , :].reshape (- 1 )[: :] = logits
546548 # Update n_tokens
547549 self .n_tokens += n_tokens
548550
You can’t perform that action at this time.
0 commit comments