diff --git a/src/petals/models/llama/speculative_model.py b/src/petals/models/llama/speculative_model.py
index f8b8faea..88fab05b 100644
--- a/src/petals/models/llama/speculative_model.py
+++ b/src/petals/models/llama/speculative_model.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Optional, Union, List
 
 import torch
 from transformers.generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
@@ -9,7 +9,6 @@
 from petals.models.llama.config import DistributedLlamaConfig
 from petals.models.llama.model import DistributedLlamaForCausalLM
 
-
 class DistributedLlamaForSpeculativeGeneration(DistributedLlamaForCausalLM, GenerationMixin):
     def __init__(self, config: DistributedLlamaConfig, small_model: LlamaForCausalLM):
         DistributedLlamaForCausalLM.__init__(self, config)
@@ -39,7 +38,7 @@ def _sample(
         batch_size = input_ids.shape[0]
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         finished = False
-        firsts = True
+        first_iteration = True
 
         while not finished:
             speculative_inference_iteration_size = min(
@@ -57,18 +56,19 @@
             assert input_ids.shape[1] + speculative_inference_iteration_size == full_sequence.shape[1]
 
             input_for_validation = full_sequence
-            if not firsts:
+            if not first_iteration:
                 self.active_session.position = input_ids.shape[1] - 1
                 input_for_validation = input_for_validation[:, -speculative_inference_iteration_size - 1 :]
             else:
-                firsts = False
+                first_iteration = False
                 input_for_validation = input_for_validation[:, :-1]
+
             with torch.no_grad():
                 precise_model_outputs = self(input_for_validation)
                 full_token_logits = precise_model_outputs.logits[:, -speculative_inference_iteration_size:, :].clone()
 
-            all_valid_tokens = []
-            first_token = None
+            all_valid_tokens: List[torch.Tensor] = []
+            first_token: Optional[torch.Tensor] = None
             for i in range(speculative_inference_iteration_size):
                 token_logits = full_token_logits[:, i, :]
                 token_scores = logits_processor(
@@ -109,3 +109,4 @@
             streamer.end()
 
         return input_ids
+
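
Note (reviewer addition, not part of the patch): the sketch below shows how the new class might be exercised end to end. It is a minimal, hedged example; the model names, the use of from_pretrained to forward the draft model positionally into __init__, and the generate() arguments are illustrative assumptions, not something this diff establishes.

import torch
from transformers import AutoTokenizer, LlamaForCausalLM

from petals.models.llama.speculative_model import DistributedLlamaForSpeculativeGeneration

# Hypothetical model choices: a large Llama-family model served over Petals
# and a small local draft model that proposes tokens for speculative decoding.
MODEL_NAME = "petals-team/StableBeluga2"
DRAFT_NAME = "JackFram/llama-160m"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
small_model = LlamaForCausalLM.from_pretrained(DRAFT_NAME, torch_dtype=torch.float32)

# Assumption: from_pretrained forwards extra positional args to __init__
# (standard transformers behavior), so `small_model` reaches the constructor.
model = DistributedLlamaForSpeculativeGeneration.from_pretrained(MODEL_NAME, small_model)

inputs = tokenizer("The capital of France is", return_tensors="pt")["input_ids"]
with torch.inference_mode():
    # _sample() lets the draft model propose a block of tokens per iteration,
    # then validates that block with one forward pass of the distributed model.
    outputs = model.generate(inputs, max_new_tokens=32, do_sample=True, temperature=0.7)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))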