Skip to content

Commit b7c4e40

Browse files
committed
update agent base and agent factory
1 parent aaee20b commit b7c4e40

File tree

2 files changed

+85
-22
lines changed

2 files changed

+85
-22
lines changed

src/backend/sql_agents/agent_base.py

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import logging
44
from abc import ABC, abstractmethod
5-
from typing import Any, Generic, TypeVar
5+
from typing import Any, Generic, List, Optional, TypeVar, Union
66

77
from azure.ai.projects.models import (
88
ResponseFormatJsonSchema,
@@ -30,26 +30,61 @@ def __init__(
3030
agent_type: AgentType,
3131
config: AgentsConfigDialect,
3232
deployment_name: AgentModelDeployment,
33-
33+
temperature: float = 0.0,
3434
):
35-
"""Initialize the base SQL agent."""
35+
"""Initialize the base SQL agent.
36+
37+
Args:
38+
agent_type: The type of agent to create.
39+
config: The dialect configuration for the agent.
40+
deployment_name: The model deployment to use.
41+
temperature: The temperature parameter for the model.
42+
"""
3643
self.agent_type = agent_type
3744
self.config = config
3845
self.deployment_name = deployment_name
46+
self.temperature = temperature
3947
self.agent: AzureAIAgent = None
4048

41-
4249
@property
4350
@abstractmethod
4451
def response_schema(self) -> type:
4552
"""Get the response schema for this agent."""
4653
pass
4754

4855
@property
49-
@abstractmethod
50-
def num_candidates(self) -> int:
51-
"""Get the number of candidates for this agent."""
52-
pass
56+
def num_candidates(self) -> Optional[int]:
57+
"""Get the number of candidates for this agent.
58+
59+
Returns:
60+
The number of candidates, or None if not applicable.
61+
"""
62+
return None
63+
64+
@property
65+
def plugins(self) -> Optional[List[Union[str, Any]]]:
66+
"""Get the plugins for this agent.
67+
68+
Returns:
69+
A list of plugins, or None if not applicable.
70+
"""
71+
return None
72+
73+
def get_kernel_arguments(self) -> KernelArguments:
74+
"""Get the kernel arguments for this agent.
75+
76+
Returns:
77+
A KernelArguments object with the necessary arguments.
78+
"""
79+
args = {
80+
"target": self.config.sql_dialect_out,
81+
"source": self.config.sql_dialect_in,
82+
}
83+
84+
if self.num_candidates is not None:
85+
args["numCandidates"] = str(self.num_candidates)
86+
87+
return KernelArguments(**args)
5388

5489
async def setup(self) -> AzureAIAgent:
5590
"""Setup the agent with Azure AI."""
@@ -62,11 +97,7 @@ async def setup(self) -> AzureAIAgent:
6297
logger.error("Prompt file for %s not found.", _name)
6398
raise ValueError(f"Prompt file for {_name} not found.") from exc
6499

65-
kernel_args = KernelArguments(
66-
target=self.config.sql_dialect_out,
67-
numCandidates=str(self.num_candidates),
68-
source=self.config.sql_dialect_in,
69-
)
100+
kernel_args = self.get_kernel_arguments()
70101

71102
# Define an agent on the Azure AI agent service
72103
agent_definition = await app_config.ai_project_client.agents.create_agent(
@@ -84,11 +115,17 @@ async def setup(self) -> AzureAIAgent:
84115
)
85116

86117
# Create a Semantic Kernel agent based on the agent definition
87-
self.agent = AzureAIAgent(
88-
client=app_config.ai_project_client,
89-
definition=agent_definition,
90-
arguments=kernel_args,
91-
)
118+
agent_kwargs = {
119+
"client": app_config.ai_project_client,
120+
"definition": agent_definition,
121+
"arguments": kernel_args,
122+
}
123+
124+
# Add plugins if specified
125+
if self.plugins:
126+
agent_kwargs["plugins"] = self.plugins
127+
128+
self.agent = AzureAIAgent(**agent_kwargs)
92129

93130
return self.agent
94131

src/backend/sql_agents/agent_factory.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""Factory for creating SQL migration agents."""
22

33
import logging
4-
from typing import Type, TypeVar
4+
from typing import Type, TypeVar, Optional, Dict, Any
55

66
from common.models.api import AgentType
77
from semantic_kernel.agents.azure_ai.azure_ai_agent import AzureAIAgent
88
from sql_agents.agent_base import BaseSQLAgent
99
from sql_agents.agent_config import AgentModelDeployment, AgentsConfigDialect
10-
from sql_agents.specific_agents import MigratorAgent, PickerAgent
10+
from sql_agents.specific_agents import MigratorAgent, PickerAgent, SyntaxCheckerAgent
1111

1212
logger = logging.getLogger(__name__)
1313
logger.setLevel(logging.DEBUG)
@@ -22,6 +22,7 @@ class SQLAgentFactory:
2222
_agent_classes = {
2323
AgentType.PICKER: PickerAgent,
2424
AgentType.MIGRATOR: MigratorAgent,
25+
AgentType.SYNTAX_CHECKER: SyntaxCheckerAgent,
2526
}
2627

2728
@classmethod
@@ -30,13 +31,38 @@ async def create_agent(
3031
agent_type: AgentType,
3132
config: AgentsConfigDialect,
3233
deployment_name: AgentModelDeployment,
34+
temperature: float = 0.0,
35+
extra_params: Optional[Dict[str, Any]] = None,
3336
) -> AzureAIAgent:
34-
"""Create and setup an agent of the specified type."""
37+
"""Create and setup an agent of the specified type.
38+
39+
Args:
40+
agent_type: The type of agent to create.
41+
config: The dialect configuration for the agent.
42+
deployment_name: The model deployment to use.
43+
temperature: The temperature parameter for the model.
44+
extra_params: Additional parameters to pass to the agent constructor.
45+
46+
Returns:
47+
A configured AzureAIAgent instance.
48+
"""
3549
agent_class = cls._agent_classes.get(agent_type)
3650
if not agent_class:
3751
raise ValueError(f"Unknown agent type: {agent_type}")
3852

39-
agent = agent_class(agent_type, config, deployment_name)
53+
# Prepare constructor parameters
54+
params = {
55+
"agent_type": agent_type,
56+
"config": config,
57+
"deployment_name": deployment_name,
58+
"temperature": temperature,
59+
}
60+
61+
# Add any extra parameters provided
62+
if extra_params:
63+
params.update(extra_params)
64+
65+
agent = agent_class(**params)
4066
return await agent.setup()
4167

4268
@classmethod

0 commit comments

Comments
 (0)