Skip to content

Commit 6d6886b

Browse files
Fix Bad Chat History Implementation (#120)
* Fix bad chat history instantiation * Update namign * Update chat history * Update header
1 parent 087797b commit 6d6886b

File tree

4 files changed

+17
-33
lines changed

4 files changed

+17
-33
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33
from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql
4-
from text_2_sql_core.payloads.agent_response import AgentRequestBody
4+
from text_2_sql_core.payloads.agent_request_response_pair import AgentRequestBody
55

66
__all__ = ["AutoGenText2Sql", "AgentRequestBody"]

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,13 @@
1212
from autogen_text_2_sql.custom_agents.parallel_query_solving_agent import (
1313
ParallelQuerySolvingAgent,
1414
)
15-
from autogen_agentchat.agents import UserProxyAgent
1615
from autogen_agentchat.messages import TextMessage
1716
import json
1817
import os
1918
from datetime import datetime
2019

21-
from text_2_sql_core.payloads.agent_response import (
22-
AgentResponse,
20+
from text_2_sql_core.payloads.agent_request_response_pair import (
21+
AgentRequestResponsePair,
2322
AgentRequestBody,
2423
AnswerWithSources,
2524
Source,
@@ -30,26 +29,10 @@
3029
ProcessingUpdateBody,
3130
ProcessingUpdate,
3231
)
33-
from autogen_agentchat.base import Response, TaskResult
32+
from autogen_agentchat.base import TaskResult
3433
from typing import AsyncGenerator
3534

3635

37-
class EmptyResponseUserProxyAgent(UserProxyAgent):
38-
"""UserProxyAgent that automatically responds with empty messages."""
39-
40-
def __init__(self, name):
41-
super().__init__(name=name)
42-
self._has_responded = False
43-
44-
async def on_messages_stream(self, messages, sender=None, config=None):
45-
"""Auto-respond with empty message and return Response object."""
46-
message = TextMessage(content="", source=self.name)
47-
if not self._has_responded:
48-
self._has_responded = True
49-
yield message
50-
yield Response(chat_message=message)
51-
52-
5336
class AutoGenText2Sql:
5437
def __init__(self, engine_specific_rules: str, **kwargs: dict):
5538
self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper()
@@ -71,11 +54,7 @@ def get_all_agents(self):
7154

7255
self.answer_agent = LLMAgentCreator.create("answer_agent")
7356

74-
# Auto-responding UserProxyAgent
75-
self.user_proxy = EmptyResponseUserProxyAgent(name="user_proxy")
76-
7757
agents = [
78-
self.user_proxy,
7958
self.query_rewrite_agent,
8059
self.parallel_query_solving_agent,
8160
self.answer_agent,
@@ -182,7 +161,7 @@ async def process_question(
182161
self,
183162
request: AgentRequestBody,
184163
chat_history: list[ChatHistoryItem] = None,
185-
) -> AsyncGenerator[AgentResponse | ProcessingUpdate, None]:
164+
) -> AsyncGenerator[AgentRequestResponsePair | ProcessingUpdate, None]:
186165
"""Process the complete question through the unified system.
187166
188167
Args:
@@ -208,7 +187,10 @@ async def process_question(
208187
# Update input
209188
for idx, chat in enumerate(chat_history):
210189
# For now only consider the user query
211-
agent_input[f"chat_{idx}"] = chat.request.question
190+
chat_history_key = f"chat_{idx}"
191+
agent_input[
192+
chat_history_key
193+
] = chat.request_response_pair.request.question
212194

213195
async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
214196
logging.debug("Message: %s", message)
@@ -250,7 +232,7 @@ async def process_question(
250232
logging.error("Unexpected TaskResult: %s", message)
251233
raise ValueError("Unexpected TaskResult")
252234

253-
payload = AgentResponse(request=request, response=response)
235+
payload = AgentRequestResponsePair(request=request, response=response)
254236

255237
if payload is not None:
256238
logging.debug("Payload: %s", payload)

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/agent_response.py renamed to text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/agent_request_response_pair.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from datetime import datetime, timezone
88

99

10-
class AgentResponseHeader(BaseModel):
10+
class AgentRequestResponseHeader(BaseModel):
1111
prompt_tokens: int
1212
completion_tokens: int
1313
timestamp: datetime = Field(
@@ -83,7 +83,7 @@ def add_defaults_to_injected_parameters(cls, values):
8383
return values
8484

8585

86-
class AgentResponse(BaseModel):
87-
header: AgentResponseHeader | None = Field(default=None)
86+
class AgentRequestResponsePair(BaseModel):
87+
header: AgentRequestResponseHeader | None = Field(default=None)
8888
request: AgentRequestBody
8989
response: AgentResponseBody

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/chat_history.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33
from pydantic import BaseModel, Field
4-
from text_2_sql_core.payloads.agent_response import AgentResponse
4+
from text_2_sql_core.payloads.agent_request_response_pair import (
5+
AgentRequestResponsePair,
6+
)
57
from datetime import datetime, timezone
68

79

@@ -13,4 +15,4 @@ class ChatHistoryItem(BaseModel):
1315
description="Timestamp in UTC",
1416
default_factory=lambda: datetime.now(timezone.utc),
1517
)
16-
agent_response: AgentResponse
18+
request_response_pair: AgentRequestResponsePair

0 commit comments

Comments
 (0)