1+ """Base classes for SQL migration agents."""
2+
3+ import logging
4+ from abc import ABC , abstractmethod
5+ from typing import Any , Generic , TypeVar
6+
7+ from azure .ai .projects .models import (
8+ ResponseFormatJsonSchema ,
9+ ResponseFormatJsonSchemaType ,
10+ )
11+ from common .config .config import app_config
12+ from common .models .api import AgentType
13+ from semantic_kernel .agents .azure_ai .azure_ai_agent import AzureAIAgent
14+ from semantic_kernel .functions import KernelArguments
15+ from sql_agents .agent_config import AgentModelDeployment , AgentsConfigDialect
16+ from sql_agents .helpers .utils import get_prompt
17+
18+ # Type variable for response models
19+ T = TypeVar ('T' )
20+
21+ logger = logging .getLogger (__name__ )
22+ logger .setLevel (logging .DEBUG )
23+
24+
25+ class BaseSQLAgent (Generic [T ], ABC ):
26+ """Base class for all SQL migration agents."""
27+
28+ def __init__ (
29+ self ,
30+ agent_type : AgentType ,
31+ config : AgentsConfigDialect ,
32+ deployment_name : AgentModelDeployment ,
33+
34+ ):
35+ """Initialize the base SQL agent."""
36+ self .agent_type = agent_type
37+ self .config = config
38+ self .deployment_name = deployment_name
39+ self .agent : AzureAIAgent = None
40+
41+
42+ @property
43+ @abstractmethod
44+ def response_schema (self ) -> type :
45+ """Get the response schema for this agent."""
46+ pass
47+
48+ @property
49+ @abstractmethod
50+ def num_candidates (self ) -> int :
51+ """Get the number of candidates for this agent."""
52+ pass
53+
54+ async def setup (self ) -> AzureAIAgent :
55+ """Setup the agent with Azure AI."""
56+ _deployment_name = self .deployment_name .value
57+ _name = self .agent_type .value
58+
59+ try :
60+ template_content = get_prompt (_name )
61+ except FileNotFoundError as exc :
62+ logger .error ("Prompt file for %s not found." , _name )
63+ raise ValueError (f"Prompt file for { _name } not found." ) from exc
64+
65+ kernel_args = KernelArguments (
66+ target = self .config .sql_dialect_out ,
67+ numCandidates = str (self .num_candidates ),
68+ source = self .config .sql_dialect_in ,
69+ )
70+
71+ # Define an agent on the Azure AI agent service
72+ agent_definition = await app_config .ai_project_client .agents .create_agent (
73+ model = _deployment_name ,
74+ name = _name ,
75+ instructions = template_content ,
76+ temperature = self .temperature ,
77+ response_format = ResponseFormatJsonSchemaType (
78+ json_schema = ResponseFormatJsonSchema (
79+ name = self .response_schema .__name__ ,
80+ description = f"respond with { self .response_schema .__name__ .lower ()} " ,
81+ schema = self .response_schema .model_json_schema (),
82+ )
83+ ),
84+ )
85+
86+ # Create a Semantic Kernel agent based on the agent definition
87+ self .agent = AzureAIAgent (
88+ client = app_config .ai_project_client ,
89+ definition = agent_definition ,
90+ arguments = kernel_args ,
91+ )
92+
93+ return self .agent
94+
95+ async def get_agent (self ) -> AzureAIAgent :
96+ """Get the agent, setting it up if needed."""
97+ if self .agent is None :
98+ await self .setup ()
99+ return self .agent
100+
101+ async def execute (self , inputs : Any ) -> T :
102+ """Execute the agent with the given inputs."""
103+ agent = await self .get_agent ()
104+ response = await agent .invoke (inputs )
105+ return response # Type will be inferred from T
0 commit comments