Skip to content

Commit 9f330c3

Browse files
committed
Latest attempt to get consistent token results
1 parent b728b0f commit 9f330c3

File tree

4 files changed

+19
-4
lines changed

4 files changed

+19
-4
lines changed

guidance/models/_guidance_engine_metrics.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
 class GuidanceEngineMetrics(BaseModel):
     generated_tokens: NonNegativeInt = 0
     forced_tokens: NonNegativeInt = 0
+    model_input_tokens: NonNegativeInt = 0

guidance/models/_model.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class EngineCallResponse:
     capture_groups: dict
     capture_group_log_probs: dict
     new_token_count: int
+    last_model_token_count: int

     def __init__(
         self,
@@ -140,13 +141,15 @@ def __init__(
         capture_groups,
         capture_group_log_probs,
         new_token_count,
+        last_model_token_count,
     ):
         self.new_bytes = new_bytes
         self.is_generated = is_generated
         self.new_bytes_prob = new_bytes_prob
         self.capture_groups = capture_groups
         self.capture_group_log_probs = capture_group_log_probs
         self.new_token_count = new_token_count
+        self.last_model_token_count = last_model_token_count

     def _to_proto(self):
         """Converts an EngineCallResponse object to its Protobuf representation.
@@ -739,6 +742,7 @@ def __call__(self, parser, grammar, ensure_bos_token=True):
         # TODO: remove this after the next release. This verifies that calling Rust works.
         assert "def" == engine_start("abc", "def", 1)

+        last_model_token_count = 0
         logits = None
         while True:
             is_done, logits_state, response_state = self.next(logits)
@@ -765,13 +769,19 @@ def __call__(self, parser, grammar, ensure_bos_token=True):
                 capture_groups=response_capture_groups,
                 capture_group_log_probs=response_capture_group_log_probs,
                 new_token_count=response_new_token_count,
+                last_model_token_count=last_model_token_count,
             )
+            last_model_token_count = 0

             if logits_state is not None:
                 token_ids, forced_bytes, current_temp = logits_state
-                logits = self.get_logits(token_ids, forced_bytes, current_temp)
+                logits, model_token_count = self.get_logits(
+                    token_ids, forced_bytes, current_temp
+                )
+                last_model_token_count = model_token_count

             if is_done:
+                assert last_model_token_count == 0, "Unyielded input tokens"
                 break

     def _tokenize_prefix(self, byte_string):
@@ -1393,6 +1403,7 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):
                     self.engine_metrics.generated_tokens += chunk.new_token_count
                 else:
                     self.engine_metrics.forced_tokens += chunk.new_token_count
+                self.engine_metrics.model_input_tokens += chunk.last_model_token_count

                 # convert the bytes to a string (delaying if we don't yet have a valid unicode string)
                 lm.token_count += chunk.new_token_count

guidance/models/transformers/_transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def get_logits(self, token_ids, forced_bytes, current_temp):
             model_out.logits[0, -1, : len(self.tokenizer.tokens)].cpu().numpy()
         )

-        return self._cached_logits
+        return self._cached_logits, len(new_token_ids)


 class Transformers(Model):

tests/library/test_gen.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,11 @@ def test_metrics_alt_expressions(selected_model: models.Model):
     assert str(lm) == str(lm2)
     assert lm.engine_metrics.generated_tokens == 10
     assert lm2.engine_metrics.generated_tokens == 10
-    assert lm.engine_metrics.forced_tokens == 0
-    assert lm2.engine_metrics.forced_tokens == 0
+
+    assert (
+        lm.engine_metrics.forced_tokens + lm.engine_metrics.model_input_tokens
+        == lm2.engine_metrics.forced_tokens + lm2.engine_metrics.model_input_tokens
+    )


 def test_unicode(selected_model):

0 commit comments

Comments (0)