diff --git a/deploy/docker/docker-compose.yml b/deploy/docker/docker-compose.yml index a7c26cdc..820119cd 100755 --- a/deploy/docker/docker-compose.yml +++ b/deploy/docker/docker-compose.yml @@ -166,6 +166,7 @@ services: - MONGO_DB_USER=admin - MONGO_DB_PASSWORD=crapisecretpassword - MONGO_DB_NAME=crapi + - DEFAULT_MODEL=gpt-4o-mini # - CHATBOT_OPENAI_API_KEY= depends_on: mongodb: diff --git a/deploy/helm/templates/chatbot/config.yaml b/deploy/helm/templates/chatbot/config.yaml index 31a8ceb4..9afbb753 100644 --- a/deploy/helm/templates/chatbot/config.yaml +++ b/deploy/helm/templates/chatbot/config.yaml @@ -21,3 +21,4 @@ data: MONGO_DB_PASSWORD: {{ .Values.mongodb.config.mongoPassword }} MONGO_DB_NAME: {{ .Values.mongodb.config.mongoDbName }} CHATBOT_OPENAI_API_KEY: {{ .Values.openAIApiKey }} + DEFAULT_MODEL: {{ .Values.chatbot.config.defaultModel | quote }} diff --git a/deploy/helm/values.yaml b/deploy/helm/values.yaml index 84cce3b3..9122ff70 100644 --- a/deploy/helm/values.yaml +++ b/deploy/helm/values.yaml @@ -151,6 +151,7 @@ chatbot: postgresDbDriver: postgres mongoDbDriver: mongodb secretKey: crapi + defaultModel: gpt-4o-mini deploymentLabels: app: crapi-chatbot podLabels: diff --git a/services/chatbot/src/chatbot/chat_api.py b/services/chatbot/src/chatbot/chat_api.py index b47f99c2..1c8b0211 100644 --- a/services/chatbot/src/chatbot/chat_api.py +++ b/services/chatbot/src/chatbot/chat_api.py @@ -1,12 +1,15 @@ import logging from quart import Blueprint, jsonify, request, session from uuid import uuid4 +from .config import Config from .chat_service import delete_chat_history, get_chat_history, process_user_message from .session_service import ( delete_api_key, get_api_key, + get_model_name, get_or_create_session_id, store_api_key, + store_model_name, ) chat_bp = Blueprint("chat", __name__, url_prefix="/genai") @@ -34,11 +37,22 @@ async def init(): await store_api_key(session_id, openai_api_key) return jsonify({"message": "Initialized"}), 200 
+@chat_bp.route("/model", methods=["POST"]) +async def model(): + session_id = await get_or_create_session_id() + data = await request.get_json() + model_name = Config.DEFAULT_MODEL_NAME + if data and "model_name" in data and data["model_name"]: + model_name = data["model_name"] + logger.debug("Setting model %s for session %s", model_name, session_id) + await store_model_name(session_id, model_name) + return jsonify({"model_used": model_name}), 200 @chat_bp.route("/ask", methods=["POST"]) async def chat(): session_id = await get_or_create_session_id() openai_api_key = await get_api_key(session_id) + model_name = await get_model_name(session_id) if not openai_api_key: return jsonify({"message": "Missing OpenAI API key. Please authenticate."}), 400 data = await request.get_json() @@ -46,7 +60,7 @@ async def chat(): id = data.get("id", uuid4().int & (1 << 63) - 1) if not message: return jsonify({"message": "Message is required", "id": id}), 400 - reply, response_id = await process_user_message(session_id, message, openai_api_key) + reply, response_id = await process_user_message(session_id, message, openai_api_key, model_name) return jsonify({"id": response_id, "message": reply}), 200 diff --git a/services/chatbot/src/chatbot/chat_service.py b/services/chatbot/src/chatbot/chat_service.py index 092ebfa0..7f38907e 100644 --- a/services/chatbot/src/chatbot/chat_service.py +++ b/services/chatbot/src/chatbot/chat_service.py @@ -22,13 +22,13 @@ async def delete_chat_history(session_id): await db.chat_sessions.delete_one({"session_id": session_id}) -async def process_user_message(session_id, user_message, api_key): +async def process_user_message(session_id, user_message, api_key, model_name): history = await get_chat_history(session_id) # generate a unique numeric id for the message that is random but unique source_message_id = uuid4().int & (1 << 63) - 1 history.append({"id": source_message_id, "role": "user", "content": user_message}) # Run LangGraph agent - response = 
await execute_langgraph_agent(api_key, history, session_id) + response = await execute_langgraph_agent(api_key, model_name, history, session_id) print("Response", response) reply: Messages = response.get("messages", [{}])[-1] print("Reply", reply.content) diff --git a/services/chatbot/src/chatbot/config.py b/services/chatbot/src/chatbot/config.py index 489f5ebc..2994d906 100644 --- a/services/chatbot/src/chatbot/config.py +++ b/services/chatbot/src/chatbot/config.py @@ -10,3 +10,4 @@ class Config: SECRET_KEY = os.getenv("SECRET_KEY", "super-secret") MONGO_URI = MONGO_CONNECTION_URI + DEFAULT_MODEL_NAME = os.getenv("DEFAULT_MODEL", "gpt-4o-mini") diff --git a/services/chatbot/src/chatbot/langgraph_agent.py b/services/chatbot/src/chatbot/langgraph_agent.py index 678ba9a8..ed164a4f 100644 --- a/services/chatbot/src/chatbot/langgraph_agent.py +++ b/services/chatbot/src/chatbot/langgraph_agent.py @@ -20,8 +20,6 @@ from .extensions import postgresdb from .mcp_client import mcp_client -model_name = "gpt-4o-mini" - async def get_retriever_tool(api_key): embeddings = OpenAIEmbeddings(api_key=api_key) @@ -48,7 +46,7 @@ async def get_retriever_tool(api_key): return retriever_tool -async def build_langgraph_agent(api_key): +async def build_langgraph_agent(api_key, model_name): system_prompt = textwrap.dedent( """ You are crAPI Assistant — an expert agent that helps users explore and test the Completely Ridiculous API (crAPI), a vulnerable-by-design application for learning and evaluating modern API security issues. @@ -86,7 +84,7 @@ async def build_langgraph_agent(api_key): Use the tools only if you don't know the answer. 
""" ) - llm = ChatOpenAI(api_key=api_key, model="gpt-4o-mini") + llm = ChatOpenAI(api_key=api_key, model=model_name) toolkit = SQLDatabaseToolkit(db=postgresdb, llm=llm) mcp_tools = await mcp_client.get_tools() db_tools = toolkit.get_tools() @@ -97,8 +95,8 @@ async def build_langgraph_agent(api_key): return agent_node -async def execute_langgraph_agent(api_key, messages, session_id=None): - agent = await build_langgraph_agent(api_key) +async def execute_langgraph_agent(api_key, model_name, messages, session_id=None): + agent = await build_langgraph_agent(api_key, model_name) print("messages", messages) print("Session ID", session_id) response = await agent.ainvoke({"messages": messages}) diff --git a/services/chatbot/src/chatbot/session_service.py b/services/chatbot/src/chatbot/session_service.py index d5dfd827..54902cb8 100644 --- a/services/chatbot/src/chatbot/session_service.py +++ b/services/chatbot/src/chatbot/session_service.py @@ -1,6 +1,6 @@ import os import uuid - +from .config import Config from quart import after_this_request, request from .extensions import db @@ -44,3 +44,16 @@ async def delete_api_key(session_id): await db.sessions.update_one( {"session_id": session_id}, {"$unset": {"openai_api_key": ""}} ) + +async def store_model_name(session_id, model_name): + await db.sessions.update_one( + {"session_id": session_id}, {"$set": {"model_name": model_name}}, upsert=True + ) + +async def get_model_name(session_id): + doc = await db.sessions.find_one({"session_id": session_id}) + if not doc: + return Config.DEFAULT_MODEL_NAME + if "model_name" not in doc: + return Config.DEFAULT_MODEL_NAME + return doc["model_name"] diff --git a/services/web/src/components/bot/ActionProvider.tsx b/services/web/src/components/bot/ActionProvider.tsx index 85243a9a..0990a15b 100644 --- a/services/web/src/components/bot/ActionProvider.tsx +++ b/services/web/src/components/bot/ActionProvider.tsx @@ -141,7 +141,7 @@ class ActionProvider { } console.log(res); const 
successmessage = this.createChatBotMessage( - "Chatbot initialized successfully.", + "Chatbot initialized successfully. By default, GPT-4o-mini model is being used. To change chatbot's model, please type model and press enter.", Math.floor(Math.random() * 65536), { loading: true, @@ -154,6 +154,95 @@ class ActionProvider { }); }; + handleModelSelection = (initRequired: boolean): void => { + console.log("Initialization required:", initRequired); + if (initRequired) { + const message = this.createChatBotMessage( + "Chatbot not initialized. To initialize the chatbot, please type init and press enter.", + Math.floor(Math.random() * 65536), + { + loading: true, + terminateLoading: true, + role: "assistant", + }, + ); + this.addMessageToState(message); + } else { + this.addModelSelectionToState(); + const message = this.createChatBotMessage( + `Type one of these available options and press enter:\n\n` + + `1. \`gpt-4o\` : GPT-4 Omni (fastest, multimodal, best for general use)\n\n` + + `2. \`gpt-4o-mini\` : Lighter version of GPT-4o (efficient for most tasks)\n\n` + + `3. \`gpt-4-turbo\` : GPT-4 Turbo (older but solid performance)\n\n` + + `4. \`gpt-3.5-turbo\` : GPT-3.5 Turbo (cheaper, good for lightweight tasks)\n\n` + + `5. 
\`gpt-3.5-turbo-16k\` : Like above but with 16k context window\n\n` + + `By default, GPT-4o-mini will be used if any invalid option is entered.`, + Math.floor(Math.random() * 65536), + { + loading: true, + terminateLoading: true, + role: "assistant", + }, + ); + this.addMessageToState(message); + } + }; + + handleModelConfirmation = (model_name: string | null, accessToken: string): void => { + const validModels: Record<string, string> = { + "1": "gpt-4o", + "2": "gpt-4o-mini", + "3": "gpt-4-turbo", + "4": "gpt-3.5-turbo", + "5": "gpt-3.5-turbo-16k", + "gpt-4o": "gpt-4o", + "gpt-4o-mini": "gpt-4o-mini", + "gpt-4-turbo": "gpt-4-turbo", + "gpt-3.5-turbo": "gpt-3.5-turbo", + "gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k" + }; + const selectedModel = model_name?.trim(); + const modelToUse = selectedModel && validModels[selectedModel] ? validModels[selectedModel] : null; + + const modelUrl = APIService.CHATBOT_SERVICE + "genai/model"; + superagent + .post(modelUrl) + .send({ model_name: modelToUse }) + .set("Accept", "application/json") + .set("Content-Type", "application/json") + .set("Authorization", `Bearer ${accessToken}`) + .end((err, res) => { + if (err) { + console.log(err); + const errormessage = this.createChatBotMessage( + "Failed to set model. Please try again.", + Math.floor(Math.random() * 65536), + { + loading: true, + terminateLoading: true, + role: "assistant", + }, + ); + this.addMessageToState(errormessage); + return; + } + + console.log(res); + const currentModel = res.body?.model_used || modelToUse; + const successmessage = this.createChatBotMessage( + `Model has been successfully set to ${currentModel}. 
You can now start chatting.`, + Math.floor(Math.random() * 65536), + { + loading: true, + terminateLoading: true, + role: "assistant", + }, + ); + this.addMessageToState(successmessage); + this.addModelConfirmationToState(); + }); + }; + handleChat = (message: string, accessToken: string): void => { const chatUrl = APIService.CHATBOT_SERVICE + "genai/ask"; console.log("Chat message:", message); @@ -223,7 +312,7 @@ class ActionProvider { this.addMessageToState(message); } else { const message = this.createChatBotMessage( - "Chat with the bot and exploit it.", + "Chat with the bot and exploit it. To change chatbot's model, please type model and press enter.", Math.floor(Math.random() * 65536), { loading: true, @@ -303,6 +392,20 @@ class ActionProvider { })); }; + addModelSelectionToState = (): void => { + this.setState((state) => ({ + ...state, + modelSelection: true, + })); + }; + + addModelConfirmationToState = (): void => { + this.setState((state) => ({ + ...state, + modelSelection: false, + })); + }; + clearMessages = (): void => { this.setState((state) => ({ ...state, diff --git a/services/web/src/components/bot/MessageParser.tsx b/services/web/src/components/bot/MessageParser.tsx index f7214eec..9e1a56d5 100644 --- a/services/web/src/components/bot/MessageParser.tsx +++ b/services/web/src/components/bot/MessageParser.tsx @@ -18,6 +18,7 @@ import request from "superagent"; interface State { initializationRequired?: boolean; initializing?: boolean; + modelSelection?: boolean; accessToken: string; chatHistory: ChatMessage[]; } @@ -40,6 +41,8 @@ interface ActionProvider { chatHistory: ChatMessage[], ) => void; handleNotInitialized: () => void; + handleModelSelection: (initRequired: boolean) => void; + handleModelConfirmation: (message: string, accessToken: string) => void; handleChat: (message: string, accessToken: string) => void; } @@ -107,6 +110,12 @@ class MessageParser { return this.actionProvider.handleInitialize( this.state.initializationRequired, ); + } 
else if (message_l === "model" || message_l === "models") { + const [initRequired, chatHistory] = await this.initializationRequired(); + this.state.initializationRequired = initRequired; + this.state.chatHistory = chatHistory; + console.log("State help:", this.state); + return this.actionProvider.handleModelSelection(this.state.initializationRequired); } else if ( message_l === "clear" || message_l === "reset" || @@ -121,6 +130,8 @@ class MessageParser { ); } else if (this.state.initializationRequired) { return this.actionProvider.handleNotInitialized(); + } else if (this.state.modelSelection) { + return this.actionProvider.handleModelConfirmation(message, this.state.accessToken); } return this.actionProvider.handleChat(message, this.state.accessToken);