@@ -131,6 +131,7 @@ class EngineCallResponse:
     capture_groups: dict
     capture_group_log_probs: dict
     new_token_count: int
+    last_model_token_count: int
 
     def __init__(
         self,
@@ -140,13 +141,15 @@ def __init__(
         capture_groups,
         capture_group_log_probs,
         new_token_count,
+        last_model_token_count,
     ):
         self.new_bytes = new_bytes
         self.is_generated = is_generated
         self.new_bytes_prob = new_bytes_prob
         self.capture_groups = capture_groups
         self.capture_group_log_probs = capture_group_log_probs
         self.new_token_count = new_token_count
+        self.last_model_token_count = last_model_token_count
 
     def _to_proto(self):
         """Converts an EngineCallResponse object to its Protobuf representation.
@@ -739,6 +742,7 @@ def __call__(self, parser, grammar, ensure_bos_token=True):
         # TODO: remove this after the next release. This verifies that calling Rust works.
         assert "def" == engine_start("abc", "def", 1)
 
+        last_model_token_count = 0
         logits = None
         while True:
             is_done, logits_state, response_state = self.next(logits)
@@ -765,13 +769,19 @@ def __call__(self, parser, grammar, ensure_bos_token=True):
                     capture_groups=response_capture_groups,
                     capture_group_log_probs=response_capture_group_log_probs,
                     new_token_count=response_new_token_count,
+                    last_model_token_count=last_model_token_count,
                 )
+                last_model_token_count = 0
 
             if logits_state is not None:
                 token_ids, forced_bytes, current_temp = logits_state
-                logits = self.get_logits(token_ids, forced_bytes, current_temp)
+                logits, model_token_count = self.get_logits(
+                    token_ids, forced_bytes, current_temp
+                )
+                last_model_token_count = model_token_count
 
             if is_done:
+                assert last_model_token_count == 0, "Unyielded input tokens"
                 break
 
     def _tokenize_prefix(self, byte_string):
@@ -1393,6 +1403,7 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):
                 self.engine_metrics.generated_tokens += chunk.new_token_count
             else:
                 self.engine_metrics.forced_tokens += chunk.new_token_count
+            self.engine_metrics.model_input_tokens += chunk.last_model_token_count
 
             # convert the bytes to a string (delaying if we don't yet have a valid unicode string)
             lm.token_count += chunk.new_token_count
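
For context, the accounting pattern this change implements can be illustrated with a small standalone sketch (hypothetical names, not the library's actual API): each forward pass reports how many input tokens it fed to the model, the engine carries that count on the next response it yields, then resets it, and the final assertion guarantees no input tokens go unreported.

from dataclasses import dataclass


@dataclass
class Response:
    new_token_count: int
    last_model_token_count: int  # input tokens fed to the model since the previous response


def fake_get_logits(token_ids):
    # Stand-in for the engine's get_logits: returns (logits, model_token_count).
    return object(), len(token_ids)


def run_engine(plan):
    # plan: sequence of (emitted_token_count_or_None, token_ids_for_forward_pass_or_None),
    # mimicking the next()/get_logits loop in the diff above.
    last_model_token_count = 0
    for emitted, token_ids in plan:
        if emitted is not None:
            yield Response(new_token_count=emitted,
                           last_model_token_count=last_model_token_count)
            last_model_token_count = 0  # the count has been reported; reset it
        if token_ids is not None:
            _logits, model_token_count = fake_get_logits(token_ids)
            last_model_token_count = model_token_count
    assert last_model_token_count == 0, "Unyielded input tokens"


# Two forward passes (3 + 1 input tokens), each reported on the following response.
total_input = sum(r.last_model_token_count
                  for r in run_engine([(None, [1, 2, 3]), (1, [4]), (1, None)]))
print(total_input)  # 4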