Skip to content

Commit aaee20b

Browse files
committed
create an agent base and agent factory
1 parent af839a9 commit aaee20b

File tree

2 files changed

+159
-0
lines changed

2 files changed

+159
-0
lines changed
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""Base classes for SQL migration agents."""
2+
3+
import logging
4+
from abc import ABC, abstractmethod
5+
from typing import Any, Generic, TypeVar
6+
7+
from azure.ai.projects.models import (
8+
ResponseFormatJsonSchema,
9+
ResponseFormatJsonSchemaType,
10+
)
11+
from common.config.config import app_config
12+
from common.models.api import AgentType
13+
from semantic_kernel.agents.azure_ai.azure_ai_agent import AzureAIAgent
14+
from semantic_kernel.functions import KernelArguments
15+
from sql_agents.agent_config import AgentModelDeployment, AgentsConfigDialect
16+
from sql_agents.helpers.utils import get_prompt
17+
18+
# Type variable for response models
19+
T = TypeVar('T')
20+
21+
logger = logging.getLogger(__name__)
22+
logger.setLevel(logging.DEBUG)
23+
24+
25+
class BaseSQLAgent(Generic[T], ABC):
26+
"""Base class for all SQL migration agents."""
27+
28+
def __init__(
29+
self,
30+
agent_type: AgentType,
31+
config: AgentsConfigDialect,
32+
deployment_name: AgentModelDeployment,
33+
34+
):
35+
"""Initialize the base SQL agent."""
36+
self.agent_type = agent_type
37+
self.config = config
38+
self.deployment_name = deployment_name
39+
self.agent: AzureAIAgent = None
40+
41+
42+
@property
43+
@abstractmethod
44+
def response_schema(self) -> type:
45+
"""Get the response schema for this agent."""
46+
pass
47+
48+
@property
49+
@abstractmethod
50+
def num_candidates(self) -> int:
51+
"""Get the number of candidates for this agent."""
52+
pass
53+
54+
async def setup(self) -> AzureAIAgent:
55+
"""Setup the agent with Azure AI."""
56+
_deployment_name = self.deployment_name.value
57+
_name = self.agent_type.value
58+
59+
try:
60+
template_content = get_prompt(_name)
61+
except FileNotFoundError as exc:
62+
logger.error("Prompt file for %s not found.", _name)
63+
raise ValueError(f"Prompt file for {_name} not found.") from exc
64+
65+
kernel_args = KernelArguments(
66+
target=self.config.sql_dialect_out,
67+
numCandidates=str(self.num_candidates),
68+
source=self.config.sql_dialect_in,
69+
)
70+
71+
# Define an agent on the Azure AI agent service
72+
agent_definition = await app_config.ai_project_client.agents.create_agent(
73+
model=_deployment_name,
74+
name=_name,
75+
instructions=template_content,
76+
temperature=self.temperature,
77+
response_format=ResponseFormatJsonSchemaType(
78+
json_schema=ResponseFormatJsonSchema(
79+
name=self.response_schema.__name__,
80+
description=f"respond with {self.response_schema.__name__.lower()}",
81+
schema=self.response_schema.model_json_schema(),
82+
)
83+
),
84+
)
85+
86+
# Create a Semantic Kernel agent based on the agent definition
87+
self.agent = AzureAIAgent(
88+
client=app_config.ai_project_client,
89+
definition=agent_definition,
90+
arguments=kernel_args,
91+
)
92+
93+
return self.agent
94+
95+
async def get_agent(self) -> AzureAIAgent:
96+
"""Get the agent, setting it up if needed."""
97+
if self.agent is None:
98+
await self.setup()
99+
return self.agent
100+
101+
async def execute(self, inputs: Any) -> T:
102+
"""Execute the agent with the given inputs."""
103+
agent = await self.get_agent()
104+
response = await agent.invoke(inputs)
105+
return response # Type will be inferred from T
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""Factory for creating SQL migration agents."""
2+
3+
import logging
4+
from typing import Type, TypeVar
5+
6+
from common.models.api import AgentType
7+
from semantic_kernel.agents.azure_ai.azure_ai_agent import AzureAIAgent
8+
from sql_agents.agent_base import BaseSQLAgent
9+
from sql_agents.agent_config import AgentModelDeployment, AgentsConfigDialect
10+
from sql_agents.specific_agents import MigratorAgent, PickerAgent
11+
12+
logger = logging.getLogger(__name__)
13+
logger.setLevel(logging.DEBUG)
14+
15+
# Type variable for agent response types
16+
T = TypeVar('T')
17+
18+
19+
class SQLAgentFactory:
20+
"""Factory for creating SQL migration agents."""
21+
22+
_agent_classes = {
23+
AgentType.PICKER: PickerAgent,
24+
AgentType.MIGRATOR: MigratorAgent,
25+
}
26+
27+
@classmethod
28+
async def create_agent(
29+
cls,
30+
agent_type: AgentType,
31+
config: AgentsConfigDialect,
32+
deployment_name: AgentModelDeployment,
33+
) -> AzureAIAgent:
34+
"""Create and setup an agent of the specified type."""
35+
agent_class = cls._agent_classes.get(agent_type)
36+
if not agent_class:
37+
raise ValueError(f"Unknown agent type: {agent_type}")
38+
39+
agent = agent_class(agent_type, config, deployment_name)
40+
return await agent.setup()
41+
42+
@classmethod
43+
def get_agent_class(cls, agent_type: AgentType) -> Type[BaseSQLAgent]:
44+
"""Get the agent class for the specified type."""
45+
agent_class = cls._agent_classes.get(agent_type)
46+
if not agent_class:
47+
raise ValueError(f"Unknown agent type: {agent_type}")
48+
return agent_class
49+
50+
@classmethod
51+
def register_agent_class(cls, agent_type: AgentType, agent_class: Type[BaseSQLAgent]) -> None:
52+
"""Register a new agent class with the factory."""
53+
cls._agent_classes[agent_type] = agent_class
54+
logger.info("Registered agent class %s for type %s", agent_class.__name__, agent_type.value)

0 commit comments

Comments
 (0)