11"""Manages the shared state and workflow for Willa chatbots."""
2- from typing import Any , Optional , Annotated , NotRequired
2+ from typing import Optional , Annotated , NotRequired
33from typing_extensions import TypedDict
44
5+ from langchain_core .documents import Document
56from langchain_core .language_models import BaseChatModel
67from langchain_core .messages import ChatMessage , HumanMessage , AIMessage
78from langchain_core .vectorstores .base import VectorStore
class WillaChatbotState(TypedDict):
    """Shared state passed between the nodes of the Willa chatbot graph."""

    # Full conversation history; the add_messages reducer merges new
    # messages into the existing list instead of replacing it.
    messages: Annotated[list[AnyMessage], add_messages]
    # History with TIND-tagged and system messages removed
    # (produced by the filter_messages node).
    filtered_messages: NotRequired[list[AnyMessage]]
    # Condensed form of the filtered history (produced by the summarize node).
    summarized_messages: NotRequired[list[AnyMessage]]
    # Final message list handed to the model (produced by prepare_for_generation).
    messages_for_generation: NotRequired[list[AnyMessage]]
    # Query string used against the vector store (produced by prepare_search).
    search_query: NotRequired[str]
    # Formatted TIND metadata attached to the TIND follow-up message.
    tind_metadata: NotRequired[str]
    # Retrieved documents formatted as plain dicts (produced by retrieve_context).
    documents: NotRequired[list[dict[str, str]]]
2627
2728
2829class GraphManager : # pylint: disable=too-few-public-methods
@@ -51,13 +52,15 @@ def _create_workflow(self) -> CompiledStateGraph:
5152 workflow .add_node ("summarize" , summarization_node )
5253 workflow .add_node ("prepare_search" , self ._prepare_search_query )
5354 workflow .add_node ("retrieve_context" , self ._retrieve_context )
55+ workflow .add_node ("prepare_for_generation" , self ._prepare_for_generation )
5456 workflow .add_node ("generate_response" , self ._generate_response )
5557
5658 # Define edges
5759 workflow .add_edge ("filter_messages" , "summarize" )
5860 workflow .add_edge ("summarize" , "prepare_search" )
5961 workflow .add_edge ("prepare_search" , "retrieve_context" )
60- workflow .add_edge ("retrieve_context" , "generate_response" )
62+ workflow .add_edge ("retrieve_context" , "prepare_for_generation" )
63+ workflow .add_edge ("prepare_for_generation" , "generate_response" )
6164
6265 workflow .set_entry_point ("filter_messages" )
6366 workflow .set_finish_point ("generate_response" )
@@ -68,7 +71,10 @@ def _filter_messages(self, state: WillaChatbotState) -> dict[str, list[AnyMessag
6871 """Filter out TIND messages from the conversation history."""
6972 messages = state ["messages" ]
7073
71- filtered = [msg for msg in messages if 'tind' not in msg .response_metadata ]
74+ filtered : list [AnyMessage ] = [
75+ msg for msg in messages
76+ if "tind" not in getattr (msg , "response_metadata" , {}) and msg .type != "system"
77+ ]
7278 return {"filtered_messages" : filtered }
7379
7480 def _prepare_search_query (self , state : WillaChatbotState ) -> dict [str , str ]:
@@ -79,60 +85,83 @@ def _prepare_search_query(self, state: WillaChatbotState) -> dict[str, str]:
7985
8086 # summarization may include a system message as well as any human or ai messages
8187 search_query = '\n ' .join (str (msg .content ) for msg in messages if hasattr (msg , 'content' ))
88+
89+ # if summarization fails or some other issue, truncate to the last 2048 characters
90+ if len (search_query ) > 2048 :
91+ search_query = search_query [- 2048 :]
92+
8293 return {"search_query" : search_query }
8394
84- def _retrieve_context (self , state : WillaChatbotState ) -> dict [str , str ]:
95+ def _format_retrieved_documents (self , matching_docs : list [Document ]) -> list [dict [str , str ]]:
96+ """Format documents from vector store into a list of dictionaries."""
97+ formatted_documents : list [dict [str , str ]] = []
98+ for i , doc in enumerate (matching_docs , 1 ):
99+ tind_metadata = doc .metadata .get ('tind_metadata' , {})
100+ tind_id = tind_metadata .get ('tind_id' , ['' ])[0 ]
101+ formatted_documents .append ({
102+ "id" : f"{ i } _{ tind_id } " ,
103+ "page_content" : doc .page_content ,
104+ "title" : tind_metadata .get ('title' , ['' ])[0 ],
105+ "project" : tind_metadata .get ('isPartOf' , ['' ])[0 ],
106+ "tind_link" : format_tind_context .get_tind_url (tind_id )
107+ })
108+ return formatted_documents
109+
110+ def _retrieve_context (self , state : WillaChatbotState ) -> dict [str , str | list [dict [str , str ]]]:
85111 """Retrieve relevant context from vector store."""
86112 search_query = state .get ("search_query" , "" )
87113 vector_store = self ._vector_store
88114
89115 if not search_query or not vector_store :
90- return {"docs_context " : "" , "tind_metadata " : "" }
116+ return {"tind_metadata " : "" , "documents " : [] }
91117
92118 # Search for relevant documents
93119 retriever = vector_store .as_retriever (search_kwargs = {"k" : int (CONFIG ['K_VALUE' ])})
94120 matching_docs = retriever .invoke (search_query )
121+ formatted_documents = self ._format_retrieved_documents (matching_docs )
95122
96- # Format context and metadata
97- docs_context = '\n \n ' .join (doc .page_content for doc in matching_docs )
123+ # Format tind metadata
98124 tind_metadata = format_tind_context .get_tind_context (matching_docs )
99125
100- return {"docs_context " : docs_context , "tind_metadata " : tind_metadata }
126+ return {"tind_metadata " : tind_metadata , "documents " : formatted_documents }
101127
102- # This should be refactored probably. Very bulky
103- def _generate_response (self , state : WillaChatbotState ) -> dict [str , list [AnyMessage ]]:
104- """Generate response using the model."""
def _prepare_for_generation(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
    """Assemble the message list that will be sent to the model.

    Combines the (possibly summarized) conversation with the rendered
    system prompt; short-circuits with an apology message when the
    history contains no human question.
    """
    history = state["messages"]
    conversation = state.get("summarized_messages", history)

    has_question = any(isinstance(entry, HumanMessage) for entry in history)
    if not has_question:
        return {"messages": [AIMessage(content="I'm sorry, I didn't receive a question.")]}

    rendered = get_langfuse_prompt().invoke({})

    # A prompt may render to a value exposing .messages or to a single message.
    if hasattr(rendered, "messages"):
        suffix = rendered.messages
    else:
        suffix = [rendered]

    return {"messages_for_generation": conversation + suffix}
145+
146+ def _generate_response (self , state : WillaChatbotState ) -> dict [str , list [AnyMessage ]]:
147+ """Generate response using the model."""
148+ tind_metadata = state .get ("tind_metadata" , "" )
149+ model = self ._model
150+ documents = state .get ("documents" , [])
151+ messages = state .get ("messages_for_generation" ) or state .get ("messages" , [])
152+
153+ if not model :
154+ return {"messages" : [AIMessage (content = "Model not available." )]}
155+
131156 # Get response from model
132- response = model .invoke (all_messages )
157+ response = model .invoke (
158+ messages ,
159+ additional_model_request_fields = {"documents" : documents }
160+ )
133161
134162 # Create clean response content
135163 response_content = str (response .content ) if hasattr (response , 'content' ) else str (response )
164+
136165 response_messages : list [AnyMessage ] = [AIMessage (content = response_content ),
137166 ChatMessage (content = tind_metadata , role = 'TIND' ,
138167 response_metadata = {'tind' : True })]
0 commit comments