44# This source code is licensed under the BSD-style license found in the 
55# LICENSE file in the root directory of this source tree. 
66
7+ import  time 
78from  abc  import  ABC , abstractmethod 
89from  typing  import  List , Optional 
910
@@ -97,6 +98,7 @@ def generate(  # noqa: C901
9798        pos_base : int  =  0 ,
9899    ) ->  List [int ]:
99100        # Prefill 
101+         prefill_start  =  time .time ()
100102        logits  =  self .forward (
101103            tokens = torch .tensor ([prompt_tokens ], dtype = torch .long , device = self .device ),
102104            input_pos = (
@@ -105,11 +107,13 @@ def generate(  # noqa: C901
105107                else  None 
106108            ),
107109        )
110+         prefill_time  =  time .time () -  prefill_start 
108111
109112        current_token  =  next_token (logits , temperature , top_p )
110113        print (f"{ self .tokenizer .decode_token (current_token )}  " , end = "" , flush = True )
111114        tokens  =  prompt_tokens  +  [current_token ]
112115
116+         generate_start  =  time .time ()
113117        while  len (tokens ) <  max_seq_len :
114118            if  self .use_kv_cache :
115119                logits  =  self .forward (
@@ -140,6 +144,10 @@ def generate(  # noqa: C901
140144            print (f"{ self .tokenizer .decode_token (current_token )}  " , end = "" , flush = True )
141145        print ("\n " )
142146
147+         generate_time  =  time .time () -  generate_start 
148+         print (f"Prefill time: { prefill_time }  " )
149+         print (f"Generation tok/s: { len (tokens ) /  generate_time }  " )
150+ 
143151        return  tokens  if  echo  else  tokens [len (prompt_tokens ) :]
144152
145153    def  text_completion (
0 commit comments