@@ -15,14 +15,16 @@
 import numpy as np
 import logging
 import base64
+import queue
+import threading

 logger = logging.getLogger(__name__)
 try:
     from .. import cpp
 except ImportError:
     logger.warn("Failed to load guidance.cpp, falling back to Python mirror implementations...")
     from .. import _cpp as cpp
-from .._utils import softmax
+from .._utils import softmax, CaptureEvents
 from .._parser import EarleyCommitParser
 from .._grammar import StatelessFunction, string, _call_pool, _tag_pattern, Null, replace_model_variables, unreplace_model_variables, select, Terminal
@@ -83,6 +85,7 @@ def __init__(self, tokens, bos_token_id=None, eos_token_id=None, echo=True, comp
         self._event_queue = None # TODO: these are for streaming results in code, but that still needs to be implemented
         self._event_parent = None
         self._last_display = 0 # used to track the last display call to enable throttling
+        self._last_event_stream = 0 # used to track the last event streaming call to enable throttling

         # build a prefix tree of the tokens
         self._token_trie = cpp.ByteTrie(tokens, np.arange(len(tokens)))
@@ -139,6 +142,9 @@ def _send_to_event_queue(self, value):
             self._event_queue.put(value)
         if self._event_parent is not None:
             self._event_parent._send_to_event_queue(value)
+
+    def stream(self):
+        return ModelStream(self)

     def copy(self):
         '''Create a shallow copy of the model object.'''
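
For context, here is a minimal sketch of how the new `stream()` entry point is meant to be consumed (assuming `lm` is an already-constructed `Model`, and using guidance's `gen` function; the variable names are illustrative):

```python
from guidance import gen

# nothing executes yet; stream() returns a delayed ModelStream
stream = lm.stream() + gen(name="answer", max_tokens=10)

# iteration starts a worker thread and yields model snapshots as events arrive
for partial_lm in stream:
    print(str(partial_lm))
```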
@@ -151,10 +157,12 @@ def copy(self):
         new_lm._variables_log_probs = self._variables_log_probs.copy()
         new_lm.opened_blocks = self.opened_blocks.copy()

-        # create a new clean event queue # TODO: can we delete this now?
-        new_lm._event_queue = None
+        # create a new clean event queue
+        new_lm._event_queue = None # we start with no event queue because nobody is listening to us yet
         if self._event_queue is not None:
-            new_lm._event_parent = self
+            new_lm._event_parent = self # the current lm has an event queue, so we make it our parent
+        elif self._event_parent is not None:
+            new_lm._event_parent = self._event_parent # otherwise, if the current lm has an event parent then that is also our parent

         return new_lm

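The parent-chaining above means events emitted by any copy bubble up to whichever ancestor actually owns a queue. A rough illustration of the propagation (poking at private fields purely for demonstration, and assuming `lm` is a `Model` instance):

```python
import queue

root = lm.copy()
root._event_queue = queue.Queue()  # a listener attaches a queue here

child = root.copy()        # child._event_parent is root
grandchild = child.copy()  # grandchild inherits root as its event parent

# the event walks the parent chain and lands in root's queue
grandchild._send_to_event_queue("hello")
assert root._event_queue.get() == "hello"
```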
@@ -177,8 +185,15 @@ def _inplace_append(self, value, force_silent=False):
         if not force_silent:
             self._update_display()

-        # TODO: is this needed? This was for programmatic streaming...
-        self._send_to_event_queue(self)
+        # this is for programmatic streaming, among other things
+        if Model._throttle_refresh > 0:
+            curr_time = time.time()
+            if curr_time - self._last_event_stream >= self.max_display_rate:
+                self._last_event_stream = curr_time
+                self._send_to_event_queue(self)
+        else:
+            self._send_to_event_queue(self)
+

     def _update_display(self, throttle=True):
         if self.echo:
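The new branch is a standard minimum-interval throttle: while a refresh-throttling scope is active, at most one event per `max_display_rate` seconds is emitted. The same pattern in isolation (hypothetical names, not guidance API):

```python
import time

class ThrottledEmitter:
    def __init__(self, min_interval=0.1):
        self.min_interval = min_interval
        self._last_emit = 0.0

    def maybe_emit(self, send):
        # fire only if enough wall-clock time has passed since the last emit
        now = time.time()
        if now - self._last_emit >= self.min_interval:
            self._last_emit = now
            send()
```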
@@ -305,7 +320,9 @@ def __add__(self, value):
         else:
             out = value(lm)
             if out is None:
-                raise Exception(f"A guidance function did not return a model object! Did you forget to return the new lm at the end of your function?")
+                raise Exception(f"A guidance function returned `None`, not a model object! Did you forget to return the new lm at the end of your function?")
+            if not isinstance(out, Model):
+                raise Exception(f"A guidance function did not return a model object! Did you try to add a function to a model without calling the function? For example `model + guidance_function()` is correct, while `model + guidance_function` will cause this error.")

         # this flushes the display
         out._inplace_append("")
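The added `isinstance` check catches the mistake its message describes: adding the function object itself rather than its result. A quick sketch (assuming a `@guidance`-decorated function):

```python
import guidance
from guidance import gen

@guidance
def qa(lm):
    lm += "Q: Why is the sky blue?\nA: " + gen(name="answer", max_tokens=20)
    return lm  # forgetting this return triggers the first error above

lm = lm + qa()   # correct: call the function, then add its result
# lm = lm + qa   # incorrect: per the new message, this raises the second error
```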
@@ -940,9 +957,10 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                 # self._cache_state["new_token_ids"].append(sampled_token_ind)

                 # capture the named groups from the parse tree
-                parse_tree = parser.parse_tree()
-                _record_captures(parse_tree, captured_data, captured_log_prob_data, parser.bytes)
-
+                new_captured_data, new_captured_log_prob_data = parser.get_captures()
+                captured_data.update(new_captured_data)
+                captured_log_prob_data.update(new_captured_log_prob_data)
+
                 # we have no valid log prob data if we didn't compute it
                 yield new_bytes[hidden_count:], is_generated, new_bytes_prob, captured_data, captured_log_prob_data, token_count - last_token_count
                 last_token_count = token_count
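These captured groups are what back named lookups on the model, with the log probabilities tracked alongside when the backend computes them. A small usage sketch (assuming guidance's `gen` with a `name=` argument):

```python
from guidance import gen

lm2 = lm + "The answer is " + gen(name="answer", max_tokens=5)
print(lm2["answer"])  # the text captured for the named group
```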
@@ -953,7 +971,11 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
                 # yield the snippet of text created by the next token
                 out = new_bytes[hidden_count:]
                 if len(out) > 0:
-                    yield out, is_generated, new_bytes_prob, {}, {}, token_count - last_token_count # note that we don't capture groups until a complete parse right now...
+                    # capture the named groups from the (partial) parse tree # TODO: disabled for now until we handle list_append correctly
+                    # new_captured_data, new_captured_log_prob_data = parser.get_captures()
+                    # captured_data.update(new_captured_data)
+                    # captured_log_prob_data.update(new_captured_log_prob_data)
+                    yield out, is_generated, new_bytes_prob, captured_data, captured_log_prob_data, token_count - last_token_count # note that we don't capture groups until a complete parse right now...
                     last_token_count = token_count
                 hidden_count = 0
                 token_count += 1 # note we only update this for tokens that emit non-hidden content
@@ -968,6 +990,60 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
             else:
                 token_byte_positions.append(token_byte_positions[-1] + len(sampled_token))

+class ModelStream:
+    def __init__(self, model, grammar=None, timeout=5):
+        '''Create a model stream object that delays execution until it is iterated over.'''
+        if model.echo:
+            model = model.copy()
+            model.echo = False # turn off display echoing
+        self.model = model
+        self.grammar = grammar
+        self.timeout = timeout
+
+    def __add__(self, grammar):
+        '''Extend this delayed chain of execution with another grammar append.'''
+        return ModelStream(self.model, grammar)
+
+    def _inner_run(self, model):
+        '''This runs the model stream without iterating, and is only used internally by __iter__.'''
+        if isinstance(self.grammar, ModelStream):
+            model = self.grammar._inner_run(model)
+        elif self.grammar is None:
+            model = self.model + ""
+        else:
+            model = self.model + self.grammar
+        # note there is nothing to return: running the appends above emits
+        # events through the model's event queue as a side effect
+
+    def __iter__(self):
+        '''Starts a thread to execute the model and grammar, yielding events as they occur.'''
+
+        # Create a thread-safe queue to hold events
+        with CaptureEvents(self.model) as events:
+
+            # Define the target function for the thread
+            def target():
+                self._inner_run(self.model)
+                events.put(None) # mark that we are done
+
+            # Start the thread
+            thread = threading.Thread(target=target)
+            thread.start()
+
+            # Yield events from the queue as they become available
+            while True:
+                try:
+                    # Wait for an event with a timeout to allow for thread termination
+                    event = events.get(timeout=self.timeout)
+                    if event is None:
+                        break
+                    yield event
+                except queue.Empty:
+                    # Check if the thread is still alive
+                    if not thread.is_alive():
+                        break
+
+            # Ensure the thread has completed
+            thread.join()
+
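`__iter__` leans on the imported `CaptureEvents` helper to temporarily attach a queue to the model. Its real implementation lives in `.._utils`; the sketch below is only an assumption about the contract it must satisfy (a context manager that yields a thread-safe queue):

```python
import queue

class CaptureEvents:
    '''Context manager that attaches a fresh event queue to a model for its scope.'''
    def __init__(self, model):
        self.model = model

    def __enter__(self):
        # while the scope is open, _send_to_event_queue calls land in this queue
        self.model._event_queue = queue.Queue()
        return self.model._event_queue

    def __exit__(self, exc_type, exc_value, traceback):
        self.model._event_queue = None  # detach so later events are dropped
```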
 class Chat(Model):
     '''The base class for all chat-tuned models.'''

@@ -1033,55 +1109,6 @@ def throttle_refresh():
 class ConstraintException(Exception):
     pass

-def _record_captures(initial_item, data, log_prob_data, byte_data):
-    stack = [(initial_item, 0)]
-    used_names = set() # track which capture names have been used so self-recursive children don't overwrite their parents
-
-    while stack:
-        item, byte_pos = stack.pop()
-        # terminal nodes
-        if isinstance(item, Terminal):
-
-            # if we are at a capture group node then we save the matched terminal byte
-            if item.capture_name is not None:
-                data[item.capture_name] = item.byte
-                log_prob_data[item.capture_name] = 0
-
-        # internal nodes
-        else:
-            start_byte_pos = byte_pos
-
-            # recurse for all our non-null children
-            for child in item.children:
-                if child is not None:
-                    stack.append((child, byte_pos))
-                    # _record_captures(child, data, log_prob_data, byte_data, byte_pos)
-                    if isinstance(child, Terminal):
-                        byte_pos += len(child)
-                    else:
-                        byte_pos = child.start # note that "start" means "end" since this is a reversed state set
-
-            # if we are at a capture group node then we save the matched bytes range
-            # note that we record this after calling our children so that we save the outermost version of self-recursive calls
-            cname = item.node.capture_name
-            if cname is not None and cname not in used_names and not item.node.hidden:
-
-                # see if we are doing a list append
-                if cname.startswith("__LIST_APPEND:"):
-                    cname = cname[14:] # trim off the list append tag
-                    if cname not in data or not isinstance(data[cname], list):
-                        data[cname] = []
-                        log_prob_data[cname] = []
-                    data[cname].append(byte_data[start_byte_pos:item.start])
-                    log_prob_data[cname].append(item.log_prob)
-
-                # or just a regular assignment
-                else:
-                    data[cname] = byte_data[start_byte_pos:item.start] # note that "start" means "end" since this is a reversed state set
-                    log_prob_data[cname] = item.log_prob
-
-                used_names.add(cname)
-
 # def _compute_probs(trie, probs, found):
 #     '''Computes the log probabilities for each internal trie node.'''
 #     if trie.value is not None: