Skip to content

Commit 863435f

Browse files
Revert to main versions: llm.py, main.py, request_models.py
1 parent 6195940 commit 863435f

File tree

3 files changed

+75
-163
lines changed

3 files changed

+75
-163
lines changed

genAi/llm.py

Lines changed: 53 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -12,112 +12,76 @@
1212

1313

1414
class StudyLLM:
15-
# Class-level attributes for lazy initialization
16-
_chat_llm = None
17-
_generation_llm = None
18-
19-
@classmethod
20-
def _get_chat_llm(cls):
21-
"""Lazy initialization of chat LLM"""
22-
if cls._chat_llm is None:
23-
cls._chat_llm = ChatOpenAI(
24-
model="llama3.3:latest",
25-
temperature=0.5,
26-
api_key=os.getenv("OPEN_WEBUI_API_KEY_CHAT"),
27-
base_url="https://gpu.aet.cit.tum.de/api/",
28-
)
29-
return cls._chat_llm
30-
31-
@classmethod
32-
def _get_generation_llm(cls):
33-
"""Lazy initialization of generation LLM"""
34-
if cls._generation_llm is None:
35-
cls._generation_llm = ChatOpenAI(
36-
model="llama3.3:latest",
37-
temperature=0.5,
38-
api_key=os.getenv("OPEN_WEBUI_API_KEY_GEN"),
39-
base_url="https://gpu.aet.cit.tum.de/api/",
40-
)
41-
return cls._generation_llm
42-
43-
@property
44-
def chat_llm(self):
45-
"""Get the chat LLM instance"""
46-
return self._get_chat_llm()
47-
48-
@chat_llm.setter
49-
def chat_llm(self, value):
50-
"""Set the chat LLM instance (for testing)"""
51-
StudyLLM._chat_llm = value
52-
53-
@chat_llm.deleter
54-
def chat_llm(self):
55-
"""Reset the chat LLM instance (for testing)"""
56-
StudyLLM._chat_llm = None
57-
58-
@property
59-
def generation_llm(self):
60-
"""Get the generation LLM instance"""
61-
return self._get_generation_llm()
62-
63-
@generation_llm.setter
64-
def generation_llm(self, value):
65-
"""Set the generation LLM instance (for testing)"""
66-
StudyLLM._generation_llm = value
67-
68-
@generation_llm.deleter
69-
def generation_llm(self):
70-
"""Reset the generation LLM instance (for testing)"""
71-
StudyLLM._generation_llm = None
72-
15+
# for chat
16+
chat_llm = ChatOpenAI(
17+
model="llama3.3:latest",
18+
temperature=0.5,
19+
api_key=os.getenv("OPEN_WEBUI_API_KEY_CHAT"),
20+
base_url="https://gpu.aet.cit.tum.de/api/"
21+
)
22+
23+
# For summaries, quizzes, flashcards
24+
generation_llm = ChatOpenAI(
25+
model="llama3.3:latest",
26+
temperature=0.5,
27+
api_key=os.getenv("OPEN_WEBUI_API_KEY_GEN"),
28+
base_url="https://gpu.aet.cit.tum.de/api/"
29+
)
30+
7331
def __init__(self, doc_path: str):
74-
base_system_template = (
75-
"You are an expert on the information in the context given below.\n"
76-
"Use the context as your primary knowledge source. If you can't fulfill your task given the context, just say that.\n"
77-
"context: {context}\n"
78-
"Your task is {task}"
79-
)
80-
self.base_prompt_template = ChatPromptTemplate.from_messages(
81-
[("system", base_system_template), ("human", "{input}")]
82-
)
83-
try:
32+
base_system_template = ("You are an expert on the information in the context given below.\n"
33+
"Use the context as your primary knowledge source. If you can't fulfill your task given the context, just say that.\n"
34+
"context: {context}\n"
35+
"Your task is {task}"
36+
)
37+
self.base_prompt_template = ChatPromptTemplate.from_messages([
38+
('system', base_system_template),
39+
('human', '{input}')
40+
])
41+
try:
8442
self.rag_helper = RAGHelper(doc_path)
8543
except Exception as e:
8644
raise ValueError(f"Error initializing RAGHelper: {e}")
8745

46+
8847
async def prompt(self, prompt: str) -> str:
8948
"""
9049
Call the LLM with a given prompt.
91-
50+
9251
Args:
9352
prompt (str): The input prompt for the LLM.
94-
53+
9554
Returns:
9655
str: The response from the LLM.
9756
"""
98-
task = (
57+
task = (
9958
"To answer questions based on your context."
10059
"If you're asked a question that does not relate to your context, do not answer it - instead, answer by saying you're only familiar with <the topic in your context>.\n"
101-
)
102-
60+
)
61+
10362
context = self.rag_helper.retrieve(prompt, top_k=5)
10463
chain = self.base_prompt_template | self.chat_llm
105-
response = await chain.ainvoke(
106-
{"context": context, "task": task, "input": prompt}
107-
)
108-
64+
response = await chain.ainvoke({
65+
'context': context,
66+
'task':task,
67+
'input':prompt
68+
})
69+
10970
return response.content
11071

11172
async def summarize(self):
11273
"""
11374
Summarize the given document using the LLM.
114-
75+
11576
Returns:
11677
str: The summary of the document.
11778
"""
118-
79+
11980
map_prompt = PromptTemplate.from_template(
120-
(f"Write a medium length summary of the following:\n\n" "{text}")
81+
(
82+
f"Write a medium length summary of the following:\n\n"
83+
"{text}"
84+
)
12185
)
12286

12387
combine_prompt = PromptTemplate.from_template(
@@ -135,42 +99,37 @@ async def summarize(self):
13599
self.generation_llm,
136100
chain_type="map_reduce",
137101
map_prompt=map_prompt,
138-
combine_prompt=combine_prompt,
139-
)
140-
141-
result = await chain.ainvoke(
142-
{"input_documents": self.rag_helper.summary_chunks}
102+
combine_prompt=combine_prompt
143103
)
144104

105+
result = await chain.ainvoke({"input_documents": self.rag_helper.summary_chunks})
106+
145107
return result["output_text"]
146-
108+
147109
async def generate_flashcards(self):
148110
"""
149111
Generate flashcards from the document using the LLM.
150-
112+
151113
Returns:
152114
list: A list of flashcard objects.
153115
"""
154116
flashcard_chain = FlashcardChain(self.generation_llm)
155117
cards = await flashcard_chain.invoke(self.rag_helper.summary_chunks)
156118
return cards
157-
119+
158120
async def generate_quiz(self):
159121
"""
160122
Generate a quiz from the document using the LLM.
161-
123+
162124
Returns:
163125
list: A quiz object.
164126
"""
165127
quiz_chain = QuizChain(self.generation_llm)
166128
quiz = await quiz_chain.invoke(self.rag_helper.summary_chunks)
167129
return quiz
168-
130+
169131
def cleanup(self):
170132
"""
171133
Cleanup resources used by the LLM.
172134
"""
173-
try:
174-
self.rag_helper.cleanup()
175-
except Exception as e:
176-
print(f"Error during RAGHelper cleanup: {e}")
135+
self.rag_helper.cleanup()

genAi/main.py

Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,7 @@
33
from fastapi import FastAPI
44
from fastapi.responses import JSONResponse
55
from helpers import save_document
6-
from request_models import (
7-
CreateSessionRequest,
8-
PromptRequest,
9-
SummaryRequest,
10-
QuizRequest,
11-
FlashcardRequest,
12-
ProcessRequest,
13-
)
6+
from request_models import CreateSessionRequest, PromptRequest, SummaryRequest, QuizRequest, FlashcardRequest
147
from llm import StudyLLM
158
from prometheus_fastapi_instrumentator import Instrumentator
169

@@ -21,15 +14,13 @@
2114

2215
llm_instances: dict[str, StudyLLM] = {}
2316

24-
2517
@asynccontextmanager
2618
async def lifespan(_):
2719
yield
2820
# Shutdown: cleanup
2921
for llm in llm_instances.values():
3022
llm.cleanup()
3123

32-
3324
app = FastAPI(
3425
title="tutor",
3526
openapi_tags=[
@@ -47,20 +38,19 @@ async def lifespan(_):
4738
},
4839
{"name": "Ingestion", "description": "Endpoints to start ingestion processes."},
4940
],
50-
lifespan=lifespan,
41+
lifespan=lifespan
5142
)
5243

5344
Instrumentator(
54-
excluded_handlers=["/metrics"],
45+
excluded_handlers=['/metrics'],
5546
should_group_status_codes=False,
56-
should_instrument_requests_inprogress=True,
57-
).instrument(app).expose(app)
47+
should_instrument_requests_inprogress=True
48+
).instrument(app).expose(app)
5849

5950

6051
# llm_instances["dummy"] = StudyLLM("./documents/example/W07_Microservices_and_Scalable_Architectures.pdf") # TODO: remove
6152
# llm_instances["dummy2"] = StudyLLM("./documents/example/dummy_knowledge.txt") # TODO: remove
6253

63-
6454
# Auxiliary Endpoints
6555
@app.get("/health")
6656
async def health_check():
@@ -81,10 +71,8 @@ async def load_session(data: CreateSessionRequest):
8171
if data.session_id in llm_instances:
8272
logger.info(f"Session {data.session_id} already exists")
8373
return {"message": "Session already loaded."}
84-
85-
logger.info(
86-
f"Creating new session {data.session_id} for document {data.document_name}"
87-
)
74+
75+
logger.info(f"Creating new session {data.session_id} for document {data.document_name}")
8876
doc_name = f"{data.session_id}_{data.document_name}"
8977
path = save_document(doc_name, data.document_base64)
9078
llm_instances[data.session_id] = StudyLLM(path)
@@ -105,10 +93,8 @@ async def receive_prompt(data: PromptRequest):
10593
if data.session_id not in llm_instances:
10694
error_msg = f"Session {data.session_id} not found. Please ensure the document was processed successfully."
10795
logger.error(error_msg)
108-
return JSONResponse(
109-
status_code=404, content={"response": f"ERROR: {error_msg}"}
110-
)
111-
96+
return JSONResponse(status_code=404, content={"response": f"ERROR: {error_msg}"})
97+
11298
logger.info(f"Processing chat request for session {data.session_id}")
11399
response = await llm_instances[data.session_id].prompt(data.message)
114100
return {"response": response}
@@ -117,7 +103,6 @@ async def receive_prompt(data: PromptRequest):
117103
logger.error(error_msg)
118104
return {"response": f"ERROR: {error_msg}"}
119105

120-
121106
@app.post("/summary")
122107
async def generate_summary(data: SummaryRequest):
123108
"""
@@ -128,7 +113,7 @@ async def generate_summary(data: SummaryRequest):
128113
error_msg = f"Session {data.session_id} not found. Please ensure the document was processed successfully."
129114
logger.error(error_msg)
130115
return {"response": f"ERROR: {error_msg}"}
131-
116+
132117
logger.info(f"Generating summary for session {data.session_id}")
133118
response = await llm_instances[data.session_id].summarize()
134119
return {"response": response}
@@ -137,7 +122,6 @@ async def generate_summary(data: SummaryRequest):
137122
logger.error(error_msg)
138123
return {"response": f"ERROR: {error_msg}"}
139124

140-
141125
@app.post("/flashcard")
142126
async def generate_flashcards(data: FlashcardRequest):
143127
"""
@@ -148,19 +132,16 @@ async def generate_flashcards(data: FlashcardRequest):
148132
error_msg = f"Session {data.session_id} not found. Please ensure the document was processed successfully."
149133
logger.error(error_msg)
150134
return {"response": {"flashcards": [], "error": error_msg}}
151-
135+
152136
logger.info(f"Generating flashcards for session {data.session_id}")
153137
response = await llm_instances[data.session_id].generate_flashcards()
154138
logger.info(f"Flashcards generated successfully for session {data.session_id}")
155139
return {"response": response}
156140
except Exception as e:
157-
error_msg = (
158-
f"Flashcard generation error for session {data.session_id}: {str(e)}"
159-
)
141+
error_msg = f"Flashcard generation error for session {data.session_id}: {str(e)}"
160142
logger.error(error_msg)
161143
return {"response": {"flashcards": [], "error": error_msg}}
162144

163-
164145
@app.post("/quiz")
165146
async def generate_quiz(data: QuizRequest):
166147
"""
@@ -171,7 +152,7 @@ async def generate_quiz(data: QuizRequest):
171152
error_msg = f"Session {data.session_id} not found. Please ensure the document was processed successfully."
172153
logger.error(error_msg)
173154
return {"response": {"questions": [], "error": error_msg}}
174-
155+
175156
logger.info(f"Generating quiz for session {data.session_id}")
176157
response = await llm_instances[data.session_id].generate_quiz()
177158
logger.info(f"Quiz generated successfully for session {data.session_id}")
@@ -181,9 +162,8 @@ async def generate_quiz(data: QuizRequest):
181162
logger.error(error_msg)
182163
return {"response": {"questions": [], "error": error_msg}}
183164

184-
185165
@app.post("/process")
186-
async def process_document(data: ProcessRequest):
166+
async def process_document(data: SummaryRequest):
187167
"""Compatibility endpoint for Kotlin genai-service (/process).
188168
It creates a session (if not present) and immediately returns QUEUED.
189169
(Actual processing e.g. summary generation can be triggered asynchronously.)"""
@@ -199,13 +179,13 @@ async def process_document(data: ProcessRequest):
199179
"requestId": session_id,
200180
"status": "QUEUED",
201181
"message": "Document queued for processing",
202-
"estimatedTime": None,
182+
"estimatedTime": None
203183
}
204184
except Exception as e:
205185
logger.error(f"/process error: {str(e)}")
206186
return {
207187
"requestId": None,
208188
"status": "FAILED",
209189
"message": f"Failed to process document: {str(e)}",
210-
"estimatedTime": None,
211-
}
190+
"estimatedTime": None
191+
}

0 commit comments

Comments (0)