Skip to content

Commit 03a8bae

Browse files
committed
Fix test_save_stop
1 parent 9425899 commit 03a8bae

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

guidance/models/_model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -957,7 +957,10 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
957957
# self._cache_state["new_token_ids"].append(sampled_token_ind)
958958

959959
# capture the named groups from the parse tree
960-
captured_data, captured_log_prob_data = parser.get_captures()
960+
new_captured_data, new_captured_log_prob_data = parser.get_captures()
961+
captured_data.update(new_captured_data)
962+
captured_log_prob_data.update(new_captured_log_prob_data)
963+
961964
# we have no valid log prob data if we didn't compute it
962965
yield new_bytes[hidden_count:], is_generated, new_bytes_prob, captured_data, captured_log_prob_data, token_count - last_token_count
963966
last_token_count = token_count
@@ -969,7 +972,9 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
969972
out = new_bytes[hidden_count:]
970973
if len(out) > 0:
971974
# capture the named groups from the (partial) parse tree,
972-
captured_data, captured_log_prob_data = parser.get_captures()
975+
new_captured_data, new_captured_log_prob_data = parser.get_captures()
976+
captured_data.update(new_captured_data)
977+
captured_log_prob_data.update(new_captured_log_prob_data)
973978
yield out, is_generated, new_bytes_prob, captured_data, captured_log_prob_data, token_count - last_token_count # note that we don't capture groups until a complete parse right now...
974979
last_token_count = token_count
975980
hidden_count = 0

0 commit comments

Comments
 (0)