Skip to content

Commit 2dd1b90

Browse files
committed
Split agent mods and setup
1 parent a0d54f3 commit 2dd1b90

File tree

16 files changed

+303
-521
lines changed

16 files changed

+303
-521
lines changed

src/backend/common/config/config.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import os
1616

1717
from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential
18-
from semantic_kernel.agents.azure_ai.azure_ai_agent import AzureAIAgent
1918

2019

2120
class Config:
@@ -51,10 +50,6 @@ def __init__(self):
5150

5251
self.__azure_credentials = DefaultAzureCredential()
5352

54-
self.ai_project_client = AzureAIAgent.create_client(
55-
credential=self.get_azure_credentials()
56-
)
57-
5853
def get_azure_credentials(self):
5954
"""Retrieve Azure credentials, either from environment variables or managed identity."""
6055
if all([self.azure_tenant_id, self.azure_client_id, self.azure_client_secret]):

src/backend/sql_agents/__init__.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,30 @@
1-
"""This module initializes the agents and helpers for the"""
1+
# """This module initializes the agents and helpers for the"""
22

3-
from common.models.api import AgentType
4-
from sql_agents.fixer.agent import setup_fixer_agent
5-
from sql_agents.helpers.sk_utils import create_kernel_with_chat_completion
6-
from sql_agents.helpers.utils import get_prompt
7-
from sql_agents.migrator.agent import setup_migrator_agent
8-
from sql_agents.picker.agent import setup_picker_agent
9-
from sql_agents.semantic_verifier.agent import setup_semantic_verifier_agent
10-
from sql_agents.syntax_checker.agent import setup_syntax_checker_agent
3+
# from common.models.api import AgentType
4+
from sql_agents.fixer.agent import FixerAgent, setup_fixer_agent
5+
from sql_agents.migrator.agent import MigratorAgent, setup_migrator_agent
6+
from sql_agents.picker.agent import PickerAgent, setup_picker_agent
7+
from sql_agents.semantic_verifier.agent import (
8+
SemanticVerifierAgent,
9+
setup_semantic_verifier_agent,
10+
)
11+
from sql_agents.syntax_checker.agent import (
12+
SyntaxCheckerAgent,
13+
setup_syntax_checker_agent,
14+
)
1115

12-
# Import the configuration function
13-
from .agent_config import AgentsConfigDialect, create_config
16+
# from sql_agents.agent_config import AgentBaseConfig
17+
# from sql_agents.agent_factory import SQLAgentFactory
1418

1519
__all__ = [
16-
"create_kernel_with_chat_completion",
1720
"setup_migrator_agent",
21+
"MigratorAgent",
1822
"setup_fixer_agent",
23+
"FixerAgent",
1924
"setup_picker_agent",
25+
"PickerAgent",
2026
"setup_syntax_checker_agent",
27+
"SyntaxCheckerAgent",
2128
"setup_semantic_verifier_agent",
22-
"get_prompt",
23-
"create_config",
24-
"AgentType",
29+
"SemanticVerifierAgent",
2530
]

src/backend/sql_agents/agent_base.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,14 @@
88
ResponseFormatJsonSchema,
99
ResponseFormatJsonSchemaType,
1010
)
11-
from common.config.config import app_config
1211
from common.models.api import AgentType
1312
from semantic_kernel.agents.azure_ai.azure_ai_agent import AzureAIAgent
1413
from semantic_kernel.functions import KernelArguments
15-
from sql_agents.agent_config import AgentModelDeployment, AgentsConfigDialect
14+
from sql_agents.agent_config import AgentBaseConfig
1615
from sql_agents.helpers.utils import get_prompt
1716

1817
# Type variable for response models
19-
T = TypeVar('T')
18+
T = TypeVar("T")
2019

2120
logger = logging.getLogger(__name__)
2221
logger.setLevel(logging.DEBUG)
@@ -28,12 +27,12 @@ class BaseSQLAgent(Generic[T], ABC):
2827
def __init__(
2928
self,
3029
agent_type: AgentType,
31-
config: AgentsConfigDialect,
32-
deployment_name: AgentModelDeployment,
30+
config: AgentBaseConfig,
31+
deployment_name: None,
3332
temperature: float = 0.0,
3433
):
3534
"""Initialize the base SQL agent.
36-
35+
3736
Args:
3837
agent_type: The type of agent to create.
3938
config: The dialect configuration for the agent.
@@ -55,7 +54,7 @@ def response_schema(self) -> type:
5554
@property
5655
def num_candidates(self) -> Optional[int]:
5756
"""Get the number of candidates for this agent.
58-
57+
5958
Returns:
6059
The number of candidates, or None if not applicable.
6160
"""
@@ -64,32 +63,32 @@ def num_candidates(self) -> Optional[int]:
6463
@property
6564
def plugins(self) -> Optional[List[Union[str, Any]]]:
6665
"""Get the plugins for this agent.
67-
66+
6867
Returns:
6968
A list of plugins, or None if not applicable.
7069
"""
7170
return None
7271

7372
def get_kernel_arguments(self) -> KernelArguments:
7473
"""Get the kernel arguments for this agent.
75-
74+
7675
Returns:
7776
A KernelArguments object with the necessary arguments.
7877
"""
7978
args = {
80-
"target": self.config.sql_dialect_out,
81-
"source": self.config.sql_dialect_in,
79+
"target": self.config.sql_to,
80+
"source": self.config.sql_from,
8281
}
83-
82+
8483
if self.num_candidates is not None:
8584
args["numCandidates"] = str(self.num_candidates)
86-
85+
8786
return KernelArguments(**args)
8887

8988
async def setup(self) -> AzureAIAgent:
9089
"""Setup the agent with Azure AI."""
91-
_deployment_name = self.deployment_name.value
9290
_name = self.agent_type.value
91+
_deployment_name = self.config.model_type.get(self.agent_type)
9392

9493
try:
9594
template_content = get_prompt(_name)
@@ -100,7 +99,7 @@ async def setup(self) -> AzureAIAgent:
10099
kernel_args = self.get_kernel_arguments()
101100

102101
# Define an agent on the Azure AI agent service
103-
agent_definition = await app_config.ai_project_client.agents.create_agent(
102+
agent_definition = await self.config.ai_project_client.agents.create_agent(
104103
model=_deployment_name,
105104
name=_name,
106105
instructions=template_content,
@@ -116,15 +115,15 @@ async def setup(self) -> AzureAIAgent:
116115

117116
# Create a Semantic Kernel agent based on the agent definition
118117
agent_kwargs = {
119-
"client": app_config.ai_project_client,
118+
"client": self.config.ai_project_client,
120119
"definition": agent_definition,
121120
"arguments": kernel_args,
122121
}
123-
122+
124123
# Add plugins if specified
125124
if self.plugins:
126125
agent_kwargs["plugins"] = self.plugins
127-
126+
128127
self.agent = AzureAIAgent(**agent_kwargs)
129128

130129
return self.agent
@@ -139,4 +138,4 @@ async def execute(self, inputs: Any) -> T:
139138
"""Execute the agent with the given inputs."""
140139
agent = await self.get_agent()
141140
response = await agent.invoke(inputs)
142-
return response # Type will be inferred from T
141+
return response # Type will be inferred from T
Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,37 @@
1-
"""Configuration for the agents module."""
1+
"""Configuration class for the agents.
2+
This class loads configuration values from environment variables and provides
3+
properties to access them. It also stores an Azure AI client and SQL dialect
4+
configuration for the agents, that will be set per batch.
5+
Access to .env variables requires adding the `python-dotenv` package to, or
6+
configuration of the env python path through the IDE. For example, in VSCode, the
7+
settings.json file in the .vscode folder should include the following:
8+
{
9+
"python.envFile": "${workspaceFolder}/.env"
10+
}
11+
"""
212

3-
import json
413
import os
514
from enum import Enum
615

7-
from dotenv import load_dotenv
16+
from azure.ai.projects.aio import AIProjectClient
17+
from common.models.api import AgentType
818

9-
load_dotenv()
1019

11-
12-
class AgentModelDeployment(Enum):
20+
class AgentBaseConfig:
1321
"""Agent model deployment names."""
1422

15-
MIGRATOR_AGENT_MODEL_DEPLOY = os.getenv("MIGRATOR_AGENT_MODEL_DEPLOY")
16-
PICKER_AGENT_MODEL_DEPLOY = os.getenv("PICKER_AGENT_MODEL_DEPLOY")
17-
FIXER_AGENT_MODEL_DEPLOY = os.getenv("FIXER_AGENT_MODEL_DEPLOY")
18-
SEMANTIC_VERIFIER_AGENT_MODEL_DEPLOY = os.getenv(
19-
"SEMANTIC_VERIFIER_AGENT_MODEL_DEPLOY"
20-
)
21-
SYNTAX_CHECKER_AGENT_MODEL_DEPLOY = os.getenv("SYNTAX_CHECKER_AGENT_MODEL_DEPLOY")
22-
SELECTION_MODEL_DEPLOY = os.getenv("SELECTION_MODEL_DEPLOY")
23-
TERMINATION_MODEL_DEPLOY = os.getenv("TERMINATION_MODEL_DEPLOY")
24-
25-
26-
class AgentsConfigDialect:
27-
"""Configuration for the agents module."""
28-
29-
def __init__(self, sql_dialect_in, sql_dialect_out):
30-
self.sql_dialect_in = sql_dialect_in
31-
self.sql_dialect_out = sql_dialect_out
32-
33-
34-
def create_config(sql_dialect_in, sql_dialect_out):
35-
"""Create and return a new AgentConfig object."""
36-
return AgentsConfigDialect(sql_dialect_in, sql_dialect_out)
23+
def __init__(self, project_client: AIProjectClient, sql_from: str, sql_to: str):
24+
25+
self.ai_project_client: AIProjectClient = project_client
26+
self.sql_from = sql_from
27+
self.sql_to = sql_to
28+
29+
model_type = {
30+
AgentType.MIGRATOR: os.getenv("MIGRATOR_AGENT_MODEL_DEPLOY"),
31+
AgentType.PICKER: os.getenv("PICKER_AGENT_MODEL_DEPLOY"),
32+
AgentType.FIXER: os.getenv("FIXER_AGENT_MODEL_DEPLOY"),
33+
AgentType.SEMANTIC_VERIFIER: os.getenv("SEMANTIC_VERIFIER_AGENT_MODEL_DEPLOY"),
34+
AgentType.SYNTAX_CHECKER: os.getenv("SYNTAX_CHECKER_AGENT_MODEL_DEPLOY"),
35+
AgentType.SELECTION: os.getenv("SELECTION_MODEL_DEPLOY"),
36+
AgentType.TERMINATION: os.getenv("TERMINATION_MODEL_DEPLOY"),
37+
}
Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,25 @@
11
"""Factory for creating SQL migration agents."""
22

33
import logging
4-
from typing import Type, TypeVar, Optional, Dict, Any
4+
from typing import Any, Dict, Optional, Type, TypeVar
55

66
from common.models.api import AgentType
77
from semantic_kernel.agents.azure_ai.azure_ai_agent import AzureAIAgent
8+
from sql_agents import (
9+
FixerAgent,
10+
MigratorAgent,
11+
PickerAgent,
12+
SemanticVerifierAgent,
13+
SyntaxCheckerAgent,
14+
)
815
from sql_agents.agent_base import BaseSQLAgent
9-
from sql_agents.agent_config import AgentModelDeployment, AgentsConfigDialect
10-
from sql_agents.migrator.agent import MigratorAgent
11-
from sql_agents.picker.agent import PickerAgent
12-
from sql_agents.syntax_checker.agent import SyntaxCheckerAgent
13-
from sql_agents.fixer.agent import FixerAgent
14-
from sql_agents.semantic_verifier.agent import SemanticVerifierAgent
15-
from sql_agents.helpers.utils import get_prompt
16+
from sql_agents.agent_config import AgentBaseConfig
1617

1718
logger = logging.getLogger(__name__)
1819
logger.setLevel(logging.DEBUG)
1920

2021
# Type variable for agent response types
21-
T = TypeVar('T')
22+
T = TypeVar("T")
2223

2324

2425
class SQLAgentFactory:
@@ -36,20 +37,19 @@ class SQLAgentFactory:
3637
async def create_agent(
3738
cls,
3839
agent_type: AgentType,
39-
config: AgentsConfigDialect,
40-
deployment_name: AgentModelDeployment,
40+
config: AgentBaseConfig,
4141
temperature: float = 0.0,
42-
**kwargs
42+
**kwargs,
4343
) -> AzureAIAgent:
4444
"""Create and setup an agent of the specified type.
45-
45+
4646
Args:
4747
agent_type: The type of agent to create.
4848
config: The dialect configuration for the agent.
4949
deployment_name: The model deployment to use.
5050
temperature: The temperature parameter for the model.
5151
**kwargs: Additional parameters to pass to the agent constructor.
52-
52+
5353
Returns:
5454
A configured AzureAIAgent instance.
5555
"""
@@ -61,24 +61,29 @@ async def create_agent(
6161
params = {
6262
"agent_type": agent_type,
6363
"config": config,
64-
"deployment_name": deployment_name,
6564
"temperature": temperature,
66-
**kwargs
65+
**kwargs,
6766
}
68-
67+
6968
agent = agent_class(**params)
7069
return await agent.setup()
71-
70+
7271
@classmethod
7372
def get_agent_class(cls, agent_type: AgentType) -> Type[BaseSQLAgent]:
7473
"""Get the agent class for the specified type."""
7574
agent_class = cls._agent_classes.get(agent_type)
7675
if not agent_class:
7776
raise ValueError(f"Unknown agent type: {agent_type}")
7877
return agent_class
79-
78+
8079
@classmethod
81-
def register_agent_class(cls, agent_type: AgentType, agent_class: Type[BaseSQLAgent]) -> None:
80+
def register_agent_class(
81+
cls, agent_type: AgentType, agent_class: Type[BaseSQLAgent]
82+
) -> None:
8283
"""Register a new agent class with the factory."""
8384
cls._agent_classes[agent_type] = agent_class
84-
logger.info("Registered agent class %s for type %s", agent_class.__name__, agent_type.value)
85+
logger.info(
86+
"Registered agent class %s for type %s",
87+
agent_class.__name__,
88+
agent_type.value,
89+
)

0 commit comments

Comments
 (0)