Commit 6e18eb3

fix: bugfix
1 parent dc35345 commit 6e18eb3

File tree: 1 file changed (+58 / -89 lines)

agentic_rag/local_rag_agent.py

Lines changed: 58 additions & 89 deletions
@@ -57,7 +57,10 @@ def __init__(self, model_name: str):
         Args:
             model_name: Name of the Ollama model to use
         """
-        # Use the model name directly without any transformation
+        # Ensure model name has :latest suffix
+        if not model_name.endswith(":latest"):
+            model_name = f"{model_name}:latest"
+
         self.model_name = model_name
         self._check_ollama_running()
 
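Note: the net effect of this hunk is that the handler always stores a fully tagged model name. Restated as a standalone helper (the function name is illustrative, not from the repository):

    def ensure_latest_tag(model_name: str) -> str:
        # Same rule as the added lines above: append ":latest" only when it is missing.
        if not model_name.endswith(":latest"):
            model_name = f"{model_name}:latest"
        return model_name

    assert ensure_latest_tag("mistral") == "mistral:latest"
    assert ensure_latest_tag("mistral:latest") == "mistral:latest"

One consequence to be aware of: an explicitly tagged name such as "mistral:7b" would also get ":latest" appended and would then likely fail the availability check in the next hunk.
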
@@ -74,13 +77,11 @@ def _check_ollama_running(self):
 
             # Check if the requested model is available
             if self.model_name not in available_models:
-                # Try with :latest suffix
-                if f"{self.model_name}:latest" in available_models:
-                    self.model_name = f"{self.model_name}:latest"
-                    print(f"Using model with :latest suffix: {self.model_name}")
-                else:
-                    print(f"Model '{self.model_name}' not found in Ollama. Available models: {', '.join(available_models)}")
-                    print(f"You can pull it with: ollama pull {self.model_name}")
+                print(f"Model '{self.model_name}' not found in Ollama. Available models: {', '.join(available_models)}")
+                print(f"You can pull it with: ollama pull {self.model_name}")
+                raise ValueError(f"Model '{self.model_name}' not found in Ollama")
+            else:
+                print(f"Using Ollama model: {self.model_name}")
         except Exception as e:
             raise ConnectionError(f"Failed to connect to Ollama. Please make sure Ollama is running. Error: {str(e)}")
 
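With the silent ":latest" retry gone (the constructor now normalizes the name up front), a missing model is a hard failure instead of a warning. A hedged sketch of how a caller might respond to the new ValueError, using the `ollama pull` command that the error message itself suggests; "OllamaHandler" is a stand-in for the handler class defined in this file:

    import subprocess

    def get_handler(name: str):
        try:
            return OllamaHandler(name)
        except ValueError:
            # Pull the missing model via the Ollama CLI, then retry once.
            subprocess.run(["ollama", "pull", name], check=True)
            return OllamaHandler(name)
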
@@ -92,6 +93,9 @@ def __call__(self, prompt, max_new_tokens=512, temperature=0.1, top_p=0.95, **kw
         try:
             import ollama
 
+            print(f"\nGenerating response with Ollama model: {self.model_name}")
+            print(f"Prompt: {prompt[:100]}...") # Print first 100 chars of prompt
+
             # Generate text
             response = ollama.generate(
                 model=self.model_name,
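For reference, the call that the new logging wraps looks roughly like this when used on its own. The options shown are assumptions chosen to mirror the handler's defaults (temperature 0.1, top_p 0.95, 512 new tokens); the actual options block sits outside this hunk:

    import ollama

    # Minimal direct use of the same API; "mistral:latest" is a placeholder tag.
    response = ollama.generate(
        model="mistral:latest",
        prompt="Why is the sky blue?",
        options={"temperature": 0.1, "top_p": 0.95, "num_predict": 512},
    )
    print(response["response"])  # raw generated text

The next hunk then wraps this raw text as [{"generated_text": ...}] so callers can treat the handler like a transformers text-generation pipeline.
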
@@ -103,6 +107,8 @@ def __call__(self, prompt, max_new_tokens=512, temperature=0.1, top_p=0.95, **kw
                 }
             )
 
+            print(f"Response generated successfully with {self.model_name}")
+
             # Format result to match transformers pipeline output
             formatted_result = [{
                 "generated_text": response["response"]
@@ -114,7 +120,7 @@ def __call__(self, prompt, max_new_tokens=512, temperature=0.1, top_p=0.95, **kw
             raise Exception(f"Failed to generate text with Ollama: {str(e)}")
 
 class LocalRAGAgent:
-    def __init__(self, vector_store: VectorStore = None, model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
+    def __init__(self, vector_store: VectorStore = None, model_name: str = None,
                  use_cot: bool = False, collection: str = None, skip_analysis: bool = False,
                  quantization: str = None, use_oracle_db: bool = True):
         """Initialize local RAG agent with vector store and local LLM
@@ -165,7 +171,7 @@ def __init__(self, vector_store: VectorStore = None, model_name: str = "mistrala
         # skip_analysis parameter kept for backward compatibility but no longer used
 
         # Check if this is an Ollama model
-        self.is_ollama = model_name.startswith("ollama:") or "Ollama - " in model_name
+        self.is_ollama = model_name and (model_name.startswith("ollama:") or "Ollama - " in model_name)
 
         if self.is_ollama:
             # Extract the actual model name from the prefix
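The added `model_name and (...)` guard is what makes the new model_name=None default safe: calling .startswith on None would otherwise raise AttributeError. Restated as a standalone check (the helper name is illustrative):

    def looks_like_ollama(model_name) -> bool:
        # None (or "") short-circuits to False instead of raising on .startswith.
        return bool(model_name and (model_name.startswith("ollama:") or "Ollama - " in model_name))

    assert looks_like_ollama("ollama:mistral")
    assert looks_like_ollama("Ollama - Mistral")
    assert not looks_like_ollama(None)
    assert not looks_like_ollama("mistralai/Mistral-7B-Instruct-v0.2")
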
@@ -178,6 +184,10 @@ def __init__(self, vector_store: VectorStore = None, model_name: str = "mistrala
             else:
                 ollama_model_name = model_name
 
+            # Add :latest suffix if not present
+            if not ollama_model_name.endswith(":latest"):
+                ollama_model_name = f"{ollama_model_name}:latest"
+
             # Load Ollama model
             print("\nLoading Ollama model...")
             print(f"Model: {ollama_model_name}")
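This mirrors the ":latest" normalization added to the handler's __init__ in the first hunk, so the tag is already pinned before the handler re-checks it. A sketch of the assumed end-to-end name handling for this branch (the prefix-stripping step lives in unshown lines above and is only summarized here):

    def normalize_ollama_name(model_name: str) -> str:
        # Assumed flow: strip the "ollama:" UI prefix, then pin the ":latest" tag.
        if model_name.startswith("ollama:"):
            model_name = model_name[len("ollama:"):]
        if not model_name.endswith(":latest"):
            model_name = f"{model_name}:latest"
        return model_name

    assert normalize_ollama_name("ollama:mistral") == "mistral:latest"
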
@@ -188,87 +198,46 @@ def __init__(self, vector_store: VectorStore = None, model_name: str = "mistrala
 
             # Create pipeline-like interface
             self.pipeline = self.ollama_handler
-
+            print(f"Using Ollama model: {ollama_model_name}")
         else:
-            # Load HuggingFace token from config
-            try:
-                with open('config.yaml', 'r') as f:
-                    config = yaml.safe_load(f)
-                    token = config.get('HUGGING_FACE_HUB_TOKEN')
-                    if not token:
-                        raise ValueError("HUGGING_FACE_HUB_TOKEN not found in config.yaml")
-            except Exception as e:
-                raise Exception(f"Failed to load HuggingFace token from config.yaml: {str(e)}")
-
-            # Load model and tokenizer
-            print("\nLoading model and tokenizer...")
-            print(f"Model: {model_name}")
-            if quantization:
-                print(f"Quantization: {quantization}")
-            print("Note: Initial loading and inference can take 1-5 minutes depending on your hardware.")
-            print("Subsequent queries will be faster but may still take 30-60 seconds per response.")
-
-            # Check if CUDA is available and set appropriate dtype
-            if torch.cuda.is_available():
-                print("CUDA is available. Using GPU acceleration.")
-                dtype = torch.float16
+            # Only initialize Mistral if no model is specified
+            if not model_name:
+                print("\nLoading default model and tokenizer...")
+                print("Model: mistralai/Mistral-7B-Instruct-v0.2")
+                self.model_name = "mistralai/Mistral-7B-Instruct-v0.2"
+                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.model_name,
+                    device_map="auto",
+                    torch_dtype=torch.float16,
+                    load_in_8bit=quantization == "8bit",
+                    load_in_4bit=quantization == "4bit"
+                )
+                self.pipeline = pipeline(
+                    "text-generation",
+                    model=self.model,
+                    tokenizer=self.tokenizer,
+                    device_map="auto"
+                )
+                print(f"Using default model: {self.model_name}")
             else:
-                print("CUDA is not available. Using CPU only (this will be slow).")
-                dtype = torch.float32
-
-            # Set up model loading parameters
-            model_kwargs = {
-                "torch_dtype": dtype,
-                "device_map": "auto",
-                "token": token,
-                "low_cpu_mem_usage": True,
-                "offload_folder": "offload"
-            }
-
-            # Apply quantization if specified
-            if quantization == '4bit':
-                try:
-                    from transformers import BitsAndBytesConfig
-                    quantization_config = BitsAndBytesConfig(
-                        load_in_4bit=True,
-                        bnb_4bit_compute_dtype=torch.float16,
-                        bnb_4bit_use_double_quant=True,
-                        bnb_4bit_quant_type="nf4"
-                    )
-                    model_kwargs["quantization_config"] = quantization_config
-                    print("Using 4-bit quantization with bitsandbytes")
-                except ImportError:
-                    print("Warning: bitsandbytes not installed. Falling back to standard loading.")
-                    print("To use 4-bit quantization, install bitsandbytes: pip install bitsandbytes")
-            elif quantization == '8bit':
-                try:
-                    from transformers import BitsAndBytesConfig
-                    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-                    model_kwargs["quantization_config"] = quantization_config
-                    print("Using 8-bit quantization with bitsandbytes")
-                except ImportError:
-                    print("Warning: bitsandbytes not installed. Falling back to standard loading.")
-                    print("To use 8-bit quantization, install bitsandbytes: pip install bitsandbytes")
-
-            # Load model with appropriate settings
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                **model_kwargs
-            )
-            self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
-
-            # Create text generation pipeline with optimized settings
-            self.pipeline = pipeline(
-                "text-generation",
-                model=self.model,
-                tokenizer=self.tokenizer,
-                max_new_tokens=512,
-                do_sample=True,
-                temperature=0.1,
-                top_p=0.95,
-                device_map="auto"
-            )
-            print("✓ Model loaded successfully")
+                print(f"\nUsing specified model: {model_name}")
+                self.model_name = model_name
+                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.model_name,
+                    device_map="auto",
+                    torch_dtype=torch.float16,
+                    load_in_8bit=quantization == "8bit",
+                    load_in_4bit=quantization == "4bit"
+                )
+                self.pipeline = pipeline(
+                    "text-generation",
+                    model=self.model,
+                    tokenizer=self.tokenizer,
+                    device_map="auto"
+                )
+                print(f"Using specified model: {self.model_name}")
 
         # Create LLM wrapper
         self.llm = LocalLLM(self.pipeline)
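The rewritten HuggingFace branch drops the config.yaml token lookup and the explicit BitsAndBytesConfig handling, forwarding load_in_8bit / load_in_4bit straight to from_pretrained. The quantized path therefore still presumes bitsandbytes is installed, and gated checkpoints presumably rely on the standard Hugging Face credential cache now that HUGGING_FACE_HUB_TOKEN is no longer read. A hedged instantiation sketch covering the three resulting paths (my_store is a placeholder VectorStore instance):

    # Default path: model_name omitted -> mistralai/Mistral-7B-Instruct-v0.2 in float16.
    agent_default = LocalRAGAgent(vector_store=my_store)

    # Explicit HuggingFace model with 4-bit weights (assumes bitsandbytes and a CUDA GPU).
    agent_quantized = LocalRAGAgent(vector_store=my_store,
                                    model_name="mistralai/Mistral-7B-Instruct-v0.2",
                                    quantization="4bit")

    # Ollama-backed model: the name is normalized to "mistral:latest" and served by the local daemon.
    agent_ollama = LocalRAGAgent(vector_store=my_store, model_name="ollama:mistral")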
