Skip to content

Commit 865e948

Browse files
committed
add base chat model and webui service
1 parent b62ccea commit 865e948

File tree

7 files changed

+136
-9
lines changed

7 files changed

+136
-9
lines changed

genai/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from dotenv import load_dotenv
22
from waitress import serve
33
from flask import Flask
4-
from controller.generate_controller import generate_bp
4+
from genai.controller.generate_controller import generate_bp
55

6-
from config import Config
6+
from genai.config import Config
77

88
app = Flask(__name__)
99
app.register_blueprint(generate_bp)

genai/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
"Config",
1010
[
1111
"api_key_openai",
12-
"waitress"
12+
"waitress",
13+
"api_openwebui"
1314
],
1415
)
1516

1617
Config = ConfigT(
1718
api_key_openai=environ.get("API_SECRET_OPENAI_MINE"),
1819
waitress=environ.get("USE_WAITRESS", "false").lower() == "true",
20+
api_openwebui=environ.get("API_OPENWEBUI")
1921
)

genai/controller/generate_controller.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,23 @@
33
import logging
44
from werkzeug.utils import secure_filename
55

6-
from rag.ingestion_pipeline import IngestionPipeline
7-
from vector_database.qdrant_vdb import QdrantVDB
6+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
7+
from langchain_core.messages import HumanMessage
8+
9+
from genai.rag.ingestion_pipeline import IngestionPipeline
10+
from genai.vector_database.qdrant_vdb import QdrantVDB
11+
from genai.rag.llm.chat_model import ChatModel
12+
813

914
# Set Logging
1015
logging.getLogger().setLevel(logging.INFO)
1116

17+
# Set ChatModel
18+
llm = ChatModel(model_name="llama3.3:latest")
19+
20+
# Set Vector Database
21+
qdrant = QdrantVDB()
22+
1223
generate_bp = Blueprint('generate', __name__)
1324

1425

@@ -31,8 +42,6 @@ def upload_file():
3142

3243
try:
3344
collection_name = "recipes"
34-
# Initialize vector database
35-
qdrant = QdrantVDB()
3645
# Check if the file already in the collection
3746
if (qdrant.client.collection_exists(collection_name)
3847
and qdrant.collection_contains_file(
@@ -69,6 +78,47 @@ def upload_file():
6978
os.remove(file_path)
7079

7180

72-
@generate_bp.route('/api/generate', methods=['POST'])
81+
@generate_bp.route('/genai/generate', methods=['POST'])
def generate():
    """Generate a recipe answer for the given query using RAG.

    Expects a JSON body with "query" and "conversation_id". Retrieves the
    5 most similar documents from the "recipes" collection, builds a chat
    prompt with that context, and returns the LLM response.

    Returns:
        200 with {"response": ...} on success,
        400 if the JSON body is missing required keys,
        404 if the "recipes" collection has not been created yet,
        500 on any unexpected failure.
    """
    data = request.get_json()

    if not data or "query" not in data or "conversation_id" not in data:
        return jsonify({"error": "Missing 'query' or 'conversation_id'"}), 400

    query = data["query"]
    conversation_id = data["conversation_id"]  # will be used

    try:
        collection_name = "recipes"

        # Bug fix: the original only had the happy-path `if` branch, so a
        # missing collection made the view fall through and return None,
        # which crashes Flask. Fail fast with an explicit 404 instead.
        if not qdrant.client.collection_exists(collection_name):
            return jsonify({
                "error": f"Collection '{collection_name}' does not exist. "
                         "Upload documents first."
            }), 404

        # Get vector store
        vector_store = qdrant.create_and_get_vector_storage(
            collection_name
        )

        # Retrieve 5 similar documents
        retriever = vector_store.as_retriever(search_kwargs={"k": 5})
        retrieved_docs = retriever.invoke(query)
        docs_content = "\n\n".join(doc.page_content for doc in retrieved_docs)

        # Prepare prompt: system message carries the retrieved context,
        # the user's query is appended via the messages placeholder.
        prompt_template = ChatPromptTemplate([
            ("system", "You are a helpful assistant for recipe generation based on the given ingredients and the following context:\n\n{context}"),
            MessagesPlaceholder("msgs")
        ])

        prompt = prompt_template.invoke({
            "context": docs_content,
            "msgs": HumanMessage(content=query)
        })

        response = llm.invoke(prompt)
        return jsonify({
            "response": response.content,
        }), 200

    except Exception as e:
        # Top-level boundary: surface the error to the client as a 500.
        return jsonify({"error": str(e)}), 500
123+
124+

genai/rag/llm/__init__.py

Whitespace-only changes.

genai/rag/llm/chat_model.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from typing import List
2+
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
3+
from langchain_core.language_models.chat_models import BaseChatModel
4+
from langchain_core.outputs import ChatResult, ChatGeneration
5+
from pydantic import Field
6+
7+
from genai.service.openwebui_service import generate_response
8+
9+
10+
class ChatModel(BaseChatModel):
    """Custom LangChain chat model backed by the Open WebUI service.

    Delegates text generation to ``generate_response``, which talks to the
    Open WebUI chat-completions endpoint over HTTP.
    """

    # Name of the model served by Open WebUI.
    model_name: str = Field(default="llama3.3:latest")

    def _generate(self, messages: List[BaseMessage], stop=None, **kwargs) -> ChatResult:
        """Produce a single chat completion for *messages*.

        The backend accepts one flat prompt string, so the contents of ALL
        incoming messages are concatenated. Bug fix: the original kept only
        HumanMessage contents, silently dropping system messages — which is
        where the RAG context is injected by the controller's prompt
        template, so retrieved context never reached the model.
        """
        prompt = "\n".join(str(msg.content) for msg in messages)
        response_text = generate_response(self.model_name, prompt)

        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=response_text))]
        )

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for logging/serialization."""
        return "recipai-custom-model"
24+
25+
# For Testing purposes
26+
# if __name__ == "__main__":
27+
# llm = ChatModel(model_name="llama3.3:latest")
28+
29+
# message = HumanMessage(content="What is langchain, explain very briefly?")
30+
31+
# response = llm.invoke([message])
32+
33+
# print("LLM response:\n", response.content)

genai/service/openwebui_service.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import requests
2+
3+
from genai.config import Config
4+
5+
# Base URL of the Open WebUI deployment. No trailing slash: the endpoint
# path below is appended with its own leading "/" (the original value
# ended in "/", producing a double slash in the request URL).
BASE_URL = "https://gpu.aet.cit.tum.de"


def generate_response(model_name: str, prompt: str):
    """Request a chat completion from the Open WebUI server.

    Args:
        model_name: Name of the model to run (e.g. ``"llama3.3:latest"``).
        prompt: User prompt, sent as a single chat message.

    Returns:
        The assistant's reply text.

    Raises:
        RuntimeError: On HTTP errors, timeouts, or any other request
            failure; the underlying ``requests`` exception is chained.
    """
    url = f"{BASE_URL}/api/chat/completions"

    headers = {
        "Authorization": f"Bearer {Config.api_openwebui}",
        "Content-Type": "application/json",
    }

    payload = {
        "model": model_name,
        "messages": [
            {
                "role": "user",
                "content": prompt,
            }
        ],
    }

    try:
        response = requests.post(
            url,
            json=payload,
            headers=headers,
            timeout=120,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

    except requests.exceptions.HTTPError as e:
        # `response` is always bound here: raise_for_status() only runs
        # after requests.post() returned.
        raise RuntimeError(
            f"HTTP error from LLM server: {e} (status {response.status_code})"
        ) from e
    except requests.exceptions.Timeout as e:
        raise RuntimeError(f"Request to LLM timed out: {e}") from e
    except requests.exceptions.RequestException as e:
        raise RuntimeError(f"Request to LLM failed: {e}") from e

genai/service/rag_service.py

Whitespace-only changes.

0 commit comments

Comments (0)