
Commit 5ad2304

Merge pull request #806 from riedgar-ms/riedgar-ms/model-metrics-01

[Feature] Monitor token consumption

2 parents 7b4d85f + d860cb2

File tree: 5 files changed (+83 −4)

  guidance/models/_guidance_engine_metrics.py (new)
  guidance/models/_model.py
  guidance/models/llama_cpp/_llama_cpp.py
  guidance/models/transformers/_transformers.py
  tests/library/test_gen.py
guidance/models/_guidance_engine_metrics.py

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+from pydantic import BaseModel, NonNegativeInt
+
+
+class GuidanceEngineMetrics(BaseModel):
+    engine_input_tokens: NonNegativeInt = 0
+    engine_output_tokens: NonNegativeInt = 0
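
The new metrics class is a plain pydantic model, so NonNegativeInt guarantees the counters can never be constructed with a negative value. A minimal standalone sketch of that behaviour (illustrative, not part of the commit):

from pydantic import BaseModel, NonNegativeInt, ValidationError


class GuidanceEngineMetrics(BaseModel):
    engine_input_tokens: NonNegativeInt = 0
    engine_output_tokens: NonNegativeInt = 0


metrics = GuidanceEngineMetrics()
metrics.engine_input_tokens += 12  # counters default to zero and accumulate
print(metrics)  # engine_input_tokens=12 engine_output_tokens=0

try:
    GuidanceEngineMetrics(engine_output_tokens=-1)
except ValidationError:
    print("negative token counts are rejected at construction time")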

guidance/models/_model.py

Lines changed: 8 additions & 1 deletion

@@ -36,6 +36,8 @@
         "Failed to load guidance.cpp, falling back to Python mirror implementations..."
     )
     from .. import _cpp as cpp
+
+from ._guidance_engine_metrics import GuidanceEngineMetrics
 from .._rust.guidancerust import engine_start
 from .._utils import softmax, CaptureEvents
 from .._parser import EarleyCommitParser, Parser

@@ -203,6 +205,11 @@ def __init__(self, tokenizer, compute_log_probs=False):
         self._token_trie.match = True
         self._token_trie.match_version = 0
 
+        self.metrics = GuidanceEngineMetrics()
+
+    def reset_metrics(self):
+        self.metrics = GuidanceEngineMetrics()
+
     def start(self, parser, grammar, ensure_bos_token=True):
         """Start processing parser state executed through the grammar.
 
@@ -1626,4 +1633,4 @@ def _check_dominated(node, parser, match_version, next_byte_mask):
     parser.pos = curr_pos
     if not child_dominate:
         return False
-    return True
+    return True
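
With metrics initialised in the engine's __init__ and reset_metrics() swapping in a fresh GuidanceEngineMetrics, any model's token consumption can be inspected through its engine. A hedged usage sketch (the model path is hypothetical; any Engine-backed model should work the same way):

from guidance import gen, models

lm = models.LlamaCpp("path/to/model.gguf")  # hypothetical path
lm.engine.reset_metrics()  # zero both counters

lm += "The capital of France is " + gen("answer", max_tokens=3)

print(lm.engine.metrics.engine_input_tokens)   # tokens fed into the model
print(lm.engine.metrics.engine_output_tokens)  # forward passes recorded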

guidance/models/llama_cpp/_llama_cpp.py

Lines changed: 3 additions & 0 deletions

@@ -193,9 +193,12 @@ def get_logits(self, token_ids, forced_bytes, current_temp):
         batch.logits[n_tokens - 1] = True
 
         ret = llama_cpp.llama_decode(self.model_obj.ctx, batch)
+        self.metrics.engine_input_tokens += n_tokens
         if ret != 0:
             raise Exception(f"Call to llama_cpp.llama_decode returned {ret}.")
 
+        self.metrics.engine_output_tokens += 1
+
         # get the logits
         logits = llama_cpp.llama_get_logits(self.model_obj.ctx)
         if llama_cpp.__version__ < "0.2.58":
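
The placement of the two increments encodes this backend's accounting rule: every token handed to llama_decode counts as engine input (so context re-fed after token healing is counted again), and each successful forward pass counts as exactly one output token, however many bytes guidance later forces. Roughly (a sketch; record_decode is an illustrative name, not the commit's code):

def record_decode(metrics, n_tokens):
    metrics.engine_input_tokens += n_tokens  # whole decode batch is engine input
    metrics.engine_output_tokens += 1        # one logits vector per pass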

guidance/models/transformers/_transformers.py

Lines changed: 6 additions & 0 deletions

@@ -122,6 +122,10 @@ def _tokenizer(self, model, **kwargs):
 
         return tokenizer
 
+    def __call__(self, byte_string):
+        tokenisation = self._orig_tokenizer(byte_string)
+        return tokenisation["input_ids"]
+
 
 class TransformersEngine(Engine):
     def __init__(self, model, tokenizer, compute_log_probs, **kwargs):

@@ -265,6 +269,8 @@ def get_logits(self, token_ids, forced_bytes, current_temp):
         self._cached_logits = (
             model_out.logits[0, -1, : len(self.tokenizer.tokens)].cpu().numpy()
         )
+        self.metrics.engine_input_tokens += len(new_token_ids)
+        self.metrics.engine_output_tokens += 1
 
         return self._cached_logits
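
Here the count uses len(new_token_ids) rather than the full prompt length; assuming new_token_ids is the suffix left after the engine's KV-cache prefix is reused, repeated calls over a growing prompt are not double-counted. A sketch of that rule (record_cached_decode and num_cached are illustrative names, not the commit's):

def record_cached_decode(metrics, token_ids, num_cached):
    new_token_ids = token_ids[num_cached:]  # only the uncached suffix is re-run
    metrics.engine_input_tokens += len(new_token_ids)
    metrics.engine_output_tokens += 1       # one logits row per forward pass
    return new_token_ids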

tests/library/test_gen.py

Lines changed: 60 additions & 3 deletions

@@ -2,7 +2,7 @@
 
 import pytest
 
-from guidance import gen, models
+from guidance import gen, models, select
 
 
 def test_basic():

@@ -73,6 +73,56 @@ def test_stop_quote(selected_model):
     assert not lm["title"].endswith('"')
 
 
+def test_metrics_smoke(selected_model: models.Model):
+    lm = selected_model
+    lm.engine.reset_metrics()
+
+    lm += "abcd"
+    print(f"{lm.engine.metrics=}")
+    lm += gen("first", max_tokens=1)
+    print(f"{lm.engine.metrics=}")
+    # Can't be sure of exact count due to token healing
+    assert (
+        lm.engine.metrics.engine_output_tokens == 1
+        or lm.engine.metrics.engine_output_tokens == 2
+    )
+    assert lm.engine.metrics.engine_input_tokens >= 1
+    last_input_tokens = lm.engine.metrics.engine_input_tokens
+
+    lm += "fg"
+    lm += gen("second", max_tokens=1)
+    # Again, trouble with healing
+    assert (
+        lm.engine.metrics.engine_output_tokens >= 2
+        or lm.engine.metrics.engine_output_tokens <= 4
+    )
+    assert lm.engine.metrics.engine_input_tokens > last_input_tokens
+
+
+def test_metrics_select(selected_model: models.Model):
+    lm = selected_model
+    lm.engine.reset_metrics()
+
+    lm += "I will "
+    lm += select(
+        [
+            "ride a bicycle down the road",
+            "row in a boat along the river",
+            "go for a swim in the ocean",
+        ]
+    )
+    print(f"lm={str(lm)}")
+    print(f"{lm.engine.metrics=}")
+    assert lm.engine.metrics.engine_input_tokens > 1
+    assert lm.engine.metrics.engine_output_tokens > 0
+    # Guidance should be able to force the generation after only a couple of tokens
+    # so even though the options are long, relatively few output tokens should be
+    # needed
+    assert (
+        lm.engine.metrics.engine_input_tokens > lm.engine.metrics.engine_output_tokens
+    )
+
+
 def test_unicode(selected_model):
     # black makes this test ugly -- easier to read with fmt: off
     # fmt: off

@@ -85,11 +135,18 @@ def test_unicode(selected_model):
     # fmt: on
 
 
-def test_unicode2(selected_model):
+def test_unicode2(selected_model: models.Model):
     lm = selected_model
+    lm.engine.reset_metrics()
     prompt = "Janet’s ducks lay 16 eggs per day"
     lm += prompt + gen(max_tokens=10)
-    assert True
+    assert lm.engine.metrics.engine_input_tokens > 1
+    # Due to token healing, we can't be sure of the
+    # precise output count
+    assert (
+        lm.engine.metrics.engine_output_tokens == 10
+        or lm.engine.metrics.engine_output_tokens == 11
+    )
 
 
 def test_gsm8k():
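
The tests assert ranges rather than exact counts because token healing can add or remove a boundary token. One caveat worth a follow-up: in test_metrics_smoke, the condition engine_output_tokens >= 2 or engine_output_tokens <= 4 holds for every integer, so that healing check can never fail. If a closed interval was intended, a chained comparison would express it (a suggested fix, not part of this commit):

n = lm.engine.metrics.engine_output_tokens
assert 2 <= n <= 4  # closed interval; the `or` form above is vacuously true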
