diff --git a/chatarena/backends/hf_transformers.py b/chatarena/backends/hf_transformers.py index a307b4d9..f50573be 100644 --- a/chatarena/backends/hf_transformers.py +++ b/chatarena/backends/hf_transformers.py @@ -3,11 +3,16 @@ from typing import List from tenacity import retry, stop_after_attempt, wait_random_exponential +import torch -from ..message import SYSTEM_NAME as SYSTEM +from ..message import SYSTEM_NAME from ..message import Message from .base import IntelligenceBackend, register_backend +END_OF_MESSAGE = "" # End of message token specified by us not OpenAI +BASE_PROMPT = f"The messages always end with the token {END_OF_MESSAGE}." +DEFAULT_MAX_TOKENS = 256 + @contextmanager def suppress_stdout_stderr(): @@ -22,9 +27,8 @@ def suppress_stdout_stderr(): try: import transformers from transformers import pipeline - from transformers.pipelines.conversational import ( - Conversation, - ConversationalPipeline, + from transformers.pipelines.text_generation import ( + TextGenerationPipeline, ) except ImportError: is_transformers_available = False @@ -39,25 +43,40 @@ class TransformersConversational(IntelligenceBackend): stateful = False type_name = "transformers:conversational" - def __init__(self, model: str, device: int = -1, **kwargs): - super().__init__(model=model, device=device, **kwargs) + def __init__( + self, + model: str, + max_tokens: int = DEFAULT_MAX_TOKENS, + merge_other_agents_as_one_user: bool = True, + **kwargs, + ): + super().__init__(model=model, **kwargs) self.model = model - self.device = device + self.max_tokens = max_tokens assert is_transformers_available, "Transformers package is not installed" self.chatbot = pipeline( - task="conversational", model=self.model, device=self.device + task="text-generation", + model=self.model, + device_map="auto", + model_kwargs={"torch_dtype": torch.bfloat16}, ) + self.terminators = [ + self.chatbot.tokenizer.eos_token_id, + self.chatbot.tokenizer.convert_tokens_to_ids("<|eot_id|>"), + ] + self.merge_other_agent_as_user = merge_other_agents_as_one_user - @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60)) def _get_response(self, conversation): - conversation = self.chatbot(conversation) - response = conversation.generated_responses[-1] - return response - @staticmethod - def _msg_template(agent_name, content): - return f"[{agent_name}]: {content}" + conversation = self.chatbot( + conversation, + max_new_tokens=self.max_tokens, + eos_token_id=self.terminators, + pad_token_id=self.chatbot.tokenizer.eos_token_id, + ) + response = conversation[0]["generated_text"][-1]["content"] + return response def query( self, @@ -69,60 +88,57 @@ def query( *args, **kwargs, ) -> str: - user_inputs, generated_responses = [], [] - all_messages = ( - [(SYSTEM, global_prompt), (SYSTEM, role_desc)] - if global_prompt - else [(SYSTEM, role_desc)] - ) + if global_prompt: # Prepend the global prompt if it exists + system_prompt = f"You are a helpful assistant.\n{global_prompt.strip()}\n{BASE_PROMPT}\n\nYour name is {agent_name}.\n\nYour role:{role_desc}" + else: + system_prompt = f"You are a helpful assistant. Your name is {agent_name}.\n\nYour role:{role_desc}\n\n{BASE_PROMPT}" + all_messages = [(SYSTEM_NAME, system_prompt)] for msg in history_messages: - all_messages.append((msg.agent_name, msg.content)) - if request_msg: - all_messages.append((SYSTEM, request_msg.content)) + if msg.agent_name == SYSTEM_NAME: + all_messages.append((SYSTEM_NAME, msg.content)) + else: # non-system messages are suffixed with the end of message token + all_messages.append((msg.agent_name, f"{msg.content}{END_OF_MESSAGE}")) - prev_is_user = False # Whether the previous message is from the user - for i, message in enumerate(all_messages): + if request_msg: + all_messages.append((SYSTEM_NAME, request_msg.content)) + else: # The default request message that reminds the agent its role and instruct it to speak + all_messages.append( + (SYSTEM_NAME, f"Now you speak, {agent_name}.{END_OF_MESSAGE}") + ) + + messages = [] + for i, msg in enumerate(all_messages): if i == 0: assert ( - message[0] == SYSTEM + msg[0] == SYSTEM_NAME ) # The first message should be from the system - - if message[0] != agent_name: - if not prev_is_user: - user_inputs.append(self._msg_template(message[0], message[1])) - else: - user_inputs[-1] += "\n" + self._msg_template(message[0], message[1]) - prev_is_user = True + messages.append({"role": "system", "content": msg[1]}) else: - if prev_is_user: - generated_responses.append(message[1]) + if msg[0] == agent_name: + messages.append({"role": "assistant", "content": msg[1]}) else: - generated_responses[-1] += "\n" + message[1] - prev_is_user = False - - assert len(user_inputs) == len(generated_responses) + 1 - past_user_inputs = user_inputs[:-1] - new_user_input = user_inputs[-1] - - # Recreate a conversation object from the history messages - conversation = Conversation( - text=new_user_input, - past_user_inputs=past_user_inputs, - generated_responses=generated_responses, - ) + if messages[-1]["role"] == "user": # last message is from user + if self.merge_other_agent_as_user: + messages[-1][ + "content" + ] = f"{messages[-1]['content']}\n\n[{msg[0]}]: {msg[1]}" + else: + messages.append( + {"role": "user", "content": f"[{msg[0]}]: {msg[1]}"} + ) + elif ( + messages[-1]["role"] == "assistant" + ): # consecutive assistant messages + # Merge the assistant messages + messages[-1]["content"] = f"{messages[-1]['content']}\n{msg[1]}" + elif messages[-1]["role"] == "system": + messages.append( + {"role": "user", "content": f"[{msg[0]}]: {msg[1]}"} + ) + else: + raise ValueError(f"Invalid role: {messages[-1]['role']}") # Get the response - response = self._get_response(conversation) + response = self._get_response(messages) return response - - -# conversation = Conversation("Going to the movies tonight - any suggestions?") -# -# # Steps usually performed by the model when generating a response: -# # 1. Mark the user input as processed (moved to the history) -# conversation.mark_processed() -# # 2. Append a mode response -# conversation.append_response("The Big lebowski.") -# -# conversation.add_user_input("Is it good?") diff --git a/pyproject.toml b/pyproject.toml index 3a816763..01cb3108 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,12 +43,12 @@ langchain = ["langchain>=0.0.340"] gradio = ["gradio==3.34.0", "pydantic==1.10.13"] pettingzoo = ["pettingzoo>=1.24.0", "chess==1.9.4", "rlcard==1.0.5", "pygame==2.3.0", "gymnasium>=0.28.1"] umshini = ["pettingzoo>=1.24.1", "chess==1.9.4", "rlcard==1.0.5", "pygame==2.3.0", "gymnasium>=0.28.1", "langchain>=0.0.340", "colorama>=0.4.6"] -all_backends = ["anthropic>=0.2.8", "cohere>=4.3.1", "transformers>=4.27.4", "bardapi==0.1.11", "langchain>=0.0.135"] +all_backends = ["anthropic>=0.2.8", "cohere>=4.3.1", "transformers>=4.27.4", "bardapi==0.1.11", "langchain>=0.0.135", "accelerate>=0.33.0"] all_envs = ["pettingzoo>=1.24.0", "chess==1.9.4", "rlcard==1.0.5", "pygame==2.3.0", "langchain>=0.0.135"] database = ["supabase==2.0.3"] testing = ["deptry>=0.12.0", "pytest>=7.4.3", "pytest-cov>=4.1.0", "pytest-xdist>=3.4.0"] all = ["anthropic==0.2.8", "cohere==4.3.1", "transformers>=4.27.4", "gradio==3.34.0", "pydantic==1.10.13", "pettingzoo>=1.24.0", "chess==1.9.4", "rlcard==1.0.5", "pygame==2.3.0", "gymnasium>=0.28.1", - "colorama>=0.4.6", "supabase==2.0.3", "bardapi==0.1.11", "langchain>=0.0.340", "deptry>=0.12.0", "pytest>=7.4.3", "pytest-cov>=4.1.0", "pytest-xdist>=3.4.0"] + "colorama>=0.4.6", "supabase==2.0.3", "bardapi==0.1.11", "langchain>=0.0.340", "deptry>=0.12.0", "pytest>=7.4.3", "pytest-cov>=4.1.0", "pytest-xdist>=3.4.0", "accelerate>=0.33.0"] [tool.deptry.per_rule_ignores] DEP002 = [ "pytest", "pytest-cov", "deptry", "pytest-xdist", "chess", "rlcard", "pygame", "pydantic" ] diff --git a/tests/unit/test_hf_transformers.py b/tests/unit/test_hf_transformers.py index e2e21b27..3123401b 100644 --- a/tests/unit/test_hf_transformers.py +++ b/tests/unit/test_hf_transformers.py @@ -20,10 +20,9 @@ class TestHFTransformers(TestCase): - @unittest.skip("TODO: fix failing test") def test_transformers_conv_1(self): backend = TransformersConversational( - model="facebook/blenderbot-400M-distill", device=-1 + model="meta-llama/Meta-Llama-3.1-8B-Instruct" ) history_messages = [ @@ -45,7 +44,7 @@ def test_transformers_conv_1(self): def test_transformers_conv_2(self): backend = TransformersConversational( - model="facebook/blenderbot-400M-distill", device=-1 + model="meta-llama/Meta-Llama-3.1-8B-Instruct" ) history_messages = [