
Commit 45b186d

Track and format chat history
Given the multiplicity of chat formats, formatting the prompt for chat workflows with open models can be a real hassle and is error-prone. In this PR we introduce a `Chat` class that allows users to track the conversation and easily print the corresponding prompt.
1 parent 3d2689a commit 45b186d

File tree

4 files changed: +151 −0 lines changed

docs/reference/chat.md

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
# Chat history

## Filter messages

In some situations you may want to filter the messages before building the prompt, for instance to use RAG. In this case you can subclass `Chat` and override the `filter` method:

```python
from prompts import Chat


class RAGChat(Chat):

    def filter(self):
        filtered_messages = []
        for message in self.history:
            if message.role == "user" and "Hi" in message.content:
                filtered_messages.append(message)

        return filtered_messages
```
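As a self-contained illustration of the filtering idea, here is the same logic with a plain-Python stand-in for the `Message` dataclass defined in `prompts/chat.py` (the `filter_history` helper name is made up for this sketch):

```python
from dataclasses import dataclass


# Stand-in for the Message dataclass defined in prompts/chat.py.
@dataclass
class Message:
    role: str
    content: str


def filter_history(history):
    """Keep only user messages that contain "Hi", as in the docs example."""
    return [m for m in history if m.role == "user" and "Hi" in m.content]


history = [
    Message("system", "Be brief."),
    Message("user", "Hi there!"),
    Message("user", "Unrelated question."),
]
filtered = filter_history(history)  # keeps only the "Hi there!" message
```

In the library itself the same filtering would live in an overridden `filter` method, so that `str(chat)` builds the prompt from the filtered history.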

mkdocs.yml

Lines changed: 1 addition & 0 deletions
@@ -76,3 +76,4 @@ nav:
      - Prompt template: reference/template.md
      - Dispatch: reference/dispatch.md
      - Special tokens: reference/special_tokens.md
+     - Chat History: reference/chat.md

prompts/chat.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional

from prompts.templates import template


class Role(Enum):
    system = "system"
    user = "user"
    assistant = "assistant"


@dataclass
class Message:
    role: Role
    content: str


class Chat:
    history: List[Message]

    def __init__(self, model_name: str, system_msg: Optional[str] = None):
        from transformers import AutoTokenizer

        # This is annoying, we need to handle those ourselves.
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.history = []
        if system_msg is not None:
            self.history.append(Message(Role.system, system_msg))

    def __str__(self):
        """Render the prompt that corresponds to the chat history in the
        format that `model_name` expects.

        In order to be compatible with any library we choose to append the
        token that corresponds to the beginning of the assistant's response
        whenever the last message is from a `user`.

        TODO: Is not adding this token ever useful? This needs to be
        properly documented.

        Correctness, i.e. the alternation between user and assistant
        messages, is checked after filtering the history.

        """
        history = self.filter()
        if not self._is_history_valid(history):
            raise ValueError("History not valid")

        prompt = chat_template[self.model_name](history)

        # TODO: translate this to the format expected by HuggingFace,
        # i.e. use tokenizer.apply_chat_template(chat, tokenize=False)

        return prompt

    def _is_history_valid(self, history: List[Message]) -> bool:
        """Check that user and assistant messages alternate.

        System messages are ignored; two consecutive messages from the
        same role make the history invalid.

        """
        conversation = [m for m in history if m.role != Role.system]
        for previous, current in zip(conversation, conversation[1:]):
            if previous.role == current.role:
                return False
        return True

    def filter(self):
        """Filter the messages before building the prompt.

        Users who want to filter messages before building the prompt, for
        instance with a RAG step, should subclass `Chat` and override this
        method.

        """
        return self.history

    def __getitem__(self, key):
        """Index a message by position, or get all messages for a role.

        `chat[1]` returns the second message in the history;
        `chat["user"]` returns all messages with the `user` role.

        """
        if isinstance(key, str):
            return [message for message in self.history if message.role == Role(key)]
        return self.history[key]

    def user(self, msg: str):
        """Add a new user message."""
        self.history.append(Message(Role.user, msg))

    def assistant(self, msg: str):
        """Add a new assistant message."""
        self.history.append(Message(Role.assistant, msg))


@template
def chat_template(messages):
    """
    {% for message in messages %}
    {%- if loop.first %}
    {%- if message.role == 'system' %}
    {{- message.content + bos }}\n
    {%- else %}
    {{- bos + user.begin + message.content + user.end }}
    {%- endif %}
    {%- else %}
    {%- if message.role == 'user' %}
    \n{{- user.begin + message.content + user.end }}
    {%- else %}
    \n{{- assistant.begin + message.content + assistant.end }}
    {%- endif %}
    {%- endif %}
    {% endfor %}
    {%- if messages[-1].role == 'user' %}
    \n{{ assistant.begin }}
    {% endif %}"""
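To make the template's behavior concrete, here is a standalone sketch of the rendering logic it implements, using made-up special tokens (`<s>`, `[USER]`, `[ASSISTANT]` are placeholders for this sketch, not tokens from any real model):

```python
# Placeholder special tokens; a real model defines its own.
BOS = "<s>"
USER_B, USER_E = "[USER]", "[/USER]"
ASSISTANT_B, ASSISTANT_E = "[ASSISTANT]", "[/ASSISTANT]"


def render(messages):
    """Render (role, content) pairs the way the template above does."""
    parts = []
    for i, (role, content) in enumerate(messages):
        if i == 0:
            if role == "system":
                parts.append(content + BOS)
            else:
                parts.append(BOS + USER_B + content + USER_E)
        elif role == "user":
            parts.append(USER_B + content + USER_E)
        else:
            parts.append(ASSISTANT_B + content + ASSISTANT_E)
    # When the user speaks last, open the assistant's turn so the
    # model continues as the assistant.
    if messages[-1][0] == "user":
        parts.append(ASSISTANT_B)
    return "\n".join(parts)


prompt = render([("system", "Be helpful."), ("user", "Hi!")])
```

The trailing `ASSISTANT_B` is the "beginning of the assistant's response" token that the `__str__` docstring above describes.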

tests/test_chat.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
import pytest

from prompts.chat import Chat


def test_simple():
    chat = Chat("gpt2", "system message")
    chat.user("new user message")
    chat.assistant("new assistant message")
    print(chat)

    assert chat["assistant"][0].content == "new assistant message"
    assert chat["user"][0].content == "new user message"
    assert chat[1].content == "new user message"


def test_error():
    with pytest.raises(ValueError):
        chat = Chat("gpt2", "system message")
        chat.user("new user message")
        chat.user("new user message")
        print(chat)
