Skip to content

Commit 0afe462

Browse files
authored
Merge pull request #568 from hodlen/main
Adding streaming support under guidance v0.1
2 parents e706971 + 258a99e commit 0afe462

File tree

5 files changed

+188
-62
lines changed

5 files changed

+188
-62
lines changed

guidance/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.1.10"
1+
__version__ = "0.1.11"
22

33
import functools
44
import os

guidance/_grammar.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,7 @@ class Join(StatelessFunction):
418418
__slots__ = ("nullable", "values", "name", "hidden", "commit_point", "capture_name", "max_tokens")
419419

420420
def __init__(self, values, name=None, max_tokens=100000000) -> None:
421+
values = [string(v) if isinstance(v, (str, bytes)) else v for v in values] # wrap raw strings
421422
self.nullable = all(v.nullable for v in values)
422423
self.values = [v for v in values if not isinstance(v, Null)]
423424
self.name = name if name is not None else StatelessFunction._new_name()

guidance/_parser.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from sys import stderr
12
import numpy as np
23
from ordered_set import OrderedSet
34
from ._grammar import Join, Select, Terminal, Null, Byte, ByteRange
@@ -378,14 +379,92 @@ def _reversed_state_sets(self):
378379

379380
def parse_tree(self):
380381
reversed_state_sets = self._reversed_state_sets()
382+
root_item = None
381383

382384
# find the matching root state
383385
for item in reversed_state_sets[0]:
384386
if item.node == self.grammar and item.start == len(self.bytes) and item.pos == len(item.values): # note that ".start" mean end because items are reversed
385387
root_item = item
388+
if root_item is None:
389+
return None
386390
self._compute_parse_tree(0, root_item, reversed_state_sets)
387391
return root_item
388-
392+
393+
def get_captures(self):
394+
root_node = self.parse_tree()
395+
data = {}
396+
log_prob_data = {}
397+
if root_node is not None:
398+
# parse complete, so we can get the captures
399+
self._record_captures_from_root(root_node, data, log_prob_data)
400+
return data, log_prob_data
401+
# compute on partially parsed tree
402+
self._record_captures_partial(data, log_prob_data)
403+
return data, log_prob_data
404+
405+
def _record_captures_partial(self, data, log_prob_data):
406+
byte_data = self.bytes
407+
408+
for item in self.state_sets[self.state_set_pos]:
409+
cname = item.node.capture_name
410+
if cname is None:
411+
continue
412+
captured_value = byte_data[item.start:self.earliest_hidden_start()]
413+
if captured_value.endswith(b'<'):
414+
print("WARNING: Captured value ends with '<' which is a special character in the parser!", file=stderr)
415+
data[cname] = captured_value
416+
log_prob_data[cname] = item.log_prob
417+
418+
def _record_captures_from_root(self, initial_item, data, log_prob_data):
419+
byte_data = self.bytes
420+
stack = [(initial_item, 0)]
421+
used_names = set() # track which capture names have been used so self-recursive children don't overwrite their parents
422+
423+
while stack:
424+
item, byte_pos = stack.pop()
425+
# terminal nodes
426+
if isinstance(item, Terminal):
427+
428+
# if we are at a capture group node then we save the matched terminal byte
429+
if item.capture_name is not None:
430+
data[item.capture_name] = item.byte
431+
log_prob_data[item.capture_name] = 0
432+
433+
# internal nodes
434+
else:
435+
start_byte_pos = byte_pos
436+
437+
# recurse for all our non-null children
438+
for child in item.children:
439+
if child is not None:
440+
stack.append((child, byte_pos))
441+
# _record_captures(child, data, log_prob_data, byte_data, byte_pos)
442+
if isinstance(child, Terminal):
443+
byte_pos += len(child)
444+
else:
445+
byte_pos = child.start # note that "start" means "end" since this is a reversed state set
446+
447+
# if we are at a capture group node then we save the matched bytes range
448+
# note that we record this after calling our children so that we save the outermost version of self-recursive calls
449+
cname = item.node.capture_name
450+
if cname is not None and cname not in used_names and not item.node.hidden:
451+
452+
# see if we are doing a list append
453+
if cname.startswith("__LIST_APPEND:"):
454+
cname = cname[14:] # trim off the list append tag
455+
if cname not in data or not isinstance(data[cname], list):
456+
data[cname] = []
457+
log_prob_data[cname] = []
458+
data[cname].append(byte_data[start_byte_pos:item.start])
459+
log_prob_data[cname].append(item.log_prob)
460+
461+
# or just a regular assignment
462+
else:
463+
data[cname] = byte_data[start_byte_pos:item.start] # note that "start" means "end" since this is a reversed state set
464+
log_prob_data[cname] = item.log_prob
465+
466+
used_names.add(cname)
467+
389468
def _compute_parse_tree(self, initial_pos, initial_item, reversed_state_sets):
390469
stack = [(initial_pos, initial_item)]
391470

guidance/models/_model.py

Lines changed: 87 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,16 @@
1515
import numpy as np
1616
import logging
1717
import base64
18+
import queue
19+
import threading
1820

1921
logger = logging.getLogger(__name__)
2022
try:
2123
from .. import cpp
2224
except ImportError:
2325
logger.warn("Failed to load guidance.cpp, falling back to Python mirror implementations...")
2426
from .. import _cpp as cpp
25-
from .._utils import softmax
27+
from .._utils import softmax, CaptureEvents
2628
from .._parser import EarleyCommitParser
2729
from .._grammar import StatelessFunction, string, _call_pool, _tag_pattern, Null, replace_model_variables, unreplace_model_variables, select, Terminal
2830

@@ -83,6 +85,7 @@ def __init__(self, tokens, bos_token_id=None, eos_token_id=None, echo=True, comp
8385
self._event_queue = None # TODO: these are for streaming results in code, but that needs to be implemented
8486
self._event_parent = None
8587
self._last_display = 0 # used to track the last display call to enable throttling
88+
self._last_event_stream = 0 # used to track the last event streaming call to enable throttling
8689

8790
# build a prefix tree of the tokens
8891
self._token_trie = cpp.ByteTrie(tokens, np.arange(len(tokens)))
@@ -139,6 +142,9 @@ def _send_to_event_queue(self, value):
139142
self._event_queue.put(value)
140143
if self._event_parent is not None:
141144
self._event_parent._send_to_event_queue(value)
145+
146+
def stream(self):
147+
return ModelStream(self)
142148

143149
def copy(self):
144150
'''Create a shallow copy of the model object.'''
@@ -151,10 +157,12 @@ def copy(self):
151157
new_lm._variables_log_probs = self._variables_log_probs.copy()
152158
new_lm.opened_blocks = self.opened_blocks.copy()
153159

154-
# create a new clean event queue # TODO: can we delete this now?
155-
new_lm._event_queue = None
160+
# create a new clean event queue
161+
new_lm._event_queue = None # we start with no event queue because nobody is listening to us yet
156162
if self._event_queue is not None:
157-
new_lm._event_parent = self
163+
new_lm._event_parent = self # the current lm has an event queue, so we make it our parent
164+
elif self._event_parent is not None:
165+
new_lm._event_parent = self._event_parent # otherwise, if the current lm has an event parent then that is also our parent
158166

159167
return new_lm
160168

@@ -177,8 +185,15 @@ def _inplace_append(self, value, force_silent=False):
177185
if not force_silent:
178186
self._update_display()
179187

180-
# TODO: is this needed? This was for programmatic streaming...
181-
self._send_to_event_queue(self)
188+
# this is for programmatic streaming among other things
189+
if Model._throttle_refresh > 0:
190+
curr_time = time.time()
191+
if curr_time - self._last_event_stream >= self.max_display_rate:
192+
self._last_event_stream = curr_time
193+
self._send_to_event_queue(self)
194+
else:
195+
self._send_to_event_queue(self)
196+
182197

183198
def _update_display(self, throttle=True):
184199
if self.echo:
@@ -305,7 +320,9 @@ def __add__(self, value):
305320
else:
306321
out = value(lm)
307322
if out is None:
308-
raise Exception(f"A guidance function did not return a model object! Did you forget to return the new lm at the end of your function?")
323+
raise Exception(f"A guidance function returned `None`, not a model object! Did you forget to return the new lm at the end of your function?")
324+
if not isinstance(out, Model):
325+
raise Exception(f"A guidance function did not return a model object! Did you try to add a function to a model without calling the function? For example `model + guidance_function()` is correct, while `model + guidance_function` will cause this error.")
309326

310327
# this flushes the display
311328
out._inplace_append("")
@@ -940,9 +957,10 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
940957
# self._cache_state["new_token_ids"].append(sampled_token_ind)
941958

942959
# capture the named groups from the parse tree
943-
parse_tree = parser.parse_tree()
944-
_record_captures(parse_tree, captured_data, captured_log_prob_data, parser.bytes)
945-
960+
new_captured_data, new_captured_log_prob_data = parser.get_captures()
961+
captured_data.update(new_captured_data)
962+
captured_log_prob_data.update(new_captured_log_prob_data)
963+
946964
# we have no valid log prob data if we didn't compute it
947965
yield new_bytes[hidden_count:], is_generated, new_bytes_prob, captured_data, captured_log_prob_data, token_count - last_token_count
948966
last_token_count = token_count
@@ -953,7 +971,11 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
953971
# yield the snippet of text created by the next token
954972
out = new_bytes[hidden_count:]
955973
if len(out) > 0:
956-
yield out, is_generated, new_bytes_prob, {}, {}, token_count - last_token_count # note that we don't capture groups until a complete parse right now...
974+
# capture the named groups from the (partial) parse tree, # TODO: disabled for now until we handle list_append correctly
975+
# new_captured_data, new_captured_log_prob_data = parser.get_captures()
976+
# captured_data.update(new_captured_data)
977+
# captured_log_prob_data.update(new_captured_log_prob_data)
978+
yield out, is_generated, new_bytes_prob, captured_data, captured_log_prob_data, token_count - last_token_count # note that we don't capture groups until a complete parse right now...
957979
last_token_count = token_count
958980
hidden_count = 0
959981
token_count += 1 # note we only update this for tokens that emit non-hidden content
@@ -968,6 +990,60 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
968990
else:
969991
token_byte_positions.append(token_byte_positions[-1] + len(sampled_token))
970992

993+
class ModelStream:
994+
def __init__(self, model, grammar=None, timeout=5):
995+
'''Create a model stream object that delays execution until it is iterated over.'''
996+
if model.echo:
997+
model = model.copy()
998+
model.echo = False # turn off display echoing
999+
self.model = model
1000+
self.grammar = grammar
1001+
self.timeout = timeout
1002+
1003+
def __add__(self, grammar):
1004+
'''Extend this delayed chain of execution with another grammar append.'''
1005+
return ModelStream(self.model, grammar)
1006+
1007+
def _inner_run(self, model):
1008+
'''This runs the model stream without iterating, and is only used internally by __iter__.'''
1009+
if isinstance(self.grammar, ModelStream):
1010+
model = self.grammar._inner_run(model)
1011+
elif self.grammar is None:
1012+
model = self.model + ""
1013+
else:
1014+
model = self.model + self.grammar
1015+
1016+
def __iter__(self):
1017+
'''Starts a thread to execute the model and grammar, yielding events as they occur.'''
1018+
1019+
# Create a thread-safe queue to hold events
1020+
with CaptureEvents(self.model) as events:
1021+
1022+
# Define the target function for the thread
1023+
def target():
1024+
self._inner_run(self.model)
1025+
events.put(None) # mark that we are done
1026+
1027+
# Start the thread
1028+
thread = threading.Thread(target=target)
1029+
thread.start()
1030+
1031+
# Yield events from the queue as they become available
1032+
while True:
1033+
try:
1034+
# Wait for an event with a timeout to allow for thread termination
1035+
event = events.get(timeout=self.timeout)
1036+
if event is None:
1037+
break
1038+
yield event
1039+
except queue.Empty:
1040+
# Check if the thread is still alive
1041+
if not thread.is_alive():
1042+
break
1043+
1044+
# Ensure the thread has completed
1045+
thread.join()
1046+
9711047
class Chat(Model):
9721048
'''The base class for all chat-tuned models.'''
9731049

@@ -1033,55 +1109,6 @@ def throttle_refresh():
10331109
class ConstraintException(Exception):
10341110
pass
10351111

1036-
def _record_captures(initial_item, data, log_prob_data, byte_data):
1037-
stack = [(initial_item, 0)]
1038-
used_names = set() # track which capture names have been used so self-recursive children don't overwrite their parents
1039-
1040-
while stack:
1041-
item, byte_pos = stack.pop()
1042-
# terminal nodes
1043-
if isinstance(item, Terminal):
1044-
1045-
# if we are at a capture group node then we save the matched terminal byte
1046-
if item.capture_name is not None:
1047-
data[item.capture_name] = item.byte
1048-
log_prob_data[item.capture_name] = 0
1049-
1050-
# internal nodes
1051-
else:
1052-
start_byte_pos = byte_pos
1053-
1054-
# recurse for all our non-null children
1055-
for child in item.children:
1056-
if child is not None:
1057-
stack.append((child, byte_pos))
1058-
# _record_captures(child, data, log_prob_data, byte_data, byte_pos)
1059-
if isinstance(child, Terminal):
1060-
byte_pos += len(child)
1061-
else:
1062-
byte_pos = child.start # note that "start" means "end" since this is a reversed state set
1063-
1064-
# if we are at a capture group node then we save the matched bytes range
1065-
# note that we record this after calling our children so that we save the outermost version of self-recursive calls
1066-
cname = item.node.capture_name
1067-
if cname is not None and cname not in used_names and not item.node.hidden:
1068-
1069-
# see if we are doing a list append
1070-
if cname.startswith("__LIST_APPEND:"):
1071-
cname = cname[14:] # trim off the list append tag
1072-
if cname not in data or not isinstance(data[cname], list):
1073-
data[cname] = []
1074-
log_prob_data[cname] = []
1075-
data[cname].append(byte_data[start_byte_pos:item.start])
1076-
log_prob_data[cname].append(item.log_prob)
1077-
1078-
# or just a regular assignment
1079-
else:
1080-
data[cname] = byte_data[start_byte_pos:item.start] # note that "start" means "end" since this is a reversed state set
1081-
log_prob_data[cname] = item.log_prob
1082-
1083-
used_names.add(cname)
1084-
10851112
# def _compute_probs(trie, probs, found):
10861113
# '''Computes the log probabilities for each internal trie node.'''
10871114
# if trie.value is not None:

notebooks/tutorials/intro_to_guidance.ipynb

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,25 @@
581581
"gpt35 + experts(query='What is the meaning of life?')"
582582
]
583583
},
584+
{
585+
"cell_type": "markdown",
586+
"metadata": {},
587+
"source": [
588+
"## Streaming\n",
589+
"\n",
590+
"Often you want to get the results of a generation as it is happening so you can update an interface. You can do this programmatically using the `.stream()` method of model objects. This creates a `ModelStream` that you can use to accumulate updates. These updates don't get executed until you iterate over the `ModelStream` object. When you iterate over the object you get lots of partially completed model objects as the guidance program is executed."
591+
]
592+
},
593+
{
594+
"cell_type": "code",
595+
"execution_count": null,
596+
"metadata": {},
597+
"outputs": [],
598+
"source": [
599+
"for part in llama2.stream() + qa_bot(query):\n",
600+
" part # do something with the partially executed lm"
601+
]
602+
},
584603
{
585604
"attachments": {},
586605
"cell_type": "markdown",

0 commit comments

Comments
 (0)