diff --git a/neural_network/chatbot/README.md b/neural_network/chatbot/README.md new file mode 100644 index 000000000000..acddd4c8f671 --- /dev/null +++ b/neural_network/chatbot/README.md @@ -0,0 +1,83 @@ + + +# Chatbot with LLM Integration and Database Storage + +This chatbot application integrates LLM (Large Language Model) API services, **Together** and **Groq** (you can use either one of them), to generate AI-driven responses. It stores conversation history in a MySQL database and manages chat sessions with triggers that update the status of conversations automatically. + +## Features +- Supports LLM response generation using **Together** and **Groq** APIs. +- Stores chat sessions and message exchanges in MySQL database tables. +- Automatically updates chat session status using database triggers. +- Manages conversation history with user-assistant interaction. + +## Requirements + +Before running the application, ensure the following dependencies are installed: + +- Python 3.13+ +- MySQL Server +- The following Python libraries: +  ```bash +  pip3 install -r requirements.txt +  ``` + +## Setup Instructions + +### Step 1: Set Up Environment Variables + +Create a `.env` file in the root directory of your project and add the following entries for your database credentials and API keys: + +``` +# Together API key +TOGETHER_API_KEY="YOUR_API_KEY" + +# Groq API key +GROQ_API_KEY = "YOUR_API_KEY" + +# MySQL connection settings (if you're running locally) +DB_USER = "" +DB_PASSWORD = "" +DB_HOST = "127.0.0.1" +DB_NAME = "ChatDB" +PORT = "3306" + +# API service to use ("Groq" or "Together") +API_SERVICE = "Groq" +``` + +### Step 2: Create MySQL Tables and Trigger + +The `create_tables()` function in the script automatically creates the necessary tables and a trigger for updating chat session statuses. To ensure the database is set up correctly, the function is called at the beginning of the script. + +Ensure that your MySQL server is running and accessible before running the code. 
+ +### Step 3: Run the Application + +To start the chatbot: + +1. Ensure your MySQL server is running. +2. Open a terminal and run the Python script: + +```bash +python3 chat_db.py +``` + +The chatbot will initialize, and you can interact with it by typing your inputs. Type `/stop` to end the conversation. + +### Step 4: Test and Validate Code + +This project uses doctests to ensure that the functions work as expected. To run the doctests: + +```bash +python3 -m doctest -v chat_db.py +``` + +Make sure to add doctests to all your functions where applicable, to validate both valid and erroneous inputs. + +### Key Functions + +- **create_tables()**: Sets up the MySQL tables (`Chat_history` and `Chat_data`) and the `update_is_stream` trigger. +- **insert_chat_history()**: Inserts a new chat session into the `Chat_history` table. +- **insert_chat_data()**: Inserts user-assistant message pairs into the `Chat_data` table. +- **generate_llm_response()**: Generates a response from the selected LLM API service, either **Together** or **Groq**. 
+ diff --git a/neural_network/chatbot/chat_db.py b/neural_network/chatbot/chat_db.py new file mode 100644 index 000000000000..c1758ed0ac95 --- /dev/null +++ b/neural_network/chatbot/chat_db.py @@ -0,0 +1,312 @@ +""" +credits : https://medium.com/google-developer-experts/beyond-live-sessions-building-persistent-memory-chatbots-with-langchain-gemini-pro-and-firebase-19d6f84e21d3 +""" + +import os +import datetime +import mysql.connector +from dotenv import load_dotenv +from together import Together +from groq import Groq +import unittest +from unittest.mock import patch +from io import StringIO + +load_dotenv() + +# Database configuration +db_config = { + "user": os.environ.get("DB_USER"), + "password": os.environ.get("DB_PASSWORD"), + "host": os.environ.get("DB_HOST"), + "database": os.environ.get("DB_NAME"), +} + + +class LLMService: + def __init__(self, api_service: str): + self.api_service = api_service + if self.api_service == "Together": + self.client = Together(api_key=os.environ.get("TOGETHER_API_KEY")) + else: + self.client = Groq(api_key=os.environ.get("GROQ_API_KEY")) + + def generate_response(self, conversation_history: list[dict]) -> str: + """ + Generate a response from the LLM based on the conversation history. 
+ + Example: + >>> llm_service = LLMService(api_service="Groq") + >>> response = llm_service.generate_response([{"role": "user", "content": "Hello"}]) + >>> isinstance(response, str) + True + """ + if self.api_service == "Together": + response = self.client.chat.completions.create( + model="meta-llama/Llama-3.2-3B-Instruct-Turbo", + messages=conversation_history, + max_tokens=512, + temperature=0.3, + top_p=0.7, + top_k=50, + repetition_penalty=1, + stop=["<|eot_id|>", "<|eom_id|>"], + stream=False, + ) + else: + response = self.client.chat.completions.create( + model="llama3-8b-8192", + messages=conversation_history, + max_tokens=1024, + temperature=0.3, + top_p=0.7, + stop=["<|eot_id|>", "<|eom_id|>"], + stream=False, + ) + + return response.choices[0].message.content + + +class ChatDB: + @staticmethod + def create_tables() -> None: + """ + Create the ChatDB.Chat_history and ChatDB.Chat_data tables + if they do not exist. Also, create a trigger to update is_stream + in Chat_data when Chat_history.is_stream is updated. 
+ + Example: + >>> ChatDB.create_tables() + Tables and trigger created successfully + """ + try: + conn = mysql.connector.connect(**db_config) + cursor = conn.cursor() + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS ChatDB.Chat_history ( + chat_id INT AUTO_INCREMENT PRIMARY KEY, + start_time DATETIME, + is_stream INT + ) + """ + ) + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS ChatDB.Chat_data ( + id INT AUTO_INCREMENT PRIMARY KEY, + chat_id INT, + user TEXT, + assistant TEXT, + FOREIGN KEY (chat_id) REFERENCES ChatDB.Chat_history(chat_id) + ) + """ + ) + + cursor.execute("DROP TRIGGER IF EXISTS update_is_stream;") + + cursor.execute( + """ + CREATE TRIGGER update_is_stream + AFTER UPDATE ON ChatDB.Chat_history + FOR EACH ROW + BEGIN + UPDATE ChatDB.Chat_data + SET is_stream = NEW.is_stream + WHERE chat_id = NEW.chat_id; + END; + """ + ) + + conn.commit() + print("Tables and trigger created successfully") + except mysql.connector.Error as err: + print(f"Error: {err}") + finally: + cursor.close() + conn.close() + + @staticmethod + def insert_chat_history(start_time: datetime.datetime, is_stream: int) -> int: + """ + Insert a new row into the ChatDB.Chat_history table and return the inserted chat_id. 
+ + Example: + >>> from datetime import datetime + >>> chat_id = ChatDB.insert_chat_history(datetime(2024, 1, 1, 12, 0, 0), 1) + >>> isinstance(chat_id, int) + True + """ + try: + conn = mysql.connector.connect(**db_config) + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO ChatDB.Chat_history (start_time, is_stream) + VALUES (%s, %s) + """, + (start_time, is_stream), + ) + conn.commit() + cursor.execute("SELECT LAST_INSERT_ID()") + chat_id = cursor.fetchone()[0] + print("Chat history inserted successfully.") + return chat_id + except mysql.connector.Error as err: + print(f"Error: {err}") + return None + finally: + cursor.close() + conn.close() + + @staticmethod + def get_latest_chat_id() -> int: + """ + Retrieve the latest chat_id from the ChatDB.Chat_history table. + :return: The latest chat_id or None if no chat_id exists. + + Example: + >>> chat_id = ChatDB.get_latest_chat_id() + >>> isinstance(chat_id, int) + True + """ + try: + conn = mysql.connector.connect(**db_config) + cursor = conn.cursor() + cursor.execute( + """ + SELECT chat_id FROM ChatDB.Chat_history + ORDER BY chat_id DESC LIMIT 1 + """ + ) + chat_id = cursor.fetchone()[0] + return chat_id if chat_id else None + except mysql.connector.Error as err: + print(f"Error: {err}") + return None + finally: + cursor.close() + conn.close() + + @staticmethod + def insert_chat_data( + chat_id: int, user_message: str, assistant_message: str + ) -> None: + """ + Insert a new row into the ChatDB.Chat_data table. + + Example: + >>> ChatDB.insert_chat_data(1, 'Hello', 'Hi there!') + Chat data inserted successfully. 
+ """ + try: + conn = mysql.connector.connect(**db_config) + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO ChatDB.Chat_data (chat_id, user, assistant) + VALUES (%s, %s, %s) + """, + (chat_id, user_message, assistant_message), + ) + conn.commit() + print("Chat data inserted successfully.") + except mysql.connector.Error as err: + print(f"Error: {err}") + finally: + cursor.close() + conn.close() + + +class Chatbot: + def __init__(self, api_service: str): + self.llm_service = LLMService(api_service) + self.conversation_history = [] + self.chat_id_pk = None + self.start_time = datetime.datetime.now(datetime.timezone.utc) + + def chat_session(self) -> None: + """ + Start a chatbot session, allowing the user to interact with the LLM. + Saves conversation history in the database and ends the session on "/stop" command. + + Example: + >>> chatbot = Chatbot(api_service="Groq") + >>> chatbot.chat_session() # This will be mocked in the tests + Welcome to the chatbot! Type '/stop' to end the conversation. + """ + print("Welcome to the chatbot! 
Type '/stop' to end the conversation.") + + while True: + user_input = input("\nYou: ").strip() + self.conversation_history.append({"role": "user", "content": user_input}) + + if self.chat_id_pk is None: + if user_input.lower() == "/stop": + break + bot_response = self.llm_service.generate_response( + self.conversation_history + ) + self.conversation_history.append( + {"role": "assistant", "content": bot_response} + ) + + is_stream = 1 # New conversation + self.chat_id_pk = ChatDB.insert_chat_history( + self.start_time, is_stream + ) # Return the chat_id + if self.chat_id_pk: + ChatDB.insert_chat_data(self.chat_id_pk, user_input, bot_response) + else: + if user_input.lower() == "/stop": + is_stream = 2 # End of conversation + current_time = datetime.datetime.now(datetime.timezone.utc) + ChatDB.insert_chat_history(current_time, is_stream) + break + + bot_response = self.llm_service.generate_response( + self.conversation_history + ) + self.conversation_history.append( + {"role": "assistant", "content": bot_response} + ) + + is_stream = 0 # Continuation of conversation + current_time = datetime.datetime.now(datetime.timezone.utc) + ChatDB.insert_chat_history(current_time, is_stream) + ChatDB.insert_chat_data(self.chat_id_pk, user_input, bot_response) + + if len(self.conversation_history) > 1000: + self.conversation_history = self.conversation_history[-3:] + + +# Test cases for Chatbot +class TestChatbot(unittest.TestCase): + @patch("builtins.input", side_effect=["Hello", "/stop"]) + @patch("sys.stdout", new_callable=StringIO) + def test_chat_session(self, mock_stdout, mock_input): + """ + Test the chat_session method for expected welcome message. + """ + chatbot = Chatbot(api_service="Groq") + chatbot.chat_session() + + # Check for the welcome message in the output + output = mock_stdout.getvalue().strip().splitlines() + self.assertIn( + "Welcome to the chatbot! 
Type '/stop' to end the conversation.", output + ) + self.assertTrue( + any("Chat history inserted successfully." in line for line in output) + ) + self.assertTrue( + any("Chat data inserted successfully." in line for line in output) + ) + + +if __name__ == "__main__": + # + ChatDB.create_tables() + unittest.main() diff --git a/neural_network/chatbot/requirements.txt b/neural_network/chatbot/requirements.txt new file mode 100644 index 000000000000..0f1204243a5d --- /dev/null +++ b/neural_network/chatbot/requirements.txt @@ -0,0 +1,57 @@ +aiohappyeyeballs==2.4.2 +aiohttp==3.10.8 +aiosignal==1.3.1 +annotated-types==0.7.0 +anyio==4.6.0 +asgiref==3.8.1 +attrs==24.2.0 +black==24.10.0 +certifi==2024.8.30 +cfgv==3.4.0 +charset-normalizer==3.3.2 +click==8.1.7 +distlib==0.3.9 +distro==1.9.0 +Django==5.1.1 +djangorestframework==3.15.2 +eval_type_backport==0.2.0 +filelock==3.16.1 +frozenlist==1.4.1 +groq==0.11.0 +h11==0.14.0 +httpcore==1.0.5 +httpx==0.27.2 +identify==2.6.1 +idna==3.10 +markdown-it-py==3.0.0 +mdurl==0.1.2 +multidict==6.1.0 +mypy-extensions==1.0.0 +mysql-connector-python==9.0.0 +nodeenv==1.9.1 +numpy==2.1.1 +packaging==24.1 +pathspec==0.12.1 +pillow==10.4.0 +platformdirs==4.3.6 +pre_commit==4.0.1 +pyarrow==17.0.0 +pydantic==2.9.2 +pydantic_core==2.23.4 +Pygments==2.18.0 +python-dotenv==1.0.1 +PyYAML==6.0.2 +requests==2.32.3 +rich==13.8.1 +ruff==0.7.0 +shellingham==1.5.4 +sniffio==1.3.1 +sqlparse==0.5.1 +tabulate==0.9.0 +together==1.3.0 +tqdm==4.66.5 +typer==0.12.5 +typing_extensions==4.12.2 +urllib3==2.2.3 +virtualenv==20.27.0 +yarl==1.13.1