Skip to content

Commit 5c50051

Browse files
committed
Tweak where stats are grabbed
1 parent 822f8e1 commit 5c50051

File tree

3 files changed: 8 additions and 4 deletions.

guidance/models/_guidance_engine_metrics.py

Lines changed: 0 additions & 1 deletion
--- a/guidance/models/_guidance_engine_metrics.py
+++ b/guidance/models/_guidance_engine_metrics.py
@@ -2,6 +2,5 @@
 
 
 class GuidanceEngineMetrics(BaseModel):
-    prompt_tokens: NonNegativeInt = 0
     generated_tokens: NonNegativeInt = 0
     forced_tokens: NonNegativeInt = 0

guidance/models/_model.py

Lines changed: 2 additions & 2 deletions
(Leading indentation was lost in extraction; the hunk below shows relative indentation only.)
--- a/guidance/models/_model.py
+++ b/guidance/models/_model.py
@@ -539,18 +539,18 @@ def next(self, logits):
     # self._captured_log_prob_data.update(new_captured_log_prob_data)
     # yield out, self._is_generated, self._new_bytes_prob, self._captured_data, self._captured_log_prob_data, self._token_count - self._last_token_count # note that we don't capture groups until a complete parse right now...
 
+    self._token_count += 1 # note we only update this for tokens that emit non-hidden content
     response_state = (
         out,
         is_generated,
         self._new_bytes_prob if self.compute_log_probs else 1.0,
         self._captured_data,
         self._captured_log_prob_data,
-        self._token_count - self._last_token_count,
+        self._token_count - self._last_token_count + 1,
     )
 
     self._last_token_count = self._token_count
     self._hidden_count = 0
-    self._token_count += 1 # note we only update this for tokens that emit non-hidden content
 else:
     self._hidden_count -= len(new_bytes)
 
tests/library/test_gen.py

Lines changed: 6 additions & 1 deletion
--- a/tests/library/test_gen.py
+++ b/tests/library/test_gen.py
@@ -101,11 +101,16 @@ def test_metrics_select(selected_model: models.Model):
     lm += select(["ride a bike", "row a boat", "go for a swim"])
     print(f"lm={str(lm)}")
     print(f"{lm.engine_metrics=}")
+    assert lm.engine_metrics.forced_tokens > 0
+    assert lm.engine_metrics.generated_tokens > 0
+    assert lm.engine_metrics.forced_tokens > lm.engine_metrics.generated_tokens
+    prev_stats = lm.engine_metrics.copy()
     lm += " and afterwards "
     lm += select(["walk to town", "walk to a show"])
     print(f"lm={str(lm)}")
     print(f"{lm.engine_metrics=}")
-    assert False
+    assert lm.engine_metrics.forced_tokens > prev_stats.forced_tokens
+    assert lm.engine_metrics.generated_tokens > prev_stats.generated_tokens
 
 
 def test_unicode(selected_model):

0 commit comments

Comments
 (0)