44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import time
78from abc import ABC , abstractmethod
89from typing import List , Optional
910
@@ -97,6 +98,7 @@ def generate( # noqa: C901
9798 pos_base : int = 0 ,
9899 ) -> List [int ]:
99100 # Prefill
101+ prefill_start = time .time ()
100102 logits = self .forward (
101103 tokens = torch .tensor ([prompt_tokens ], dtype = torch .long , device = self .device ),
102104 input_pos = (
@@ -105,11 +107,13 @@ def generate( # noqa: C901
105107 else None
106108 ),
107109 )
110+ prefill_time = time .time () - prefill_start
108111
109112 current_token = next_token (logits , temperature , top_p )
110113 print (f"{ self .tokenizer .decode_token (current_token )} " , end = "" , flush = True )
111114 tokens = prompt_tokens + [current_token ]
112115
116+ generate_start = time .time ()
113117 while len (tokens ) < max_seq_len :
114118 if self .use_kv_cache :
115119 logits = self .forward (
@@ -140,6 +144,10 @@ def generate( # noqa: C901
140144 print (f"{ self .tokenizer .decode_token (current_token )} " , end = "" , flush = True )
141145 print ("\n " )
142146
147+ generate_time = time .time () - generate_start
148+ print (f"Prefill time: { prefill_time } " )
149+ print (f"Generation tok/s: { len (tokens ) / generate_time } " )
150+
143151 return tokens if echo else tokens [len (prompt_tokens ) :]
144152
145153 def text_completion (
0 commit comments