Skip to content

Commit 791a2e1

Browse files
authored
investigating separating out documents from the rest of the message history (#95)
* investigating separating out documents from the rest of the message history and instructions.
* preserving cohere response citations - this gets cohere specific response field that includes citations for the response text
* add prepare generation node - temporarily add raw citations to response.
* improving citation output prep
* strip citations
* fix return type for retrieve docs
* pylint fixes
* truncate search query if needed ... suggestions from @awilfox
* pylint fix. aaarrrrgghghghghg
* possible summarization fix
* fixed summarization issues
* typing
* clean up refactor and addressing some @awilfox and copilot suggestions
* mypy type error
1 parent 497cda4 commit 791a2e1

File tree

1 file changed

+58
-29
lines changed

1 file changed

+58
-29
lines changed

willa/chatbot/graph_manager.py

Lines changed: 58 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Manages the shared state and workflow for Willa chatbots."""
2-
from typing import Any, Optional, Annotated, NotRequired
2+
from typing import Optional, Annotated, NotRequired
33
from typing_extensions import TypedDict
44

5+
from langchain_core.documents import Document
56
from langchain_core.language_models import BaseChatModel
67
from langchain_core.messages import ChatMessage, HumanMessage, AIMessage
78
from langchain_core.vectorstores.base import VectorStore
@@ -19,10 +20,10 @@ class WillaChatbotState(TypedDict):
1920
messages: Annotated[list[AnyMessage], add_messages]
2021
filtered_messages: NotRequired[list[AnyMessage]]
2122
summarized_messages: NotRequired[list[AnyMessage]]
22-
docs_context: NotRequired[str]
23+
messages_for_generation: NotRequired[list[AnyMessage]]
2324
search_query: NotRequired[str]
2425
tind_metadata: NotRequired[str]
25-
context: NotRequired[dict[str, Any]]
26+
documents: NotRequired[list[dict[str, str]]]
2627

2728

2829
class GraphManager: # pylint: disable=too-few-public-methods
@@ -51,13 +52,15 @@ def _create_workflow(self) -> CompiledStateGraph:
5152
workflow.add_node("summarize", summarization_node)
5253
workflow.add_node("prepare_search", self._prepare_search_query)
5354
workflow.add_node("retrieve_context", self._retrieve_context)
55+
workflow.add_node("prepare_for_generation", self._prepare_for_generation)
5456
workflow.add_node("generate_response", self._generate_response)
5557

5658
# Define edges
5759
workflow.add_edge("filter_messages", "summarize")
5860
workflow.add_edge("summarize", "prepare_search")
5961
workflow.add_edge("prepare_search", "retrieve_context")
60-
workflow.add_edge("retrieve_context", "generate_response")
62+
workflow.add_edge("retrieve_context", "prepare_for_generation")
63+
workflow.add_edge("prepare_for_generation", "generate_response")
6164

6265
workflow.set_entry_point("filter_messages")
6366
workflow.set_finish_point("generate_response")
@@ -68,7 +71,10 @@ def _filter_messages(self, state: WillaChatbotState) -> dict[str, list[AnyMessag
6871
"""Filter out TIND messages from the conversation history."""
6972
messages = state["messages"]
7073

71-
filtered = [msg for msg in messages if 'tind' not in msg.response_metadata]
74+
filtered: list[AnyMessage] = [
75+
msg for msg in messages
76+
if "tind" not in getattr(msg, "response_metadata", {}) and msg.type != "system"
77+
]
7278
return {"filtered_messages": filtered}
7379

7480
def _prepare_search_query(self, state: WillaChatbotState) -> dict[str, str]:
@@ -79,60 +85,83 @@ def _prepare_search_query(self, state: WillaChatbotState) -> dict[str, str]:
7985

8086
# summarization may include a system message as well as any human or ai messages
8187
search_query = '\n'.join(str(msg.content) for msg in messages if hasattr(msg, 'content'))
88+
89+
# if summarization fails or some other issue, truncate to the last 2048 characters
90+
if len(search_query) > 2048:
91+
search_query = search_query[-2048:]
92+
8293
return {"search_query": search_query}
8394

84-
def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str]:
95+
def _format_retrieved_documents(self, matching_docs: list[Document]) -> list[dict[str, str]]:
96+
"""Format documents from vector store into a list of dictionaries."""
97+
formatted_documents: list[dict[str, str]] = []
98+
for i, doc in enumerate(matching_docs, 1):
99+
tind_metadata = doc.metadata.get('tind_metadata', {})
100+
tind_id = tind_metadata.get('tind_id', [''])[0]
101+
formatted_documents.append({
102+
"id": f"{i}_{tind_id}",
103+
"page_content": doc.page_content,
104+
"title": tind_metadata.get('title', [''])[0],
105+
"project": tind_metadata.get('isPartOf', [''])[0],
106+
"tind_link": format_tind_context.get_tind_url(tind_id)
107+
})
108+
return formatted_documents
109+
110+
def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str | list[dict[str, str]]]:
85111
"""Retrieve relevant context from vector store."""
86112
search_query = state.get("search_query", "")
87113
vector_store = self._vector_store
88114

89115
if not search_query or not vector_store:
90-
return {"docs_context": "", "tind_metadata": ""}
116+
return {"tind_metadata": "", "documents": []}
91117

92118
# Search for relevant documents
93119
retriever = vector_store.as_retriever(search_kwargs={"k": int(CONFIG['K_VALUE'])})
94120
matching_docs = retriever.invoke(search_query)
121+
formatted_documents = self._format_retrieved_documents(matching_docs)
95122

96-
# Format context and metadata
97-
docs_context = '\n\n'.join(doc.page_content for doc in matching_docs)
123+
# Format tind metadata
98124
tind_metadata = format_tind_context.get_tind_context(matching_docs)
99125

100-
return {"docs_context": docs_context, "tind_metadata": tind_metadata}
126+
return {"tind_metadata": tind_metadata, "documents": formatted_documents}
101127

102-
# This should be refactored probably. Very bulky
103-
def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
104-
"""Generate response using the model."""
128+
def _prepare_for_generation(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
129+
"""Prepare the current and past messages for response generation."""
105130
messages = state["messages"]
106131
summarized_conversation = state.get("summarized_messages", messages)
107-
docs_context = state.get("docs_context", "")
108-
tind_metadata = state.get("tind_metadata", "")
109-
model = self._model
110-
111-
if not model:
112-
return {"messages": [AIMessage(content="Model not available.")]}
113-
114-
# Get the latest human message
115-
latest_message = next(
116-
(msg for msg in reversed(messages) if isinstance(msg, HumanMessage)),
117-
None
118-
)
119132

120-
if not latest_message:
133+
if not any(isinstance(msg, HumanMessage) for msg in messages):
121134
return {"messages": [AIMessage(content="I'm sorry, I didn't receive a question.")]}
122135

123136
prompt = get_langfuse_prompt()
124-
system_messages = prompt.invoke({'context': docs_context,
125-
'question': latest_message.content})
137+
system_messages = prompt.invoke({})
138+
126139
if hasattr(system_messages, "messages"):
127140
all_messages = summarized_conversation + system_messages.messages
128141
else:
129142
all_messages = summarized_conversation + [system_messages]
130143

144+
return {"messages_for_generation": all_messages}
145+
146+
def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
147+
"""Generate response using the model."""
148+
tind_metadata = state.get("tind_metadata", "")
149+
model = self._model
150+
documents = state.get("documents", [])
151+
messages = state.get("messages_for_generation") or state.get("messages", [])
152+
153+
if not model:
154+
return {"messages": [AIMessage(content="Model not available.")]}
155+
131156
# Get response from model
132-
response = model.invoke(all_messages)
157+
response = model.invoke(
158+
messages,
159+
additional_model_request_fields={"documents": documents}
160+
)
133161

134162
# Create clean response content
135163
response_content = str(response.content) if hasattr(response, 'content') else str(response)
164+
136165
response_messages: list[AnyMessage] = [AIMessage(content=response_content),
137166
ChatMessage(content=tind_metadata, role='TIND',
138167
response_metadata={'tind': True})]

0 commit comments

Comments (0)