Skip to content

Commit 9425899

Browse files
committed
Add ModelStream class and the model.stream() method
1 parent 36dc95d commit 9425899

File tree

2 files changed

+79
-9
lines changed

2 files changed

+79
-9
lines changed

guidance/_grammar.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,7 @@ class Join(StatelessFunction):
418418
__slots__ = ("nullable", "values", "name", "hidden", "commit_point", "capture_name", "max_tokens")
419419

420420
def __init__(self, values, name=None, max_tokens=100000000) -> None:
421+
values = [string(v) if isinstance(v, (str, bytes)) else v for v in values] # wrap raw strings
421422
self.nullable = all(v.nullable for v in values)
422423
self.values = [v for v in values if not isinstance(v, Null)]
423424
self.name = name if name is not None else StatelessFunction._new_name()

guidance/models/_model.py

Lines changed: 78 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,16 @@
1515
import numpy as np
1616
import logging
1717
import base64
18+
import queue
19+
import threading
1820

1921
logger = logging.getLogger(__name__)
2022
try:
2123
from .. import cpp
2224
except ImportError:
2325
logger.warn("Failed to load guidance.cpp, falling back to Python mirror implementations...")
2426
from .. import _cpp as cpp
25-
from .._utils import softmax
27+
from .._utils import softmax, CaptureEvents
2628
from .._parser import EarleyCommitParser
2729
from .._grammar import StatelessFunction, string, _call_pool, _tag_pattern, Null, replace_model_variables, unreplace_model_variables, select, Terminal
2830

@@ -83,6 +85,7 @@ def __init__(self, tokens, bos_token_id=None, eos_token_id=None, echo=True, comp
8385
self._event_queue = None # TODO: these are for streaming results in code, but that needs to be implemented
8486
self._event_parent = None
8587
self._last_display = 0 # used to track the last display call to enable throttling
88+
self._last_event_stream = 0 # used to track the last event streaming call to enable throttling
8689

8790
# build a prefix tree of the tokens
8891
self._token_trie = cpp.ByteTrie(tokens, np.arange(len(tokens)))
@@ -139,6 +142,9 @@ def _send_to_event_queue(self, value):
139142
self._event_queue.put(value)
140143
if self._event_parent is not None:
141144
self._event_parent._send_to_event_queue(value)
145+
146+
def stream(self):
    '''Return a ModelStream wrapping this model that delays execution until it is iterated over.'''
    return ModelStream(self)
142148

143149
def copy(self):
144150
'''Create a shallow copy of the model object.'''
@@ -151,10 +157,12 @@ def copy(self):
151157
new_lm._variables_log_probs = self._variables_log_probs.copy()
152158
new_lm.opened_blocks = self.opened_blocks.copy()
153159

154-
# create a new clean event queue # TODO: can we delete this now?
155-
new_lm._event_queue = None
160+
# create a new clean event queue
161+
new_lm._event_queue = None # we start with no event queue because nobody is listening to us yet
156162
if self._event_queue is not None:
157-
new_lm._event_parent = self
163+
new_lm._event_parent = self # the current lm has an event queue, so we make it our parent
164+
elif self._event_parent is not None:
165+
new_lm._event_parent = self._event_parent # otherwise, if the current lm has an event parent then that is also our parent
158166

159167
return new_lm
160168

@@ -177,8 +185,15 @@ def _inplace_append(self, value, force_silent=False):
177185
if not force_silent:
178186
self._update_display()
179187

180-
# TODO: is this needed? This was for programmatic streaming...
181-
self._send_to_event_queue(self)
188+
# this is for programmatic streaming among other things
189+
if Model._throttle_refresh > 0:
190+
curr_time = time.time()
191+
if curr_time - self._last_event_stream >= self.max_display_rate:
192+
self._last_event_stream = curr_time
193+
self._send_to_event_queue(self)
194+
else:
195+
self._send_to_event_queue(self)
196+
182197

183198
def _update_display(self, throttle=True):
184199
if self.echo:
@@ -299,15 +314,15 @@ def __add__(self, value):
299314

300315
# run stateless functions (grammar nodes)
301316
elif isinstance(value, StatelessFunction):
302-
value._event_parent = lm
303317
out = lm._run_stateless(value)
304318

305319
# run stateful functions
306320
else:
307-
value._event_parent = lm
308321
out = value(lm)
309322
if out is None:
310-
raise Exception(f"A guidance function did not return a model object! Did you forget to return the new lm at the end of your function?")
323+
raise Exception(f"A guidance function returned `None`, not a model object! Did you forget to return the new lm at the end of your function?")
324+
if not isinstance(out, Model):
325+
raise Exception(f"A guidance function did not return a model object! Did you try to add a function to a model without calling the function? For example `model + guidance_function()` is correct, while `model + guidance_function` will cause this error.")
311326

312327
# this flushes the display
313328
out._inplace_append("")
@@ -970,6 +985,60 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
970985
else:
971986
token_byte_positions.append(token_byte_positions[-1] + len(sampled_token))
972987

988+
class ModelStream:
    '''A delayed-execution wrapper around a Model that streams events as they are produced.

    Nothing is executed at construction time; the model (and any appended grammar)
    only runs when the stream is iterated over, and each intermediate model state is
    yielded as an event.
    '''

    def __init__(self, model, grammar=None, timeout=5):
        '''Create a model stream object that delays execution until it is iterated over.

        Parameters
        ----------
        model : Model
            The model that the delayed execution will extend.
        grammar : grammar node, ModelStream, or None
            What to append to the model when the stream runs; None means just flush.
        timeout : float
            Seconds to wait for a new event before checking whether the worker
            thread has died (guards against hanging forever on a dead thread).
        '''
        if model.echo:
            model = model.copy()
            model.echo = False  # turn off display echoing; the consumer gets events instead
        self.model = model
        self.grammar = grammar
        self.timeout = timeout

    def __add__(self, grammar):
        '''Extend this delayed chain of execution with another grammar append.'''
        # bug fix: the original dropped self.timeout when chaining, silently
        # resetting a caller-configured timeout back to the default.
        # NOTE(review): chaining replaces self.grammar rather than composing with
        # it, so only the last appended grammar survives — confirm this is intended.
        return ModelStream(self.model, grammar, self.timeout)

    def _inner_run(self, model):
        '''Run the delayed execution without iterating; only used internally by __iter__.'''
        if isinstance(self.grammar, ModelStream):
            model = self.grammar._inner_run(model)
        elif self.grammar is None:
            model = self.model + ""
        else:
            model = self.model + self.grammar
        # bug fix: the original never returned, so the recursive branch above
        # always received None instead of the extended model.
        return model

    def __iter__(self):
        '''Start a thread to execute the model and grammar, yielding events as they occur.'''

        # capture events generated during execution in a thread-safe queue
        with CaptureEvents(self.model) as events:

            # the worker that actually runs the (delayed) model execution
            def target():
                try:
                    self._inner_run(self.model)
                finally:
                    # robustness fix: enqueue the done-sentinel even if execution
                    # raised, so the consumer loop below terminates promptly
                    # instead of stalling for a full timeout period.
                    events.put(None)

            thread = threading.Thread(target=target)
            thread.start()

            # yield events from the queue as they become available
            while True:
                try:
                    # wait for an event with a timeout so we can detect thread death
                    event = events.get(timeout=self.timeout)
                    if event is None:
                        break  # sentinel: execution finished
                    yield event
                except queue.Empty:
                    # no event arrived; stop if the worker thread is gone
                    if not thread.is_alive():
                        break

            # ensure the thread has completed before leaving the context manager
            thread.join()
1041+
9731042
class Chat(Model):
9741043
'''The base class for all chat-tuned models.'''
9751044

0 commit comments

Comments
 (0)