Commit 4da011f

Expose grammar match() method
1 parent 9306961 commit 4da011f

3 files changed (+53 −7 lines changed)


guidance/_grammar.py

Lines changed: 38 additions & 0 deletions
@@ -5,6 +5,7 @@
 import types
 import re
 from . import _serialization_pb2
+from . import _parser

 tag_start = "{{G|"
 tag_end = "|G}}"
@@ -83,6 +84,27 @@ def __radd__(model):
             return self(model)
         return RawFunction(__radd__, [], {})

+class Match:
+    def __init__(self, captures, log_probs, partial):
+        self.captures = captures
+        self.log_probs = log_probs
+        self.partial = partial
+
+    def __getitem__(self, key):
+        return self.captures[key]
+
+    def __len__(self):
+        return len(self.captures)
+
+    def __bool__(self):
+        return True
+
+    def __str__(self):
+        return str(self.captures)
+
+    def __repr__(self):
+        return "<guidance.Match object; captures="+str(self.captures)+"; partial="+str(self.partial)+">"
+
 class GrammarFunction(Function):
     num_used_names = 0

@@ -123,6 +145,22 @@ def __radd__(self, value):
     def __getitem__(self, value):
         raise StatefulException("GrammarFunctions can't access state!")

+    def match(self, byte_string, allow_partial=False):
+        if isinstance(byte_string, str):
+            byte_string = byte_string.encode()
+        parser = _parser.EarleyCommitParser(self)
+
+        for i in range(len(byte_string)):
+            try:
+                parser.consume_byte(byte_string[i:i+1])
+            except:
+                return None
+
+        if not allow_partial and not parser.matched():
+            return None
+        else:
+            return Match(*parser.get_captures(), partial=not parser.matched())
+
     @staticmethod
     def _new_name():
         num_used = GrammarFunction.num_used_names
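For illustration, a minimal usage sketch of the new match() method, assuming guidance.select builds a GrammarFunction with a named capture (the capture value may surface as bytes at this layer); this is an illustrative example, not part of the commit:

    from guidance import select

    # a tiny grammar that accepts either "cat" or "dog" and captures the choice
    grammar = select(["cat", "dog"], name="animal")

    m = grammar.match("cat")         # full match -> a Match object (always truthy)
    print(m["animal"], m.partial)    # captured value (possibly bytes here), partial == False

    print(grammar.match("ca"))                        # None: incomplete and partial matches not allowed
    print(grammar.match("ca", allow_partial=True))    # Match with partial == True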

guidance/models/_mock.py

Lines changed: 14 additions & 6 deletions
@@ -4,7 +4,7 @@
 from ._remote import RemoteEngine

 class MockEngine(Engine):
-    def __init__(self, tokenizer, byte_patterns, compute_log_probs):
+    def __init__(self, tokenizer, byte_patterns, compute_log_probs, force):
         super().__init__(tokenizer, compute_log_probs=compute_log_probs)

         self._valid_mask = np.zeros(len(tokenizer.tokens))
@@ -14,6 +14,7 @@ def __init__(self, tokenizer, byte_patterns, compute_log_probs):
                 self._valid_mask[i] = 1.0
             except:
                 pass
+        self.force = force

         # allow a single byte pattern to be passed
         if isinstance(byte_patterns, (bytes, str)):
@@ -23,8 +24,10 @@ def __init__(self, tokenizer, byte_patterns, compute_log_probs):
         for i,pattern in enumerate(byte_patterns):
             if isinstance(pattern, str):
                 byte_patterns[i] = pattern.encode("utf8")
-
+
         self.byte_patterns = byte_patterns
+
+        # seed the random number generator
         self._rand_generator = np.random.default_rng(seed=42)

     def get_logits(self, token_ids, forced_bytes, current_temp):
@@ -34,8 +37,13 @@ def get_logits(self, token_ids, forced_bytes, current_temp):
         # build the byte strings
         byte_string = b"".join(self.tokenizer.tokens[i] for i in token_ids)

-        # we randomly generate valid unicode bytes
-        logits = self._rand_generator.standard_normal(len(self.tokenizer.tokens)) * self._valid_mask
+        # if we are forcing the bytes patterns then don't allow other tokens
+        if self.force:
+            logits = np.ones(len(self.tokenizer.tokens)) * -np.inf
+
+        # otherwise we randomly generate valid unicode bytes
+        else:
+            logits = self._rand_generator.standard_normal(len(self.tokenizer.tokens)) * self._valid_mask

         # if we have a pattern that matches then force the next token
         bias = 100.0
@@ -55,7 +63,7 @@ def _get_next_tokens(self, byte_string):
                 yield i

 class Mock(Model):
-    def __init__(self, byte_patterns=[], echo=True, compute_log_probs=False, **kwargs):
+    def __init__(self, byte_patterns=[], echo=True, compute_log_probs=False, force=False, **kwargs):
         '''Build a new Mock model object that represents a model in a given state.'''

         if isinstance(byte_patterns, str) and byte_patterns.startswith("http"):
@@ -67,7 +75,7 @@ def __init__(self, byte_patterns=[], echo=True, compute_log_probs=False, **kwarg
                 0,
                 0
             )
-            engine = MockEngine(tokenizer, byte_patterns, compute_log_probs)
+            engine = MockEngine(tokenizer, byte_patterns, compute_log_probs, force)


         super().__init__(
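A hypothetical test-style sketch of the new force flag; the prompt text, the byte pattern, and the "<s>" prefix are illustrative assumptions that follow the existing Mock conventions:

    from guidance import models, select

    # with force=True the mock engine starts from -inf logits, so generation can
    # only follow the supplied byte patterns instead of random valid-utf8 tokens
    lm = models.Mock(b"<s>the best animal is a cat", force=True)
    lm += "the best animal is a " + select(["cat", "dog"], name="animal")
    print(lm["animal"])   # expected to capture "cat", the forced pattern's continuation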

guidance/models/_model.py

Lines changed: 1 addition & 1 deletion
@@ -311,6 +311,7 @@ def __call__(self, parser, grammar, ensure_bos_token=True):
                 # if we walked all the way to a forced token then we advance without computing the logits
                 # we are forced if there are no more options and we are either in the middle of the grammar or at a trie leaf
                 is_forced = next_byte_mask_sum <= 1 and (len(trie) == 0 if parser.matched() else trie != self._token_trie)
+                token_pos = 0
                 if is_forced:
                     sampled_token_ind = trie.value
                     sampled_token = self.tokenizer.tokens[sampled_token_ind]
@@ -319,7 +320,6 @@ def __call__(self, parser, grammar, ensure_bos_token=True):

                 # we are at the end of the grammar
                 elif next_byte_mask_sum == 0:
-                    token_pos = 0

                     # mark the token we "sampled" if we have comsumed some bytes
                     if trie != self._token_trie:
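The token_pos change appears to hoist the initialization so that every branch sees a defined value; a schematic illustration of the pattern (not the engine code, and assuming a later read of token_pos on the forced path):

    # schematic only: why hoisting the initialization matters
    def before(is_forced, at_grammar_end):
        if is_forced:
            pass                 # token_pos never assigned on this path
        elif at_grammar_end:
            token_pos = 0
        return token_pos         # UnboundLocalError when is_forced is True

    def after(is_forced, at_grammar_end):
        token_pos = 0            # defined before the branches, on every path
        if is_forced:
            pass
        elif at_grammar_end:
            pass
        return token_pos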
