Skip to content

Commit ba48033

Browse files
committed
feat: enhanced model downloads
1 parent 08fb3ee commit ba48033

File tree

3 files changed: 90 additions and 19 deletions.

agentic_rag/gradio_app.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def chat(message: str, history: List[List[str]], agent_type: str, use_cot: bool,
103103
model_type = "Local (Mistral)"
104104
elif "Ollama" in agent_type:
105105
model_type = "Ollama"
106-
# Extract model name from agent_type
106+
# Extract model name from agent_type and use correct Ollama model names
107107
if "llama3" in agent_type.lower():
108108
model_name = "ollama:llama3"
109109
elif "phi-3" in agent_type.lower():
@@ -588,6 +588,18 @@ def download_model(model_type: str) -> str:
588588
print(f"Pulling Ollama model: {model_name}")
589589
start_time = time.time()
590590

591+
# Check if model already exists
592+
try:
593+
models = ollama.list().models
594+
available_models = [model.model for model in models]
595+
596+
# Check for model with or without :latest suffix
597+
if model_name in available_models or f"{model_name}:latest" in available_models:
598+
return f"✅ Model {model_name} is already available in Ollama."
599+
except Exception:
600+
# If we can't check, proceed with pull anyway
601+
pass
602+
591603
# Pull the model with progress tracking
592604
progress_text = ""
593605
for progress in ollama.pull(model_name, stream=True):

agentic_rag/local_rag_agent.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ def __init__(self, model_name: str):
5151
Args:
5252
model_name: Name of the Ollama model to use
5353
"""
54-
self.model_name = model_name
54+
# Remove the 'ollama:' prefix if present
55+
self.model_name = model_name.replace("ollama:", "") if model_name.startswith("ollama:") else model_name
5556
self._check_ollama_running()
5657

5758
def _check_ollama_running(self):
@@ -67,8 +68,13 @@ def _check_ollama_running(self):
6768

6869
# Check if the requested model is available
6970
if self.model_name not in available_models:
70-
print(f"Model '{self.model_name}' not found in Ollama. Available models: {', '.join(available_models)}")
71-
print(f"You can pull it with: ollama pull {self.model_name}")
71+
# Try with :latest suffix
72+
if f"{self.model_name}:latest" in available_models:
73+
self.model_name = f"{self.model_name}:latest"
74+
print(f"Using model with :latest suffix: {self.model_name}")
75+
else:
76+
print(f"Model '{self.model_name}' not found in Ollama. Available models: {', '.join(available_models)}")
77+
print(f"You can pull it with: ollama pull {self.model_name}")
7278
except Exception as e:
7379
raise ConnectionError(f"Failed to connect to Ollama. Please make sure Ollama is running. Error: {str(e)}")
7480

@@ -427,6 +433,33 @@ def _generate_response(self, query: str, context: List[Dict[str, Any]]) -> Dict[
427433
prompt = template.format(context=context_str, query=query)
428434
response = self._generate_text(prompt)
429435

436+
# Add sources to response if available
437+
if context:
438+
# Group sources by document
439+
sources = {}
440+
for item in context:
441+
source = item['metadata'].get('source', 'Unknown')
442+
if source not in sources:
443+
sources[source] = set()
444+
445+
# Add page number if available
446+
if 'page' in item['metadata']:
447+
sources[source].add(str(item['metadata']['page']))
448+
# Add file path if available for code
449+
if 'file_path' in item['metadata']:
450+
sources[source] = item['metadata']['file_path']
451+
452+
# Print concise source information
453+
print("\nSources detected:")
454+
for source, details in sources.items():
455+
if isinstance(details, set): # PDF with pages
456+
pages = ", ".join(sorted(details))
457+
print(f"Document: {source} (pages: {pages})")
458+
else: # Code with file path
459+
print(f"Code file: {source}")
460+
461+
response['sources'] = sources
462+
430463
return {
431464
"answer": response,
432465
"context": context

agentic_rag/rag_agent.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -190,25 +190,51 @@ def _process_query_standard(self, query: str) -> Dict[str, Any]:
190190
return response
191191

192192
def _generate_response(self, query: str, context: List[Dict[str, Any]]) -> Dict[str, Any]:
193-
"""Generate a response using the retrieved context"""
194-
context_str = "\n\n".join([f"Context {i+1}:\n{item['content']}"
195-
for i, item in enumerate(context)])
193+
"""Generate a response based on the query and context"""
194+
# Format context for the prompt
195+
formatted_context = "\n\n".join([f"Context {i+1}:\n{item['content']}"
196+
for i, item in enumerate(context)])
196197

197-
template = """Answer the following query using the provided context.
198-
Respond as if you are knowledgeable about the topic and incorporate the context naturally.
199-
Do not mention limitations in the context or that you couldn't find specific information.
200-
201-
Context:
202-
{context}
203-
204-
Query: {query}
205-
206-
Answer:"""
198+
# Create the prompt
199+
system_prompt = """You are an AI assistant answering questions based on the provided context.
200+
Answer the question based on the context provided. If the answer is not in the context, say "I don't have enough information to answer this question." Be concise and accurate."""
207201

208-
prompt = ChatPromptTemplate.from_template(template)
209-
messages = prompt.format_messages(context=context_str, query=query)
202+
# Create messages for the chat model
203+
messages = [
204+
{"role": "system", "content": system_prompt},
205+
{"role": "user", "content": f"Context:\n{formatted_context}\n\nQuestion: {query}"}
206+
]
207+
208+
# Generate response
210209
response = self.llm.invoke(messages)
211210

211+
# Add sources to response if available
212+
if context:
213+
# Group sources by document
214+
sources = {}
215+
for item in context:
216+
source = item['metadata'].get('source', 'Unknown')
217+
if source not in sources:
218+
sources[source] = set()
219+
220+
# Add page number if available
221+
if 'page' in item['metadata']:
222+
sources[source].add(str(item['metadata']['page']))
223+
# Add file path if available for code
224+
if 'file_path' in item['metadata']:
225+
sources[source] = item['metadata']['file_path']
226+
227+
# Print concise source information
228+
print("\nSources detected:")
229+
for source, details in sources.items():
230+
if isinstance(details, set): # PDF with pages
231+
pages = ", ".join(sorted(details))
232+
print(f"Document: {source} (pages: {pages})")
233+
else: # Code with file path
234+
print(f"Code file: {source}")
235+
236+
response['sources'] = sources
237+
212238
return {
213239
"answer": response.content,
214240
"context": context

0 commit comments

Comments
 (0)