@@ -51,10 +51,11 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
 
 
 class LlamaRunner(ABC):
-    def __init__(self, tokenizer_path: str, model_args: ModelArgs):
+    def __init__(self, tokenizer_path: str, model_args: ModelArgs, device: str = "cpu"):
         self.params = model_args
         self.tokenizer = get_tokenizer(tokenizer_path)
         assert model_args.vocab_size == self.tokenizer.n_words
+        self.device = device
 
     @abstractmethod
     def forward(
@@ -73,9 +74,9 @@ def generate(  # noqa: C901
     ) -> List[int]:
         # prefill
         logits = self.forward(
-            tokens=torch.tensor([prompt_tokens], dtype=torch.long),
+            tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
-                torch.tensor([0], dtype=torch.long)
+                torch.tensor([0], dtype=torch.long, device=self.device)
                 if self.params.use_kv_cache
                 else None
             ),
@@ -87,14 +88,21 @@ def generate(  # noqa: C901
         while len(tokens) < self.params.max_seq_len:
             if self.params.use_kv_cache:
                 logits = self.forward(
-                    tokens=torch.tensor([[current_token]], dtype=torch.long),
-                    input_pos=torch.tensor([len(tokens) - 1], dtype=torch.long),
+                    tokens=torch.tensor(
+                        [[current_token]], dtype=torch.long, device=self.device
+                    ),
+                    input_pos=torch.tensor(
+                        [len(tokens) - 1], dtype=torch.long, device=self.device
+                    ),
                 )
             else:
-                logits = self.forward(tokens=torch.tensor([tokens], dtype=torch.long))
+                logits = self.forward(
+                    tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
+                )
             current_token = next_token(logits, temperature, top_p)
             if current_token == self.tokenizer.eos_id or (
-                hasattr(self, "stop_tokens") and current_token in self.stop_tokens
+                hasattr(self.tokenizer, "stop_tokens")
+                and current_token in self.tokenizer.stop_tokens
             ):
                 break
             tokens.append(current_token)
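
Taken together, the hunks above do two things: they thread a device argument through every tensor the runner allocates, so prefill and decode inputs land directly on the target device instead of defaulting to CPU, and they look up stop_tokens on the tokenizer rather than on the runner itself, which is where multi-EOS tokenizers such as Llama 3's keep their extra stop ids. A minimal sketch of a concrete subclass using the new parameter follows; the EagerLlamaRunner name, the Transformer eager model, and its call signature are illustrative assumptions, not part of this diff.

# Sketch only: `Transformer` and its (tokens, input_pos) call signature are
# assumed stand-ins for whatever eager model a concrete runner wraps.
from typing import Optional

import torch

class EagerLlamaRunner(LlamaRunner):
    def __init__(self, tokenizer_path: str, model_args: ModelArgs, device: str = "cuda"):
        super().__init__(tokenizer_path, model_args, device=device)
        # Keep the weights on the same device the base class uses for its inputs.
        self.model = Transformer(model_args).to(device)

    def forward(
        self, tokens: torch.Tensor, input_pos: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # `generate` already builds `tokens`/`input_pos` on self.device,
        # so no extra .to() transfers are needed here.
        with torch.no_grad():
            return self.model(tokens, input_pos)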