55import tempfile
66from dotenv import load_dotenv
77import yaml
8+ import torch
9+ import time
810
911from pdf_processor import PDFProcessor
1012from web_processor import WebProcessor
@@ -89,19 +91,49 @@ def chat(message: str, history: List[List[str]], agent_type: str, use_cot: bool,
8991 # Skip analysis for General Knowledge or when using standard chat interface (not CoT)
9092 skip_analysis = collection == "General Knowledge" or not use_cot
9193
94+ # Parse agent type to determine model and quantization
95+ quantization = None
96+ model_name = None
97+
98+ if "4-bit" in agent_type :
99+ quantization = "4bit"
100+ model_type = "Local (Mistral)"
101+ elif "8-bit" in agent_type :
102+ quantization = "8bit"
103+ model_type = "Local (Mistral)"
104+ elif "GGUF" in agent_type :
105+ model_type = "GGUF"
106+ # Extract model name from agent_type
107+ if "Phi-4-mini" in agent_type :
108+ model_name = "unsloth/Phi-4-mini-instruct-GGUF"
109+ elif "Qwen_QwQ-32B" in agent_type :
110+ model_name = "bartowski/Qwen_QwQ-32B-GGUF"
111+ elif "TinyR1-32B" in agent_type :
112+ model_name = "bartowski/qihoo360_TinyR1-32B-Preview-GGUF"
113+ else :
114+ model_type = agent_type
115+
92116 # Select appropriate agent and reinitialize with correct settings
93- if agent_type == "Local (Mistral) " :
117+ if "Local" in model_type or model_type == "GGUF " :
94118 if not hf_token :
95119 response_text = "Local agent not available. Please check your HuggingFace token configuration."
96120 print (f"Error: { response_text } " )
97121 return history + [[message , response_text ]]
98- agent = LocalRAGAgent (vector_store , use_cot = use_cot , collection = collection , skip_analysis = skip_analysis )
122+
123+ # Use specified model_name for GGUF models, otherwise use default
124+ if model_type == "GGUF" and model_name :
125+ agent = LocalRAGAgent (vector_store , model_name = model_name , use_cot = use_cot ,
126+ collection = collection , skip_analysis = skip_analysis )
127+ else :
128+ agent = LocalRAGAgent (vector_store , use_cot = use_cot , collection = collection ,
129+ skip_analysis = skip_analysis , quantization = quantization )
99130 else :
100131 if not openai_key :
101132 response_text = "OpenAI agent not available. Please check your OpenAI API key configuration."
102133 print (f"Error: { response_text } " )
103134 return history + [[message , response_text ]]
104- agent = RAGAgent (vector_store , openai_api_key = openai_key , use_cot = use_cot , collection = collection , skip_analysis = skip_analysis )
135+ agent = RAGAgent (vector_store , openai_api_key = openai_key , use_cot = use_cot ,
136+ collection = collection , skip_analysis = skip_analysis )
105137
106138 # Process query and get response
107139 print ("Processing query..." )
@@ -183,6 +215,83 @@ def create_interface():
183215 > **Note on Performance**: When using the Local (Mistral) model, initial loading can take 1-5 minutes, and each query may take 30-60 seconds to process depending on your hardware. OpenAI queries are typically much faster.
184216 """ )
185217
218+ # Create model choices list for reuse
219+ model_choices = []
220+ if hf_token :
221+ model_choices .extend ([
222+ "Local (Mistral)" ,
223+ "Local (Mistral) - 4-bit Quantized" ,
224+ "Local (Mistral) - 8-bit Quantized" ,
225+ "GGUF - Phi-4-mini-instruct" ,
226+ "GGUF - Qwen_QwQ-32B" ,
227+ "GGUF - TinyR1-32B-Preview"
228+ ])
229+ if openai_key :
230+ model_choices .append ("OpenAI" )
231+
232+ # Model Management Tab (First Tab)
233+ with gr .Tab ("Model Management" ):
234+ gr .Markdown ("""
235+ ## Model Management
236+
237+ Download models in advance to prepare them for use in the chat interface.
238+ This can help avoid delays when first using a model and ensure all models are properly downloaded.
239+
240+ > **Note**: Some models may require accepting terms on the Hugging Face website before downloading.
241+ > If you encounter a 401 error, please follow the link provided to accept the model terms.
242+ """ )
243+
244+ with gr .Row ():
245+ with gr .Column ():
246+ model_dropdown = gr .Dropdown (
247+ choices = model_choices ,
248+ value = model_choices [0 ] if model_choices else None ,
249+ label = "Select Model to Download" ,
250+ interactive = True
251+ )
252+ download_button = gr .Button ("Download Selected Model" )
253+ model_status = gr .Textbox (
254+ label = "Download Status" ,
255+ placeholder = "Select a model and click Download to begin..." ,
256+ interactive = False
257+ )
258+
259+ with gr .Column ():
260+ gr .Markdown ("""
261+ ### Model Information
262+
263+ **Local (Mistral)**: The default Mistral-7B-Instruct-v0.2 model.
264+ - Size: ~14GB
265+ - VRAM Required: ~8GB
266+ - Good balance of quality and speed
267+
268+ **Local (Mistral) - 4-bit Quantized**: 4-bit quantized version of Mistral-7B.
269+ - Size: ~4GB
270+ - VRAM Required: ~4GB
271+ - Faster inference with minimal quality loss
272+
273+ **Local (Mistral) - 8-bit Quantized**: 8-bit quantized version of Mistral-7B.
274+ - Size: ~7GB
275+ - VRAM Required: ~6GB
276+ - Balance between quality and memory usage
277+
278+ **GGUF - Phi-4-mini-instruct**: Microsoft's Phi-4-mini model in GGUF format.
279+ - Size: ~2-4GB
280+ - VRAM Required: Scales based on available VRAM
281+ - Efficient small model with good performance
282+
283+ **GGUF - Qwen_QwQ-32B**: Qwen 32B model in GGUF format.
284+ - Size: ~20GB
285+ - VRAM Required: Scales based on available VRAM
286+ - High-quality large model
287+
288+ **GGUF - TinyR1-32B-Preview**: Qihoo 360's TinyR1 32B model in GGUF format.
289+ - Size: ~20GB
290+ - VRAM Required: Scales based on available VRAM
291+ - High-quality large model
292+ """ )
293+
294+ # Document Processing Tab
186295 with gr .Tab ("Document Processing" ):
187296 with gr .Row ():
188297 with gr .Column ():
@@ -203,9 +312,10 @@ def create_interface():
203312 with gr .Tab ("Standard Chat Interface" ):
204313 with gr .Row ():
205314 with gr .Column (scale = 1 ):
315+ # Create model choices with quantization options
206316 standard_agent_dropdown = gr .Dropdown (
207- choices = [ "Local (Mistral)" , "OpenAI" ] if openai_key else [ "Local (Mistral)" ] ,
208- value = "Local (Mistral)" ,
317+ choices = model_choices ,
318+ value = model_choices [ 0 ] if model_choices else None ,
209319 label = "Select Agent"
210320 )
211321 with gr .Column (scale = 1 ):
@@ -230,9 +340,10 @@ def create_interface():
230340 with gr .Tab ("Chain of Thought Chat Interface" ):
231341 with gr .Row ():
232342 with gr .Column (scale = 1 ):
343+ # Create model choices with quantization options
233344 cot_agent_dropdown = gr .Dropdown (
234- choices = [ "Local (Mistral)" , "OpenAI" ] if openai_key else [ "Local (Mistral)" ] ,
235- value = "Local (Mistral)" ,
345+ choices = model_choices ,
346+ value = model_choices [ 0 ] if model_choices else None ,
236347 label = "Select Agent"
237348 )
238349 with gr .Column (scale = 1 ):
@@ -260,6 +371,9 @@ def create_interface():
260371 url_button .click (process_url , inputs = [url_input ], outputs = [url_output ])
261372 repo_button .click (process_repo , inputs = [repo_input ], outputs = [repo_output ])
262373
374+ # Model download event handler
375+ download_button .click (download_model , inputs = [model_dropdown ], outputs = [model_status ])
376+
263377 # Standard chat handlers
264378 standard_msg .submit (
265379 chat ,
@@ -359,5 +473,123 @@ def main():
359473 inbrowser = True
360474 )
361475
476+ def download_model (model_type : str ) -> str :
477+ """Download a model and return status message"""
478+ try :
479+ print (f"Downloading model: { model_type } " )
480+
481+ # Parse model type to determine model and quantization
482+ quantization = None
483+ model_name = "mistralai/Mistral-7B-Instruct-v0.2" # Default model
484+
485+ if "4-bit" in model_type :
486+ quantization = "4bit"
487+ elif "8-bit" in model_type :
488+ quantization = "8bit"
489+ elif "GGUF" in model_type :
490+ # Extract model name from model_type
491+ if "Phi-4-mini" in model_type :
492+ model_name = "unsloth/Phi-4-mini-instruct-GGUF"
493+ elif "Qwen_QwQ-32B" in model_type :
494+ model_name = "bartowski/Qwen_QwQ-32B-GGUF"
495+ elif "TinyR1-32B" in model_type :
496+ model_name = "bartowski/qihoo360_TinyR1-32B-Preview-GGUF"
497+
498+ # Check if HuggingFace token is available
499+ if not hf_token :
500+ return "❌ Error: HuggingFace token not found in config.yaml. Please add your token first."
501+
502+ # Start download timer
503+ start_time = time .time ()
504+
505+ # For GGUF models, use the GGUFModelHandler to download
506+ if "GGUF" in model_type :
507+ try :
508+ from local_rag_agent import GGUFModelHandler
509+ from huggingface_hub import list_repo_files
510+
511+ # Extract repo_id
512+ parts = model_name .split ('/' )
513+ repo_id = '/' .join (parts [:2 ])
514+
515+ # Check if model is gated
516+ try :
517+ files = list_repo_files (repo_id , token = hf_token )
518+ gguf_files = [f for f in files if f .endswith ('.gguf' )]
519+
520+ if not gguf_files :
521+ return f"❌ Error: No GGUF files found in repo: { repo_id } "
522+
523+ # Download the model
524+ handler = GGUFModelHandler (model_name )
525+
526+ # Calculate download time
527+ download_time = time .time () - start_time
528+ return f"✅ Successfully downloaded { model_type } in { download_time :.1f} seconds."
529+
530+ except Exception as e :
531+ if "401" in str (e ):
532+ return f"❌ Error: This model is gated. Please accept the terms on the Hugging Face website: https://huggingface.co/{ repo_id } "
533+ else :
534+ return f"❌ Error downloading model: { str (e )} "
535+
536+ except ImportError :
537+ return "❌ Error: llama-cpp-python not installed. Please install with: pip install llama-cpp-python"
538+
539+ # For Transformers models, use the Transformers library
540+ else :
541+ try :
542+ from transformers import AutoTokenizer , AutoModelForCausalLM , BitsAndBytesConfig
543+
544+ # Download tokenizer first (smaller download to check access)
545+ try :
546+ tokenizer = AutoTokenizer .from_pretrained (model_name , token = hf_token )
547+ except Exception as e :
548+ if "401" in str (e ):
549+ return f"❌ Error: This model is gated. Please accept the terms on the Hugging Face website: https://huggingface.co/{ model_name } "
550+ else :
551+ return f"❌ Error downloading tokenizer: { str (e )} "
552+
553+ # Set up model loading parameters
554+ model_kwargs = {
555+ "token" : hf_token ,
556+ "device_map" : None , # Don't load on GPU for download only
557+ }
558+
559+ # Apply quantization if specified
560+ if quantization == '4bit' :
561+ try :
562+ quantization_config = BitsAndBytesConfig (
563+ load_in_4bit = True ,
564+ bnb_4bit_compute_dtype = torch .float16 ,
565+ bnb_4bit_use_double_quant = True ,
566+ bnb_4bit_quant_type = "nf4"
567+ )
568+ model_kwargs ["quantization_config" ] = quantization_config
569+ except ImportError :
570+ return "❌ Error: bitsandbytes not installed. Please install with: pip install bitsandbytes>=0.41.0"
571+ elif quantization == '8bit' :
572+ try :
573+ quantization_config = BitsAndBytesConfig (load_in_8bit = True )
574+ model_kwargs ["quantization_config" ] = quantization_config
575+ except ImportError :
576+ return "❌ Error: bitsandbytes not installed. Please install with: pip install bitsandbytes>=0.41.0"
577+
578+ # Download model (but don't load it fully to save memory)
579+ AutoModelForCausalLM .from_pretrained (
580+ model_name ,
581+ ** model_kwargs
582+ )
583+
584+ # Calculate download time
585+ download_time = time .time () - start_time
586+ return f"✅ Successfully downloaded { model_type } in { download_time :.1f} seconds."
587+
588+ except Exception as e :
589+ return f"❌ Error downloading model: { str (e )} "
590+
591+ except Exception as e :
592+ return f"❌ Error: { str (e )} "
593+
362594if __name__ == "__main__" :
363595 main ()
0 commit comments