@@ -252,31 +252,44 @@ def __add__(self, value):
252252 # inside this context we are free to drop display calls that come too close together
253253 with throttle_refresh ():
254254
255- # close any newly closed contexts
255+ # find what new blocks need to be applied
256+ new_blocks = []
257+ for context in Model .open_blocks :
258+ if context not in lm .opened_blocks :
259+ new_blocks .append (context )
260+
261+ # mark this so we don't re-add when computing the opener or closer (even though we don't know the close text yet)
262+ lm .opened_blocks [context ] = (0 , "" )
263+
264+ # find what old blocks need to be removed
265+ old_blocks = []
256266 for context in list (reversed (lm .opened_blocks )):
257267 if context not in Model .open_blocks and context in lm .opened_blocks :
258- pos , close_text = lm .opened_blocks [context ] # save so we can delete it before adding it
259- if context . name is not None :
260- lm . _variables [ context . name ] = format_pattern . sub ( "" , lm . _state [ pos :])
268+ old_blocks . append (( lm .opened_blocks [context ], context ))
269+
270+ # delete this so we don't re-close when computing the opener or closer
261271 del lm .opened_blocks [context ]
262- lm ._inplace_append (close_text )
272+
273+ # close any newly closed contexts
274+ for (pos , close_text ), context in old_blocks :
275+ if context .name is not None :
276+ lm ._variables [context .name ] = format_pattern .sub ("" , lm ._state [pos :])
277+ lm += context .closer
263278
264279 # apply any newly opened contexts (new from this object's perspective)
265- for context in Model .open_blocks :
266- if context not in lm .opened_blocks :
267- lm .opened_blocks [context ] = (0 , "" ) # mark this so we don't readd when computing the opener (even though we don't know the close text yet)
268- lm += context .opener
269- with grammar_only ():
270- tmp = lm + context .closer
271- close_text = tmp ._state [len (lm ._state ):] # get the new state added by calling the closer
272- lm .opened_blocks [context ] = (len (lm ._state ), close_text )
273-
274- # clear out names that we override
275- if context .name is not None :
276- if context .name in lm ._variables :
277- del lm ._variables [context .name ]
278- if context .name in lm ._variables_log_probs :
279- del lm ._variables_log_probs [context .name ]
280+ for context in new_blocks :
281+ lm += context .opener
282+ with grammar_only ():
283+ tmp = lm + context .closer
284+ close_text = tmp ._state [len (lm ._state ):] # get the new state added by calling the closer
285+ lm .opened_blocks [context ] = (len (lm ._state ), close_text )
286+
287+ # clear out names that we override
288+ if context .name is not None :
289+ if context .name in lm ._variables :
290+ del lm ._variables [context .name ]
291+ if context .name in lm ._variables_log_probs :
292+ del lm ._variables_log_probs [context .name ]
280293
281294 # wrap raw string values
282295 if isinstance (value , str ):
@@ -367,6 +380,32 @@ def get(self, key, default=None):
367380 The value to return if the variable is not current set.
368381 '''
369382 return self ._variables .get (key , default )
383+
384+ def setattr (self , key , value ):
385+ '''Return a new model with the given model attribute set.
386+
387+ Parameters
388+ ----------
389+ key : str
390+ The name of the attribute to be set.
391+ value : any
392+ The value to set the attribute to.
393+ '''
394+ copy = self .copy ()
395+ setattr (copy , key , value )
396+ return copy
397+
398+ def delattr (self , key ):
399+ '''Return a new model with the given attribute deleted.
400+
401+ Parameters
402+ ----------
403+ key : str
404+ The attribute name to remove.
405+ '''
406+ copy = self .copy ()
407+ delattr (copy , key )
408+ return copy
370409
371410 def set (self , key , value ):
372411 '''Return a new model with the given variable value set.
@@ -957,9 +996,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
957996 # self._cache_state["new_token_ids"].append(sampled_token_ind)
958997
959998 # capture the named groups from the parse tree
960- new_captured_data , new_captured_log_prob_data = parser .get_captures ()
961- captured_data .update (new_captured_data )
962- captured_log_prob_data .update (new_captured_log_prob_data )
999+ parser .get_captures (captured_data , captured_log_prob_data )
9631000
9641001 # we have no valid log prob data if we didn't compute it
9651002 yield new_bytes [hidden_count :], is_generated , new_bytes_prob , captured_data , captured_log_prob_data , token_count - last_token_count
0 commit comments