Skip to content

Commit d5978ab

Browse files
committed
fix: adjust RAG tool behavior and docs (#198)
1 parent 090439d commit d5978ab

File tree

7 files changed

+230
-111
lines changed

7 files changed

+230
-111
lines changed

spoon_ai/rag/config.py

Lines changed: 49 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -30,50 +30,7 @@ class RagConfig:
3030
rag_dir: str = ".rag_store"
3131

3232

33-
_PLACEHOLDER_PATTERNS = [
34-
r"^sk-your-.*-key-here$",
35-
r"^sk-your-openai-api-key-here$",
36-
r"^your-.*-api-key-here$",
37-
r"^your_api_key$",
38-
r"^api_key_here$",
39-
r"^<.*>$",
40-
r"^\[.*\]$",
41-
r"^\{.*\}$",
42-
]
43-
44-
# Mapping of known OpenAI-compatible providers to their defaults
45-
# This allows using project-standard keys (e.g. DEEPSEEK_API_KEY) with RAG automatically.
46-
_COMPATIBLE_PROVIDERS: Dict[str, Dict[str, str]] = {
47-
"deepseek": {
48-
"env_key": "DEEPSEEK_API_KEY",
49-
"base_url": "https://api.deepseek.com/v1",
50-
"default_model": "", # Let server decide or user override
51-
},
52-
"openrouter": {
53-
"env_key": "OPENROUTER_API_KEY",
54-
"base_url": "https://openrouter.ai/api/v1",
55-
"default_model": "",
56-
},
57-
# Note: Gemini and Anthropic are not strictly OpenAI-compatible for embeddings (paths differ),
58-
# so we do not auto-map them to AnyRoute to avoid runtime errors unless explicitly configured.
59-
}
60-
61-
62-
def _is_placeholder(value: Optional[str]) -> bool:
63-
if not value or not isinstance(value, str):
64-
return True
65-
v = value.strip().lower()
66-
if not v:
67-
return True
68-
for p in _PLACEHOLDER_PATTERNS:
69-
if re.match(p, v):
70-
return True
71-
# Common keywords that indicate examples
72-
for k in ("placeholder", "example", "sample", "demo", "insert", "replace", "change-me"):
73-
if k in v:
74-
return True
75-
return False
76-
33+
from spoon_ai.llm.config import ConfigurationManager
7734

7835
def get_default_config() -> RagConfig:
7936
backend = os.getenv("RAG_BACKEND", "faiss").lower()
@@ -83,50 +40,66 @@ def get_default_config() -> RagConfig:
8340
chunk_size = int(os.getenv("CHUNK_SIZE", "800"))
8441
chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "120"))
8542

86-
# Embeddings provider selection
87-
embeddings_provider = None
88-
89-
# 1. AnyRoute (Explicit RAG config) - Highest Priority
90-
anyroute_api_key = os.getenv("ANYROUTE_API_KEY")
91-
anyroute_base = os.getenv("ANYROUTE_BASE_URL")
92-
anyroute_model = os.getenv("ANYROUTE_MODEL")
43+
# Use LLM ConfigurationManager for standardized provider detection
44+
config_manager = ConfigurationManager()
9345

94-
# 2. OpenAI (Native support)
95-
openai_key = os.getenv("OPENAI_API_KEY")
96-
97-
# Logic to determine provider
98-
if (anyroute_api_key and anyroute_base) and not (_is_placeholder(anyroute_api_key) or _is_placeholder(anyroute_base)):
46+
# 1. Determine active provider
47+
# Try ANYROUTE_API_KEY explicitly first (legacy RAG priority)
48+
anyroute_key = os.getenv("ANYROUTE_API_KEY")
49+
# Use static method from ConfigurationManager
50+
if anyroute_key and not ConfigurationManager._is_placeholder_value(anyroute_key):
9951
embeddings_provider = "anyroute"
100-
elif openai_key and not _is_placeholder(openai_key):
101-
embeddings_provider = "openai"
52+
anyroute_base = os.getenv("ANYROUTE_BASE_URL", "https://api.openai.com/v1") # Default generic
53+
anyroute_model = os.getenv("ANYROUTE_MODEL")
54+
openai_key = None
10255
else:
103-
# 3. Try Auto-mapping compatible providers (DeepSeek, OpenRouter, etc.)
104-
for name, defaults in _COMPATIBLE_PROVIDERS.items():
105-
key_val = os.getenv(defaults["env_key"])
106-
if key_val and not _is_placeholder(key_val):
107-
embeddings_provider = "anyroute"
108-
anyroute_api_key = key_val
109-
# Use provider default base URL if explicit ANYROUTE_BASE_URL is missing
110-
anyroute_base = anyroute_base or defaults["base_url"]
111-
# Use provider default model if explicit ANYROUTE_MODEL is missing
112-
if not anyroute_model and defaults["default_model"]:
113-
anyroute_model = defaults["default_model"]
114-
break
56+
# Fallback to LLM module's intelligent selection
57+
# This picks defaults based on available API keys (OpenAI > Anthropic > OpenRouter...)
58+
# Note: Anthropic/Gemini are not directly supported for embeddings here unless mapped
59+
provider = config_manager.get_default_provider()
11560

116-
# 4. Fallback
117-
if not embeddings_provider:
118-
embeddings_provider = "hash" # deterministic offline fallback
61+
# Load full config for the selected provider
62+
try:
63+
llm_config = config_manager.load_provider_config(provider)
64+
except Exception:
65+
llm_config = None
66+
67+
embeddings_provider = "hash" # Default fallback
68+
anyroute_key = None
69+
anyroute_base = None
70+
anyroute_model = None
71+
openai_key = None
11972

73+
if llm_config:
74+
if provider == "openai":
75+
embeddings_provider = "openai"
76+
openai_key = llm_config.api_key
77+
elif provider in ("deepseek", "openrouter", "anyroute"):
78+
# Map compatible OpenAI-like providers to AnyRoute client
79+
embeddings_provider = "anyroute"
80+
anyroute_key = llm_config.api_key
81+
anyroute_base = llm_config.base_url
82+
83+
# Check for explicit override or intelligent default
84+
env_model = os.getenv("ANYROUTE_MODEL")
85+
if env_model:
86+
anyroute_model = env_model
87+
elif provider == "openrouter" and "embedding" not in llm_config.model.lower():
88+
# OpenRouter: Default to openai/text-embedding-3-small if main model is not an embedding model
89+
anyroute_model = "openai/text-embedding-3-small"
90+
else:
91+
anyroute_model = llm_config.model
92+
12093
return RagConfig(
12194
backend=backend,
12295
collection=collection,
12396
top_k=top_k,
12497
chunk_size=chunk_size,
12598
chunk_overlap=chunk_overlap,
12699
embeddings_provider=embeddings_provider,
127-
anyroute_api_key=None if _is_placeholder(anyroute_api_key) else anyroute_api_key,
128-
anyroute_base_url=None if _is_placeholder(anyroute_base) else anyroute_base,
100+
anyroute_api_key=anyroute_key,
101+
anyroute_base_url=anyroute_base,
129102
anyroute_model=anyroute_model,
130-
openai_api_key=None if _is_placeholder(openai_key) else openai_key,
103+
openai_api_key=openai_key,
131104
rag_dir=rag_dir,
132105
)

spoon_ai/rag/vectorstores/chroma_store.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,40 @@ def _get_collection(self, name: str):
3131

3232
def add(self, *, collection: str, ids: List[str], embeddings: List[List[float]], metadatas: List[Dict]) -> None:
    """Insert vectors and their metadata into a Chroma collection.

    Args:
        collection: Name of the target collection.
        ids: One identifier per vector.
        embeddings: The vectors to store, aligned with ``ids``.
        metadatas: One metadata dict per vector, aligned with ``ids``.

    Raises:
        ValueError: If the underlying ``add`` call fails with a message that
            mentions a dimension problem. The original error is attached as
            ``__cause__``.
        Exception: Any other failure from the underlying call is re-raised
            unchanged.
    """
    col = self._get_collection(collection)
    try:
        col.add(ids=ids, embeddings=embeddings, metadatas=metadatas)
    except Exception as e:
        # "dimensionality" contains "dimension", so one substring test covers both.
        if "dimension" not in str(e).lower():
            # Bare raise re-raises the active exception without rebinding it.
            raise
        raise ValueError(
            f"Chroma embedding dimension mismatch in collection '{collection}'. "
            "You may be using a different embedding model than the one used to create this collection. "
            f"Consider deleting the collection via `store.delete_collection('{collection}')` "
            "or using a new collection name."
        ) from e
3546

3647
def query(self, *, collection: str, query_embeddings: List[List[float]], top_k: int = 5, filter: Optional[Dict] = None) -> List[List[Tuple[str, float, Dict]]]:
3748
col = self._get_collection(collection)
38-
# Chroma >=1.3 disallows requesting "ids" in include; request metadatas+distances only.
39-
res = col.query(query_embeddings=query_embeddings, n_results=top_k, include=["metadatas", "distances"])
49+
try:
50+
# Chroma >=1.3 disallows requesting "ids" in include; request metadatas+distances only.
51+
# Pass filter as 'where' clause for metadata filtering
52+
res = col.query(
53+
query_embeddings=query_embeddings,
54+
n_results=top_k,
55+
include=["metadatas", "distances"],
56+
where=filter # Pass explicit filter dict
57+
)
58+
except Exception as e:
59+
msg = str(e).lower()
60+
if "dimension" in msg or "dimensionality" in msg:
61+
raise ValueError(
62+
f"Chroma query dimension mismatch in collection '{collection}'. "
63+
"The query embedding dimension does not match the collection's index. "
64+
"Please ensure you are using the same embedding model as when the data was ingested."
65+
) from e
66+
raise e
67+
4068
out: List[List[Tuple[str, float, Dict]]] = []
4169
q = len(query_embeddings)
4270
for i in range(q):

spoon_ai/rag/vectorstores/faiss_store.py

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import os
34
from typing import Dict, List, Optional, Tuple
45

56
from .base import VectorStore
@@ -8,11 +9,73 @@
89
class FaissVectorStore(VectorStore):
910
"""FAISS-backed local vector store (cosine via inner product + L2 norm)."""
1011

11-
def __init__(self) -> None:
12+
def __init__(self, *, persist_dir: Optional[str] = None) -> None:
13+
import os
14+
self.persist_dir = persist_dir or os.getenv("RAG_FAISS_DIR", os.path.join(os.getenv("RAG_DIR", ".rag_store"), "faiss"))
1215
self._collections: Dict[str, Dict] = {}
16+
self._load()
17+
18+
def _get_index_path(self, collection: str) -> str:
19+
return os.path.join(self.persist_dir, f"{collection}.index")
20+
21+
def _get_meta_path(self, collection: str) -> str:
22+
return os.path.join(self.persist_dir, f"{collection}.pkl")
23+
24+
def _load(self):
    """Populate ``self._collections`` from the files in ``self.persist_dir``.

    Each ``<name>.index`` file that has a sibling ``<name>.pkl`` file is read
    back as collection ``<name>``. Index files without a metadata file are
    skipped. A pair that fails to load is logged and skipped so that one bad
    collection does not prevent the others from loading.
    """
    import logging
    import os
    import pickle
    import faiss  # type: ignore

    # isdir (rather than exists) also covers a path that is a regular file,
    # which would otherwise make os.listdir raise.
    if not os.path.isdir(self.persist_dir):
        return

    logger = logging.getLogger(__name__)
    suffix = ".index"

    for fname in os.listdir(self.persist_dir):
        if not fname.endswith(suffix):
            continue

        collection = fname[: -len(suffix)]
        index_path = os.path.join(self.persist_dir, fname)
        meta_path = self._get_meta_path(collection)

        if not os.path.exists(meta_path):
            continue

        try:
            index = faiss.read_index(index_path)
            # NOTE(security): pickle.load can execute arbitrary code from the
            # file, so persist_dir must only contain files this store wrote.
            with open(meta_path, "rb") as f:
                meta_data = pickle.load(f)
            entry = {
                "index": index,
                "ids": meta_data["ids"],
                "metas": meta_data["metas"],
                "dim": meta_data["dim"],
            }
        except Exception:
            # Best-effort load: report the unreadable pair and move on.
            logger.warning(
                "Error loading FAISS collection '%s'; skipping it",
                collection,
                exc_info=True,
            )
            continue

        # Register only after every piece loaded, so no partial entry is kept.
        self._collections[collection] = entry
56+
57+
def _save(self, collection: str):
58+
import os
59+
import pickle
60+
import faiss # type: ignore
61+
62+
os.makedirs(self.persist_dir, exist_ok=True)
63+
col = self._collections.get(collection)
64+
if not col:
65+
return
66+
67+
index_path = self._get_index_path(collection)
68+
meta_path = self._get_meta_path(collection)
69+
70+
faiss.write_index(col["index"], index_path)
71+
with open(meta_path, "wb") as f:
72+
pickle.dump({
73+
"ids": col["ids"],
74+
"metas": col["metas"],
75+
"dim": col["dim"]
76+
}, f)
1377

1478
def _get_or_create(self, collection: str, dim: Optional[int] = None):
15-
import numpy as np # noqa: F401
1679
import faiss # type: ignore
1780

1881
col = self._collections.get(collection)
@@ -53,11 +116,19 @@ def add(self, *, collection: str, ids: List[str], embeddings: List[List[float]],
53116
col["ids"].extend(ids)
54117
for id_, md in zip(ids, metadatas):
55118
col["metas"][id_] = md
119+
120+
# Persist changes
121+
self._save(collection)
56122

57123
def query(self, *, collection: str, query_embeddings: List[List[float]], top_k: int = 5, filter: Optional[Dict] = None) -> List[List[Tuple[str, float, Dict]]]:
58124
import numpy as np
59125

60-
col = self._get_or_create(collection)
126+
# Ensure loaded or created if not in memory (but _load handles init)
127+
col = self._collections.get(collection)
128+
if not col:
129+
# If not in memory and not loaded, it doesn't exist
130+
return [[] for _ in query_embeddings]
131+
61132
if len(col["ids"]) == 0:
62133
return [[] for _ in query_embeddings]
63134

@@ -85,5 +156,14 @@ def query(self, *, collection: str, query_embeddings: List[List[float]], top_k:
85156
return results
86157

87158
def delete_collection(self, collection: str) -> None:
159+
import os
88160
self._collections.pop(collection, None)
161+
# Also remove from disk
162+
try:
163+
if os.path.exists(self._get_index_path(collection)):
164+
os.remove(self._get_index_path(collection))
165+
if os.path.exists(self._get_meta_path(collection)):
166+
os.remove(self._get_meta_path(collection))
167+
except Exception:
168+
pass
89169

spoon_ai/rag/vectorstores/pinecone_store.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def _extract_index_names(obj) -> set:
133133
def add(self, *, collection: str, ids: List[str], embeddings: List[List[float]], metadatas: List[Dict]) -> None:
134134
index = self._ensure_index(dim=len(embeddings[0]) if embeddings else None)
135135
vectors = [
136-
{"id": id_, "values": vec, "metadata": md}
136+
{"id": id_, "values": [float(x) for x in vec], "metadata": md}
137137
for id_, vec, md in zip(ids, embeddings, metadatas)
138138
]
139139
index.upsert(vectors=vectors, namespace=collection)
@@ -142,7 +142,8 @@ def query(self, *, collection: str, query_embeddings: List[List[float]], top_k:
142142
index = self._ensure_index()
143143
results: List[List[Tuple[str, float, Dict]]] = []
144144
for q in query_embeddings:
145-
res = index.query(namespace=collection, vector=q, top_k=top_k, include_metadata=True)
145+
# Pass filter dict directly (Pinecone uses Mongo-style filters)
146+
res = index.query(namespace=collection, vector=q, top_k=top_k, include_metadata=True, filter=filter)
146147
matches = res.get("matches", []) if isinstance(res, dict) else getattr(res, "matches", [])
147148
out: List[Tuple[str, float, Dict]] = []
148149
for m in matches:

spoon_ai/rag/vectorstores/qdrant_store.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,25 @@ def add(self, *, collection: str, ids: List[str], embeddings: List[List[float]],
6161

6262
def query(self, *, collection: str, query_embeddings: List[List[float]], top_k: int = 5, filter: Optional[Dict] = None) -> List[List[Tuple[str, float, Dict]]]:
6363
client = self._client_or_raise()
64+
65+
# Build Qdrant filter (dict structure to avoid imports)
66+
q_filter = None
67+
if filter:
68+
musts = []
69+
for k, v in filter.items():
70+
musts.append({"key": k, "match": {"value": v}})
71+
q_filter = {"must": musts}
72+
6473
results: List[List[Tuple[str, float, Dict]]] = []
6574
for q in query_embeddings:
6675
# qdrant-client >=1.x uses query_points for vector search; ensure payload returned
67-
res = client.query_points(collection_name=collection, query=q, limit=top_k, with_payload=True)
76+
res = client.query_points(
77+
collection_name=collection,
78+
query=q,
79+
limit=top_k,
80+
with_payload=True,
81+
query_filter=q_filter
82+
)
6883
# Normalize response
6984
try:
7085
points = res.points # type: ignore[attr-defined]

0 commit comments

Comments
 (0)