
Commit 1b6b8bb

PeriniMsis0k0 authored and committed
feat: add langgraph studio agents
1 parent 831d751 commit 1b6b8bb

6 files changed: 556 additions, 0 deletions


agents/corrective_rag.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
from agents.utils import llm
from langchain.schema import Document
from langchain_core.messages import HumanMessage, SystemMessage
from typing import List
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from pydantic import BaseModel, Field


retriever = ...  # TODO: Add retriever

class GraphState(TypedDict):
    question: str
    generation: str
    documents: List[Document]
    attempted_generations: int

class InputState(TypedDict):
    question: str

class OutputState(TypedDict):
    generation: str
    documents: List[Document]

def retrieve_documents(state: GraphState):
    """
    Args:
        state (dict): The current graph state
    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE DOCUMENTS---")
    question = state["question"]
    documents = retriever.invoke(question)
    return {"documents": documents}

RAG_PROMPT = """You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know.
Use three sentences maximum and keep the answer concise.

Question: {question}
Context: {context}
Answer:"""

def generate_response(state: GraphState):
    print("---GENERATE RESPONSE---")
    question = state["question"]
    documents = state["documents"]
    attempted_generations = state.get("attempted_generations", 0)  # By default we set attempted_generations to 0 if it doesn't exist yet
    formatted_docs = "\n\n".join(doc.page_content for doc in documents)

    # Invoke our LLM with our RAG prompt
    rag_prompt_formatted = RAG_PROMPT.format(context=formatted_docs, question=question)
    generation = llm.invoke([HumanMessage(content=rag_prompt_formatted)])
    return {
        "generation": generation,
        "attempted_generations": attempted_generations + 1  # In our state update, we increment attempted_generations
    }

class GradeDocuments(BaseModel):
    is_relevant: bool = Field(
        description="The document is relevant to the question, true or false"
    )

grade_documents_llm = llm.with_structured_output(GradeDocuments)
grade_documents_system_prompt = """You are a grader assessing relevance of a retrieved document to a user question. \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
Give a binary score true or false to indicate whether the document is relevant to the question."""
grade_documents_prompt = "Here is the retrieved document: \n\n {document} \n\n Here is the user question: \n\n {question}"

def grade_documents(state):
    """
    Args:
        state (dict): The current graph state
    Returns:
        state (dict): Updates documents key with only filtered relevant documents
    """
    print("---GRADE DOCUMENTS---")
    question = state["question"]
    documents = state["documents"]
    # Score each doc
    filtered_docs = []
    for d in documents:
        grade_documents_prompt_formatted = grade_documents_prompt.format(document=d.page_content, question=question)
        score = grade_documents_llm.invoke(
            [SystemMessage(content=grade_documents_system_prompt)] + [HumanMessage(content=grade_documents_prompt_formatted)]
        )
        grade = score.is_relevant
        if grade:
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            continue
    return {"documents": filtered_docs}

def decide_to_generate(state):
    """
    Args:
        state (dict): The current graph state
    Returns:
        str: Binary decision for next node to call
    """
    print("---ASSESS GRADED DOCUMENTS---")
    filtered_documents = state["documents"]

    if not filtered_documents:
        print(
            "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, END---"
        )
        return "none relevant"
    else:
        # We have relevant documents, so generate answer
        print("---DECISION: GENERATE---")
        return "some relevant"

class GradeHallucinations(BaseModel):
    """Binary score for hallucination present in generation answer."""
    grounded_in_facts: bool = Field(
        description="Answer is grounded in the facts, true or false"
    )

grade_hallucinations_llm = llm.with_structured_output(GradeHallucinations)
grade_hallucinations_system_prompt = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score true or false. True means that the answer is grounded in / supported by the set of facts."""
grade_hallucinations_prompt = "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"

ATTEMPTED_GENERATION_MAX = 3

def grade_hallucinations(state):
    print("---CHECK HALLUCINATIONS---")
    documents = state["documents"]
    generation = state["generation"]
    attempted_generations = state["attempted_generations"]

    formatted_docs = "\n\n".join(doc.page_content for doc in documents)

    grade_hallucinations_prompt_formatted = grade_hallucinations_prompt.format(
        documents=formatted_docs,
        generation=generation
    )

    score = grade_hallucinations_llm.invoke(
        [SystemMessage(content=grade_hallucinations_system_prompt)] + [HumanMessage(content=grade_hallucinations_prompt_formatted)]
    )
    grade = score.grounded_in_facts

    # Check hallucination
    if grade:
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        return "supported"
    elif attempted_generations >= ATTEMPTED_GENERATION_MAX:  # New condition!
        print("---DECISION: TOO MANY ATTEMPTS, GIVE UP---")
        raise RuntimeError("Too many attempted generations with hallucinations, giving up.")
        # return "give up"  # Note: We could also do this to silently fail
    else:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"

graph_builder = StateGraph(GraphState, input=InputState, output=OutputState)
graph_builder.add_node("retrieve_documents", retrieve_documents)
graph_builder.add_node("generate_response", generate_response)
graph_builder.add_node("grade_documents", grade_documents)
graph_builder.add_edge(START, "retrieve_documents")
graph_builder.add_edge("retrieve_documents", "grade_documents")
graph_builder.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "some relevant": "generate_response",
        "none relevant": END
    })
graph_builder.add_conditional_edges(
    "generate_response",
    grade_hallucinations,
    {
        "supported": END,
        "not supported": "generate_response"
    })

graph = graph_builder.compile()
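The module leaves `retriever` as a TODO, so the compiled graph cannot run as committed. Below is a minimal local-run sketch, not part of this commit: the in-memory vector store, embedding model, sample document, and question are all illustrative assumptions for exercising the graph once the placeholder is filled in.

# Hypothetical wiring for the `retriever` placeholder; in practice this would
# replace the `retriever = ...` line inside agents/corrective_rag.py.
from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import OpenAIEmbeddings

vector_store = InMemoryVectorStore(OpenAIEmbeddings())
vector_store.add_documents([
    Document(page_content="LangGraph models agent workflows as state graphs."),
])
retriever = vector_store.as_retriever()

# InputState only needs a question; OutputState returns the final generation
# (an AIMessage) and the filtered documents.
result = graph.invoke({"question": "What does LangGraph model workflows as?"})
print(result["generation"].content)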

agents/memory_hil_rag.py

Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
from agents.utils import llm
from langchain.schema import Document
from typing import List
from typing_extensions import TypedDict, Annotated
from langgraph.graph import StateGraph, START, END
from langgraph.types import interrupt
from pydantic import BaseModel, Field
import operator
from langchain_core.messages import AnyMessage, get_buffer_string, SystemMessage, HumanMessage


retriever = ...  # TODO: Add retriever

class GraphState(TypedDict):
    question: str
    messages: Annotated[List[AnyMessage], operator.add]  # We now track a list of messages
    generation: str
    documents: List[Document]
    attempted_generations: int

class InputState(TypedDict):
    question: str

class OutputState(TypedDict):
    messages: Annotated[List[AnyMessage], operator.add]  # We output messages now in our OutputState
    documents: List[Document]

def retrieve_documents(state: GraphState):
    """
    Args:
        state (dict): The current graph state
    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE DOCUMENTS---")
    question = state["question"]
    documents = retriever.invoke(question)
    return {"documents": documents}

RAG_PROMPT_WITH_CHAT_HISTORY = """You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the latest question in the conversation.
If you don't know the answer, just say that you don't know.
The pre-existing conversation may provide important context to the question.
Use three sentences maximum and keep the answer concise.

Existing Conversation:
{conversation}

Latest Question:
{question}

Additional Context from Documents:
{context}

Answer:"""

def generate_response(state: GraphState):
    # We interrupt the graph, and ask the user for some additional context
    additional_context = interrupt("Do you have anything else to add that you think is relevant?")
    print("---GENERATE RESPONSE---")
    question = state["question"]
    documents = state["documents"]
    # For simplicity, we'll just append the additional context to the conversation history
    conversation = get_buffer_string(state["messages"]) + additional_context
    attempted_generations = state.get("attempted_generations", 0)
    formatted_docs = "\n\n".join(doc.page_content for doc in documents)

    rag_prompt_formatted = RAG_PROMPT_WITH_CHAT_HISTORY.format(context=formatted_docs, conversation=conversation, question=question)
    generation = llm.invoke([HumanMessage(content=rag_prompt_formatted)])
    return {
        "generation": generation,
        "attempted_generations": attempted_generations + 1
    }

class GradeDocuments(BaseModel):
    is_relevant: bool = Field(
        description="The document is relevant to the question, true or false"
    )

grade_documents_llm = llm.with_structured_output(GradeDocuments)
grade_documents_system_prompt = """You are a grader assessing relevance of a retrieved document to a conversation between a user and an AI assistant, and the user's latest question. \n
If the document contains keyword(s) or semantic meaning related to the user question, definitely grade it as relevant. \n
It does not need to be a stringent test. The goal is to filter out erroneous retrievals that are not relevant at all. \n
Give a binary score true or false to indicate whether the document is relevant to the question."""
grade_documents_prompt = "Here is the retrieved document: \n\n {document} \n\n Here is the conversation so far: \n\n {conversation} \n\n Here is the user question: \n\n {question}"

def grade_documents(state):
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]
    conversation = get_buffer_string(state["messages"])

    filtered_docs = []
    for d in documents:
        grade_documents_prompt_formatted = grade_documents_prompt.format(document=d.page_content, question=question, conversation=conversation)
        score = grade_documents_llm.invoke(
            [SystemMessage(content=grade_documents_system_prompt)] + [HumanMessage(content=grade_documents_prompt_formatted)]
        )
        grade = score.is_relevant
        if grade:
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            continue
    return {"documents": filtered_docs}

def decide_to_generate(state):
    """
    Args:
        state (dict): The current graph state
    Returns:
        str: Binary decision for next node to call
    """
    print("---ASSESS GRADED DOCUMENTS---")
    filtered_documents = state["documents"]

    if not filtered_documents:
        print(
            "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, END---"
        )
        return "none relevant"
    else:
        # We have relevant documents, so generate answer
        print("---DECISION: GENERATE---")
        return "some relevant"

class GradeHallucinations(BaseModel):
    """Binary score for hallucination present in generation answer."""
    grounded_in_facts: bool = Field(
        description="Answer is grounded in the facts, true or false"
    )

grade_hallucinations_llm = llm.with_structured_output(GradeHallucinations)
grade_hallucinations_system_prompt = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score true or false. True means that the answer is grounded in / supported by the set of facts."""
grade_hallucinations_prompt = "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"

ATTEMPTED_GENERATION_MAX = 3

def grade_hallucinations(state):
    print("---CHECK HALLUCINATIONS---")
    documents = state["documents"]
    generation = state["generation"]
    attempted_generations = state["attempted_generations"]

    formatted_docs = "\n\n".join(doc.page_content for doc in documents)

    grade_hallucinations_prompt_formatted = grade_hallucinations_prompt.format(
        documents=formatted_docs,
        generation=generation
    )

    score = grade_hallucinations_llm.invoke(
        [SystemMessage(content=grade_hallucinations_system_prompt)] + [HumanMessage(content=grade_hallucinations_prompt_formatted)]
    )
    grade = score.grounded_in_facts

    # Check hallucination
    if grade:
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        return "supported"
    elif attempted_generations >= ATTEMPTED_GENERATION_MAX:  # New condition!
        print("---DECISION: TOO MANY ATTEMPTS, GIVE UP---")
        raise RuntimeError("Too many attempted generations with hallucinations, giving up.")
        # return "give up"  # Note: We could also do this to silently fail
    else:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"

def configure_memory(state):
    question = state["question"]
    generation = state["generation"]
    return {
        "messages": [HumanMessage(content=question), generation],  # Add generation to our messages list
        "attempted_generations": 0,  # Reset this value to 0
        "documents": []  # Reset documents to empty
    }

graph_builder = StateGraph(GraphState, input=InputState, output=OutputState)
graph_builder.add_node("retrieve_documents", retrieve_documents)
graph_builder.add_node("generate_response", generate_response)
graph_builder.add_node("grade_documents", grade_documents)
graph_builder.add_node("configure_memory", configure_memory)  # New node for configuring memory

graph_builder.add_edge(START, "retrieve_documents")
graph_builder.add_edge("retrieve_documents", "grade_documents")
graph_builder.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "some relevant": "generate_response",
        "none relevant": END
    })
graph_builder.add_conditional_edges(
    "generate_response",
    grade_hallucinations,
    {
        "supported": "configure_memory",
        "not supported": "generate_response"
    })
graph_builder.add_edge("configure_memory", END)

graph = graph_builder.compile()
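Because `generate_response` calls `interrupt()`, this graph only runs end-to-end with a checkpointer attached. A minimal local-run sketch follows; it assumes the `retriever` placeholder has been filled in (as in the earlier sketch), and the in-memory checkpointer, thread id, question, and resume text are all illustrative (LangGraph Studio supplies its own persistence, so none of this is needed there).

# Local-run sketch only: compile with an in-memory checkpointer so interrupt()
# can pause the graph, then resume it with Command(resume=...).
from langgraph.checkpoint.memory import MemorySaver
from langgraph.types import Command

checkpointed_graph = graph_builder.compile(checkpointer=MemorySaver())
config = {"configurable": {"thread_id": "demo-thread"}}  # hypothetical thread id

# First invocation runs until interrupt() pauses inside generate_response.
checkpointed_graph.invoke({"question": "What is corrective RAG?"}, config)

# Resuming feeds the human's reply back as interrupt()'s return value; the
# graph then generates, grades the result, updates memory, and ends.
result = checkpointed_graph.invoke(Command(resume="No, nothing to add."), config)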
