
Commit 822f8e1

I don't think I need these bits
1 parent f8de7c8 commit 822f8e1

2 files changed: +7 −22 lines

guidance/models/_model.py

Lines changed: 7 additions & 18 deletions
@@ -204,9 +204,6 @@ def __init__(self, tokenizer, compute_log_probs=False):
         )
         self._token_trie.match = True
         self._token_trie.match_version = 0
-        # Any time get_logits is called, it should update this
-        # This does add to the list of "Thread Unsafety"
-        self.metrics = GuidanceEngineMetrics()

     def start(self, parser, grammar, ensure_bos_token=True):
         """Start processing parser state executed through the grammar.
@@ -687,7 +684,6 @@ def next(self, logits):
             self._sampled_token = self.tokenizer.tokens[self._sampled_token_ind]
             self._new_bytes_prob = 1.0
             self._was_forced = True
-            self.metrics.forced_tokens += 1

         # we are at the end of the grammar
         elif next_byte_mask_sum == 0:
@@ -758,6 +754,8 @@ def __call__(self, parser, grammar, ensure_bos_token=True):
                 response_new_token_count,
             ) = response_state

+            print(f"{response_is_generated=} {response_new_token_count=} {response_new_bytes=}")
+
             yield EngineCallResponse(
                 new_bytes=response_new_bytes,
                 is_generated=response_is_generated,
@@ -1382,9 +1380,6 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):
         # we will return a new extended version of ourselves, which we track as `lm`
         lm = self

-        # Prepare our metrics update. This is part of our Thread Unsafety programme
-        metrics_before = lm.engine.metrics.model_copy(deep=True)
-
         # single generation
         if n == 1:
             generated_value = ""
@@ -1398,6 +1393,11 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):
                 # if not self.engine.compute_log_probs:
                 #     chunk.new_bytes_prob = 1.0

+                if chunk.is_generated:
+                    self.engine_metrics.generated_tokens += chunk.new_token_count
+                else:
+                    self.engine_metrics.forced_tokens += chunk.new_token_count
+
                 # convert the bytes to a string (delaying if we don't yet have a valid unicode string)
                 lm.token_count += chunk.new_token_count
                 chunk.new_bytes = delayed_bytes + chunk.new_bytes
@@ -1466,17 +1466,6 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):

         unreplace_model_variables(replacements)

-        # Now update our metrics while maintaining Thread Unsafety
-        lm.engine_metrics.prompt_tokens += (
-            self.engine.metrics.prompt_tokens - metrics_before.prompt_tokens
-        )
-        lm.engine_metrics.generated_tokens += (
-            self.engine.metrics.generated_tokens - metrics_before.generated_tokens
-        )
-        lm.engine_metrics.forced_tokens += (
-            self.engine.metrics.forced_tokens - metrics_before.forced_tokens
-        )
-
         logger.debug("finish Model._run_stateless")

         return lm
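
The hunks above replace the engine-owned metrics object (and the before/after delta reconciliation) with direct per-chunk accumulation inside _run_stateless. For context, the container appears to be a pydantic model: the removed code calls model_copy(deep=True), a pydantic v2 method. What follows is a minimal, hypothetical sketch of that container and of the accumulation pattern this commit introduces; the real GuidanceEngineMetrics is defined elsewhere in the guidance codebase, so the field set and the stand-in chunk data here are assumptions.

from pydantic import BaseModel


class GuidanceEngineMetrics(BaseModel):
    # Counters referenced by this diff; the real class may define more fields.
    prompt_tokens: int = 0
    generated_tokens: int = 0
    forced_tokens: int = 0


# Per-chunk accumulation, mirroring the if/else added in the _run_stateless
# hunk above: sampled tokens count as generated, grammar-driven tokens as forced.
metrics = GuidanceEngineMetrics()
for is_generated, new_token_count in [(True, 5), (False, 2), (True, 1)]:
    if is_generated:
        metrics.generated_tokens += new_token_count
    else:
        metrics.forced_tokens += new_token_count

print(metrics)  # prompt_tokens=0 generated_tokens=6 forced_tokens=2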

guidance/models/transformers/_transformers.py

Lines changed: 0 additions & 4 deletions
@@ -269,10 +269,6 @@ def get_logits(self, token_ids, forced_bytes, current_temp):
             model_out.logits[0, -1, : len(self.tokenizer.tokens)].cpu().numpy()
         )

-        # Update metrics
-        self.metrics.prompt_tokens += len(new_token_ids)
-        self.metrics.generated_tokens += 1
-
         return self._cached_logits

