22
33import logging
44from abc import ABC , abstractmethod
5- from typing import Any , Generic , TypeVar
5+ from typing import Any , Generic , List , Optional , TypeVar , Union
66
77from azure .ai .projects .models import (
88 ResponseFormatJsonSchema ,
@@ -30,26 +30,61 @@ def __init__(
3030 agent_type : AgentType ,
3131 config : AgentsConfigDialect ,
3232 deployment_name : AgentModelDeployment ,
33-
33+ temperature : float = 0.0 ,
3434 ):
35- """Initialize the base SQL agent."""
35+ """Initialize the base SQL agent.
36+
37+ Args:
38+ agent_type: The type of agent to create.
39+ config: The dialect configuration for the agent.
40+ deployment_name: The model deployment to use.
41+ temperature: The temperature parameter for the model.
42+ """
3643 self .agent_type = agent_type
3744 self .config = config
3845 self .deployment_name = deployment_name
46+ self .temperature = temperature
3947 self .agent : AzureAIAgent = None
4048
41-
4249 @property
4350 @abstractmethod
4451 def response_schema (self ) -> type :
4552 """Get the response schema for this agent."""
4653 pass
4754
4855 @property
49- @abstractmethod
50- def num_candidates (self ) -> int :
51- """Get the number of candidates for this agent."""
52- pass
56+ def num_candidates (self ) -> Optional [int ]:
57+ """Get the number of candidates for this agent.
58+
59+ Returns:
60+ The number of candidates, or None if not applicable.
61+ """
62+ return None
63+
64+ @property
65+ def plugins (self ) -> Optional [List [Union [str , Any ]]]:
66+ """Get the plugins for this agent.
67+
68+ Returns:
69+ A list of plugins, or None if not applicable.
70+ """
71+ return None
72+
73+ def get_kernel_arguments (self ) -> KernelArguments :
74+ """Get the kernel arguments for this agent.
75+
76+ Returns:
77+ A KernelArguments object with the necessary arguments.
78+ """
79+ args = {
80+ "target" : self .config .sql_dialect_out ,
81+ "source" : self .config .sql_dialect_in ,
82+ }
83+
84+ if self .num_candidates is not None :
85+ args ["numCandidates" ] = str (self .num_candidates )
86+
87+ return KernelArguments (** args )
5388
5489 async def setup (self ) -> AzureAIAgent :
5590 """Setup the agent with Azure AI."""
@@ -62,11 +97,7 @@ async def setup(self) -> AzureAIAgent:
6297 logger .error ("Prompt file for %s not found." , _name )
6398 raise ValueError (f"Prompt file for { _name } not found." ) from exc
6499
65- kernel_args = KernelArguments (
66- target = self .config .sql_dialect_out ,
67- numCandidates = str (self .num_candidates ),
68- source = self .config .sql_dialect_in ,
69- )
100+ kernel_args = self .get_kernel_arguments ()
70101
71102 # Define an agent on the Azure AI agent service
72103 agent_definition = await app_config .ai_project_client .agents .create_agent (
@@ -84,11 +115,17 @@ async def setup(self) -> AzureAIAgent:
84115 )
85116
86117 # Create a Semantic Kernel agent based on the agent definition
87- self .agent = AzureAIAgent (
88- client = app_config .ai_project_client ,
89- definition = agent_definition ,
90- arguments = kernel_args ,
91- )
118+ agent_kwargs = {
119+ "client" : app_config .ai_project_client ,
120+ "definition" : agent_definition ,
121+ "arguments" : kernel_args ,
122+ }
123+
124+ # Add plugins if specified
125+ if self .plugins :
126+ agent_kwargs ["plugins" ] = self .plugins
127+
128+ self .agent = AzureAIAgent (** agent_kwargs )
92129
93130 return self .agent
94131
0 commit comments