@@ -159,54 +159,65 @@ def extractive_summarize(self, document: str) -> str:
159159 return final_summary
160160
161161
162- class MistralSummarizer :
162+ class LLMSummarizer :
163163 """
164- A class for summarizing documents using the Mistral model .
164+ A class for summarizing documents using an LLM .
165165
166166 Args:
167- model_path (str): The path to the Mistral model.
167+ repo_id (str): The path to a model in the Hugging Face model hub.
168+ filename (str): The filename of the model.
168169 text_splitter (TextSplitter, optional): The text splitter to use for splitting documents into chunks. Defaults to RecursiveCharacterTextSplitter.
169- prompt_template (str, optional): The prompt template to use for generating prompts. Defaults to None.
170- refine_template (str, optional): The refine template to use for refining summaries. Defaults to None.
170+ prompt_template (str): The prompt template to use for generating prompts.
171+ refine_template (str): The refine template to use for refining summaries.
172+ model_options (dict, optional): The options to use for the LLM model. Defaults to {
173+ 'n_ctx': 4096,
174+ 'max_tokens': 512,
175+ 'n_batch': 16,
176+ 'n_threads': 6,
177+ 'temperature': 0.2,
178+ 'top_p': 0.9,
179+ 'repeat_penalty': 1.18,
180+ 'verbose': True,
181+ 'chat_format': "chatml",
182+ }
171183 """
172184
173- def __init__ (self , repo_id : str , filename : str , text_splitter = RecursiveCharacterTextSplitter , prompt_template : str = None , refine_template : str = None ):
185+ def __init__ (self , repo_id : str , filename : str ,
186+ prompt_template : str , refine_template : str ,
187+ model_options : dict = {
188+ 'n_ctx' : 4096 ,
189+ 'max_tokens' : 512 ,
190+ 'n_batch' : 16 ,
191+ 'n_threads' : 6 ,
192+ 'temperature' : 0.2 ,
193+ 'top_p' : 0.9 ,
194+ 'repeat_penalty' : 1.18 ,
195+ 'verbose' : True ,
196+ 'chat_format' : "chatml" ,
197+ },
198+ text_splitter = RecursiveCharacterTextSplitter ):
199+
174200 self .repo_id = repo_id
175201 self .filename = filename
176202 self .layers = - 1 if torch .cuda .is_available () else None
177203 self .model = None
178204 self .text_splitter = text_splitter (chunk_size = 2048 )
179205 self .prompt_template = PromptTemplate .from_template (prompt_template )
180206 self .refine_template = PromptTemplate .from_template (refine_template )
207+ self .model_options = model_options
181208
182209 @contextmanager
183- def load_model (self , n_ctx = 4096 , max_tokens = 512 , n_batch = 512 , n_threads = 6 , temperature = 0.2 ):
210+ def load_model (self ):
184211 """
185- Context manager for loading and unloading the Mistral model.
186-
187- Args:
188- n_ctx (int, optional): The context size for the model. Defaults to 4096.
189- max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 1024.
190- n_batch (int, optional): The batch size for model inference. Defaults to 512.
191- n_threads (int, optional): The number of threads to use for model inference. Defaults to 4.
192- temperature (float, optional): The temperature for sampling from the model. Defaults to 0.2.
212+ Context manager for loading and unloading the Hugging Face model.
193213
194214 Yields:
195215 LlamaCpp: The loaded LLM.
196216 """
197217 self .model = CustomLlamaCpp (
198218 repo_id = self .repo_id ,
199219 filename = self .filename ,
200- n_gpu_layers = self .layers ,
201- n_ctx = n_ctx ,
202- max_tokens = max_tokens ,
203- n_batch = n_batch ,
204- n_threads = n_threads ,
205- temperature = temperature ,
206- top_p = 0.9 ,
207- repeat_penalty = 1.18 , # Trying to avoid repeating the same words
208- verbose = True ,
209- chat_format = "chatml" ,
220+ ** self .model_options ,
210221 )
211222 try :
212223 yield self .model
@@ -265,11 +276,11 @@ class AudioSummarizationPipeline:
265276 TypeError: If summarizer is not an instance of LLMSummarizer.
266277 """
267278
268- def __init__ (self , audio_path , transcriber : WhisperTranscriber , summarizer : MistralSummarizer , extractor : ExtractiveSummarizer ):
279+ def __init__ (self , audio_path , transcriber : WhisperTranscriber , summarizer : LLMSummarizer , extractor : ExtractiveSummarizer ):
269280 if not isinstance (transcriber , WhisperTranscriber ):
270281 raise TypeError (
271282 f'transcriber must be an instance of WhisperTranscriber, got { type (transcriber )} instead' )
272- if not isinstance (summarizer , MistralSummarizer ):
283+ if not isinstance (summarizer , LLMSummarizer ):
273284 raise TypeError (
274285 f'summarizer must be an instance of MistralSummarizer, got { type (summarizer )} instead' )
275286 if not isinstance (extractor , ExtractiveSummarizer ):
@@ -333,7 +344,7 @@ def run(self, extractive_summary=False):
333344 SVÆRT VIKTIG: Ikke nevn deg selv, kun skriv sammendraget. Ingen intro, ingen annen tekst [/INST]
334345 """
335346 transcriber = WhisperTranscriber ()
336- summarizer = MistralSummarizer (repo_id = "TheBloke/dolphin-2.6-mistral-7B-dpo-laser-GGUF" , filename = '*Q4_K_M.gguf' ,
347+ summarizer = LLMSummarizer (repo_id = "TheBloke/dolphin-2.6-mistral-7B-dpo-laser-GGUF" , filename = '*Q4_K_M.gguf' ,
337348 prompt_template = prompt_template , refine_template = refine_template )
338349 extractor = ExtractiveSummarizer ()
339350 audio_path = '/home/magsam/workspace/huginn-hears/test_files/king.mp3'
0 commit comments