@@ -204,9 +204,6 @@ def __init__(self, tokenizer, compute_log_probs=False):
204204 )
205205 self ._token_trie .match = True
206206 self ._token_trie .match_version = 0
207- # Any time get_logits is called, it should update this
208- # This does add to the list of "Thread Unsafety"
209- self .metrics = GuidanceEngineMetrics ()
210207
211208 def start (self , parser , grammar , ensure_bos_token = True ):
212209 """Start processing parser state executed through the grammar.
@@ -687,7 +684,6 @@ def next(self, logits):
687684 self ._sampled_token = self .tokenizer .tokens [self ._sampled_token_ind ]
688685 self ._new_bytes_prob = 1.0
689686 self ._was_forced = True
690- self .metrics .forced_tokens += 1
691687
692688 # we are at the end of the grammar
693689 elif next_byte_mask_sum == 0 :
@@ -758,6 +754,8 @@ def __call__(self, parser, grammar, ensure_bos_token=True):
758754 response_new_token_count ,
759755 ) = response_state
760756
757+ print (f"{ response_is_generated = } { response_new_token_count = } { response_new_bytes = } " )
758+
761759 yield EngineCallResponse (
762760 new_bytes = response_new_bytes ,
763761 is_generated = response_is_generated ,
@@ -1382,9 +1380,6 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):
13821380 # we will return a new extended version of ourselves, which we track as `lm`
13831381 lm = self
13841382
1385- # Prepare our metrics update. This is part of our Thread Unsafety programme
1386- metrics_before = lm .engine .metrics .model_copy (deep = True )
1387-
13881383 # single generation
13891384 if n == 1 :
13901385 generated_value = ""
@@ -1398,6 +1393,11 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):
13981393 # if not self.engine.compute_log_probs:
13991394 # chunk.new_bytes_prob = 1.0
14001395
1396+ if chunk .is_generated :
1397+ self .engine_metrics .generated_tokens += chunk .new_token_count
1398+ else :
1399+ self .engine_metrics .forced_tokens += chunk .new_token_count
1400+
14011401 # convert the bytes to a string (delaying if we don't yet have a valid unicode string)
14021402 lm .token_count += chunk .new_token_count
14031403 chunk .new_bytes = delayed_bytes + chunk .new_bytes
@@ -1466,17 +1466,6 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):
14661466
14671467 unreplace_model_variables (replacements )
14681468
1469- # Now update our metrics while maintaining Thread Unsafety
1470- lm .engine_metrics .prompt_tokens += (
1471- self .engine .metrics .prompt_tokens - metrics_before .prompt_tokens
1472- )
1473- lm .engine_metrics .generated_tokens += (
1474- self .engine .metrics .generated_tokens - metrics_before .generated_tokens
1475- )
1476- lm .engine_metrics .forced_tokens += (
1477- self .engine .metrics .forced_tokens - metrics_before .forced_tokens
1478- )
1479-
14801469 logger .debug ("finish Model._run_stateless" )
14811470
14821471 return lm
0 commit comments