Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 77 additions & 61 deletions chatarena/backends/hf_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
from typing import List

from tenacity import retry, stop_after_attempt, wait_random_exponential
import torch

from ..message import SYSTEM_NAME as SYSTEM
from ..message import SYSTEM_NAME
from ..message import Message
from .base import IntelligenceBackend, register_backend

END_OF_MESSAGE = "<EOS>" # End of message token specified by us not OpenAI
BASE_PROMPT = f"The messages always end with the token {END_OF_MESSAGE}."
DEFAULT_MAX_TOKENS = 256


@contextmanager
def suppress_stdout_stderr():
Expand All @@ -22,9 +27,8 @@ def suppress_stdout_stderr():
try:
import transformers
from transformers import pipeline
from transformers.pipelines.conversational import (
Conversation,
ConversationalPipeline,
from transformers.pipelines.text_generation import (
TextGenerationPipeline,
)
except ImportError:
is_transformers_available = False
Expand All @@ -39,25 +43,40 @@ class TransformersConversational(IntelligenceBackend):
stateful = False
type_name = "transformers:conversational"

def __init__(self, model: str, device: int = -1, **kwargs):
super().__init__(model=model, device=device, **kwargs)
def __init__(
self,
model: str,
max_tokens: int = DEFAULT_MAX_TOKENS,
merge_other_agents_as_one_user: bool = True,
**kwargs,
):
super().__init__(model=model, **kwargs)
self.model = model
self.device = device
self.max_tokens = max_tokens

assert is_transformers_available, "Transformers package is not installed"
self.chatbot = pipeline(
task="conversational", model=self.model, device=self.device
task="text-generation",
model=self.model,
device_map="auto",
model_kwargs={"torch_dtype": torch.bfloat16},
)
self.terminators = [
self.chatbot.tokenizer.eos_token_id,
self.chatbot.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
self.merge_other_agent_as_user = merge_other_agents_as_one_user

@retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
def _get_response(self, conversation):
conversation = self.chatbot(conversation)
response = conversation.generated_responses[-1]
return response

@staticmethod
def _msg_template(agent_name, content):
return f"[{agent_name}]: {content}"
conversation = self.chatbot(
conversation,
max_new_tokens=self.max_tokens,
eos_token_id=self.terminators,
pad_token_id=self.chatbot.tokenizer.eos_token_id,
)
response = conversation[0]["generated_text"][-1]["content"]
return response

def query(
self,
Expand All @@ -69,60 +88,57 @@ def query(
*args,
**kwargs,
) -> str:
user_inputs, generated_responses = [], []
all_messages = (
[(SYSTEM, global_prompt), (SYSTEM, role_desc)]
if global_prompt
else [(SYSTEM, role_desc)]
)
if global_prompt: # Prepend the global prompt if it exists
system_prompt = f"You are a helpful assistant.\n{global_prompt.strip()}\n{BASE_PROMPT}\n\nYour name is {agent_name}.\n\nYour role:{role_desc}"
else:
system_prompt = f"You are a helpful assistant. Your name is {agent_name}.\n\nYour role:{role_desc}\n\n{BASE_PROMPT}"

all_messages = [(SYSTEM_NAME, system_prompt)]
for msg in history_messages:
all_messages.append((msg.agent_name, msg.content))
if request_msg:
all_messages.append((SYSTEM, request_msg.content))
if msg.agent_name == SYSTEM_NAME:
all_messages.append((SYSTEM_NAME, msg.content))
else: # non-system messages are suffixed with the end of message token
all_messages.append((msg.agent_name, f"{msg.content}{END_OF_MESSAGE}"))

prev_is_user = False # Whether the previous message is from the user
for i, message in enumerate(all_messages):
if request_msg:
all_messages.append((SYSTEM_NAME, request_msg.content))
else: # The default request message that reminds the agent its role and instruct it to speak
all_messages.append(
(SYSTEM_NAME, f"Now you speak, {agent_name}.{END_OF_MESSAGE}")
)

messages = []
for i, msg in enumerate(all_messages):
if i == 0:
assert (
message[0] == SYSTEM
msg[0] == SYSTEM_NAME
) # The first message should be from the system

if message[0] != agent_name:
if not prev_is_user:
user_inputs.append(self._msg_template(message[0], message[1]))
else:
user_inputs[-1] += "\n" + self._msg_template(message[0], message[1])
prev_is_user = True
messages.append({"role": "system", "content": msg[1]})
else:
if prev_is_user:
generated_responses.append(message[1])
if msg[0] == agent_name:
messages.append({"role": "assistant", "content": msg[1]})
else:
generated_responses[-1] += "\n" + message[1]
prev_is_user = False

assert len(user_inputs) == len(generated_responses) + 1
past_user_inputs = user_inputs[:-1]
new_user_input = user_inputs[-1]

# Recreate a conversation object from the history messages
conversation = Conversation(
text=new_user_input,
past_user_inputs=past_user_inputs,
generated_responses=generated_responses,
)
if messages[-1]["role"] == "user": # last message is from user
if self.merge_other_agent_as_user:
messages[-1][
"content"
] = f"{messages[-1]['content']}\n\n[{msg[0]}]: {msg[1]}"
else:
messages.append(
{"role": "user", "content": f"[{msg[0]}]: {msg[1]}"}
)
elif (
messages[-1]["role"] == "assistant"
): # consecutive assistant messages
# Merge the assistant messages
messages[-1]["content"] = f"{messages[-1]['content']}\n{msg[1]}"
elif messages[-1]["role"] == "system":
messages.append(
{"role": "user", "content": f"[{msg[0]}]: {msg[1]}"}
)
else:
raise ValueError(f"Invalid role: {messages[-1]['role']}")

# Get the response
response = self._get_response(conversation)
response = self._get_response(messages)
return response


# conversation = Conversation("Going to the movies tonight - any suggestions?")
#
# # Steps usually performed by the model when generating a response:
# # 1. Mark the user input as processed (moved to the history)
# conversation.mark_processed()
# # 2. Append a model response
# conversation.append_response("The Big lebowski.")
#
# conversation.add_user_input("Is it good?")
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ langchain = ["langchain>=0.0.340"]
gradio = ["gradio==3.34.0", "pydantic==1.10.13"]
pettingzoo = ["pettingzoo>=1.24.0", "chess==1.9.4", "rlcard==1.0.5", "pygame==2.3.0", "gymnasium>=0.28.1"]
umshini = ["pettingzoo>=1.24.1", "chess==1.9.4", "rlcard==1.0.5", "pygame==2.3.0", "gymnasium>=0.28.1", "langchain>=0.0.340", "colorama>=0.4.6"]
all_backends = ["anthropic>=0.2.8", "cohere>=4.3.1", "transformers>=4.27.4", "bardapi==0.1.11", "langchain>=0.0.135"]
all_backends = ["anthropic>=0.2.8", "cohere>=4.3.1", "transformers>=4.27.4", "bardapi==0.1.11", "langchain>=0.0.135", "accelerate>=0.33.0"]
all_envs = ["pettingzoo>=1.24.0", "chess==1.9.4", "rlcard==1.0.5", "pygame==2.3.0", "langchain>=0.0.135"]
database = ["supabase==2.0.3"]
testing = ["deptry>=0.12.0", "pytest>=7.4.3", "pytest-cov>=4.1.0", "pytest-xdist>=3.4.0"]
all = ["anthropic==0.2.8", "cohere==4.3.1", "transformers>=4.27.4", "gradio==3.34.0", "pydantic==1.10.13", "pettingzoo>=1.24.0", "chess==1.9.4", "rlcard==1.0.5", "pygame==2.3.0", "gymnasium>=0.28.1",
"colorama>=0.4.6", "supabase==2.0.3", "bardapi==0.1.11", "langchain>=0.0.340", "deptry>=0.12.0", "pytest>=7.4.3", "pytest-cov>=4.1.0", "pytest-xdist>=3.4.0"]
"colorama>=0.4.6", "supabase==2.0.3", "bardapi==0.1.11", "langchain>=0.0.340", "deptry>=0.12.0", "pytest>=7.4.3", "pytest-cov>=4.1.0", "pytest-xdist>=3.4.0", "accelerate>=0.33.0"]

[tool.deptry.per_rule_ignores]
DEP002 = [ "pytest", "pytest-cov", "deptry", "pytest-xdist", "chess", "rlcard", "pygame", "pydantic" ]
Expand Down
5 changes: 2 additions & 3 deletions tests/unit/test_hf_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@


class TestHFTransformers(TestCase):
@unittest.skip("TODO: fix failing test")
def test_transformers_conv_1(self):
backend = TransformersConversational(
model="facebook/blenderbot-400M-distill", device=-1
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
)

history_messages = [
Expand All @@ -45,7 +44,7 @@ def test_transformers_conv_1(self):

def test_transformers_conv_2(self):
backend = TransformersConversational(
model="facebook/blenderbot-400M-distill", device=-1
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
)

history_messages = [
Expand Down