1212
1313
class StudyLLM:
    """LLM facade over a single document: conversational Q&A plus generation
    of summaries, flashcards, and quizzes, with retrieval backed by RAGHelper.

    The two ChatOpenAI clients below are class attributes, shared by every
    instance of StudyLLM.
    """

    # for chat
    # NOTE(review): constructed at class-definition (import) time, and the
    # API key env var is read exactly once here — confirm that is intended
    # (a missing env var at import silently yields api_key=None).
    chat_llm = ChatOpenAI(
        model="llama3.3:latest",
        temperature=0.5,
        api_key=os.getenv("OPEN_WEBUI_API_KEY_CHAT"),
        base_url="https://gpu.aet.cit.tum.de/api/"
    )

    # For summaries, quizzes, flashcards
    generation_llm = ChatOpenAI(
        model="llama3.3:latest",
        temperature=0.5,
        api_key=os.getenv("OPEN_WEBUI_API_KEY_GEN"),
        base_url="https://gpu.aet.cit.tum.de/api/"
    )
7331 def __init__ (self , doc_path : str ):
74- base_system_template = (
75- "You are an expert on the information in the context given below .\n "
76- "Use the context as your primary knowledge source. If you can't fulfill your task given the context, just say that. \n "
77- "context: {context} \n "
78- "Your task is {task}"
79- )
80- self . base_prompt_template = ChatPromptTemplate . from_messages (
81- [( "system" , base_system_template ), ( " human" , " {input}" )]
82- )
83- try :
32+ base_system_template = ("You are an expert on the information in the context given below. \n "
33+ "Use the context as your primary knowledge source. If you can't fulfill your task given the context, just say that .\n "
34+ "context: {context} \n "
35+ "Your task is {task} "
36+ )
37+ self . base_prompt_template = ChatPromptTemplate . from_messages ([
38+ ( 'system' , base_system_template ),
39+ ( ' human' , ' {input}' )
40+ ] )
41+ try :
8442 self .rag_helper = RAGHelper (doc_path )
8543 except Exception as e :
8644 raise ValueError (f"Error initializing RAGHelper: { e } " )
8745
46+
8847 async def prompt (self , prompt : str ) -> str :
8948 """
9049 Call the LLM with a given prompt.
91-
50+
9251 Args:
9352 prompt (str): The input prompt for the LLM.
94-
53+
9554 Returns:
9655 str: The response from the LLM.
9756 """
98- task = (
57+ task = (
9958 "To answer questions based on your context."
10059 "If you're asked a question that does not relate to your context, do not answer it - instead, answer by saying you're only familiar with <the topic in your context>.\n "
101- )
102-
60+ )
61+
10362 context = self .rag_helper .retrieve (prompt , top_k = 5 )
10463 chain = self .base_prompt_template | self .chat_llm
105- response = await chain .ainvoke (
106- {"context" : context , "task" : task , "input" : prompt }
107- )
108-
64+ response = await chain .ainvoke ({
65+ 'context' : context ,
66+ 'task' :task ,
67+ 'input' :prompt
68+ })
69+
10970 return response .content
11071
11172 async def summarize (self ):
11273 """
11374 Summarize the given document using the LLM.
114-
75+
11576 Returns:
11677 str: The summary of the document.
11778 """
118-
79+
11980 map_prompt = PromptTemplate .from_template (
120- (f"Write a medium length summary of the following:\n \n " "{text}" )
81+ (
82+ f"Write a medium length summary of the following:\n \n "
83+ "{text}"
84+ )
12185 )
12286
12387 combine_prompt = PromptTemplate .from_template (
@@ -135,42 +99,37 @@ async def summarize(self):
13599 self .generation_llm ,
136100 chain_type = "map_reduce" ,
137101 map_prompt = map_prompt ,
138- combine_prompt = combine_prompt ,
139- )
140-
141- result = await chain .ainvoke (
142- {"input_documents" : self .rag_helper .summary_chunks }
102+ combine_prompt = combine_prompt
143103 )
144104
105+ result = await chain .ainvoke ({"input_documents" : self .rag_helper .summary_chunks })
106+
145107 return result ["output_text" ]
146-
108+
147109 async def generate_flashcards (self ):
148110 """
149111 Generate flashcards from the document using the LLM.
150-
112+
151113 Returns:
152114 list: A list of flashcard objects.
153115 """
154116 flashcard_chain = FlashcardChain (self .generation_llm )
155117 cards = await flashcard_chain .invoke (self .rag_helper .summary_chunks )
156118 return cards
157-
119+
158120 async def generate_quiz (self ):
159121 """
160122 Generate a quiz from the document using the LLM.
161-
123+
162124 Returns:
163125 list: A quiz object.
164126 """
165127 quiz_chain = QuizChain (self .generation_llm )
166128 quiz = await quiz_chain .invoke (self .rag_helper .summary_chunks )
167129 return quiz
168-
130+
    def cleanup(self):
        """Cleanup resources used by the LLM.

        Delegates to RAGHelper.cleanup(). NOTE(review): any exception raised
        by the helper now propagates to the caller (an earlier revision
        swallowed and printed it) — confirm callers expect that.
        """
        self.rag_helper.cleanup()
0 commit comments