@@ -57,7 +57,10 @@ def __init__(self, model_name: str):
         Args:
             model_name: Name of the Ollama model to use
         """
-        # Use the model name directly without any transformation
+        # Ensure model name has :latest suffix
+        if not model_name.endswith(":latest"):
+            model_name = f"{model_name}:latest"
+
         self.model_name = model_name
         self._check_ollama_running()

@@ -74,13 +77,11 @@ def _check_ollama_running(self):

             # Check if the requested model is available
             if self.model_name not in available_models:
-                # Try with :latest suffix
-                if f"{self.model_name}:latest" in available_models:
-                    self.model_name = f"{self.model_name}:latest"
-                    print(f"Using model with :latest suffix: {self.model_name}")
-                else:
-                    print(f"Model '{self.model_name}' not found in Ollama. Available models: {', '.join(available_models)}")
-                    print(f"You can pull it with: ollama pull {self.model_name}")
+                print(f"Model '{self.model_name}' not found in Ollama. Available models: {', '.join(available_models)}")
+                print(f"You can pull it with: ollama pull {self.model_name}")
+                raise ValueError(f"Model '{self.model_name}' not found in Ollama")
+            else:
+                print(f"Using Ollama model: {self.model_name}")

         except Exception as e:
             raise ConnectionError(f"Failed to connect to Ollama. Please make sure Ollama is running. Error: {str(e)}")

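Taken together, the first two hunks change the contract of the Ollama wrapper: the constructor now pins an explicit ":latest" tag, and _check_ollama_running fails fast instead of silently retrying with a suffix. A minimal standalone sketch of that behaviour (the helper names below are illustrative, not part of the diff):

def normalize_model_name(model_name: str) -> str:
    # Mirrors the new constructor logic: append ":latest" when the suffix is missing.
    # (Assumes callers pass bare names like "llama3"; a tagged name such as "llama3:8b"
    # would also get ":latest" appended under this rule.)
    if not model_name.endswith(":latest"):
        model_name = f"{model_name}:latest"
    return model_name

def require_model(model_name: str, available_models: list[str]) -> None:
    # Mirrors the new _check_ollama_running behaviour: raise instead of guessing a tag.
    if model_name not in available_models:
        raise ValueError(f"Model '{model_name}' not found in Ollama")

assert normalize_model_name("llama3") == "llama3:latest"
assert normalize_model_name("llama3:latest") == "llama3:latest"
require_model("llama3:latest", ["llama3:latest", "mistral:latest"])  # passes silently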
@@ -92,6 +93,9 @@ def __call__(self, prompt, max_new_tokens=512, temperature=0.1, top_p=0.95, **kw
         try:
             import ollama

+            print(f"\nGenerating response with Ollama model: {self.model_name}")
+            print(f"Prompt: {prompt[:100]}...")  # Print first 100 chars of prompt
+
             # Generate text
             response = ollama.generate(
                 model=self.model_name,
@@ -103,6 +107,8 @@ def __call__(self, prompt, max_new_tokens=512, temperature=0.1, top_p=0.95, **kw
                 }
             )

+            print(f"Response generated successfully with {self.model_name}")
+
             # Format result to match transformers pipeline output
             formatted_result = [{
                 "generated_text": response["response"]
@@ -114,7 +120,7 @@ def __call__(self, prompt, max_new_tokens=512, temperature=0.1, top_p=0.95, **kw
             raise Exception(f"Failed to generate text with Ollama: {str(e)}")

 class LocalRAGAgent:
-    def __init__(self, vector_store: VectorStore = None, model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
+    def __init__(self, vector_store: VectorStore = None, model_name: str = None,
                  use_cot: bool = False, collection: str = None, skip_analysis: bool = False,
                  quantization: str = None, use_oracle_db: bool = True):
         """Initialize local RAG agent with vector store and local LLM
@@ -165,7 +171,7 @@ def __init__(self, vector_store: VectorStore = None, model_name: str = "mistrala
         # skip_analysis parameter kept for backward compatibility but no longer used

         # Check if this is an Ollama model
-        self.is_ollama = model_name.startswith("ollama:") or "Ollama - " in model_name
+        self.is_ollama = model_name and (model_name.startswith("ollama:") or "Ollama - " in model_name)

         if self.is_ollama:
             # Extract the actual model name from the prefix
@@ -178,6 +184,10 @@ def __init__(self, vector_store: VectorStore = None, model_name: str = "mistrala
             else:
                 ollama_model_name = model_name

+            # Add :latest suffix if not present
+            if not ollama_model_name.endswith(":latest"):
+                ollama_model_name = f"{ollama_model_name}:latest"
+
             # Load Ollama model
             print("\nLoading Ollama model...")
             print(f"Model: {ollama_model_name}")
@@ -188,87 +198,46 @@ def __init__(self, vector_store: VectorStore = None, model_name: str = "mistrala

             # Create pipeline-like interface
             self.pipeline = self.ollama_handler
-
+            print(f"Using Ollama model: {ollama_model_name}")
         else:
-            # Load HuggingFace token from config
-            try:
-                with open('config.yaml', 'r') as f:
-                    config = yaml.safe_load(f)
-                    token = config.get('HUGGING_FACE_HUB_TOKEN')
-                    if not token:
-                        raise ValueError("HUGGING_FACE_HUB_TOKEN not found in config.yaml")
-            except Exception as e:
-                raise Exception(f"Failed to load HuggingFace token from config.yaml: {str(e)}")
-
-            # Load model and tokenizer
-            print("\nLoading model and tokenizer...")
-            print(f"Model: {model_name}")
-            if quantization:
-                print(f"Quantization: {quantization}")
-            print("Note: Initial loading and inference can take 1-5 minutes depending on your hardware.")
-            print("Subsequent queries will be faster but may still take 30-60 seconds per response.")
-
-            # Check if CUDA is available and set appropriate dtype
-            if torch.cuda.is_available():
-                print("CUDA is available. Using GPU acceleration.")
-                dtype = torch.float16
+            # Only initialize Mistral if no model is specified
+            if not model_name:
+                print("\nLoading default model and tokenizer...")
+                print("Model: mistralai/Mistral-7B-Instruct-v0.2")
+                self.model_name = "mistralai/Mistral-7B-Instruct-v0.2"
+                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.model_name,
+                    device_map="auto",
+                    torch_dtype=torch.float16,
+                    load_in_8bit=quantization == "8bit",
+                    load_in_4bit=quantization == "4bit"
+                )
+                self.pipeline = pipeline(
+                    "text-generation",
+                    model=self.model,
+                    tokenizer=self.tokenizer,
+                    device_map="auto"
+                )
+                print(f"Using default model: {self.model_name}")
             else:
-                print("CUDA is not available. Using CPU only (this will be slow).")
-                dtype = torch.float32
-
-            # Set up model loading parameters
-            model_kwargs = {
-                "torch_dtype": dtype,
-                "device_map": "auto",
-                "token": token,
-                "low_cpu_mem_usage": True,
-                "offload_folder": "offload"
-            }
-
-            # Apply quantization if specified
-            if quantization == '4bit':
-                try:
-                    from transformers import BitsAndBytesConfig
-                    quantization_config = BitsAndBytesConfig(
-                        load_in_4bit=True,
-                        bnb_4bit_compute_dtype=torch.float16,
-                        bnb_4bit_use_double_quant=True,
-                        bnb_4bit_quant_type="nf4"
-                    )
-                    model_kwargs["quantization_config"] = quantization_config
-                    print("Using 4-bit quantization with bitsandbytes")
-                except ImportError:
-                    print("Warning: bitsandbytes not installed. Falling back to standard loading.")
-                    print("To use 4-bit quantization, install bitsandbytes: pip install bitsandbytes")
-            elif quantization == '8bit':
-                try:
-                    from transformers import BitsAndBytesConfig
-                    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-                    model_kwargs["quantization_config"] = quantization_config
-                    print("Using 8-bit quantization with bitsandbytes")
-                except ImportError:
-                    print("Warning: bitsandbytes not installed. Falling back to standard loading.")
-                    print("To use 8-bit quantization, install bitsandbytes: pip install bitsandbytes")
-
-            # Load model with appropriate settings
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                **model_kwargs
-            )
-            self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
-
-            # Create text generation pipeline with optimized settings
-            self.pipeline = pipeline(
-                "text-generation",
-                model=self.model,
-                tokenizer=self.tokenizer,
-                max_new_tokens=512,
-                do_sample=True,
-                temperature=0.1,
-                top_p=0.95,
-                device_map="auto"
-            )
-            print("✓ Model loaded successfully")
+                print(f"\nUsing specified model: {model_name}")
+                self.model_name = model_name
+                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.model_name,
+                    device_map="auto",
+                    torch_dtype=torch.float16,
+                    load_in_8bit=quantization == "8bit",
+                    load_in_4bit=quantization == "4bit"
+                )
+                self.pipeline = pipeline(
+                    "text-generation",
+                    model=self.model,
+                    tokenizer=self.tokenizer,
+                    device_map="auto"
+                )
+                print(f"Using specified model: {self.model_name}")

         # Create LLM wrapper
         self.llm = LocalLLM(self.pipeline)
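With these changes, model selection in LocalRAGAgent works roughly as sketched below. The import path and model names are assumptions for illustration; the constructor arguments come from the diff above:

# Hypothetical usage; the module path "local_rag_agent" is an assumption.
from local_rag_agent import LocalRAGAgent

# Ollama route: the "ollama:" prefix marks an Ollama model, and ":latest" is appended
# when missing, so "ollama:llama3" resolves to the local Ollama model "llama3:latest".
ollama_agent = LocalRAGAgent(model_name="ollama:llama3", use_cot=True)

# Hugging Face route: with model_name=None the agent now falls back to
# mistralai/Mistral-7B-Instruct-v0.2; quantization="4bit"/"8bit" maps to
# load_in_4bit/load_in_8bit in from_pretrained.
default_agent = LocalRAGAgent(quantization="4bit")

# Any other Hugging Face model id is loaded as specified.
hf_agent = LocalRAGAgent(model_name="mistralai/Mistral-7B-Instruct-v0.2")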