Skip to content

Commit 090439d

Browse files
committed
feat: enhance RAG module and docs (#198)
1 parent ca96159 commit 090439d

File tree

11 files changed

+663
-71
lines changed

11 files changed

+663
-71
lines changed

doc/rag/backends.md

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ Vector Stores
66

77
The RAG registry supports the following backends via `RAG_BACKEND`:
88

9-
- `faiss` (default): 本地与离线友好。提供真实 FAISS 适配器(如未安装 `faiss` 将自动回退到内存余弦检索)。
10-
- `pinecone`: 云端向量库(需要 `PINECONE_API_KEY`,默认索引名 `spoon-rag`,可用 `RAG_PINECONE_INDEX` 指定)。已提供真实适配器。
11-
- `qdrant`: 本地/云端(需要 `pip install qdrant-client` 且服务默认 `http://localhost:6333`)。已提供真实适配器。
12-
- 支持嵌入式模式:`QDRANT_URL=:memory:` 或设置 `QDRANT_PATH=:memory:` 可在本地无服务运行。
13-
- `chroma`: 本地(需要 `pip install chromadb`,默认持久目录 `${RAG_DIR:-.rag_store}/chroma`)。已提供真实适配器。
9+
- `faiss` (default): Local and offline friendly. Provides a real FAISS adapter (falls back to in-memory cosine retrieval if `faiss` is not installed).
10+
- `pinecone`: Cloud vector database (requires `PINECONE_API_KEY`, default index name `spoon-rag`, can be specified via `RAG_PINECONE_INDEX`). Real adapter provided.
11+
- `qdrant`: Local/Cloud (requires `pip install qdrant-client` and service at `http://localhost:6333` by default). Real adapter provided.
12+
- Supports embedded mode: set `QDRANT_URL=:memory:` or `QDRANT_PATH=:memory:` to run locally without a service.
13+
- `chroma`: Local (requires `pip install chromadb`, default persistence directory `${RAG_DIR:-.rag_store}/chroma`). Real adapter provided.
1414

15-
说明:为了保证离线测试稳定,`faiss` 仍默认以内存向量库运行。其余后端需要按需安装/配置,适配器层保持统一 API
15+
Note: To keep offline tests stable, the `faiss` backend still runs as an in-memory vector store by default. The other backends must be installed and configured as needed; the adapter layer exposes the same unified API for all of them.
1616

1717
Embeddings
1818
----------
@@ -21,29 +21,29 @@ Embeddings
2121
- OpenAI: set `OPENAI_API_KEY` (uses `text-embedding-3-small` by default).
2222
- Hash (fallback): deterministic offline embedding for tests and demos (no env needed).
2323

24-
Backend Smoke 测试
25-
------------------
24+
Backend Smoke Tests
25+
-------------------
2626

27-
- FAISS
27+
- FAISS:
2828
```bash
29-
# 如未安装 faiss,会自动回退到内存实现
29+
# Automatically falls back to in-memory implementation if faiss is not installed
3030
RAG_BACKEND=faiss RAG_FAKE_QA=1 python examples/smoke/rag_faiss_smoke.py
3131
```
3232

33-
- Pinecone
33+
- Pinecone:
3434
```bash
3535
export PINECONE_API_KEY=...
3636
RAG_BACKEND=pinecone RAG_FAKE_QA=1 python examples/smoke/rag_pinecone_smoke.py
3737
```
3838

39-
- Qdrant(需本地/远程服务与 `qdrant-client`
39+
- Qdrant (requires a local/remote service and `qdrant-client`):
4040
```bash
4141
pip install qdrant-client
42-
export QDRANT_URL=http://localhost:6333 # 如需
42+
export QDRANT_URL=http://localhost:6333 # if needed
4343
RAG_BACKEND=qdrant RAG_FAKE_QA=1 python examples/smoke/rag_qdrant_smoke.py
4444
```
4545

46-
- Chroma(需 `chromadb`
46+
- Chroma (requires `chromadb`):
4747
```bash
4848
pip install chromadb
4949
RAG_BACKEND=chroma RAG_FAKE_QA=1 python examples/smoke/rag_chroma_smoke.py

examples/rag_graph_agent_demo.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""Graph Agent demo using the RAG component (ingest -> retrieve -> answer).
2+
3+
Offline by default (hash embeddings + in-memory store).
4+
"""
5+
6+
import asyncio
7+
from typing import Dict, Any, List
8+
9+
from spoon_ai.graph import StateGraph
10+
from spoon_ai.agents.graph_agent import GraphAgent
11+
from spoon_ai.chat import ChatBot
12+
13+
from spoon_ai.rag import (
14+
get_default_config,
15+
get_embedding_client,
16+
get_vector_store,
17+
RagIndex,
18+
RagRetriever,
19+
RagQA,
20+
)
21+
22+
23+
def build_pipeline():
    """Assemble and compile the ingest -> retrieve -> answer RAG graph.

    The vector store and embedding client are created from the default RAG
    config and shared by the index and the retriever. One graph node is
    registered per stage, chained in order, with ingestion as the entry point.

    Returns:
        Whatever ``StateGraph.compile()`` returns for the three-node graph.
    """
    import os

    config = get_default_config()
    vector_store = get_vector_store(config.backend)
    embedder = get_embedding_client(
        config.embeddings_provider,
        openai_api_key=config.openai_api_key,
        openai_model=config.openai_embeddings_model,
        anyroute_api_key=config.anyroute_api_key,
        anyroute_base_url=config.anyroute_base_url,
        anyroute_model=config.anyroute_model,
    )

    rag_index = RagIndex(config=config, store=vector_store, embeddings=embedder)
    rag_retriever = RagRetriever(config=config, store=vector_store, embeddings=embedder)

    # Offline-friendly: with RAG_FAKE_QA=1 no ChatBot is created and RagQA
    # receives llm=None.
    fake_qa = os.getenv("RAG_FAKE_QA") == "1"
    chat_llm = None if fake_qa else ChatBot()
    rag_qa = RagQA(config=config, llm=chat_llm)

    def ingest_node(state: Dict[str, Any]) -> Dict[str, Any]:
        """Ingest ``state["inputs"]`` and report what the index returned."""
        return {"ingested": rag_index.ingest(state.get("inputs", []))}

    def retrieve_node(state: Dict[str, Any]) -> Dict[str, Any]:
        """Retrieve chunks for ``state["question"]`` and build their context."""
        hits = rag_retriever.retrieve(state.get("question", ""))
        context = rag_retriever.build_context(hits)
        return {"chunks": hits, "context": context}

    async def answer_node(state: Dict[str, Any]) -> Dict[str, Any]:
        """Answer ``state["question"]`` from the previously retrieved chunks."""
        question = state.get("question", "")
        outcome = await rag_qa.answer(question, state.get("chunks", []))
        return {"answer": outcome.answer, "citations": outcome.citations}

    stages = (
        ("RAGIngestNode", ingest_node),
        ("RAGRetrieveNode", retrieve_node),
        ("RAGAnswerNode", answer_node),
    )

    workflow = StateGraph(dict)
    for node_name, handler in stages:
        workflow.add_node(node_name, handler)
    # Chain every stage to the one that follows it.
    for (source, _), (target, _) in zip(stages, stages[1:]):
        workflow.add_edge(source, target)
    workflow.set_entry_point(stages[0][0])
    return workflow.compile()
66+
67+
68+
async def main():
    """Build the RAG graph, run it once against ``./doc`` and print the result."""
    pipeline = build_pipeline()
    # NOTE(review): this GraphAgent is constructed but never used below; the
    # run goes through the compiled graph directly. Kept to preserve behavior.
    agent = GraphAgent(name="rag-graph", graph=pipeline.graph)

    outcome = await pipeline.ainvoke(
        {
            "inputs": ["./doc"],
            "question": "How do I install the SDK?",
        }
    )

    print("\n== RAG Graph Agent Demo ==")
    for label, key in (("Answer:", "answer"), ("Citations:", "citations")):
        print(label, outcome.get(key))
79+
80+
81+
if __name__ == "__main__":
82+
asyncio.run(main())

examples/rag_react_agent_demo.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""ReAct Agent demo using the RAG component (ingest -> search -> QA).
2+
3+
Runs offline by default using a deterministic hash embedding and in-memory store.
4+
If OPENAI_API_KEY or ANYROUTE_* are present, will use those embeddings.
5+
6+
Env vars:
7+
- RAG_BACKEND=faiss|pinecone|qdrant|chroma (default: faiss)
8+
- RAG_COLLECTION=<name> (default: default)
9+
- TOP_K, CHUNK_SIZE, CHUNK_OVERLAP optional
10+
"""
11+
12+
import asyncio
13+
import os
14+
from spoon_ai.chat import ChatBot
15+
from spoon_ai.agents.toolcall import ToolCallAgent
16+
from spoon_ai.tools import ToolManager
17+
from spoon_ai.tools.rag_tools import RAGIngestTool, RAGSearchTool, RAGQATool
18+
19+
20+
async def main() -> None:
    """Send ingest, search and QA requests to a tool-calling agent, echoing each exchange."""
    toolbox = ToolManager([RAGIngestTool(), RAGSearchTool(), RAGQATool()])
    react_agent = ToolCallAgent(name="rag-react", llm=ChatBot(), available_tools=toolbox)

    print("\n== RAG ReAct Agent Demo ==\n")

    # The ingest target (a local directory or url) comes from RAG_DOCS.
    docs_location = os.getenv("RAG_DOCS", "./doc")
    requests = (
        f"Use rag_ingest to index docs in {docs_location}",  # 1) Ingest
        "Use rag_search to find info about installation",  # 2) Search
        "Use rag_qa to answer: How do I install the SDK?",  # 3) QA
    )

    for position, request in enumerate(requests):
        # Only the first exchange is printed without a leading blank line.
        user_label = "User:" if position == 0 else "\nUser:"
        print(user_label, request)
        reply = await react_agent.run(request)
        print("Assistant:", reply)
53+
54+
55+
if __name__ == "__main__":
56+
asyncio.run(main())
57+

examples/smoke/rag_react_offline_smoke.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import asyncio
77
import os
88
try:
9-
from spoon_ai.tools.rag_tools import RAGIngestTool, RAGSearchTool, RAGQATool
9+
from spoon_ai.tools.rag_tools import RAGIngestTool, RAGSearchTool, RAGQATool
1010
except ModuleNotFoundError:
1111
import sys, pathlib
1212
sys.path.insert(0, str(pathlib.Path(__file__).resolve().parents[2]))

examples/smoke/rag_tools_smoke.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""Smoke test for RAG Tools, verifying dependency injection and backward compatibility.
2+
3+
Run:
4+
python3 examples/smoke/rag_tools_smoke.py
5+
"""
6+
7+
import asyncio
8+
import os
9+
import sys
10+
import pathlib
11+
12+
# Add project root to path
13+
sys.path.insert(0, str(pathlib.Path(__file__).resolve().parents[2]))
14+
15+
# Mock missing dependencies to isolate tool logic testing.
# Every sys.modules entry below has to be registered BEFORE rag_tools is
# imported, so the order of these statements matters.
from unittest.mock import MagicMock

# Mock openai and related modules that might be missing in this env
sys.modules["openai"] = MagicMock()
sys.modules["spoon_ai.min_pypi"] = MagicMock()  # potential internal dep

# We need to mock ChatBot because rag_tools imports it
mock_chat = MagicMock()
sys.modules["spoon_ai.chat"] = mock_chat

# Mock RAG backends
sys.modules["pinecone"] = MagicMock()
sys.modules["chromadb"] = MagicMock()
sys.modules["qdrant_client"] = MagicMock()
sys.modules["faiss"] = MagicMock()

try:
    from spoon_ai.tools.rag_tools import RAGQATool, RAGSearchTool
    # Mock ChatBot inside the module if needed, though sys.modules should cover it
    mock_chat.ChatBot = MagicMock
except ImportError as e:
    print(f"Error importing modules: {e}")
    # Print sys.path to help debug
    print(sys.path)
    sys.exit(1)
42+
43+
class MockLLM:
    """Minimal stand-in LLM whose ``ask`` always returns the same canned reply."""

    async def ask(self, messages, **kwargs):
        """Ignore ``messages`` and any keyword options; return the canned answer."""
        canned_reply = "This is a mock answer with citation [1]."
        return canned_reply
46+
47+
async def main():
    """Run the RAG tool smoke checks, printing one [PASS]/[FAIL] line per check.

    Checks 1-4 construct ``RAGQATool`` with and without an injected ``llm`` and
    with Pydantic fields, then verify ``name`` and ``_llm``. Check 5 inspects
    the ``RAGSearchTool`` parameter schema for the documented defaults.

    A check reports [PASS] only after all of its assertions have succeeded, so
    a failing assertion produces a single [FAIL] line and nothing else.
    """
    print("=== Testing RAG Tools Compatibility ===")

    # 1. Test Backward Compatibility (No args)
    try:
        tool_v1 = RAGQATool()
        assert tool_v1._llm is None, "_llm should be None when no llm is passed"
    except Exception as e:
        print(f"[FAIL] V1 Initialization failed: {e}")
    else:
        print("[PASS] V1 Initialization (No args) passed.")

    # 2. Test Backward Compatibility (Pydantic fields)
    try:
        tool_v2 = RAGQATool(name="custom_rag", description="Custom desc")
        assert tool_v2.name == "custom_rag", f"unexpected name {tool_v2.name!r}"
        assert tool_v2._llm is None, "_llm should be None when no llm is passed"
    except Exception as e:
        print(f"[FAIL] V1 Initialization (Pydantic fields) failed: {e}")
    else:
        print("[PASS] V1 Initialization (Pydantic fields) passed.")

    # 3. Test Dependency Injection (New feature)
    try:
        mock_llm = MockLLM()
        tool_v3 = RAGQATool(llm=mock_llm)
        assert tool_v3._llm is mock_llm, "_llm should be the injected llm"
    except Exception as e:
        print(f"[FAIL] V2 Initialization (Dependency Injection) failed: {e}")
    else:
        print("[PASS] V2 Initialization (Dependency Injection) passed.")

    # 4. Test Dependency Injection Mixed with Pydantic
    try:
        mock_llm = MockLLM()
        tool_v4 = RAGQATool(llm=mock_llm, name="injected_tool")
        assert tool_v4.name == "injected_tool", f"unexpected name {tool_v4.name!r}"
        assert tool_v4._llm is mock_llm, "_llm should be the injected llm"
    except Exception as e:
        print(f"[FAIL] V2 Initialization (Mixed) failed: {e}")
    else:
        print("[PASS] V2 Initialization (Mixed) passed.")

    print("\n=== Testing Parameter Defaults ===")
    # 5. Check schema defaults
    search_tool = RAGSearchTool()
    props = search_tool.parameters["properties"]

    if "default: 5" in props["top_k"].get("description", ""):
        print("[PASS] RAGSearchTool top_k description contains default.")
    else:
        print(f"[FAIL] RAGSearchTool top_k description missing default: {props['top_k']}")

    if "default: 'default'" in props["collection"].get("description", ""):
        print("[PASS] RAGSearchTool collection description contains default.")
    else:
        print(f"[FAIL] RAGSearchTool collection description missing default: {props['collection']}")

    print("\nAll Smoke Tests Completed.")
102+
103+
if __name__ == "__main__":
104+
asyncio.run(main())

0 commit comments

Comments
 (0)