Skip to content

Commit 1e8b9dc

Browse files
committed
feat: added model downloads chapter and new available GGUF and quantized bitsandbytes models
1 parent 6746b12 commit 1e8b9dc

File tree

1 file changed

+239
-7
lines changed

1 file changed

+239
-7
lines changed

agentic_rag/gradio_app.py

Lines changed: 239 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import tempfile
66
from dotenv import load_dotenv
77
import yaml
8+
import torch
9+
import time
810

911
from pdf_processor import PDFProcessor
1012
from web_processor import WebProcessor
@@ -89,19 +91,49 @@ def chat(message: str, history: List[List[str]], agent_type: str, use_cot: bool,
8991
# Skip analysis for General Knowledge or when using standard chat interface (not CoT)
9092
skip_analysis = collection == "General Knowledge" or not use_cot
9193

94+
# Parse agent type to determine model and quantization
95+
quantization = None
96+
model_name = None
97+
98+
if "4-bit" in agent_type:
99+
quantization = "4bit"
100+
model_type = "Local (Mistral)"
101+
elif "8-bit" in agent_type:
102+
quantization = "8bit"
103+
model_type = "Local (Mistral)"
104+
elif "GGUF" in agent_type:
105+
model_type = "GGUF"
106+
# Extract model name from agent_type
107+
if "Phi-4-mini" in agent_type:
108+
model_name = "unsloth/Phi-4-mini-instruct-GGUF"
109+
elif "Qwen_QwQ-32B" in agent_type:
110+
model_name = "bartowski/Qwen_QwQ-32B-GGUF"
111+
elif "TinyR1-32B" in agent_type:
112+
model_name = "bartowski/qihoo360_TinyR1-32B-Preview-GGUF"
113+
else:
114+
model_type = agent_type
115+
92116
# Select appropriate agent and reinitialize with correct settings
93-
if agent_type == "Local (Mistral)":
117+
if "Local" in model_type or model_type == "GGUF":
94118
if not hf_token:
95119
response_text = "Local agent not available. Please check your HuggingFace token configuration."
96120
print(f"Error: {response_text}")
97121
return history + [[message, response_text]]
98-
agent = LocalRAGAgent(vector_store, use_cot=use_cot, collection=collection, skip_analysis=skip_analysis)
122+
123+
# Use specified model_name for GGUF models, otherwise use default
124+
if model_type == "GGUF" and model_name:
125+
agent = LocalRAGAgent(vector_store, model_name=model_name, use_cot=use_cot,
126+
collection=collection, skip_analysis=skip_analysis)
127+
else:
128+
agent = LocalRAGAgent(vector_store, use_cot=use_cot, collection=collection,
129+
skip_analysis=skip_analysis, quantization=quantization)
99130
else:
100131
if not openai_key:
101132
response_text = "OpenAI agent not available. Please check your OpenAI API key configuration."
102133
print(f"Error: {response_text}")
103134
return history + [[message, response_text]]
104-
agent = RAGAgent(vector_store, openai_api_key=openai_key, use_cot=use_cot, collection=collection, skip_analysis=skip_analysis)
135+
agent = RAGAgent(vector_store, openai_api_key=openai_key, use_cot=use_cot,
136+
collection=collection, skip_analysis=skip_analysis)
105137

106138
# Process query and get response
107139
print("Processing query...")
@@ -183,6 +215,83 @@ def create_interface():
183215
> **Note on Performance**: When using the Local (Mistral) model, initial loading can take 1-5 minutes, and each query may take 30-60 seconds to process depending on your hardware. OpenAI queries are typically much faster.
184216
""")
185217

218+
# Create model choices list for reuse
219+
model_choices = []
220+
if hf_token:
221+
model_choices.extend([
222+
"Local (Mistral)",
223+
"Local (Mistral) - 4-bit Quantized",
224+
"Local (Mistral) - 8-bit Quantized",
225+
"GGUF - Phi-4-mini-instruct",
226+
"GGUF - Qwen_QwQ-32B",
227+
"GGUF - TinyR1-32B-Preview"
228+
])
229+
if openai_key:
230+
model_choices.append("OpenAI")
231+
232+
# Model Management Tab (First Tab)
233+
with gr.Tab("Model Management"):
234+
gr.Markdown("""
235+
## Model Management
236+
237+
Download models in advance to prepare them for use in the chat interface.
238+
This can help avoid delays when first using a model and ensure all models are properly downloaded.
239+
240+
> **Note**: Some models may require accepting terms on the Hugging Face website before downloading.
241+
> If you encounter a 401 error, please follow the link provided to accept the model terms.
242+
""")
243+
244+
with gr.Row():
245+
with gr.Column():
246+
model_dropdown = gr.Dropdown(
247+
choices=model_choices,
248+
value=model_choices[0] if model_choices else None,
249+
label="Select Model to Download",
250+
interactive=True
251+
)
252+
download_button = gr.Button("Download Selected Model")
253+
model_status = gr.Textbox(
254+
label="Download Status",
255+
placeholder="Select a model and click Download to begin...",
256+
interactive=False
257+
)
258+
259+
with gr.Column():
260+
gr.Markdown("""
261+
### Model Information
262+
263+
**Local (Mistral)**: The default Mistral-7B-Instruct-v0.2 model.
264+
- Size: ~14GB
265+
- VRAM Required: ~8GB
266+
- Good balance of quality and speed
267+
268+
**Local (Mistral) - 4-bit Quantized**: 4-bit quantized version of Mistral-7B.
269+
- Size: ~4GB
270+
- VRAM Required: ~4GB
271+
- Faster inference with minimal quality loss
272+
273+
**Local (Mistral) - 8-bit Quantized**: 8-bit quantized version of Mistral-7B.
274+
- Size: ~7GB
275+
- VRAM Required: ~6GB
276+
- Balance between quality and memory usage
277+
278+
**GGUF - Phi-4-mini-instruct**: Microsoft's Phi-4-mini model in GGUF format.
279+
- Size: ~2-4GB
280+
- VRAM Required: Scales based on available VRAM
281+
- Efficient small model with good performance
282+
283+
**GGUF - Qwen_QwQ-32B**: Qwen 32B model in GGUF format.
284+
- Size: ~20GB
285+
- VRAM Required: Scales based on available VRAM
286+
- High-quality large model
287+
288+
**GGUF - TinyR1-32B-Preview**: Qihoo 360's TinyR1 32B model in GGUF format.
289+
- Size: ~20GB
290+
- VRAM Required: Scales based on available VRAM
291+
- High-quality large model
292+
""")
293+
294+
# Document Processing Tab
186295
with gr.Tab("Document Processing"):
187296
with gr.Row():
188297
with gr.Column():
@@ -203,9 +312,10 @@ def create_interface():
203312
with gr.Tab("Standard Chat Interface"):
204313
with gr.Row():
205314
with gr.Column(scale=1):
315+
# Create model choices with quantization options
206316
standard_agent_dropdown = gr.Dropdown(
207-
choices=["Local (Mistral)", "OpenAI"] if openai_key else ["Local (Mistral)"],
208-
value="Local (Mistral)",
317+
choices=model_choices,
318+
value=model_choices[0] if model_choices else None,
209319
label="Select Agent"
210320
)
211321
with gr.Column(scale=1):
@@ -230,9 +340,10 @@ def create_interface():
230340
with gr.Tab("Chain of Thought Chat Interface"):
231341
with gr.Row():
232342
with gr.Column(scale=1):
343+
# Create model choices with quantization options
233344
cot_agent_dropdown = gr.Dropdown(
234-
choices=["Local (Mistral)", "OpenAI"] if openai_key else ["Local (Mistral)"],
235-
value="Local (Mistral)",
345+
choices=model_choices,
346+
value=model_choices[0] if model_choices else None,
236347
label="Select Agent"
237348
)
238349
with gr.Column(scale=1):
@@ -260,6 +371,9 @@ def create_interface():
260371
url_button.click(process_url, inputs=[url_input], outputs=[url_output])
261372
repo_button.click(process_repo, inputs=[repo_input], outputs=[repo_output])
262373

374+
# Model download event handler
375+
download_button.click(download_model, inputs=[model_dropdown], outputs=[model_status])
376+
263377
# Standard chat handlers
264378
standard_msg.submit(
265379
chat,
@@ -359,5 +473,123 @@ def main():
359473
inbrowser=True
360474
)
361475

476+
def download_model(model_type: str) -> str:
477+
"""Download a model and return status message"""
478+
try:
479+
print(f"Downloading model: {model_type}")
480+
481+
# Parse model type to determine model and quantization
482+
quantization = None
483+
model_name = "mistralai/Mistral-7B-Instruct-v0.2" # Default model
484+
485+
if "4-bit" in model_type:
486+
quantization = "4bit"
487+
elif "8-bit" in model_type:
488+
quantization = "8bit"
489+
elif "GGUF" in model_type:
490+
# Extract model name from model_type
491+
if "Phi-4-mini" in model_type:
492+
model_name = "unsloth/Phi-4-mini-instruct-GGUF"
493+
elif "Qwen_QwQ-32B" in model_type:
494+
model_name = "bartowski/Qwen_QwQ-32B-GGUF"
495+
elif "TinyR1-32B" in model_type:
496+
model_name = "bartowski/qihoo360_TinyR1-32B-Preview-GGUF"
497+
498+
# Check if HuggingFace token is available
499+
if not hf_token:
500+
return "❌ Error: HuggingFace token not found in config.yaml. Please add your token first."
501+
502+
# Start download timer
503+
start_time = time.time()
504+
505+
# For GGUF models, use the GGUFModelHandler to download
506+
if "GGUF" in model_type:
507+
try:
508+
from local_rag_agent import GGUFModelHandler
509+
from huggingface_hub import list_repo_files
510+
511+
# Extract repo_id
512+
parts = model_name.split('/')
513+
repo_id = '/'.join(parts[:2])
514+
515+
# Check if model is gated
516+
try:
517+
files = list_repo_files(repo_id, token=hf_token)
518+
gguf_files = [f for f in files if f.endswith('.gguf')]
519+
520+
if not gguf_files:
521+
return f"❌ Error: No GGUF files found in repo: {repo_id}"
522+
523+
# Download the model
524+
handler = GGUFModelHandler(model_name)
525+
526+
# Calculate download time
527+
download_time = time.time() - start_time
528+
return f"✅ Successfully downloaded {model_type} in {download_time:.1f} seconds."
529+
530+
except Exception as e:
531+
if "401" in str(e):
532+
return f"❌ Error: This model is gated. Please accept the terms on the Hugging Face website: https://huggingface.co/{repo_id}"
533+
else:
534+
return f"❌ Error downloading model: {str(e)}"
535+
536+
except ImportError:
537+
return "❌ Error: llama-cpp-python not installed. Please install with: pip install llama-cpp-python"
538+
539+
# For Transformers models, use the Transformers library
540+
else:
541+
try:
542+
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
543+
544+
# Download tokenizer first (smaller download to check access)
545+
try:
546+
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
547+
except Exception as e:
548+
if "401" in str(e):
549+
return f"❌ Error: This model is gated. Please accept the terms on the Hugging Face website: https://huggingface.co/{model_name}"
550+
else:
551+
return f"❌ Error downloading tokenizer: {str(e)}"
552+
553+
# Set up model loading parameters
554+
model_kwargs = {
555+
"token": hf_token,
556+
"device_map": None, # Don't load on GPU for download only
557+
}
558+
559+
# Apply quantization if specified
560+
if quantization == '4bit':
561+
try:
562+
quantization_config = BitsAndBytesConfig(
563+
load_in_4bit=True,
564+
bnb_4bit_compute_dtype=torch.float16,
565+
bnb_4bit_use_double_quant=True,
566+
bnb_4bit_quant_type="nf4"
567+
)
568+
model_kwargs["quantization_config"] = quantization_config
569+
except ImportError:
570+
return "❌ Error: bitsandbytes not installed. Please install with: pip install bitsandbytes>=0.41.0"
571+
elif quantization == '8bit':
572+
try:
573+
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
574+
model_kwargs["quantization_config"] = quantization_config
575+
except ImportError:
576+
return "❌ Error: bitsandbytes not installed. Please install with: pip install bitsandbytes>=0.41.0"
577+
578+
# Download model (but don't load it fully to save memory)
579+
AutoModelForCausalLM.from_pretrained(
580+
model_name,
581+
**model_kwargs
582+
)
583+
584+
# Calculate download time
585+
download_time = time.time() - start_time
586+
return f"✅ Successfully downloaded {model_type} in {download_time:.1f} seconds."
587+
588+
except Exception as e:
589+
return f"❌ Error downloading model: {str(e)}"
590+
591+
except Exception as e:
592+
return f"❌ Error: {str(e)}"
593+
362594
if __name__ == "__main__":
363595
main()

0 commit comments

Comments
 (0)