Skip to content

Commit 683c2a9

Browse files
committed
feat: Refactor LLMSummarizer class
Use model options to allow the user to access model parameters from the app
1 parent 2e1206b commit 683c2a9

File tree

2 files changed

+41
-30
lines changed

2 files changed

+41
-30
lines changed

huginn_hears/main.py

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -159,54 +159,65 @@ def extractive_summarize(self, document: str) -> str:
159159
return final_summary
160160

161161

162-
class MistralSummarizer:
162+
class LLMSummarizer:
163163
"""
164-
A class for summarizing documents using the Mistral model.
164+
A class for summarizing documents using an LLM.
165165
166166
Args:
167-
model_path (str): The path to the Mistral model.
167+
repo_id (str): The repository ID of a model in the Hugging Face model hub.
168+
filename (str): The filename of the model.
168169
text_splitter (TextSplitter, optional): The text splitter to use for splitting documents into chunks. Defaults to RecursiveCharacterTextSplitter.
169-
prompt_template (str, optional): The prompt template to use for generating prompts. Defaults to None.
170-
refine_template (str, optional): The refine template to use for refining summaries. Defaults to None.
170+
prompt_template (str): The prompt template to use for generating prompts.
171+
refine_template (str): The refine template to use for refining summaries.
172+
model_options (dict, optional): The options to use for the LLM model. Defaults to {
173+
'n_ctx': 4096,
174+
'max_tokens': 512,
175+
'n_batch': 16,
176+
'n_threads': 6,
177+
'temperature': 0.2,
178+
'top_p': 0.9,
179+
'repeat_penalty': 1.18,
180+
'verbose': True,
181+
'chat_format': "chatml",
182+
}
171183
"""
172184

173-
def __init__(self, repo_id: str, filename: str, text_splitter=RecursiveCharacterTextSplitter, prompt_template: str = None, refine_template: str = None):
185+
def __init__(self, repo_id: str, filename: str,
186+
prompt_template: str, refine_template: str,
187+
model_options: dict = {
188+
'n_ctx': 4096,
189+
'max_tokens': 512,
190+
'n_batch': 16,
191+
'n_threads': 6,
192+
'temperature': 0.2,
193+
'top_p': 0.9,
194+
'repeat_penalty': 1.18,
195+
'verbose': True,
196+
'chat_format': "chatml",
197+
},
198+
text_splitter=RecursiveCharacterTextSplitter):
199+
174200
self.repo_id = repo_id
175201
self.filename = filename
176202
self.layers = -1 if torch.cuda.is_available() else None
177203
self.model = None
178204
self.text_splitter = text_splitter(chunk_size=2048)
179205
self.prompt_template = PromptTemplate.from_template(prompt_template)
180206
self.refine_template = PromptTemplate.from_template(refine_template)
207+
self.model_options = model_options
181208

182209
@contextmanager
183-
def load_model(self, n_ctx=4096, max_tokens=512, n_batch=512, n_threads=6, temperature=0.2):
210+
def load_model(self):
184211
"""
185-
Context manager for loading and unloading the Mistral model.
186-
187-
Args:
188-
n_ctx (int, optional): The context size for the model. Defaults to 4096.
189-
max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 1024.
190-
n_batch (int, optional): The batch size for model inference. Defaults to 512.
191-
n_threads (int, optional): The number of threads to use for model inference. Defaults to 4.
192-
temperature (float, optional): The temperature for sampling from the model. Defaults to 0.2.
212+
Context manager for loading and unloading the Hugging Face model.
193213
194214
Yields:
195215
LlamaCpp: The loaded LLM.
196216
"""
197217
self.model = CustomLlamaCpp(
198218
repo_id=self.repo_id,
199219
filename=self.filename,
200-
n_gpu_layers=self.layers,
201-
n_ctx=n_ctx,
202-
max_tokens=max_tokens,
203-
n_batch=n_batch,
204-
n_threads=n_threads,
205-
temperature=temperature,
206-
top_p=0.9,
207-
repeat_penalty=1.18, # Trying to avoid repeating the same words
208-
verbose=True,
209-
chat_format="chatml",
220+
**self.model_options,
210221
)
211222
try:
212223
yield self.model
@@ -265,11 +276,11 @@ class AudioSummarizationPipeline:
265276
TypeError: If summarizer is not an instance of LLMSummarizer.
266277
"""
267278

268-
def __init__(self, audio_path, transcriber: WhisperTranscriber, summarizer: MistralSummarizer, extractor: ExtractiveSummarizer):
279+
def __init__(self, audio_path, transcriber: WhisperTranscriber, summarizer: LLMSummarizer, extractor: ExtractiveSummarizer):
269280
if not isinstance(transcriber, WhisperTranscriber):
270281
raise TypeError(
271282
f'transcriber must be an instance of WhisperTranscriber, got {type(transcriber)} instead')
272-
if not isinstance(summarizer, MistralSummarizer):
283+
if not isinstance(summarizer, LLMSummarizer):
273284
raise TypeError(
274285
f'summarizer must be an instance of LLMSummarizer, got {type(summarizer)} instead')
275286
if not isinstance(extractor, ExtractiveSummarizer):
@@ -333,7 +344,7 @@ def run(self, extractive_summary=False):
333344
SVÆRT VIKTIG: Ikke nevn deg selv, kun skriv sammendraget. Ingen intro, ingen annen tekst [/INST]
334345
"""
335346
transcriber = WhisperTranscriber()
336-
summarizer = MistralSummarizer(repo_id="TheBloke/dolphin-2.6-mistral-7B-dpo-laser-GGUF", filename='*Q4_K_M.gguf',
347+
summarizer = LLMSummarizer(repo_id="TheBloke/dolphin-2.6-mistral-7B-dpo-laser-GGUF", filename='*Q4_K_M.gguf',
337348
prompt_template=prompt_template, refine_template=refine_template)
338349
extractor = ExtractiveSummarizer()
339350
audio_path = '/home/magsam/workspace/huginn-hears/test_files/king.mp3'

streamlit_app/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
chatml_no_prompt_template, chatml_no_refine_template,
44
chatml_en_prompt_template, chatml_en_refine_template,
55
)
6-
from huginn_hears.main import WhisperTranscriber, MistralSummarizer, ExtractiveSummarizer
6+
from huginn_hears.main import WhisperTranscriber, LLMSummarizer, ExtractiveSummarizer
77
import gc
88
import torch
99
import tempfile
@@ -90,7 +90,7 @@ def main():
9090
# Initialize the transcriber and summarizer
9191
transcriber = WhisperTranscriber()
9292
extractive_summarizer = ExtractiveSummarizer()
93-
summarizer = MistralSummarizer(repo_id=mistral_model_path, filename=mistral_filename,
93+
summarizer = LLMSummarizer(repo_id=mistral_model_path, filename=mistral_filename,
9494
prompt_template=selected_prompt_template,
9595
refine_template=selected_refine_template)
9696

0 commit comments

Comments
 (0)