@@ -1,6 +1,5 @@
 from collections.abc import Generator
-
-import torch
+from typing import Any
 
 from transformers import (
     AutoModelForCausalLM,
@@ -134,6 +133,8 @@ def _generate_full_stream(self, prompt: str) -> Generator[str, None, None]:
         Yields:
             str: Streaming response chunks.
         """
+        import torch
+
         inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)
 
         # Get generation parameters
@@ -200,6 +201,8 @@ def _generate_with_cache(self, query: str, kv: DynamicCache) -> str:
         Returns:
             str: Model response.
         """
+        import torch
+
         query_ids = self.tokenizer(
             query, return_tensors="pt", add_special_tokens=False
         ).input_ids.to(self.model.device)
@@ -287,10 +290,7 @@ def _generate_with_cache_stream(
 
             generated.append(next_token)
 
-    @torch.no_grad()
-    def _prefill(
-        self, input_ids: torch.Tensor, kv: DynamicCache
-    ) -> tuple[torch.Tensor, DynamicCache]:
+    def _prefill(self, input_ids: Any, kv: DynamicCache) -> tuple[Any, DynamicCache]:
         """
         Forward the model once, returning last-step logits and updated KV cache.
         Args:
@@ -299,22 +299,27 @@ def _prefill(
         Returns:
             tuple[torch.Tensor, DynamicCache]: (last-step logits, updated KV cache)
         """
-        out = self.model(
-            input_ids=input_ids,
-            use_cache=True,
-            past_key_values=kv,
-            return_dict=True,
-        )
+        import torch
+
+        with torch.no_grad():
+            out = self.model(
+                input_ids=input_ids,
+                use_cache=True,
+                past_key_values=kv,
+                return_dict=True,
+            )
         return out.logits[:, -1, :], out.past_key_values
 
-    def _select_next_token(self, logits: torch.Tensor) -> torch.Tensor:
+    def _select_next_token(self, logits: Any) -> Any:
         """
         Select the next token from logits using sampling or argmax, depending on config.
         Args:
             logits (torch.Tensor): Logits for the next token.
         Returns:
             torch.Tensor: Selected token ID(s).
         """
+        import torch
+
         if getattr(self.config, "do_sample", True):
             batch_size, _ = logits.size()
             dummy_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=logits.device)
@@ -323,7 +328,7 @@ def _select_next_token(self, logits: torch.Tensor) -> torch.Tensor:
             return torch.multinomial(probs, num_samples=1)
         return torch.argmax(logits, dim=-1, keepdim=True)
 
-    def _should_stop(self, token: torch.Tensor) -> bool:
+    def _should_stop(self, token: Any) -> bool:
         """
         Check if the given token is the EOS (end-of-sequence) token.
         Args:
@@ -347,6 +352,8 @@ def build_kv_cache(self, messages) -> DynamicCache:
         Returns:
             DynamicCache: The constructed KV cache object.
         """
+        import torch
+
         # Accept multiple input types and convert to standard chat messages
         if isinstance(messages, str):
            messages = [
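Every hunk above applies the same pattern: `torch` is imported lazily inside each method rather than at module level, the `@torch.no_grad()` decorator becomes a `with torch.no_grad():` block, and `torch.Tensor` annotations are relaxed to `Any`, so the module can be imported without torch installed. A minimal standalone sketch of that pattern (the `LazyTorchWrapper` class below is hypothetical and not part of this commit):

```python
from typing import Any


class LazyTorchWrapper:
    """Sketch: defer the torch import so this module can be imported
    (e.g. for docs or type checking) even when torch is not installed."""

    def __init__(self, model: Any) -> None:
        # `model` is assumed to be an already-loaded causal LM
        self.model = model

    def prefill(self, input_ids: Any) -> Any:
        # Import torch only when the method actually runs
        import torch

        # The context manager replaces @torch.no_grad(), which would
        # require torch to be importable at class-definition time
        with torch.no_grad():
            out = self.model(input_ids=input_ids, use_cache=True, return_dict=True)
        # Return last-step logits, mirroring _prefill in the diff above
        return out.logits[:, -1, :]
```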