|
12 | 12 | from autogen_text_2_sql.custom_agents.parallel_query_solving_agent import ( |
13 | 13 | ParallelQuerySolvingAgent, |
14 | 14 | ) |
| 15 | +from state_store import StateStore |
15 | 16 | from autogen_agentchat.messages import TextMessage |
16 | 17 | import json |
17 | 18 | import os |
|
31 | 32 |
|
32 | 33 |
|
33 | 34 | class AutoGenText2Sql: |
34 | | - def __init__(self, **kwargs): |
| 35 | + def __init__(self, state_store : StateStore, **kwargs): |
35 | 36 | self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper() |
36 | 37 |
|
| 38 | + if not state_store: |
| 39 | + raise ValueError("State store must be provided") |
| 40 | + self.state_store = state_store |
| 41 | + |
37 | 42 | if "use_case" not in kwargs: |
38 | 43 | logging.warning( |
39 | 44 | "No use case provided. It is advised to provide a use case to help the LLM reason." |
@@ -250,43 +255,31 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: |
250 | 255 |
|
251 | 256 | async def process_user_message( |
252 | 257 | self, |
| 258 | + thread_id: str, |
253 | 259 | message_payload: UserMessagePayload, |
254 | | - chat_history: list[InteractionPayload] = None, |
255 | 260 | ) -> AsyncGenerator[InteractionPayload, None]: |
256 | 261 | """Process the complete message through the unified system. |
257 | 262 |
|
258 | 263 | Args: |
259 | 264 | ---- |
| 265 | + thread_id (str): The ID of the thread the message belongs to. |
260 | 266 | message_payload (UserMessagePayload): The user message to process. |
261 | | - chat_history (list[str], optional): The chat history. Defaults to None. The last message is the most recent message. |
262 | 267 | injected_parameters (dict, optional): Parameters to pass to agents, read from message_payload.body. Defaults to None. |
263 | 268 |
|
264 | 269 | Returns: |
265 | 270 | ------- |
266 | 271 | AsyncGenerator[InteractionPayload, None]: The stream of response payloads from the system. |
267 | 272 | """ |
268 | 273 | logging.info("Processing message: %s", message_payload.body.user_message) |
269 | | - logging.info("Chat history: %s", chat_history) |
270 | 274 |
|
271 | 275 | agent_input = { |
272 | 276 | "message": message_payload.body.user_message, |
273 | 277 | "injected_parameters": message_payload.body.injected_parameters, |
274 | 278 | } |
275 | 279 |
|
276 | | - latest_state = None |
277 | | - if chat_history is not None: |
278 | | - # Update input |
279 | | - for chat in reversed(chat_history): |
280 | | - if chat.root.payload_type in [ |
281 | | - PayloadType.ANSWER_WITH_SOURCES, |
282 | | - PayloadType.DISAMBIGUATION_REQUESTS, |
283 | | - ]: |
284 | | - latest_state = chat.body.assistant_state |
285 | | - break |
286 | | - |
287 | | - # TODO: Trim the chat history to the last message from the user |
288 | | - if latest_state is not None: |
289 | | - await self.agentic_flow.load_state(latest_state) |
| 280 | + state = self.state_store.get_state(thread_id) |
| 281 | + if state is not None: |
| 282 | + await self.agentic_flow.load_state(state) |
290 | 283 |
|
291 | 284 | async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)): |
292 | 285 | logging.debug("Message: %s", message) |
@@ -340,7 +333,7 @@ async def process_user_message( |
340 | 333 | ): |
341 | 334 | # Get the state |
342 | 335 | assistant_state = await self.agentic_flow.save_state() |
343 | | - payload.body.assistant_state = assistant_state |
| 336 | + self.state_store.save_state(thread_id, assistant_state) |
344 | 337 |
|
345 | 338 | logging.debug("Final Payload: %s", payload) |
346 | 339 |
|
|
0 commit comments