"""Graph Agent demo using the RAG component (ingest -> retrieve -> answer).

Offline by default (hash embeddings + in-memory store).
"""
| 5 | + |
import asyncio
import os
from typing import Any, Dict, List

from spoon_ai.agents.graph_agent import GraphAgent
from spoon_ai.chat import ChatBot
from spoon_ai.graph import StateGraph
from spoon_ai.rag import (
    get_default_config,
    get_embedding_client,
    get_vector_store,
    RagIndex,
    RagRetriever,
    RagQA,
)
| 21 | + |
| 22 | + |
def build_pipeline():
    """Build and compile the RAG graph: ingest -> retrieve -> answer.

    Wires the default RAG config, vector store, and embedding client into
    three graph nodes. The compiled graph consumes a state dict containing
    ``inputs`` (paths/documents to ingest) and ``question``, and produces
    ``answer`` and ``citations``.

    Returns:
        The compiled StateGraph, ready for ``ainvoke``.
    """
    cfg = get_default_config()
    store = get_vector_store(cfg.backend)
    embed = get_embedding_client(
        cfg.embeddings_provider,
        openai_api_key=cfg.openai_api_key,
        openai_model=cfg.openai_embeddings_model,
        anyroute_api_key=cfg.anyroute_api_key,
        anyroute_base_url=cfg.anyroute_base_url,
        anyroute_model=cfg.anyroute_model,
    )

    index = RagIndex(config=cfg, store=store, embeddings=embed)
    retriever = RagRetriever(config=cfg, store=store, embeddings=embed)
    # Offline-friendly: if RAG_FAKE_QA=1, avoid initializing ChatBot
    # (RagQA presumably falls back to a fake/offline answerer when llm is None).
    llm = None if os.getenv("RAG_FAKE_QA") == "1" else ChatBot()
    qa = RagQA(config=cfg, llm=llm)

    def ingest_node(state: Dict[str, Any]) -> Dict[str, Any]:
        """Ingest the documents listed in state["inputs"]; report the count."""
        inputs = state.get("inputs", [])
        return {"ingested": index.ingest(inputs)}

    def retrieve_node(state: Dict[str, Any]) -> Dict[str, Any]:
        """Retrieve chunks relevant to state["question"] and build a context string."""
        question = state.get("question", "")
        chunks = retriever.retrieve(question)
        return {"chunks": chunks, "context": retriever.build_context(chunks)}

    async def answer_node(state: Dict[str, Any]) -> Dict[str, Any]:
        """Answer state["question"] from the retrieved chunks via RagQA."""
        question = state.get("question", "")
        chunks = state.get("chunks", [])
        res = await qa.answer(question, chunks)
        return {"answer": res.answer, "citations": res.citations}

    # Linear pipeline: ingest -> retrieve -> answer.
    graph = StateGraph(dict)
    graph.add_node("RAGIngestNode", ingest_node)
    graph.add_node("RAGRetrieveNode", retrieve_node)
    graph.add_node("RAGAnswerNode", answer_node)
    graph.add_edge("RAGIngestNode", "RAGRetrieveNode")
    graph.add_edge("RAGRetrieveNode", "RAGAnswerNode")
    graph.set_entry_point("RAGIngestNode")
    return graph.compile()
| 66 | + |
| 67 | + |
async def main() -> None:
    """Run the demo: build the pipeline, invoke it, and print the answer."""
    compiled = build_pipeline()
    # NOTE(review): the agent is constructed but never used — execution below
    # goes through compiled.ainvoke directly. Confirm whether the demo should
    # run via the agent instead, or drop this line.
    agent = GraphAgent(name="rag-graph", graph=compiled.graph)
    initial_state = {
        "inputs": ["./doc"],
        "question": "How do I install the SDK?",
    }
    result = await compiled.ainvoke(initial_state)
    print("\n== RAG Graph Agent Demo ==")
    print("Answer:", result.get("answer"))
    print("Citations:", result.get("citations"))
| 79 | + |
| 80 | + |
# Script entry point: run the async demo to completion.
if __name__ == "__main__":
    asyncio.run(main())