Skip to content

Commit f6d7f5d

Browse files
authored
Merge pull request #578 from cpcdoy/feature/debug_llm_role_blocks
Prompt role blocks debug output + new prompt implementation example
2 parents f0c2157 + 47b3339 commit f6d7f5d

File tree

7 files changed

+470
-32
lines changed

7 files changed

+470
-32
lines changed

guidance/library/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313

1414
# context blocks
1515
from ._block import block
16-
from ._role import role, system, assistant, user, function, instruction
16+
from ._role import role, system, assistant, user, function, instruction, indent_roles
1717
from ._format import monospace
1818
from ._silent import silent
19+
from ._set_var import set_var
20+
from ._set_attribute import set_attribute
1921
# from ..models._model import context_free
2022

2123
# stateless library functions

guidance/library/_role.py

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,95 @@
11
import guidance
22
from ._block import block
3+
from ._set_attribute import set_attribute
4+
5+
# Sentinel markers interpreted by the display layer when rendering model
# state. The NODISP pair hides the wrapped text from the rendered output,
# while the span pair wraps it in a highlighted HTML span — used below to
# either hide or highlight the raw role start/end text (debug output).
nodisp_start = "<||_#NODISP_||>"
nodisp_end = "<||_/NODISP_||>"
span_start = "<||_html:<span style='background-color: rgba(255, 180, 0, 0.3); border-radius: 3px;'>_||>"
span_end = "<||_html:</span>_||>"
9+
310

411
@guidance
def role_opener(lm, role_name, **kwargs):
    """Append the opening markup and model-specific start text for a role block.

    Parameters
    ----------
    lm : Model
        The model state to append to. Must be a chat model, i.e. provide
        ``get_role_start``.
    role_name : str
        The role being opened (e.g. "system", "user", "assistant").
    **kwargs
        Forwarded to ``lm.get_role_start``.

    Raises
    ------
    Exception
        If ``lm`` is not a chat model (has no ``get_role_start``).
    """
    # ``indent_roles`` is an optional model attribute (see ``indent_roles()``);
    # default to the indented, pretty HTML display.
    indent = getattr(lm, "indent_roles", True)
    if not hasattr(lm, "get_role_start"):
        # Fixed typo in the message: "in order the use" -> "in order to use".
        raise Exception(
            f"You need to use a chat model in order to use role blocks like `with {role_name}():`! Perhaps you meant to use the {type(lm).__name__}Chat class?"
        )

    # Block start container (centers elements)
    if indent:
        lm += f"<||_html:<div style='display: flex; border-bottom: 1px solid rgba(127, 127, 127, 0.2); justify-content: center; align-items: center;'><div style='flex: 0 0 80px; opacity: 0.5;'>{role_name.lower()}</div><div style='flex-grow: 1; padding: 5px; padding-top: 10px; padding-bottom: 10px; margin-top: 0px; white-space: pre-wrap; margin-bottom: 0px;'>_||>"

    # Hide the raw role-start text when indenting; highlight it when debugging.
    lm += nodisp_start if indent else span_start
    lm += lm.get_role_start(role_name, **kwargs)
    lm += nodisp_end if indent else span_end

    return lm
1138

39+
1240
@guidance
def role_closer(lm, role_name, **kwargs):
    """Append the model-specific end text for a role block, plus display markup.

    The raw role-end text is hidden inside a no-display region when roles are
    shown indented, or wrapped in a highlight span when debug output is on.
    """
    indent = getattr(lm, "indent_roles", True)

    # Pick the marker pair that hides (indent) or highlights (debug) the text.
    if indent:
        open_marker, close_marker = nodisp_start, nodisp_end
    else:
        open_marker, close_marker = span_start, span_end

    lm += open_marker
    lm += lm.get_role_end(role_name)
    lm += close_marker

    # Close the flex container that ``role_opener`` opened when indenting.
    if indent:
        lm += "<||_html:</div></div>_||>"

    return lm
1662

63+
1764
def role(role_name, text=None, **kwargs):
    """Return a context block that wraps generated content in a chat role.

    Parameters
    ----------
    role_name : str
        The chat role to open (e.g. "system", "user", "assistant").
    text : str or None
        Inline text form is not implemented yet; only the context-manager
        form (``text is None``) is supported.
    **kwargs
        Forwarded to the role opener.

    Raises
    ------
    NotImplementedError
        If ``text`` is given (inline form is not yet supported).
    """
    if text is None:
        return block(
            opener=role_opener(role_name, **kwargs),
            closer=role_closer(role_name, **kwargs),
        )
    # ``assert False`` would be silently stripped under ``python -O``;
    # raise explicitly so the unimplemented path always fails loudly.
    raise NotImplementedError(
        "Passing `text` directly is not supported yet; use `with role(...):` instead."
    )
73+
2374

2475
def system(text=None, **kwargs):
    """Open a "system" role block; all arguments are forwarded to ``role``."""
    return role("system", text=text, **kwargs)
2677

78+
2779
def user(text=None, **kwargs):
    """Open a "user" role block; all arguments are forwarded to ``role``."""
    return role("user", text=text, **kwargs)
2981

82+
3083
def assistant(text=None, **kwargs):
    """Open an "assistant" role block; all arguments are forwarded to ``role``."""
    return role("assistant", text=text, **kwargs)
3285

86+
3387
def function(text=None, **kwargs):
    """Open a "function" role block; all arguments are forwarded to ``role``."""
    return role("function", text=text, **kwargs)
3589

90+
3691
def instruction(text=None, **kwargs):
    """Open an "instruction" role block; all arguments are forwarded to ``role``."""
    return role("instruction", text=text, **kwargs)
93+
94+
def indent_roles(indent=True):
    """Context block toggling the indented (pretty) display of role blocks.

    With ``indent=False`` the raw role start/end text is shown highlighted
    instead of hidden, which is useful when debugging prompt construction.
    """
    return set_attribute("indent_roles", value=indent)

guidance/library/_set_attribute.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import guidance
2+
from ._block import block
3+
4+
@guidance
def set_attr_opener(lm, name, value):
    """Set model attribute ``name`` to ``value``, stashing any prior value.

    A pre-existing value is saved under ``"__save" + name`` so the matching
    closer can restore it when the block exits.
    """
    backup_key = "__save" + name
    if hasattr(lm, name):
        lm = lm.setattr(backup_key, getattr(lm, name))
    return lm.setattr(name, value)
9+
10+
@guidance
def set_attr_closer(lm, name):
    """Restore or remove the model attribute set by ``set_attr_opener``.

    If a previous value was stashed under ``"__save" + name`` it is restored
    and the stash deleted; otherwise the attribute was newly added by the
    opener and is simply removed.
    """
    backup_key = "__save" + name
    if hasattr(lm, backup_key):
        # Bug fix: the saved value is a model *attribute* (stored via
        # ``setattr`` in the opener), so it must be read with ``getattr``,
        # not ``lm[...]`` item access, which looks up model variables.
        return lm.setattr(name, getattr(lm, backup_key)).delattr(backup_key)
    else:
        return lm.delattr(name)
16+
17+
def set_attribute(name, value=True):
    """Context block setting model attribute ``name`` to ``value`` on entry
    and restoring its previous state on exit."""
    return block(opener=set_attr_opener(name, value), closer=set_attr_closer(name))

guidance/library/_set_var.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import guidance
2+
from ._block import block
3+
4+
@guidance
def set_opener(lm, name, value):
    """Set model variable ``name`` to ``value``, stashing any prior value.

    A pre-existing value is saved under ``"__save" + name`` so the matching
    closer can restore it when the block exits.
    """
    backup_key = "__save" + name
    if name in lm:
        lm = lm.set(backup_key, lm[name])
    return lm.set(name, value)
9+
10+
@guidance
def set_closer(lm, name):
    """Restore or remove the model variable set by ``set_opener``."""
    backup_key = "__save" + name
    if backup_key not in lm:
        # The variable did not exist before the block: just drop it.
        return lm.remove(name)
    # Put the saved value back and discard the stash.
    return lm.set(name, lm[backup_key]).remove(backup_key)
16+
17+
def set_var(name, value=True):
    """Context block setting model variable ``name`` to ``value`` on entry
    and restoring its previous state on exit."""
    return block(opener=set_opener(name, value), closer=set_closer(name))

guidance/models/_model.py

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -252,31 +252,44 @@ def __add__(self, value):
252252
# inside this context we are free to drop display calls that come too close together
253253
with throttle_refresh():
254254

255-
# close any newly closed contexts
255+
# find what new blocks need to be applied
256+
new_blocks = []
257+
for context in Model.open_blocks:
258+
if context not in lm.opened_blocks:
259+
new_blocks.append(context)
260+
261+
# mark this so we don't re-add when computing the opener or closer (even though we don't know the close text yet)
262+
lm.opened_blocks[context] = (0, "")
263+
264+
# find what old blocks need to be removed
265+
old_blocks = []
256266
for context in list(reversed(lm.opened_blocks)):
257267
if context not in Model.open_blocks and context in lm.opened_blocks:
258-
pos, close_text = lm.opened_blocks[context] # save so we can delete it before adding it
259-
if context.name is not None:
260-
lm._variables[context.name] = format_pattern.sub("", lm._state[pos:])
268+
old_blocks.append((lm.opened_blocks[context], context))
269+
270+
# delete this so we don't re-close when computing the opener or closer
261271
del lm.opened_blocks[context]
262-
lm._inplace_append(close_text)
272+
273+
# close any newly closed contexts
274+
for (pos, close_text), context in old_blocks:
275+
if context.name is not None:
276+
lm._variables[context.name] = format_pattern.sub("", lm._state[pos:])
277+
lm += context.closer
263278

264279
# apply any newly opened contexts (new from this object's perspective)
265-
for context in Model.open_blocks:
266-
if context not in lm.opened_blocks:
267-
lm.opened_blocks[context] = (0, "") # mark this so we don't readd when computing the opener (even though we don't know the close text yet)
268-
lm += context.opener
269-
with grammar_only():
270-
tmp = lm + context.closer
271-
close_text = tmp._state[len(lm._state):] # get the new state added by calling the closer
272-
lm.opened_blocks[context] = (len(lm._state), close_text)
273-
274-
# clear out names that we override
275-
if context.name is not None:
276-
if context.name in lm._variables:
277-
del lm._variables[context.name]
278-
if context.name in lm._variables_log_probs:
279-
del lm._variables_log_probs[context.name]
280+
for context in new_blocks:
281+
lm += context.opener
282+
with grammar_only():
283+
tmp = lm + context.closer
284+
close_text = tmp._state[len(lm._state):] # get the new state added by calling the closer
285+
lm.opened_blocks[context] = (len(lm._state), close_text)
286+
287+
# clear out names that we override
288+
if context.name is not None:
289+
if context.name in lm._variables:
290+
del lm._variables[context.name]
291+
if context.name in lm._variables_log_probs:
292+
del lm._variables_log_probs[context.name]
280293

281294
# wrap raw string values
282295
if isinstance(value, str):
@@ -367,6 +380,32 @@ def get(self, key, default=None):
367380
The value to return if the variable is not current set.
368381
'''
369382
return self._variables.get(key, default)
383+
384+
def setattr(self, key, value):
    '''Return a new model with the given model attribute set.

    The receiver is not mutated: a copy is made and the attribute is
    set on the copy.

    Parameters
    ----------
    key : str
        The name of the attribute to be set.
    value : any
        The value to set the attribute to.
    '''
    new_lm = self.copy()
    # The builtin ``setattr`` is used here (the method name does not
    # shadow it inside the method body).
    setattr(new_lm, key, value)
    return new_lm
397+
398+
def delattr(self, key):
    '''Return a new model with the given attribute deleted.

    The receiver is not mutated: a copy is made and the attribute is
    removed from the copy.

    Parameters
    ----------
    key : str
        The attribute name to remove.
    '''
    new_lm = self.copy()
    # The builtin ``delattr`` is used here (the method name does not
    # shadow it inside the method body).
    delattr(new_lm, key)
    return new_lm
370409

371410
def set(self, key, value):
372411
'''Return a new model with the given variable value set.
@@ -957,9 +996,7 @@ def __call__(self, grammar, max_tokens=1000000, n=1, top_p=1, temperature=0.0, e
957996
# self._cache_state["new_token_ids"].append(sampled_token_ind)
958997

959998
# capture the named groups from the parse tree
960-
new_captured_data, new_captured_log_prob_data = parser.get_captures()
961-
captured_data.update(new_captured_data)
962-
captured_log_prob_data.update(new_captured_log_prob_data)
999+
parser.get_captures(captured_data, captured_log_prob_data)
9631000

9641001
# we have no valid log prob data if we didn't compute it
9651002
yield new_bytes[hidden_count:], is_generated, new_bytes_prob, captured_data, captured_log_prob_data, token_count - last_token_count
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from ._llama import Llama, LlamaChat
1+
from ._llama import Llama, LlamaChat
2+
from ._transformers import Transformers, TransformersChat

0 commit comments

Comments
 (0)