Commit 216a5de

Rethink the metrics
1 parent 9f330c3 commit 216a5de

File tree

3 files changed (+12, -35 lines changed):

guidance/models/_guidance_engine_metrics.py
guidance/models/_model.py
guidance/models/transformers/_transformers.py

guidance/models/_guidance_engine_metrics.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -2,6 +2,5 @@
 
 
 class GuidanceEngineMetrics(BaseModel):
-    generated_tokens: NonNegativeInt = 0
-    forced_tokens: NonNegativeInt = 0
     model_input_tokens: NonNegativeInt = 0
+    model_output_tokens: NonNegativeInt = 0
```
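The net effect: the metrics model now counts raw model I/O (`model_input_tokens`, `model_output_tokens`) instead of the grammar-level `generated_tokens`/`forced_tokens` split. A minimal sketch of the whole file after this commit, assuming the usual `pydantic` import on the line above the hunk:

```python
from pydantic import BaseModel, NonNegativeInt


class GuidanceEngineMetrics(BaseModel):
    # Tokens fed to the underlying model on forward passes.
    model_input_tokens: NonNegativeInt = 0
    # Forward passes run (one logits vector, i.e. one output token, each).
    model_output_tokens: NonNegativeInt = 0
```

Because both fields are `NonNegativeInt`, pydantic rejects a negative count at construction time: `GuidanceEngineMetrics(model_input_tokens=-1)` raises a `ValidationError`.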

guidance/models/_model.py

Lines changed: 9 additions & 33 deletions
```diff
@@ -36,6 +36,8 @@
         "Failed to load guidance.cpp, falling back to Python mirror implementations..."
     )
     from .. import _cpp as cpp
+
+from ._guidance_engine_metrics import GuidanceEngineMetrics
 from .._rust.guidancerust import engine_start
 from .._utils import softmax, CaptureEvents
 from .._parser import EarleyCommitParser, Parser
@@ -52,8 +54,6 @@
 
 from .. import _serialization_pb2
 
-from ._guidance_engine_metrics import GuidanceEngineMetrics
-
 if TYPE_CHECKING:
     from ..library._block import ContextBlock
 
```
```diff
@@ -131,7 +131,6 @@ class EngineCallResponse:
     capture_groups: dict
     capture_group_log_probs: dict
     new_token_count: int
-    last_model_token_count: int
 
     def __init__(
         self,
@@ -141,15 +140,13 @@ def __init__(
         capture_groups,
         capture_group_log_probs,
         new_token_count,
-        last_model_token_count,
     ):
         self.new_bytes = new_bytes
         self.is_generated = is_generated
         self.new_bytes_prob = new_bytes_prob
         self.capture_groups = capture_groups
         self.capture_group_log_probs = capture_group_log_probs
         self.new_token_count = new_token_count
-        self.last_model_token_count = last_model_token_count
 
     def _to_proto(self):
         """Converts an EngineCallResponse object to its Protobuf representation.
```
```diff
@@ -208,6 +205,8 @@ def __init__(self, tokenizer, compute_log_probs=False):
         self._token_trie.match = True
         self._token_trie.match_version = 0
 
+        self.metrics = GuidanceEngineMetrics()
+
     def start(self, parser, grammar, ensure_bos_token=True):
         """Start processing parser state executed through the grammar.
 
```
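This moves ownership of the counters onto the `Engine` itself, so backends that override `get_logits` can update `self.metrics` as a side effect instead of threading counts back through return values. A minimal sketch of the pattern, assuming the module path from the import hunk above; the `Engine` skeleton here is illustrative, not the full class:

```python
from guidance.models._guidance_engine_metrics import GuidanceEngineMetrics


class Engine:
    def __init__(self):
        # One metrics object per engine, created alongside the token trie state.
        self.metrics = GuidanceEngineMetrics()

    def get_logits(self, token_ids, forced_bytes, current_temp):
        # Stand-in: concrete backends override this and bump self.metrics
        # (see the _transformers.py hunk below).
        raise NotImplementedError
```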
```diff
@@ -542,7 +541,6 @@ def next(self, logits):
             # self._captured_log_prob_data.update(new_captured_log_prob_data)
             # yield out, self._is_generated, self._new_bytes_prob, self._captured_data, self._captured_log_prob_data, self._token_count - self._last_token_count # note that we don't capture groups until a complete parse right now...
 
-            self._token_count += 1 # note we only update this for tokens that emit non-hidden content
             response_state = (
                 out,
                 is_generated,
@@ -554,6 +552,7 @@ def next(self, logits):
 
             self._last_token_count = self._token_count
             self._hidden_count = 0
+            self._token_count += 1 # note we only update this for tokens that emit non-hidden content
         else:
             self._hidden_count -= len(new_bytes)
 
```
```diff
@@ -740,9 +739,8 @@ def __call__(self, parser, grammar, ensure_bos_token=True):
         self.start(parser, grammar, ensure_bos_token)
 
         # TODO: remove this after the next release. This verifies that calling Rust works.
-        assert "def" == engine_start("abc", "def", 1)
+        assert("def" == engine_start("abc", "def", 1))
 
-        last_model_token_count = 0
         logits = None
         while True:
             is_done, logits_state, response_state = self.next(logits)
@@ -758,30 +756,20 @@ def __call__(self, parser, grammar, ensure_bos_token=True):
                 response_new_token_count,
             ) = response_state
 
-            print(
-                f"{response_is_generated=} {response_new_token_count=} {response_new_bytes=}"
-            )
-
             yield EngineCallResponse(
                 new_bytes=response_new_bytes,
                 is_generated=response_is_generated,
                 new_bytes_prob=response_new_bytes_prob,
                 capture_groups=response_capture_groups,
                 capture_group_log_probs=response_capture_group_log_probs,
                 new_token_count=response_new_token_count,
-                last_model_token_count=last_model_token_count,
             )
-            last_model_token_count = 0
 
             if logits_state is not None:
                 token_ids, forced_bytes, current_temp = logits_state
-                logits, model_token_count = self.get_logits(
-                    token_ids, forced_bytes, current_temp
-                )
-                last_model_token_count = model_token_count
+                logits = self.get_logits(token_ids, forced_bytes, current_temp)
 
             if is_done:
-                assert last_model_token_count == 0, "Unyielded input tokens"
                 break
 
     def _tokenize_prefix(self, byte_string):
```
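With `last_model_token_count` gone, the base `__call__` loop simply asks `get_logits` for logits and leaves token accounting to the backend. A condensed sketch of the resulting control flow; `_make_response` is a hypothetical stand-in for the `EngineCallResponse` construction, not a real method, and the `response_state` guard is inferred from the indentation between the two hunks:

```python
def __call__(self, parser, grammar, ensure_bos_token=True):
    self.start(parser, grammar, ensure_bos_token)

    logits = None
    while True:
        is_done, logits_state, response_state = self.next(logits)

        if response_state is not None:
            # Unpack response_state and yield an EngineCallResponse
            # (no last_model_token_count field anymore).
            yield self._make_response(response_state)  # hypothetical helper

        if logits_state is not None:
            token_ids, forced_bytes, current_temp = logits_state
            # Token accounting now happens inside backend get_logits overrides.
            logits = self.get_logits(token_ids, forced_bytes, current_temp)

        if is_done:
            break
```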
```diff
@@ -860,7 +848,7 @@ def _cleanup_tokens(self, token_ids, token_byte_positions):
 
         return token_ids, token_byte_positions
 
-    def get_logits(self, token_ids, forced_bytes, current_temp) -> np.ndarray:
+    def get_logits(self, token_ids, forced_bytes, current_temp):
        """A fake method designed to be overriden by subclasses."""
 
        # pretend to extend the KV cache and update the log probs
```
```diff
@@ -937,12 +925,6 @@ def __init__(self, engine, echo=True, **kwargs):
             0 # used to track the last event streaming call to enable throttling
         )
 
-        # Metrics for the model
-        self.engine_metrics = GuidanceEngineMetrics()
-
-    def reset_metrics(self):
-        self.engine_metrics = GuidanceEngineMetrics()
-
     @property
     def active_role_end(self):
         """The default end patterns we should use for `gen` calls.
```
```diff
@@ -1399,12 +1381,6 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):
                 # if not self.engine.compute_log_probs:
                 #     chunk.new_bytes_prob = 1.0
 
-                if chunk.is_generated:
-                    self.engine_metrics.generated_tokens += chunk.new_token_count
-                else:
-                    self.engine_metrics.forced_tokens += chunk.new_token_count
-                self.engine_metrics.model_input_tokens += chunk.last_model_token_count
-
                 # convert the bytes to a string (delaying if we don't yet have a valid unicode string)
                 lm.token_count += chunk.new_token_count
                 chunk.new_bytes = delayed_bytes + chunk.new_bytes
@@ -1654,4 +1630,4 @@ def _check_dominated(node, parser, match_version, next_byte_mask):
     parser.pos = curr_pos
     if not child_dominate:
         return False
-    return True
+    return True
```

guidance/models/transformers/_transformers.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -268,6 +268,8 @@ def get_logits(self, token_ids, forced_bytes, current_temp):
         self._cached_logits = (
             model_out.logits[0, -1, : len(self.tokenizer.tokens)].cpu().numpy()
         )
+        self.metrics.model_input_tokens += len(new_token_ids)
+        self.metrics.model_output_tokens += 1
 
         return self._cached_logits, len(new_token_ids)
 
```
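The accounting here is incremental: `new_token_ids` (computed earlier in this method) appears to hold only the tokens not already covered by the KV cache, so `model_input_tokens` grows by the uncached suffix length while `model_output_tokens` ticks once per forward pass. A toy trace under that assumption:

```python
# Suppose three successive get_logits calls see uncached suffixes of these lengths:
# a 10-token prompt, then one new token per step as generation proceeds.
new_token_counts = [10, 1, 1]

model_input_tokens = 0
model_output_tokens = 0
for n in new_token_counts:
    model_input_tokens += n    # mirrors self.metrics.model_input_tokens += len(new_token_ids)
    model_output_tokens += 1   # mirrors self.metrics.model_output_tokens += 1

print(model_input_tokens, model_output_tokens)  # -> 12 3
```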
