1818
1919
2020class StudyLLM :
21- llm = ChatOpenAI (
21+ # For chat
22+ chat_llm = ChatOpenAI (
2223 model = "llama3.3:latest" ,
2324 temperature = 0.5 ,
24- api_key = os .getenv ("OPEN_WEBUI_API_KEY" ),
25+ api_key = os .getenv ("OPEN_WEBUI_API_KEY_CHAT" ),
26+ base_url = "https://gpu.aet.cit.tum.de/api/"
27+ )
28+
29+ # For summaries, quizzes, flashcards
30+ generation_llm = ChatOpenAI (
31+ model = "llama3.3:latest" ,
32+ temperature = 0.5 ,
33+ api_key = os .getenv ("OPEN_WEBUI_API_KEY_GEN" ),
2534 base_url = "https://gpu.aet.cit.tum.de/api/"
2635 )
2736
@@ -39,26 +48,9 @@ def __init__(self, doc_path: str):
3948 self .rag_helper = RAGHelper (doc_path )
4049 except Exception as e :
4150 raise ValueError (f"Error initializing RAGHelper: { e } " )
42-
43- def _chain (self , output_model : BaseModel = None ):
44- """
45- Construct a chain for the LLM with given configurations.
46-
47- Args:
48- output_model (BaseModel, optional): A Pydantic model for structured output.
49- ...
50- Returns:
51- RunnableSequence: The chain for the LLM.
52- """
53- llm = self .llm
54-
55- if output_model :
56- llm = llm .with_structured_output (output_model )
57-
58- return self .base_prompt_template | llm
5951
6052
61- def prompt (self , prompt : str ) -> str :
53+ async def prompt (self , prompt : str ) -> str :
6254 """
6355 Call the LLM with a given prompt.
6456
@@ -74,13 +66,16 @@ def prompt(self, prompt: str) -> str:
7466 )
7567
7668 context = self .rag_helper .retrieve (prompt , top_k = 5 )
77- return self ._chain ().invoke ({
69+ chain = self .base_prompt_template | self .chat_llm
70+ response = await chain .ainvoke ({
7871 'context' : context ,
7972 'task' :task ,
8073 'input' :prompt
81- }).content
74+ })
75+
76+ return response .content
8277
83- def summarize (self ):
78+ async def summarize (self ):
8479 """
8580 Summarize the given document using the LLM.
8681
@@ -107,13 +102,13 @@ def summarize(self):
107102 )
108103
109104 chain = load_summarize_chain (
110- self .llm ,
105+ self .generation_llm ,
111106 chain_type = "map_reduce" ,
112107 map_prompt = map_prompt ,
113108 combine_prompt = combine_prompt
114109 )
115110
116- result = chain .invoke ({"input_documents" : self .rag_helper .summary_chunks })
111+ result = await chain .ainvoke ({"input_documents" : self .rag_helper .summary_chunks })
117112
118113 return result ["output_text" ]
119114
@@ -124,7 +119,7 @@ async def generate_flashcards(self):
124119 Returns:
125120 list: A list of flashcard objects.
126121 """
127- flashcard_chain = FlashcardChain (self .llm )
122+ flashcard_chain = FlashcardChain (self .generation_llm )
128123 cards = await flashcard_chain .invoke (self .rag_helper .summary_chunks )
129124 return cards
130125
@@ -135,7 +130,7 @@ async def generate_quiz(self):
135130 Returns:
136131 list: A quiz object.
137132 """
138- quiz_chain = QuizChain (self .llm )
133+ quiz_chain = QuizChain (self .generation_llm )
139134 quiz = await quiz_chain .invoke (self .rag_helper .summary_chunks )
140135 return quiz
141136
0 commit comments