Skip to content

Commit a0d54f3

Browse files
committed
add agent class and agent factory update
1 parent b7c4e40 commit a0d54f3

File tree

6 files changed

+363
-199
lines changed

6 files changed

+363
-199
lines changed

src/backend/sql_agents/agent_factory.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
from semantic_kernel.agents.azure_ai.azure_ai_agent import AzureAIAgent
88
from sql_agents.agent_base import BaseSQLAgent
99
from sql_agents.agent_config import AgentModelDeployment, AgentsConfigDialect
10-
from sql_agents.specific_agents import MigratorAgent, PickerAgent, SyntaxCheckerAgent
10+
from sql_agents.migrator.agent import MigratorAgent
11+
from sql_agents.picker.agent import PickerAgent
12+
from sql_agents.syntax_checker.agent import SyntaxCheckerAgent
13+
from sql_agents.fixer.agent import FixerAgent
14+
from sql_agents.semantic_verifier.agent import SemanticVerifierAgent
15+
from sql_agents.helpers.utils import get_prompt
1116

1217
logger = logging.getLogger(__name__)
1318
logger.setLevel(logging.DEBUG)
@@ -23,6 +28,8 @@ class SQLAgentFactory:
2328
AgentType.PICKER: PickerAgent,
2429
AgentType.MIGRATOR: MigratorAgent,
2530
AgentType.SYNTAX_CHECKER: SyntaxCheckerAgent,
31+
AgentType.FIXER: FixerAgent,
32+
AgentType.SEMANTIC_VERIFIER: SemanticVerifierAgent,
2633
}
2734

2835
@classmethod
@@ -32,7 +39,7 @@ async def create_agent(
3239
config: AgentsConfigDialect,
3340
deployment_name: AgentModelDeployment,
3441
temperature: float = 0.0,
35-
extra_params: Optional[Dict[str, Any]] = None,
42+
**kwargs
3643
) -> AzureAIAgent:
3744
"""Create and setup an agent of the specified type.
3845
@@ -41,7 +48,7 @@ async def create_agent(
4148
config: The dialect configuration for the agent.
4249
deployment_name: The model deployment to use.
4350
temperature: The temperature parameter for the model.
44-
extra_params: Additional parameters to pass to the agent constructor.
51+
**kwargs: Additional parameters to pass to the agent constructor.
4552
4653
Returns:
4754
A configured AzureAIAgent instance.
@@ -56,11 +63,8 @@ async def create_agent(
5663
"config": config,
5764
"deployment_name": deployment_name,
5865
"temperature": temperature,
66+
**kwargs
5967
}
60-
61-
# Add any extra parameters provided
62-
if extra_params:
63-
params.update(extra_params)
6468

6569
agent = agent_class(**params)
6670
return await agent.setup()

src/backend/sql_agents/fixer/agent.py

Lines changed: 56 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
ResponseFormatJsonSchema,
77
ResponseFormatJsonSchemaType,
88
)
9+
from backend.sql_agents.agent_base import (
10+
BaseSQLAgent,
11+
) # Ensure this import is correct and the module exists
12+
from backend.sql_agents.agent_factory import SQLAgentFactory
913
from common.config.config import app_config
1014
from common.models.api import AgentType
1115
from semantic_kernel.agents.azure_ai.azure_ai_agent import AzureAIAgent
@@ -18,42 +22,58 @@
1822
logger.setLevel(logging.DEBUG)
1923

2024

25+
class FixerAgent(BaseSQLAgent[FixerResponse]):
26+
"""Fixer agent for correcting SQL syntax errors."""
27+
28+
@property
29+
def response_schema(self) -> type:
30+
"""Get the response schema for the fixer agent."""
31+
return FixerResponse
32+
33+
34+
# async def setup_fixer_agent(
35+
# name: AgentType, config: AgentsConfigDialect, deployment_name: AgentModelDeployment
36+
# ) -> AzureAIAgent:
37+
# """Setup the fixer agent."""
38+
# _deployment_name = deployment_name.value
39+
# _name = name.value
40+
41+
# try:
42+
# template_content = get_prompt(_name)
43+
# except FileNotFoundError as exc:
44+
# logger.error("Prompt file for %s not found.", _name)
45+
# raise ValueError(f"Prompt file for {_name} not found.") from exc
46+
47+
# kernel_args = KernelArguments(target=config.sql_dialect_out)
48+
49+
# # Define an agent on the Azure AI agent service
50+
# agent_definition = await app_config.ai_project_client.agents.create_agent(
51+
# model=_deployment_name,
52+
# name=_name,
53+
# instructions=template_content,
54+
# temperature=0.0,
55+
# response_format=ResponseFormatJsonSchemaType(
56+
# json_schema=ResponseFormatJsonSchema(
57+
# name="FixerResponse",
58+
# description="respond with fixer response",
59+
# schema=FixerResponse.model_json_schema(),
60+
# )
61+
# ),
62+
# )
63+
64+
# # Create a Semantic Kernel agent based on the agent definition.
65+
# # Add RAG with docs programmatically for this one
66+
# fixer_agent = AzureAIAgent(
67+
# client=app_config.ai_project_client,
68+
# definition=agent_definition,
69+
# arguments=kernel_args,
70+
# )
71+
72+
# return fixer_agent
73+
74+
2175
async def setup_fixer_agent(
2276
name: AgentType, config: AgentsConfigDialect, deployment_name: AgentModelDeployment
2377
) -> AzureAIAgent:
24-
"""Setup the fixer agent."""
25-
_deployment_name = deployment_name.value
26-
_name = name.value
27-
28-
try:
29-
template_content = get_prompt(_name)
30-
except FileNotFoundError as exc:
31-
logger.error("Prompt file for %s not found.", _name)
32-
raise ValueError(f"Prompt file for {_name} not found.") from exc
33-
34-
kernel_args = KernelArguments(target=config.sql_dialect_out)
35-
36-
# Define an agent on the Azure AI agent service
37-
agent_definition = await app_config.ai_project_client.agents.create_agent(
38-
model=_deployment_name,
39-
name=_name,
40-
instructions=template_content,
41-
temperature=0.0,
42-
response_format=ResponseFormatJsonSchemaType(
43-
json_schema=ResponseFormatJsonSchema(
44-
name="FixerResponse",
45-
description="respond with fixer response",
46-
schema=FixerResponse.model_json_schema(),
47-
)
48-
),
49-
)
50-
51-
# Create a Semantic Kernel agent based on the agent definition.
52-
# Add RAG with docs programmatically for this one
53-
fixer_agent = AzureAIAgent(
54-
client=app_config.ai_project_client,
55-
definition=agent_definition,
56-
arguments=kernel_args,
57-
)
58-
59-
return fixer_agent
78+
"""Setup the fixer agent using the factory."""
79+
return await SQLAgentFactory.create_agent(name, config, deployment_name)

src/backend/sql_agents/migrator/agent.py

Lines changed: 64 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
ResponseFormatJsonSchema,
77
ResponseFormatJsonSchemaType,
88
)
9+
from backend.sql_agents.agent_base import BaseSQLAgent
10+
from backend.sql_agents.agent_factory import SQLAgentFactory
911
from common.config.config import app_config
1012
from common.models.api import AgentType
1113
from semantic_kernel.agents.azure_ai.azure_ai_agent import AzureAIAgent
@@ -18,47 +20,68 @@
1820
logger.setLevel(logging.DEBUG)
1921

2022

23+
class MigratorAgent(BaseSQLAgent[MigratorResponse]):
24+
"""Migrator agent for translating SQL from one dialect to another."""
25+
26+
@property
27+
def response_schema(self) -> type:
28+
"""Get the response schema for the migrator agent."""
29+
return MigratorResponse
30+
31+
@property
32+
def num_candidates(self) -> int:
33+
"""Get the number of candidates for the migrator agent."""
34+
return 3
35+
36+
2137
async def setup_migrator_agent(
2238
name: AgentType, config: AgentsConfigDialect, deployment_name: AgentModelDeployment
2339
) -> AzureAIAgent:
24-
"""Setup the migrator agent."""
25-
_deployment_name = deployment_name.value
26-
_name = name.value
27-
num_candidates = 3
28-
29-
try:
30-
template_content = get_prompt(_name)
31-
except FileNotFoundError as exc:
32-
logger.error("Prompt file for %s not found.", _name)
33-
raise ValueError(f"Prompt file for {_name} not found.") from exc
34-
35-
kernel_args = KernelArguments(
36-
target=config.sql_dialect_out,
37-
numCandidates=str(num_candidates),
38-
source=config.sql_dialect_in,
39-
)
40-
41-
# Define an agent on the Azure AI agent service
42-
agent_definition = await app_config.ai_project_client.agents.create_agent(
43-
model=_deployment_name,
44-
name=_name,
45-
instructions=template_content,
46-
temperature=0.0,
47-
response_format=ResponseFormatJsonSchemaType(
48-
json_schema=ResponseFormatJsonSchema(
49-
name="MigratorResponse",
50-
description="respond with migrator response",
51-
schema=MigratorResponse.model_json_schema(),
52-
)
53-
),
54-
)
55-
56-
# Create a Semantic Kernel agent based on the agent definition.
57-
# Add RAG with docs programmatically for this one
58-
migrator_agent = AzureAIAgent(
59-
client=app_config.ai_project_client,
60-
definition=agent_definition,
61-
arguments=kernel_args,
62-
)
63-
64-
return migrator_agent
40+
"""Setup the migrator agent using the factory."""
41+
return await SQLAgentFactory.create_agent(name, config, deployment_name)
42+
43+
44+
# async def setup_migrator_agent(
45+
# name: AgentType, config: AgentsConfigDialect, deployment_name: AgentModelDeployment
46+
# ) -> AzureAIAgent:
47+
# """Setup the migrator agent."""
48+
# _deployment_name = deployment_name.value
49+
# _name = name.value
50+
# num_candidates = 3
51+
52+
# try:
53+
# template_content = get_prompt(_name)
54+
# except FileNotFoundError as exc:
55+
# logger.error("Prompt file for %s not found.", _name)
56+
# raise ValueError(f"Prompt file for {_name} not found.") from exc
57+
58+
# kernel_args = KernelArguments(
59+
# target=config.sql_dialect_out,
60+
# numCandidates=str(num_candidates),
61+
# source=config.sql_dialect_in,
62+
# )
63+
64+
# # Define an agent on the Azure AI agent service
65+
# agent_definition = await app_config.ai_project_client.agents.create_agent(
66+
# model=_deployment_name,
67+
# name=_name,
68+
# instructions=template_content,
69+
# temperature=0.0,
70+
# response_format=ResponseFormatJsonSchemaType(
71+
# json_schema=ResponseFormatJsonSchema(
72+
# name="MigratorResponse",
73+
# description="respond with migrator response",
74+
# schema=MigratorResponse.model_json_schema(),
75+
# )
76+
# ),
77+
# )
78+
79+
# # Create a Semantic Kernel agent based on the agent definition.
80+
# # Add RAG with docs programmatically for this one
81+
# migrator_agent = AzureAIAgent(
82+
# client=app_config.ai_project_client,
83+
# definition=agent_definition,
84+
# arguments=kernel_args,
85+
# )
86+
87+
# return migrator_agent

src/backend/sql_agents/picker/agent.py

Lines changed: 64 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
ResponseFormatJsonSchema,
77
ResponseFormatJsonSchemaType,
88
)
9+
from backend.sql_agents.agent_base import BaseSQLAgent
10+
from backend.sql_agents.agent_factory import SQLAgentFactory
911
from common.config.config import app_config
1012
from common.models.api import AgentType
1113
from semantic_kernel.agents.azure_ai.azure_ai_agent import AzureAIAgent
@@ -17,49 +19,70 @@
1719
logger = logging.getLogger(__name__)
1820
logger.setLevel(logging.DEBUG)
1921

20-
NUM_CANDIDATES = 3
22+
# NUM_CANDIDATES = 3
23+
24+
25+
class PickerAgent(BaseSQLAgent[PickerResponse]):
26+
"""Picker agent for selecting the best SQL translation candidate."""
27+
28+
@property
29+
def response_schema(self) -> type:
30+
"""Get the response schema for the picker agent."""
31+
return PickerResponse
32+
33+
@property
34+
def num_candidates(self) -> int:
35+
"""Get the number of candidates for the picker agent."""
36+
return 3
2137

2238

2339
async def setup_picker_agent(
2440
name: AgentType, config: AgentsConfigDialect, deployment_name: AgentModelDeployment
2541
) -> AzureAIAgent:
26-
"""Setup the picker agent."""
27-
_deployment_name = deployment_name.value
28-
_name = name.value
29-
30-
try:
31-
template_content = get_prompt(_name)
32-
except FileNotFoundError as exc:
33-
logger.error("Prompt file for %s not found.", _name)
34-
raise ValueError(f"Prompt file for {_name} not found.") from exc
35-
36-
kernel_args = KernelArguments(
37-
target=config.sql_dialect_out,
38-
numCandidates=str(NUM_CANDIDATES),
39-
source=config.sql_dialect_in,
40-
)
41-
42-
# Define an agent on the Azure AI agent service
43-
agent_definition = await app_config.ai_project_client.agents.create_agent(
44-
model=_deployment_name,
45-
name=_name,
46-
instructions=template_content,
47-
temperature=0.0,
48-
response_format=ResponseFormatJsonSchemaType(
49-
json_schema=ResponseFormatJsonSchema(
50-
name="PickerResponse",
51-
description="respond with picker response",
52-
schema=PickerResponse.model_json_schema(),
53-
)
54-
),
55-
)
56-
57-
# Create a Semantic Kernel agent based on the agent definition.
58-
# Add RAG with docs programmatically for this one
59-
picker_agent = AzureAIAgent(
60-
client=app_config.ai_project_client,
61-
definition=agent_definition,
62-
arguments=kernel_args,
63-
)
64-
65-
return picker_agent
42+
"""Setup the picker agent using the factory."""
43+
return await SQLAgentFactory.create_agent(name, config, deployment_name)
44+
45+
46+
# async def setup_picker_agent(
47+
# name: AgentType, config: AgentsConfigDialect, deployment_name: AgentModelDeployment
48+
# ) -> AzureAIAgent:
49+
# """Setup the picker agent."""
50+
# _deployment_name = deployment_name.value
51+
# _name = name.value
52+
53+
# try:
54+
# template_content = get_prompt(_name)
55+
# except FileNotFoundError as exc:
56+
# logger.error("Prompt file for %s not found.", _name)
57+
# raise ValueError(f"Prompt file for {_name} not found.") from exc
58+
59+
# kernel_args = KernelArguments(
60+
# target=config.sql_dialect_out,
61+
# numCandidates=str(NUM_CANDIDATES),
62+
# source=config.sql_dialect_in,
63+
# )
64+
65+
# # Define an agent on the Azure AI agent service
66+
# agent_definition = await app_config.ai_project_client.agents.create_agent(
67+
# model=_deployment_name,
68+
# name=_name,
69+
# instructions=template_content,
70+
# temperature=0.0,
71+
# response_format=ResponseFormatJsonSchemaType(
72+
# json_schema=ResponseFormatJsonSchema(
73+
# name="PickerResponse",
74+
# description="respond with picker response",
75+
# schema=PickerResponse.model_json_schema(),
76+
# )
77+
# ),
78+
# )
79+
80+
# # Create a Semantic Kernel agent based on the agent definition.
81+
# # Add RAG with docs programmatically for this one
82+
# picker_agent = AzureAIAgent(
83+
# client=app_config.ai_project_client,
84+
# definition=agent_definition,
85+
# arguments=kernel_args,
86+
# )
87+
88+
# return picker_agent

0 commit comments

Comments
 (0)