|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | +from typing import AsyncGenerator, List, Sequence |
| 4 | + |
| 5 | +from autogen_agentchat.agents import BaseChatAgent |
| 6 | +from autogen_agentchat.base import Response |
| 7 | +from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage |
| 8 | +from autogen_core.base import CancellationToken |
| 9 | +from text_2_sql.autogen.utils.sql import get_entity_schemas |
| 10 | +from keybert import KeyBERT |
| 11 | +import logging |
| 12 | + |
| 13 | + |
class SqlSchemaExtractionAgent(BaseChatAgent):
    """Chat agent that looks up entity schemas relevant to a user question.

    Key phrases are extracted from the first incoming message with KeyBERT,
    joined into a comma-separated string and passed to ``get_entity_schemas``.
    Whatever that helper returns is sent back as the content of a single
    ``TextMessage`` wrapped in a ``Response``.
    """

    def __init__(self):
        # NOTE(review): the registered name still reads "sql_query_cache_agent",
        # which looks copied from a different agent. It is left unchanged here
        # because other code may refer to this agent by name -- confirm and
        # rename together with any such references.
        super().__init__(
            "sql_query_cache_agent",
            "An agent that extracts key phrases from the user question and fetches the matching entity schemas.",
        )
        # Created on first use and then reused: constructing KeyBERT loads an
        # embedding model, which is too expensive to repeat for every message.
        self._kw_model: KeyBERT | None = None

    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:
        """Message types this agent can emit (only ``TextMessage``)."""
        return [TextMessage]

    def _get_keyword_model(self) -> KeyBERT:
        """Return the shared KeyBERT instance, building it on first call."""
        if self._kw_model is None:
            self._kw_model = KeyBERT()
        return self._kw_model

    async def on_messages(
        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
    ) -> Response:
        """Run ``on_messages_stream`` to completion and return its ``Response``.

        Args:
            messages: Incoming chat messages, forwarded unchanged.
            cancellation_token: Forwarded unchanged to the stream.

        Returns:
            The last ``Response`` item produced by the stream.
        """
        response: Response | None = None
        async for message in self.on_messages_stream(messages, cancellation_token):
            # Keep only Response items; anything else yielded is ignored here.
            if isinstance(message, Response):
                response = message
        # Internal invariant: the stream always ends by yielding a Response.
        assert response is not None
        return response

    async def on_messages_stream(
        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
    ) -> AsyncGenerator[AgentMessage | Response, None]:
        """Extract key phrases from the question and yield the entity schemas.

        Args:
            messages: Incoming chat messages; only the first one is read and
                its ``content`` is treated as the user question.
            cancellation_token: Accepted for interface compatibility; not
                consulted by this implementation.

        Yields:
            A single ``Response`` whose chat message carries the value
            returned by ``get_entity_schemas``.

        Raises:
            ValueError: If ``messages`` is empty.
        """
        if not messages:
            raise ValueError(
                "SqlSchemaExtractionAgent requires at least one message."
            )
        user_question = messages[0].content

        logging.info("Extracting entity schemas based on the user question...")

        # NOTE: extract_keywords is called synchronously (no await), so the
        # event loop is blocked for the duration of the keyword extraction.
        top_keywords = self._get_keyword_model().extract_keywords(
            user_question, keyphrase_ngram_range=(1, 3), top_n=5
        )

        # Each entry is indexed as (phrase, score); only the phrase is kept.
        key_phrases = [keyword[0] for keyword in top_keywords]
        key_phrases_str = ", ".join(key_phrases)

        entity_schemas = await get_entity_schemas(key_phrases_str)

        logging.info(entity_schemas)

        # NOTE(review): entity_schemas is passed straight through as the
        # message content; if get_entity_schemas does not return a str this
        # may need serialising first -- confirm against the helper.
        yield Response(
            chat_message=TextMessage(content=entity_schemas, source=self.name)
        )

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """Reset hook; nothing is cleared between conversations."""
        pass
0 commit comments