3636 "Failed to load guidance.cpp, falling back to Python mirror implementations..."
3737 )
3838 from .. import _cpp as cpp
39+
40+ from ._guidance_engine_metrics import GuidanceEngineMetrics
3941from .._rust .guidancerust import engine_start
4042from .._utils import softmax , CaptureEvents
4143from .._parser import EarleyCommitParser , Parser
5254
5355from .. import _serialization_pb2
5456
55- from ._guidance_engine_metrics import GuidanceEngineMetrics
56-
5757if TYPE_CHECKING :
5858 from ..library ._block import ContextBlock
5959
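The metrics import moves up next to the other internal imports so that `GuidanceEngineMetrics` is already in scope for `Engine.__init__` below. For orientation, here is a minimal sketch of what such a metrics container could look like; the field names are illustrative assumptions, not the class's actual definition:

```python
# Hypothetical stand-in for GuidanceEngineMetrics (field names assumed).
from pydantic import BaseModel, NonNegativeInt


class EngineMetricsSketch(BaseModel):
    """Token counters an engine could accumulate across calls."""

    engine_input_tokens: NonNegativeInt = 0   # tokens fed to the model
    engine_output_tokens: NonNegativeInt = 0  # tokens sampled from the model
```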
@@ -131,7 +131,6 @@ class EngineCallResponse:
     capture_groups: dict
     capture_group_log_probs: dict
     new_token_count: int
-    last_model_token_count: int
 
     def __init__(
         self,
@@ -141,15 +140,13 @@ def __init__(
         capture_groups,
         capture_group_log_probs,
         new_token_count,
-        last_model_token_count,
     ):
         self.new_bytes = new_bytes
         self.is_generated = is_generated
         self.new_bytes_prob = new_bytes_prob
         self.capture_groups = capture_groups
         self.capture_group_log_probs = capture_group_log_probs
         self.new_token_count = new_token_count
-        self.last_model_token_count = last_model_token_count
 
     def _to_proto(self):
         """Converts an EngineCallResponse object to its Protobuf representation.
@@ -208,6 +205,8 @@ def __init__(self, tokenizer, compute_log_probs=False):
         self._token_trie.match = True
         self._token_trie.match_version = 0
 
+        self.metrics = GuidanceEngineMetrics()
+
     def start(self, parser, grammar, ensure_bos_token=True):
         """Start processing parser state executed through the grammar.
 
@@ -542,7 +541,6 @@ def next(self, logits):
                 # self._captured_log_prob_data.update(new_captured_log_prob_data)
                 # yield out, self._is_generated, self._new_bytes_prob, self._captured_data, self._captured_log_prob_data, self._token_count - self._last_token_count  # note that we don't capture groups until a complete parse right now...
 
-                self._token_count += 1  # note we only update this for tokens that emit non-hidden content
                 response_state = (
                     out,
                     is_generated,
@@ -554,6 +552,7 @@ def next(self, logits):
 
                 self._last_token_count = self._token_count
                 self._hidden_count = 0
+                self._token_count += 1  # note we only update this for tokens that emit non-hidden content
             else:
                 self._hidden_count -= len(new_bytes)
 
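Net effect of these two hunks: the `_token_count` increment moves from before the `response_state` tuple is built to after `_last_token_count` is snapshotted, so the snapshot no longer includes the token being emitted. A toy illustration of the two orderings (`count` and `last` stand in for the two attributes):

```python
# Old order: increment first, so the snapshot already counts the new token.
count, last = 5, 0
count += 1
last = count   # last == 6

# New order: snapshot first, then count the token that was just emitted.
count, last = 5, 0
last = count   # last == 5
count += 1     # count == 6
```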
@@ -740,9 +739,8 @@ def __call__(self, parser, grammar, ensure_bos_token=True):
         self.start(parser, grammar, ensure_bos_token)
 
         # TODO: remove this after the next release. This verifies that calling Rust works.
-        assert "def" == engine_start("abc", "def", 1)
+        assert ("def" == engine_start("abc", "def", 1))
 
-        last_model_token_count = 0
         logits = None
         while True:
             is_done, logits_state, response_state = self.next(logits)
@@ -758,30 +756,20 @@ def __call__(self, parser, grammar, ensure_bos_token=True):
                     response_new_token_count,
                 ) = response_state
 
-                print(
-                    f"{response_is_generated=} {response_new_token_count=} {response_new_bytes=}"
-                )
-
                 yield EngineCallResponse(
                     new_bytes=response_new_bytes,
                     is_generated=response_is_generated,
                     new_bytes_prob=response_new_bytes_prob,
                     capture_groups=response_capture_groups,
                     capture_group_log_probs=response_capture_group_log_probs,
                     new_token_count=response_new_token_count,
-                    last_model_token_count=last_model_token_count,
                 )
-                last_model_token_count = 0
 
             if logits_state is not None:
                 token_ids, forced_bytes, current_temp = logits_state
-                logits, model_token_count = self.get_logits(
-                    token_ids, forced_bytes, current_temp
-                )
-                last_model_token_count = model_token_count
+                logits = self.get_logits(token_ids, forced_bytes, current_temp)
 
             if is_done:
-                assert last_model_token_count == 0, "Unyielded input tokens"
                 break
 
     def _tokenize_prefix(self, byte_string):
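After this hunk `get_logits` returns only the logits, and the generator no longer threads a `last_model_token_count` through each yield. A hedged sketch of a consumer of the simplified protocol, assuming `engine`, `parser`, and `grammar` objects already exist:

```python
# Hypothetical driver loop; `engine`, `parser`, and `grammar` are assumed.
total_new_tokens = 0
for response in engine(parser, grammar):
    total_new_tokens += response.new_token_count
    if response.is_generated:
        # Sampled bytes; forced/grammar bytes arrive with is_generated=False.
        print(response.new_bytes.decode("utf8", errors="replace"), end="")
print(f"\n{total_new_tokens} new tokens")
```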
@@ -860,7 +848,7 @@ def _cleanup_tokens(self, token_ids, token_byte_positions):
 
         return token_ids, token_byte_positions
 
-    def get_logits(self, token_ids, forced_bytes, current_temp) -> np.ndarray:
+    def get_logits(self, token_ids, forced_bytes, current_temp):
         """A fake method designed to be overridden by subclasses."""
 
         # pretend to extend the KV cache and update the log probs
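The dropped return annotation goes with the contract change: overrides now return a single logits array instead of a `(logits, model_token_count)` pair. A minimal mock subclass under that assumption:

```python
import numpy as np


# Toy override of the simplified contract; the vocab size is an assumption.
class MockEngine(Engine):
    def get_logits(self, token_ids, forced_bytes, current_temp):
        # A real engine would extend the KV cache and run the model here;
        # this stub just makes every token equally likely.
        return np.zeros(32_000, dtype=np.float32)
```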
@@ -937,12 +925,6 @@ def __init__(self, engine, echo=True, **kwargs):
             0  # used to track the last event streaming call to enable throttling
         )
 
-        # Metrics for the model
-        self.engine_metrics = GuidanceEngineMetrics()
-
-    def reset_metrics(self):
-        self.engine_metrics = GuidanceEngineMetrics()
-
     @property
     def active_role_end(self):
         """The default end patterns we should use for `gen` calls.
@@ -1399,12 +1381,6 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):
             # if not self.engine.compute_log_probs:
             # chunk.new_bytes_prob = 1.0
 
-            if chunk.is_generated:
-                self.engine_metrics.generated_tokens += chunk.new_token_count
-            else:
-                self.engine_metrics.forced_tokens += chunk.new_token_count
-            self.engine_metrics.model_input_tokens += chunk.last_model_token_count
-
             # convert the bytes to a string (delaying if we don't yet have a valid unicode string)
             lm.token_count += chunk.new_token_count
             chunk.new_bytes = delayed_bytes + chunk.new_bytes
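The deleted block was the only place that split token counts into generated vs. forced and tracked model input tokens; after this change `_run_stateless` only accumulates `lm.token_count`. If a caller still wants the split, it can be rebuilt from the response stream, e.g.:

```python
# Illustrative re-creation of the deleted tallies from yielded chunks;
# `chunks` is assumed to be an iterable of EngineCallResponse objects.
generated_tokens = forced_tokens = 0
for chunk in chunks:
    if chunk.is_generated:
        generated_tokens += chunk.new_token_count
    else:
        forced_tokens += chunk.new_token_count
```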
@@ -1654,4 +1630,4 @@ def _check_dominated(node, parser, match_version, next_byte_mask):
         parser.pos = curr_pos
         if not child_dominate:
             return False
-    return True
\ No newline at end of file
+    return True