-
-
Notifications
You must be signed in to change notification settings - Fork 48.7k
added unittest #12201
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added unittest #12201
Changes from 19 commits
d6609cb
28b1f02
998eed4
7019bf4
f3d43e8
2dad12b
4ecdca1
f8510d7
fb102e6
6e7a428
922a230
a1d4cd9
508249e
4af7a67
7c49052
276528d
789e975
3e4430d
322434d
5d91b30
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
|
||
|
||
# Chatbot with LLM Integration and Database Storage | ||
|
||
This chatbot application integrates LLM (Large Language Model) API services, **Together** and **Groq**(you can use any one of them), to generate AI-driven responses. It stores conversation history in a MySQL database and manages chat sessions with triggers that update the status of conversations automatically. | ||
|
||
## Features | ||
- Supports LLM response generation using **Together** and **Groq** APIs. | ||
- Stores chat sessions and message exchanges in MySQL database tables. | ||
- Automatically updates chat session status using database triggers. | ||
- Manages conversation history with user-assistant interaction. | ||
|
||
## Requirements | ||
|
||
Before running the application, ensure the following dependencies are installed: | ||
|
||
- Python 3.13+ | ||
- MySQL Server | ||
- The following Python libraries: | ||
```bash | ||
pip3 install -r requirements.txt | ||
``` | ||
|
||
## Setup Instructions | ||
|
||
### Step 1: Set Up Environment Variables | ||
|
||
Create a `.env` file in the root directory of your project and add the following entries for your database credentials and API keys: | ||
|
||
``` | ||
# Together API key | ||
TOGETHER_API_KEY="YOUR_API_KEY" | ||
|
||
# Groq API key | ||
GROQ_API_KEY = "YOUR_API_KEY" | ||
|
||
# MySQL connectionDB (if you're running locally) | ||
DB_USER = "<DB_USER_NAME>" | ||
DB_PASSWORD = "<DB_USER_NAME>" | ||
DB_HOST = "127.0.0.1" | ||
DB_NAME = "ChatDB" | ||
PORT = "3306" | ||
|
||
# API service to you(or use "Together") | ||
API_SERVICE = "Groq" | ||
``` | ||
|
||
### Step 2: Create MySQL Tables and Trigger | ||
|
||
The `create_tables()` function in the script automatically creates the necessary tables and a trigger for updating chat session statuses. To ensure the database is set up correctly, the function is called at the beginning of the script. | ||
|
||
Ensure that your MySQL server is running and accessible before running the code. | ||
|
||
### Step 3: Run the Application | ||
|
||
To start the chatbot: | ||
|
||
1. Ensure your MySQL server is running. | ||
2. Open a terminal and run the Python script: | ||
|
||
```bash | ||
python3 chat_db.py | ||
``` | ||
|
||
The chatbot will initialize, and you can interact with it by typing your inputs. Type `/stop` to end the conversation. | ||
|
||
### Step 4: Test and Validate Code | ||
|
||
This project uses doctests to ensure that the functions work as expected. To run the doctests: | ||
|
||
```bash | ||
python3 -m doctest -v chatbot.py | ||
``` | ||
|
||
Make sure to add doctests to all your functions where applicable, to validate both valid and erroneous inputs. | ||
|
||
### Key Functions | ||
|
||
- **create_tables()**: Sets up the MySQL tables (`Chat_history` and `Chat_data`) and the `update_is_stream` trigger. | ||
- **insert_chat_history()**: Inserts a new chat session into the `Chat_history` table. | ||
- **insert_chat_data()**: Inserts user-assistant message pairs into the `Chat_data` table. | ||
- **generate_llm_response()**: Generates a response from the selected LLM API service, either **Together** or **Groq**. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,292 @@ | ||
""" | ||
credits : https://medium.com/google-developer-experts/beyond-live-sessions-building-persistent-memory-chatbots-with-langchain-gemini-pro-and-firebase-19d6f84e21d3 | ||
""" | ||
|
||
import os | ||
import datetime | ||
import mysql.connector | ||
from dotenv import load_dotenv | ||
from together import Together | ||
from groq import Groq | ||
import unittest | ||
from unittest.mock import patch | ||
from io import StringIO | ||
|
||
load_dotenv() | ||
|
||
# Database configuration | ||
db_config = { | ||
"user": os.environ.get("DB_USER"), | ||
"password": os.environ.get("DB_PASSWORD"), | ||
"host": os.environ.get("DB_HOST"), | ||
"database": os.environ.get("DB_NAME"), | ||
} | ||
|
||
class LLMService: | ||
def __init__(self, api_service: str): | ||
self.api_service = api_service | ||
if self.api_service == "Together": | ||
self.client = Together(api_key=os.environ.get("TOGETHER_API_KEY")) | ||
else: | ||
self.client = Groq(api_key=os.environ.get("GROQ_API_KEY")) | ||
|
||
def generate_response(self, conversation_history: list[dict]) -> str: | ||
""" | ||
Generate a response from the LLM based on the conversation history. | ||
|
||
Example: | ||
>>> llm_service = LLMService(api_service="Groq") | ||
>>> response = llm_service.generate_response([{"role": "user", "content": "Hello"}]) | ||
>>> isinstance(response, str) | ||
True | ||
""" | ||
if self.api_service == "Together": | ||
response = self.client.chat.completions.create( | ||
model="meta-llama/Llama-3.2-3B-Instruct-Turbo", | ||
messages=conversation_history, | ||
max_tokens=512, | ||
temperature=0.3, | ||
top_p=0.7, | ||
top_k=50, | ||
repetition_penalty=1, | ||
stop=["<|eot_id|>", "<|eom_id|>"], | ||
stream=False, | ||
) | ||
else: | ||
response = self.client.chat.completions.create( | ||
model="llama3-8b-8192", | ||
messages=conversation_history, | ||
max_tokens=1024, | ||
temperature=0.3, | ||
top_p=0.7, | ||
stop=["<|eot_id|>", "<|eom_id|>"], | ||
stream=False, | ||
) | ||
|
||
return response.choices[0].message.content | ||
|
||
|
||
class ChatDB: | ||
@staticmethod | ||
def create_tables() -> None: | ||
""" | ||
Create the ChatDB.Chat_history and ChatDB.Chat_data tables | ||
if they do not exist. Also, create a trigger to update is_stream | ||
in Chat_data when Chat_history.is_stream is updated. | ||
|
||
Example: | ||
>>> ChatDB.create_tables() | ||
Tables and trigger created successfully | ||
""" | ||
try: | ||
conn = mysql.connector.connect(**db_config) | ||
cursor = conn.cursor() | ||
|
||
cursor.execute( | ||
""" | ||
CREATE TABLE IF NOT EXISTS ChatDB.Chat_history ( | ||
chat_id INT AUTO_INCREMENT PRIMARY KEY, | ||
start_time DATETIME, | ||
is_stream INT | ||
) | ||
""" | ||
) | ||
|
||
cursor.execute( | ||
""" | ||
CREATE TABLE IF NOT EXISTS ChatDB.Chat_data ( | ||
id INT AUTO_INCREMENT PRIMARY KEY, | ||
chat_id INT, | ||
user TEXT, | ||
assistant TEXT, | ||
FOREIGN KEY (chat_id) REFERENCES ChatDB.Chat_history(chat_id) | ||
) | ||
""" | ||
) | ||
|
||
cursor.execute("DROP TRIGGER IF EXISTS update_is_stream;") | ||
|
||
cursor.execute( | ||
""" | ||
CREATE TRIGGER update_is_stream | ||
AFTER UPDATE ON ChatDB.Chat_history | ||
FOR EACH ROW | ||
BEGIN | ||
UPDATE ChatDB.Chat_data | ||
SET is_stream = NEW.is_stream | ||
WHERE chat_id = NEW.chat_id; | ||
END; | ||
""" | ||
) | ||
|
||
conn.commit() | ||
print("Tables and trigger created successfully") | ||
except mysql.connector.Error as err: | ||
print(f"Error: {err}") | ||
finally: | ||
cursor.close() | ||
conn.close() | ||
|
||
@staticmethod | ||
def insert_chat_history(start_time: datetime.datetime, is_stream: int) -> int: | ||
""" | ||
Insert a new row into the ChatDB.Chat_history table and return the inserted chat_id. | ||
|
||
Example: | ||
>>> from datetime import datetime | ||
>>> chat_id = ChatDB.insert_chat_history(datetime(2024, 1, 1, 12, 0, 0), 1) | ||
>>> isinstance(chat_id, int) | ||
True | ||
""" | ||
try: | ||
conn = mysql.connector.connect(**db_config) | ||
cursor = conn.cursor() | ||
cursor.execute( | ||
""" | ||
INSERT INTO ChatDB.Chat_history (start_time, is_stream) | ||
VALUES (%s, %s) | ||
""", | ||
(start_time, is_stream), | ||
) | ||
conn.commit() | ||
cursor.execute("SELECT LAST_INSERT_ID()") | ||
chat_id = cursor.fetchone()[0] | ||
print("Chat history inserted successfully.") | ||
return chat_id | ||
except mysql.connector.Error as err: | ||
print(f"Error: {err}") | ||
return None | ||
finally: | ||
cursor.close() | ||
conn.close() | ||
|
||
@staticmethod | ||
def get_latest_chat_id() -> int: | ||
""" | ||
Retrieve the latest chat_id from the ChatDB.Chat_history table. | ||
:return: The latest chat_id or None if no chat_id exists. | ||
|
||
Example: | ||
>>> chat_id = ChatDB.get_latest_chat_id() | ||
>>> isinstance(chat_id, int) | ||
True | ||
""" | ||
try: | ||
conn = mysql.connector.connect(**db_config) | ||
cursor = conn.cursor() | ||
cursor.execute( | ||
""" | ||
SELECT chat_id FROM ChatDB.Chat_history | ||
ORDER BY chat_id DESC LIMIT 1 | ||
""" | ||
) | ||
chat_id = cursor.fetchone()[0] | ||
return chat_id if chat_id else None | ||
except mysql.connector.Error as err: | ||
print(f"Error: {err}") | ||
return None | ||
finally: | ||
cursor.close() | ||
conn.close() | ||
|
||
@staticmethod | ||
def insert_chat_data(chat_id: int, user_message: str, assistant_message: str) -> None: | ||
""" | ||
Insert a new row into the ChatDB.Chat_data table. | ||
|
||
Example: | ||
>>> ChatDB.insert_chat_data(1, 'Hello', 'Hi there!') | ||
Chat data inserted successfully. | ||
""" | ||
try: | ||
conn = mysql.connector.connect(**db_config) | ||
cursor = conn.cursor() | ||
cursor.execute( | ||
""" | ||
INSERT INTO ChatDB.Chat_data (chat_id, user, assistant) | ||
VALUES (%s, %s, %s) | ||
""", | ||
(chat_id, user_message, assistant_message), | ||
) | ||
conn.commit() | ||
print("Chat data inserted successfully.") | ||
except mysql.connector.Error as err: | ||
print(f"Error: {err}") | ||
finally: | ||
cursor.close() | ||
conn.close() | ||
|
||
|
||
class Chatbot: | ||
def __init__(self, api_service: str): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please provide return type hint for the function: |
||
self.llm_service = LLMService(api_service) | ||
self.conversation_history = [] | ||
self.chat_id_pk = None | ||
self.start_time = datetime.datetime.now(datetime.timezone.utc) | ||
|
||
def chat_session(self) -> None: | ||
""" | ||
Start a chatbot session, allowing the user to interact with the LLM. | ||
Saves conversation history in the database and ends the session on "/stop" command. | ||
|
||
Example: | ||
>>> chatbot = Chatbot(api_service="Groq") | ||
>>> chatbot.chat_session() # This will be mocked in the tests | ||
Welcome to the chatbot! Type '/stop' to end the conversation. | ||
""" | ||
print("Welcome to the chatbot! Type '/stop' to end the conversation.") | ||
|
||
while True: | ||
user_input = input("\nYou: ").strip() | ||
self.conversation_history.append({"role": "user", "content": user_input}) | ||
|
||
if self.chat_id_pk is None: | ||
if user_input.lower() == "/stop": | ||
break | ||
bot_response = self.llm_service.generate_response(self.conversation_history) | ||
self.conversation_history.append({"role": "assistant", "content": bot_response}) | ||
|
||
is_stream = 1 # New conversation | ||
self.chat_id_pk = ChatDB.insert_chat_history(self.start_time, is_stream) # Return the chat_id | ||
if self.chat_id_pk: | ||
ChatDB.insert_chat_data(self.chat_id_pk, user_input, bot_response) | ||
else: | ||
if user_input.lower() == "/stop": | ||
is_stream = 2 # End of conversation | ||
current_time = datetime.datetime.now(datetime.timezone.utc) | ||
ChatDB.insert_chat_history(current_time, is_stream) | ||
break | ||
|
||
bot_response = self.llm_service.generate_response(self.conversation_history) | ||
self.conversation_history.append({"role": "assistant", "content": bot_response}) | ||
|
||
is_stream = 0 # Continuation of conversation | ||
current_time = datetime.datetime.now(datetime.timezone.utc) | ||
ChatDB.insert_chat_history(current_time, is_stream) | ||
ChatDB.insert_chat_data(self.chat_id_pk, user_input, bot_response) | ||
|
||
if len(self.conversation_history) > 1000: | ||
self.conversation_history = self.conversation_history[-3:] | ||
|
||
# Test cases for Chatbot | ||
class TestChatbot(unittest.TestCase): | ||
|
||
@patch('builtins.input', side_effect=["Hello", "/stop"]) | ||
@patch('sys.stdout', new_callable=StringIO) | ||
def test_chat_session(self, mock_stdout, mock_input): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please provide return type hint for the function: Please provide type hint for the parameter: Please provide type hint for the parameter: |
||
""" | ||
Test the chat_session method for expected welcome message. | ||
""" | ||
chatbot = Chatbot(api_service="Groq") | ||
chatbot.chat_session() | ||
|
||
# Check for the welcome message in the output | ||
output = mock_stdout.getvalue().strip().splitlines() | ||
self.assertIn("Welcome to the chatbot! Type '/stop' to end the conversation.", output) | ||
self.assertTrue(any("Chat history inserted successfully." in line for line in output)) | ||
self.assertTrue(any("Chat data inserted successfully." in line for line in output)) | ||
|
||
if __name__ == "__main__": | ||
# | ||
ChatDB.create_tables() | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please provide return type hint for the function:
__init__
. If the function does not return a value, please provide the type hint as:def function() -> None: