from agents.utils import llm
from langchain.schema import Document
from typing import List
from typing_extensions import Annotated, TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.types import interrupt
from pydantic import BaseModel, Field
import operator
from langchain_core.messages import AnyMessage, get_buffer_string, SystemMessage, HumanMessage


retriever = ...  # TODO: Add retriever
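
# NOTE: `llm` above comes from a local helper module, and the retriever is left
# as a TODO. The commented-out block below is a minimal sketch of what they
# might look like -- an illustrative assumption, not part of the original code.
# ChatOpenAI, OpenAIEmbeddings, and InMemoryVectorStore are stand-ins; any chat
# model, embedding model, and vector store would work.
#
# from langchain_openai import ChatOpenAI, OpenAIEmbeddings
# from langchain_core.vectorstores import InMemoryVectorStore
#
# llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
# vector_store = InMemoryVectorStore(embedding=OpenAIEmbeddings())
# vector_store.add_documents([Document(page_content="LangGraph supports human-in-the-loop interrupts.")])
# retriever = vector_store.as_retriever(search_kwargs={"k": 3})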

class GraphState(TypedDict):
    question: str
    messages: Annotated[List[AnyMessage], operator.add]  # We now track a list of messages
    generation: str
    documents: List[Document]
    attempted_generations: int

class InputState(TypedDict):
    question: str

class OutputState(TypedDict):
    messages: Annotated[List[AnyMessage], operator.add]  # We output messages now in our OutputState
    documents: List[Document]

def retrieve_documents(state: GraphState):
    """
    Args:
        state (dict): The current graph state
    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE DOCUMENTS---")
    question = state["question"]
    documents = retriever.invoke(question)
    return {"documents": documents}

RAG_PROMPT_WITH_CHAT_HISTORY = """You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the latest question in the conversation.
If you don't know the answer, just say that you don't know.
The pre-existing conversation may provide important context to the question.
Use three sentences maximum and keep the answer concise.

Existing Conversation:
{conversation}

Latest Question:
{question}

Additional Context from Documents:
{context}

Answer:"""

def generate_response(state: GraphState):
    # We interrupt the graph, and ask the user for some additional context
    additional_context = interrupt("Do you have anything else to add that you think is relevant?")
    print("---GENERATE RESPONSE---")
    question = state["question"]
    documents = state["documents"]
    # For simplicity, we'll just append the additional context to the conversation history
    conversation = get_buffer_string(state["messages"]) + additional_context
    attempted_generations = state.get("attempted_generations", 0)
    formatted_docs = "\n\n".join(doc.page_content for doc in documents)

    rag_prompt_formatted = RAG_PROMPT_WITH_CHAT_HISTORY.format(context=formatted_docs, conversation=conversation, question=question)
    generation = llm.invoke([HumanMessage(content=rag_prompt_formatted)])
    return {
        "generation": generation,
        "attempted_generations": attempted_generations + 1
    }

class GradeDocuments(BaseModel):
    is_relevant: bool = Field(
        description="The document is relevant to the question, true or false"
    )

grade_documents_llm = llm.with_structured_output(GradeDocuments)
grade_documents_system_prompt = """You are a grader assessing the relevance of a retrieved document to a conversation between a user and an AI assistant, and to the user's latest question. \n
If the document contains keyword(s) or semantic meaning related to the user question, definitely grade it as relevant. \n
It does not need to be a stringent test. The goal is to filter out erroneous retrievals that are not relevant at all. \n
Give a binary score, true or false, to indicate whether the document is relevant to the question."""
grade_documents_prompt = "Here is the retrieved document: \n\n {document} \n\n Here is the conversation so far: \n\n {conversation} \n\n Here is the user question: \n\n {question}"
def grade_documents(state):
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]
    conversation = get_buffer_string(state["messages"])

    filtered_docs = []
    for d in documents:
        grade_documents_prompt_formatted = grade_documents_prompt.format(document=d.page_content, question=question, conversation=conversation)
        score = grade_documents_llm.invoke(
            [SystemMessage(content=grade_documents_system_prompt)] + [HumanMessage(content=grade_documents_prompt_formatted)]
        )
        grade = score.is_relevant
        if grade:
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
    return {"documents": filtered_docs}

def decide_to_generate(state):
    """
    Args:
        state (dict): The current graph state
    Returns:
        str: Binary decision for next node to call
    """
    print("---ASSESS GRADED DOCUMENTS---")
    filtered_documents = state["documents"]

    if not filtered_documents:
        print("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, END---")
        return "none relevant"
    else:
        # We have relevant documents, so generate answer
        print("---DECISION: GENERATE---")
        return "some relevant"

class GradeHallucinations(BaseModel):
    """Binary score for hallucination present in generation answer."""
    grounded_in_facts: bool = Field(
        description="Answer is grounded in the facts, true or false"
    )

grade_hallucinations_llm = llm.with_structured_output(GradeHallucinations)
grade_hallucinations_system_prompt = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score true or false. True means that the answer is grounded in / supported by the set of facts."""
grade_hallucinations_prompt = "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"

ATTEMPTED_GENERATION_MAX = 3

def grade_hallucinations(state):
    print("---CHECK HALLUCINATIONS---")
    documents = state["documents"]
    generation = state["generation"]
    attempted_generations = state["attempted_generations"]

    formatted_docs = "\n\n".join(doc.page_content for doc in documents)

    grade_hallucinations_prompt_formatted = grade_hallucinations_prompt.format(
        documents=formatted_docs,
        generation=generation
    )

    score = grade_hallucinations_llm.invoke(
        [SystemMessage(content=grade_hallucinations_system_prompt)] + [HumanMessage(content=grade_hallucinations_prompt_formatted)]
    )
    grade = score.grounded_in_facts

    # Check hallucination
    if grade:
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        return "supported"
    elif attempted_generations >= ATTEMPTED_GENERATION_MAX:  # New condition!
        print("---DECISION: TOO MANY ATTEMPTS, GIVE UP---")
        raise RuntimeError("Too many attempted generations with hallucinations, giving up.")
        # return "give up"  # Note: We could also do this to silently fail
    else:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"

def configure_memory(state):
    question = state["question"]
    generation = state["generation"]
    return {
        "messages": [HumanMessage(content=question), generation],  # Add generation to our messages list
        "attempted_generations": 0,  # Reset this value to 0
        "documents": []  # Reset documents to empty
    }

graph_builder = StateGraph(GraphState, input=InputState, output=OutputState)
graph_builder.add_node("retrieve_documents", retrieve_documents)
graph_builder.add_node("generate_response", generate_response)
graph_builder.add_node("grade_documents", grade_documents)
graph_builder.add_node("configure_memory", configure_memory)  # New node for configuring memory

graph_builder.add_edge(START, "retrieve_documents")
graph_builder.add_edge("retrieve_documents", "grade_documents")
graph_builder.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "some relevant": "generate_response",
        "none relevant": END
    })
graph_builder.add_conditional_edges(
    "generate_response",
    grade_hallucinations,
    {
        "supported": "configure_memory",
        "not supported": "generate_response"
    })
graph_builder.add_edge("configure_memory", END)

from langgraph.checkpoint.memory import MemorySaver

# interrupt() can only pause and resume if the graph is compiled with a
# checkpointer. LangGraph Platform provides one automatically; for local runs
# we attach an in-memory checkpointer so the interrupt in generate_response works.
graph = graph_builder.compile(checkpointer=MemorySaver())
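
# A quick usage sketch of the interrupt/resume loop -- an illustrative
# assumption, not part of the original code (the question text and thread_id
# are made up). A thread_id is required so the checkpointer can resume the run.
#
# from langgraph.types import Command
#
# config = {"configurable": {"thread_id": "1"}}
# result = graph.invoke({"question": "What is task decomposition?"}, config=config)
# # The graph pauses at interrupt(); the prompt surfaces under result["__interrupt__"].
# # Resume by sending the user's reply back in with Command(resume=...):
# result = graph.invoke(Command(resume="No, nothing else to add."), config=config)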