Skip to content

Commit 2faa583

Browse files
committed
Trying to count forced tokens
1 parent ff46ec1 commit 2faa583

File tree

3 files changed

+18
-1
lines changed

3 files changed

+18
-1
lines changed

guidance/models/_guidance_engine_metrics.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
class GuidanceEngineMetrics(BaseModel):
55
prompt_tokens: NonNegativeInt = 0
66
generated_tokens: NonNegativeInt = 0
7+
forced_tokens: NonNegativeInt = 0

guidance/models/_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,7 @@ def next(self, logits):
687687
self._sampled_token = self.tokenizer.tokens[self._sampled_token_ind]
688688
self._new_bytes_prob = 1.0
689689
self._was_forced = True
690+
self.metrics.forced_tokens += 1
690691

691692
# we are at the end of the grammar
692693
elif next_byte_mask_sum == 0:
@@ -1472,6 +1473,9 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):
14721473
lm.engine_metrics.generated_tokens += (
14731474
self.engine.metrics.generated_tokens - metrics_before.generated_tokens
14741475
)
1476+
lm.engine_metrics.forced_tokens += (
1477+
self.engine.metrics.forced_tokens - metrics_before.forced_tokens
1478+
)
14751479

14761480
logger.debug("finish Model._run_stateless")
14771481

tests/library/test_gen.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from guidance import gen, models
5+
from guidance import gen, models, select
66

77

88
def test_basic():
@@ -78,17 +78,29 @@ def test_metrics_smoke(selected_model: models.Model):
7878
lm.reset_metrics()
7979

8080
lm += "abc"
81+
print(f"{lm.engine_metrics=}")
8182
lm += gen("first", max_tokens=1)
83+
print(f"{lm.engine_metrics=}")
8284
assert lm.engine_metrics.generated_tokens == 1
8385

8486
lm += "efg"
8587
lm += gen("second", max_tokens=1)
88+
print(f"{lm.engine_metrics=}")
8689
assert lm.engine_metrics.generated_tokens == 2
8790

8891
assert lm.current_token_count >= (
8992
lm.engine_metrics.prompt_tokens + lm.engine_metrics.generated_tokens
9093
)
9194

95+
def test_metrics_select(selected_model: models.Model):
96+
lm = selected_model
97+
lm.reset_metrics()
98+
99+
lm += "This is a great day to "
100+
lm += select(["ride a bike", "row a boat", "go for a swim"])
101+
print(f"lm={str(lm)}")
102+
print(f"{lm.engine_metrics=}")
103+
assert False
92104

93105
def test_unicode(selected_model):
94106
# black makes this test ugly -- easier to read with fmt: off

0 commit comments

Comments
 (0)