Skip to content

Commit 016003f

Browse files
Kristian NylundKristian Nylund
authored andcommitted
added state store
1 parent c416cc7 commit 016003f

File tree

4 files changed

+40
-21
lines changed

4 files changed

+40
-21
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from autogen_text_2_sql.custom_agents.parallel_query_solving_agent import (
1313
ParallelQuerySolvingAgent,
1414
)
15+
from state_store import StateStore
1516
from autogen_agentchat.messages import TextMessage
1617
import json
1718
import os
@@ -31,9 +32,13 @@
3132

3233

3334
class AutoGenText2Sql:
34-
def __init__(self, **kwargs):
35+
def __init__(self, state_store : StateStore, **kwargs):
3536
self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper()
3637

38+
if not state_store:
39+
raise ValueError("State store must be provided")
40+
self.state_store = state_store
41+
3742
if "use_case" not in kwargs:
3843
logging.warning(
3944
"No use case provided. It is advised to provide a use case to help the LLM reason."
@@ -250,43 +255,31 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
250255

251256
async def process_user_message(
252257
self,
258+
thread_id: str,
253259
message_payload: UserMessagePayload,
254-
chat_history: list[InteractionPayload] = None,
255260
) -> AsyncGenerator[InteractionPayload, None]:
256261
"""Process the complete message through the unified system.
257262
258263
Args:
259264
----
265+
thread_id (str): The ID of the thread the message belongs to.
260266
task (str): The user message to process.
261-
chat_history (list[str], optional): The chat history. Defaults to None. The last message is the most recent message.
262267
injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None.
263268
264269
Returns:
265270
-------
266271
dict: The response from the system.
267272
"""
268273
logging.info("Processing message: %s", message_payload.body.user_message)
269-
logging.info("Chat history: %s", chat_history)
270274

271275
agent_input = {
272276
"message": message_payload.body.user_message,
273277
"injected_parameters": message_payload.body.injected_parameters,
274278
}
275279

276-
latest_state = None
277-
if chat_history is not None:
278-
# Update input
279-
for chat in reversed(chat_history):
280-
if chat.root.payload_type in [
281-
PayloadType.ANSWER_WITH_SOURCES,
282-
PayloadType.DISAMBIGUATION_REQUESTS,
283-
]:
284-
latest_state = chat.body.assistant_state
285-
break
286-
287-
# TODO: Trim the chat history to the last message from the user
288-
if latest_state is not None:
289-
await self.agentic_flow.load_state(latest_state)
280+
state = self.state_store.get_state(thread_id)
281+
if state is not None:
282+
await self.agentic_flow.load_state(state)
290283

291284
async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
292285
logging.debug("Message: %s", message)
@@ -340,7 +333,7 @@ async def process_user_message(
340333
):
341334
# Get the state
342335
assistant_state = await self.agentic_flow.save_state()
343-
payload.body.assistant_state = assistant_state
336+
self.state_store.save_state(thread_id, assistant_state)
344337

345338
logging.debug("Final Payload: %s", payload)
346339

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from abc import ABC, abstractmethod
2+
3+
class StateStore(ABC):
4+
@abstractmethod
5+
def get_state(self, thread_id):
6+
pass
7+
8+
@abstractmethod
9+
def save_state(self, thread_id, state):
10+
pass
11+
12+
13+
class InMemoryStateStore(StateStore):
14+
def __init__(self):
15+
# Replace with a caching library or something to have some sort of expiry for entries so this doesn't grow forever
16+
self.cache = {}
17+
18+
def get_state(self, thread_id):
19+
return self.cache.get(thread_id)
20+
21+
def save_state(self, thread_id, state):
22+
self.cache[thread_id] = state
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from state_store import InMemoryStateStore
2+
3+
x=InMemoryStateStore()
4+
print(x.get_state("1"))
5+
x.save_state("1", {'x':2})
6+
print(x.get_state("1"))

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ class DismabiguationRequest(InteractionPayloadBase):
5757
decomposed_user_messages: list[list[str]] = Field(
5858
default_factory=list, alias="decomposedUserMessages"
5959
)
60-
assistant_state: dict | None = Field(default=None, alias="assistantState")
6160

6261
payload_type: Literal[PayloadType.DISAMBIGUATION_REQUESTS] = Field(
6362
PayloadType.DISAMBIGUATION_REQUESTS, alias="payloadType"
@@ -86,7 +85,6 @@ class Source(InteractionPayloadBase):
8685
default_factory=list, alias="decomposedUserMessages"
8786
)
8887
sources: list[Source] = Field(default_factory=list)
89-
assistant_state: dict | None = Field(default=None, alias="assistantState")
9088

9189
payload_type: Literal[PayloadType.ANSWER_WITH_SOURCES] = Field(
9290
PayloadType.ANSWER_WITH_SOURCES, alias="payloadType"

0 commit comments

Comments
 (0)