diff --git a/.github/workflows/CAdeploy.yml b/.github/workflows/CAdeploy.yml index c63599e2b..73806dbbc 100644 --- a/.github/workflows/CAdeploy.yml +++ b/.github/workflows/CAdeploy.yml @@ -10,8 +10,8 @@ on: - cron: '0 6,18 * * *' # Runs at 6:00 AM and 6:00 PM GMT env: - GPT_MIN_CAPACITY: 250 - TEXT_EMBEDDING_MIN_CAPACITY: 40 + GPT_MIN_CAPACITY: 200 + TEXT_EMBEDDING_MIN_CAPACITY: 80 BRANCH_NAME: ${{ github.head_ref || github.ref_name }} jobs: diff --git a/docs/CustomizingAzdParameters.md b/docs/CustomizingAzdParameters.md index 2de3d2923..c3ac28645 100644 --- a/docs/CustomizingAzdParameters.md +++ b/docs/CustomizingAzdParameters.md @@ -19,9 +19,10 @@ By default this template will use the environment name as the prefix to prevent | `AZURE_ENV_EMBEDDING_MODEL_CAPACITY` | integer | `80` | Set the capacity for embedding model deployment. | | `AZURE_ENV_IMAGETAG` | string | `latest` | Set the image tag (allowed values: `latest`, `dev`, `hotfix`). | | `AZURE_LOCATION` | string | `japaneast` | Sets the Azure region for resource deployment. | -| `AZURE_ENV_LOG_ANALYTICS_WORKSPACE_ID` | string | `` | Reuses an existing Log Analytics Workspace instead of provisioning a new one. | +| `AZURE_ENV_LOG_ANALYTICS_WORKSPACE_ID` | string | Guide to get your [Existing Workspace ID](/docs/re-use-log-analytics.md) | Reuses an existing Log Analytics Workspace instead of provisioning a new one. | | `AZURE_EXISTING_AI_PROJECT_RESOURCE_ID` | string | `` | Reuses an existing AI Foundry Project Resource Id instead of provisioning a new one. | + ## How to Set a Parameter To customize any of the above values, run the following command **before** `azd up`: diff --git a/docs/DeploymentGuide.md b/docs/DeploymentGuide.md index fd00a54c5..7d54bbc54 100644 --- a/docs/DeploymentGuide.md +++ b/docs/DeploymentGuide.md @@ -137,6 +137,14 @@ To adjust quota settings, follow these [steps](./AzureGPTQuotaSettings.md). +
+ + Reusing an Existing Log Analytics Workspace + + Guide to get your [Existing Workspace ID](/docs/re-use-log-analytics.md) + +
+ ### Deploying with AZD Once you've opened the project in [Codespaces](#github-codespaces), [Dev Containers](#vs-code-dev-containers), or [locally](#local-environment), you can deploy it to Azure by following these steps: diff --git a/docs/QuotaCheck.md b/docs/QuotaCheck.md index 872106410..b5a9818c3 100644 --- a/docs/QuotaCheck.md +++ b/docs/QuotaCheck.md @@ -11,7 +11,7 @@ azd auth login ### 📌 Default Models & Capacities: ``` -gpt-4o-mini:30, text-embedding-ada-002:80 +gpt-4o-mini:200, text-embedding-ada-002:80 ``` ### 📌 Default Regions: ``` @@ -37,7 +37,7 @@ eastus, uksouth, eastus2, northcentralus, swedencentral, westus, westus2, southc ``` ✔️ Check specific model(s) in default regions: ``` - ./quota_check_params.sh --models gpt-4o-mini:30,text-embedding-ada-002:80 + ./quota_check_params.sh --models gpt-4o-mini:200,text-embedding-ada-002:80 ``` ✔️ Check default models in specific region(s): ``` @@ -45,11 +45,11 @@ eastus, uksouth, eastus2, northcentralus, swedencentral, westus, westus2, southc ``` ✔️ Passing Both models and regions: ``` - ./quota_check_params.sh --models gpt-4o-mini:30 --regions eastus,westus2 + ./quota_check_params.sh --models gpt-4o-mini:200 --regions eastus,westus2 ``` ✔️ All parameters combined: ``` - ./quota_check_params.sh --models gpt-4o-mini:30,text-embedding-ada-002:80 --regions eastus,westus --verbose + ./quota_check_params.sh --models gpt-4o-mini:200,text-embedding-ada-002:80 --regions eastus,westus --verbose ``` ### **Sample Output** diff --git a/docs/images/re_use_log/logAnalytics.png b/docs/images/re_use_log/logAnalytics.png new file mode 100644 index 000000000..95402f8d1 Binary files /dev/null and b/docs/images/re_use_log/logAnalytics.png differ diff --git a/docs/images/re_use_log/logAnalyticsJson.png b/docs/images/re_use_log/logAnalyticsJson.png new file mode 100644 index 000000000..3a4093bf4 Binary files /dev/null and b/docs/images/re_use_log/logAnalyticsJson.png differ diff --git a/docs/images/re_use_log/logAnalyticsList.png b/docs/images/re_use_log/logAnalyticsList.png new file mode 100644 index 000000000..6dcf4640b Binary files /dev/null and b/docs/images/re_use_log/logAnalyticsList.png differ diff --git a/docs/re-use-log-analytics.md b/docs/re-use-log-analytics.md new file mode 100644 index 000000000..9d48b0f92 --- /dev/null +++ b/docs/re-use-log-analytics.md @@ -0,0 +1,31 @@ +[← Back to *DEPLOYMENT* guide](/docs/DeploymentGuide.md#deployment-options--steps) + +# Reusing an Existing Log Analytics Workspace +To configure your environment to use an existing Log Analytics Workspace, follow these steps: +--- +### 1. Go to Azure Portal +Go to https://portal.azure.com + +### 2. Search for Log Analytics +In the search bar at the top, type "Log Analytics workspaces" and click on it and click on the workspace you want to use. + +![alt text](../docs/images/re_use_log/logAnalyticsList.png) + +### 3. Copy Resource ID +In the Overview pane, Click on JSON View + +![alt text](../docs/images/re_use_log/logAnalytics.png) + +Copy Resource ID that is your Workspace ID + +![alt text](../docs/images/re_use_log/logAnalyticsJson.png) + +### 4. Set the Workspace ID in Your Environment +Run the following command in your terminal +```bash +azd env set AZURE_ENV_LOG_ANALYTICS_WORKSPACE_ID '' +``` +Replace `` with the value obtained from Step 3. + +### 5. Continue Deployment +Proceed with the next steps in the [deployment guide](/docs/DeploymentGuide.md#deployment-options--steps). diff --git a/infra/deploy_ai_foundry.bicep b/infra/deploy_ai_foundry.bicep index 677f09d15..a8797a154 100644 --- a/infra/deploy_ai_foundry.bicep +++ b/infra/deploy_ai_foundry.bicep @@ -37,7 +37,7 @@ var aiModelDeployments = [ name: embeddingModel model: embeddingModel sku: { - name: 'Standard' + name: 'GlobalStandard' capacity: embeddingDeploymentCapacity } raiPolicyName: 'Microsoft.Default' diff --git a/infra/main.bicep b/infra/main.bicep index e8b215439..7c225c1f8 100644 --- a/infra/main.bicep +++ b/infra/main.bicep @@ -63,7 +63,7 @@ param imageTag string = 'latest' type: 'location' usageName: [ 'OpenAI.GlobalStandard.gpt-4o-mini,200' - 'OpenAI.Standard.text-embedding-ada-002,80' + 'OpenAI.GlobalStandard.text-embedding-ada-002,80' ] } }) diff --git a/infra/main.json b/infra/main.json index cf0c402ae..59b8e85a6 100644 --- a/infra/main.json +++ b/infra/main.json @@ -5,7 +5,7 @@ "_generator": { "name": "bicep", "version": "0.36.177.2456", - "templateHash": "1860841716622591379" + "templateHash": "2238194529646818649" } }, "parameters": { @@ -114,7 +114,7 @@ "type": "location", "usageName": [ "OpenAI.GlobalStandard.gpt-4o-mini,200", - "OpenAI.Standard.text-embedding-ada-002,80" + "OpenAI.GlobalStandard.text-embedding-ada-002,80" ] }, "description": "Location for AI Foundry deployment. This is the location where the AI Foundry resources will be deployed." @@ -744,7 +744,7 @@ "_generator": { "name": "bicep", "version": "0.36.177.2456", - "templateHash": "15524709584492116446" + "templateHash": "1124249040831466979" } }, "parameters": { @@ -1038,7 +1038,7 @@ "name": "[parameters('embeddingModel')]", "model": "[parameters('embeddingModel')]", "sku": { - "name": "Standard", + "name": "GlobalStandard", "capacity": "[parameters('embeddingDeploymentCapacity')]" }, "raiPolicyName": "Microsoft.Default" diff --git a/infra/scripts/checkquota.sh b/infra/scripts/checkquota.sh index 3885a57dc..136d1902c 100755 --- a/infra/scripts/checkquota.sh +++ b/infra/scripts/checkquota.sh @@ -33,7 +33,7 @@ echo "✅ Azure subscription set successfully." # Define models and their minimum required capacities declare -A MIN_CAPACITY=( ["OpenAI.GlobalStandard.gpt-4o-mini"]=$GPT_MIN_CAPACITY - ["OpenAI.Standard.text-embedding-ada-002"]=$TEXT_EMBEDDING_MIN_CAPACITY + ["OpenAI.GlobalStandard.text-embedding-ada-002"]=$TEXT_EMBEDDING_MIN_CAPACITY ) VALID_REGION="" diff --git a/infra/scripts/copy_kb_files.sh b/infra/scripts/copy_kb_files.sh index 1d94772b9..59e318bfe 100644 --- a/infra/scripts/copy_kb_files.sh +++ b/infra/scripts/copy_kb_files.sh @@ -123,10 +123,12 @@ echo "Uploading files to Azure Blob Storage" # Using az storage blob upload-batch to upload files with managed identity authentication, as the az storage fs directory upload command is not working with managed identity authentication. az storage blob upload-batch --account-name "$storageAccount" --destination data/"$extractedFolder1" --source $extractionPath1 --auth-mode login --pattern '*' --overwrite --output none if [ $? -ne 0 ]; then - retries=3 + maxRetries=5 + retries=$maxRetries sleepTime=10 - echo "Error: Failed to upload files to Azure Blob Storage. Retrying upload...($((4 - retries)) of 3)" + attempt=1 while [ $retries -gt 0 ]; do + echo "Error: Failed to upload files to Azure Blob Storage. Retrying upload...$attempt of $maxRetries in $sleepTime seconds" sleep $sleepTime az storage blob upload-batch --account-name "$storageAccount" --destination data/"$extractedFolder1" --source $extractionPath1 --auth-mode login --pattern '*' --overwrite --output none if [ $? -eq 0 ]; then @@ -134,22 +136,26 @@ if [ $? -ne 0 ]; then break else ((retries--)) - echo "Retrying upload... ($((4 - retries)) of 3)" + ((attempt++)) sleepTime=$((sleepTime * 2)) - sleep $sleepTime fi done - exit 1 + if [ $retries -eq 0 ]; then + echo "Error: Failed to upload files after all retry attempts." + exit 1 + fi else echo "Files uploaded successfully to Azure Blob Storage." fi az storage blob upload-batch --account-name "$storageAccount" --destination data/"$extractedFolder2" --source $extractionPath2 --auth-mode login --pattern '*' --overwrite --output none if [ $? -ne 0 ]; then - retries=3 + maxRetries=5 + retries=$maxRetries + attempt=1 sleepTime=10 - echo "Error: Failed to upload files to Azure Blob Storage. Retrying upload...($((4 - retries)) of 3)" while [ $retries -gt 0 ]; do + echo "Error: Failed to upload files to Azure Blob Storage. Retrying upload...$attempt of $maxRetries in $sleepTime seconds" sleep $sleepTime az storage blob upload-batch --account-name "$storageAccount" --destination data/"$extractedFolder2" --source $extractionPath2 --auth-mode login --pattern '*' --overwrite --output none if [ $? -eq 0 ]; then @@ -157,12 +163,14 @@ if [ $? -ne 0 ]; then break else ((retries--)) - echo "Retrying upload... ($((4 - retries)) of 3)" + ((attempt++)) sleepTime=$((sleepTime * 2)) - sleep $sleepTime fi done - exit 1 + if [ $retries -eq 0 ]; then + echo "Error: Failed to upload files after all retry attempts." + exit 1 + fi else echo "Files uploaded successfully to Azure Blob Storage." fi diff --git a/infra/scripts/quota_check_params.sh b/infra/scripts/quota_check_params.sh index 62a2305c8..27e659bb3 100755 --- a/infra/scripts/quota_check_params.sh +++ b/infra/scripts/quota_check_params.sh @@ -47,7 +47,7 @@ log_verbose() { } # Default Models and Capacities (Comma-separated in "model:capacity" format) -DEFAULT_MODEL_CAPACITY="gpt-4o-mini:30,text-embedding-ada-002:80" +DEFAULT_MODEL_CAPACITY="gpt-4o-mini:200,text-embedding-ada-002:80" # Convert the comma-separated string into an array IFS=',' read -r -a MODEL_CAPACITY_PAIRS <<< "$DEFAULT_MODEL_CAPACITY" @@ -165,11 +165,7 @@ for REGION in "${REGIONS[@]}"; do FOUND=false INSUFFICIENT_QUOTA=false - if [ "$MODEL_NAME" = "text-embedding-ada-002" ]; then - MODEL_TYPES=("openai.standard.$MODEL_NAME") - else - MODEL_TYPES=("openai.standard.$MODEL_NAME" "openai.globalstandard.$MODEL_NAME") - fi + MODEL_TYPES=("openai.standard.$MODEL_NAME" "openai.globalstandard.$MODEL_NAME") for MODEL_TYPE in "${MODEL_TYPES[@]}"; do FOUND=false diff --git a/src/App/backend/agents/agent_factory.py b/src/App/backend/agents/agent_factory.py index 92c291852..df81a2caf 100644 --- a/src/App/backend/agents/agent_factory.py +++ b/src/App/backend/agents/agent_factory.py @@ -41,7 +41,8 @@ async def get_wealth_advisor_agent(cls): ) agent_name = "WealthAdvisor" - agent_instructions = "You are a helpful assistant to a Wealth Advisor." + agent_instructions = '''You are a helpful assistant to a Wealth Advisor. + If the question is unrelated to data but is conversational (e.g., greetings or follow-ups), respond appropriately using context, do not use external tools or perform any web searches for these conversational inputs.''' agent_definition = await client.agents.create_agent( model=ai_agent_settings.model_deployment_name, @@ -105,3 +106,46 @@ async def delete_all_agent_instance(cls): ) cls._search_agent["client"].close() cls._search_agent = None + + @classmethod + async def get_sql_agent(cls) -> dict: + """ + Get or create a singleton SQLQueryGenerator AzureAIAgent instance. + This agent is used to generate T-SQL queries from natural language input. + """ + async with cls._lock: + if not hasattr(cls, "_sql_agent") or cls._sql_agent is None: + + agent_instructions = config.SQL_SYSTEM_PROMPT or """ + You are an expert assistant in generating T-SQL queries based on user questions. + Always use the following schema: + 1. Table: Clients (ClientId, Client, Email, Occupation, MaritalStatus, Dependents) + 2. Table: InvestmentGoals (ClientId, InvestmentGoal) + 3. Table: Assets (ClientId, AssetDate, Investment, ROI, Revenue, AssetType) + 4. Table: ClientSummaries (ClientId, ClientSummary) + 5. Table: InvestmentGoalsDetails (ClientId, InvestmentGoal, TargetAmount, Contribution) + 6. Table: Retirement (ClientId, StatusDate, RetirementGoalProgress, EducationGoalProgress) + 7. Table: ClientMeetings (ClientId, ConversationId, Title, StartTime, EndTime, Advisor, ClientEmail) + + Rules: + - Always filter by ClientId = + - Do not use client name for filtering + - Assets table contains snapshots by date; do not sum values across dates + - Use StartTime for time-based filtering (meetings) + - Only return the raw T-SQL query. No explanations or comments. + """ + + project_client = AIProjectClient( + endpoint=config.AI_PROJECT_ENDPOINT, + credential=DefaultAzureCredentialSync(), + api_version="2025-05-01", + ) + + agent = project_client.agents.create_agent( + model=config.AZURE_OPENAI_MODEL, + instructions=agent_instructions, + name="SQLQueryGeneratorAgent", + ) + + cls._sql_agent = {"agent": agent, "client": project_client} + return cls._sql_agent diff --git a/src/App/backend/plugins/chat_with_data_plugin.py b/src/App/backend/plugins/chat_with_data_plugin.py index 86740085a..f421af7ef 100644 --- a/src/App/backend/plugins/chat_with_data_plugin.py +++ b/src/App/backend/plugins/chat_with_data_plugin.py @@ -22,42 +22,6 @@ class ChatWithDataPlugin: - @kernel_function( - name="GreetingsResponse", - description="Respond to any greeting or general questions", - ) - async def greeting( - self, input: Annotated[str, "the question"] - ) -> Annotated[str, "The output is a string"]: - """ - Simple greeting handler using Azure OpenAI. - """ - try: - if config.USE_AI_PROJECT_CLIENT: - client = self.get_project_openai_client() - - else: - client = self.get_openai_client() - - completion = client.chat.completions.create( - model=config.AZURE_OPENAI_MODEL, - messages=[ - { - "role": "system", - "content": "You are a helpful assistant to respond to greetings or general questions.", - }, - {"role": "user", "content": input}, - ], - temperature=0, - top_p=1, - n=1, - ) - - answer = completion.choices[0].message.content - except Exception as e: - answer = f"Error retrieving greeting response: {str(e)}" - return answer - @kernel_function( name="ChatWithSQLDatabase", description="Given a query about client assets, investments and scheduled meetings (including upcoming or next meeting dates/times), get details from the database based on the provided question and client id", @@ -77,87 +41,61 @@ async def get_SQL_Response( if not input or not input.strip(): return "Error: Query input is required" - clientid = ClientId - query = input + try: + from backend.agents.agent_factory import AgentFactory + agent_info = await AgentFactory.get_sql_agent() + agent = agent_info["agent"] + project_client = agent_info["client"] - # Retrieve the SQL prompt from environment variables (if available) - sql_prompt = config.SQL_SYSTEM_PROMPT - if sql_prompt: - sql_prompt = sql_prompt.replace("{query}", query).replace( - "{clientid}", clientid - ) - else: - # Fallback prompt if not set in environment - sql_prompt = f"""Generate a valid T-SQL query to find {query} for tables and columns provided below: - 1. Table: Clients - Columns: ClientId, Client, Email, Occupation, MaritalStatus, Dependents - 2. Table: InvestmentGoals - Columns: ClientId, InvestmentGoal - 3. Table: Assets - Columns: ClientId, AssetDate, Investment, ROI, Revenue, AssetType - 4. Table: ClientSummaries - Columns: ClientId, ClientSummary - 5. Table: InvestmentGoalsDetails - Columns: ClientId, InvestmentGoal, TargetAmount, Contribution - 6. Table: Retirement - Columns: ClientId, StatusDate, RetirementGoalProgress, EducationGoalProgress - 7. Table: ClientMeetings - Columns: ClientId, ConversationId, Title, StartTime, EndTime, Advisor, ClientEmail - Always use the Investment column from the Assets table as the value. - Assets table has snapshots of values by date. Do not add numbers across different dates for total values. - Do not use client name in filters. - Do not include assets values unless asked for. - ALWAYS use ClientId = {clientid} in the query filter. - ALWAYS select Client Name (Column: Client) in the query. - Query filters are IMPORTANT. Add filters like AssetType, AssetDate, etc. if needed. - When answering scheduling or time-based meeting questions, always use the StartTime column from ClientMeetings table. Use correct logic to return the most recent past meeting (last/previous) or the nearest future meeting (next/upcoming), and ensure only StartTime column is used for meeting timing comparisons. - Only return the generated SQL query. Do not return anything else.""" + thread = project_client.agents.threads.create() - try: - if config.USE_AI_PROJECT_CLIENT: - client = self.get_project_openai_client() + # Send question as message + project_client.agents.messages.create( + thread_id=thread.id, + role=MessageRole.USER, + content=f"ClientId: {ClientId}\nQuestion: {input}", + ) - else: - # Initialize the Azure OpenAI client - client = self.get_openai_client() - - completion = client.chat.completions.create( - model=config.AZURE_OPENAI_MODEL, - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": sql_prompt}, - ], + # Run the agent + run = project_client.agents.runs.create_and_process( + thread_id=thread.id, + agent_id=agent.id, temperature=0, - top_p=1, - n=1, ) - sql_query = completion.choices[0].message.content + if run.status == "failed": + return f"Error: Agent run failed: {run.last_error}" - # Remove any triple backticks if present - sql_query = sql_query.replace("```sql", "").replace("```", "") + # Get SQL query from the agent's final response + message = project_client.agents.messages.get_last_message_text_by_role( + thread_id=thread.id, + role=MessageRole.AGENT + ) + sql_query = message.text.value.strip() if message else None + + if not sql_query: + return "No SQL query was generated." - # print("Generated SQL:", sql_query) + # Clean up triple backticks (if any) + sql_query = sql_query.replace("```sql", "").replace("```", "") + # Execute the query conn = get_connection() - # conn = pyodbc.connect(connectionString) cursor = conn.cursor() cursor.execute(sql_query) - rows = cursor.fetchall() + if not rows: - answer = "No data found for that client." + result = "No data found for that client." else: - answer = "" - for row in rows: - answer += str(row) + "\n" + result = "\n".join(str(row) for row in rows) conn.close() - answer = answer[:20000] if len(answer) > 20000 else answer + return result[:20000] if len(result) > 20000 else result except Exception as e: - answer = f"Error retrieving data from SQL: {str(e)}" - return answer + logging.exception("Error in get_SQL_Response") + return f"Error retrieving SQL data: {str(e)}" @kernel_function( name="ChatWithCallTranscripts", diff --git a/src/App/tests/backend/agents/test_agent_factory.py b/src/App/tests/backend/agents/test_agent_factory.py index dcae796ad..775cdab4c 100644 --- a/src/App/tests/backend/agents/test_agent_factory.py +++ b/src/App/tests/backend/agents/test_agent_factory.py @@ -51,7 +51,8 @@ async def test_get_wealth_advisor_agent_creates_agent_when_none_exists( mock_client.agents.create_agent.assert_called_once_with( model="test-model", name="WealthAdvisor", - instructions="You are a helpful assistant to a Wealth Advisor.", + instructions='''You are a helpful assistant to a Wealth Advisor. + If the question is unrelated to data but is conversational (e.g., greetings or follow-ups), respond appropriately using context, do not use external tools or perform any web searches for these conversational inputs.''', ) mock_agent.assert_called_once() diff --git a/src/App/tests/backend/plugins/test_chat_with_data_plugin.py b/src/App/tests/backend/plugins/test_chat_with_data_plugin.py index 826cf4c5f..684c947ae 100644 --- a/src/App/tests/backend/plugins/test_chat_with_data_plugin.py +++ b/src/App/tests/backend/plugins/test_chat_with_data_plugin.py @@ -12,26 +12,6 @@ def setup_method(self): """Setup method to initialize plugin instance for each test.""" self.plugin = ChatWithDataPlugin() - @pytest.mark.asyncio - @patch.object(ChatWithDataPlugin, "get_openai_client") - async def test_greeting_returns_response(self, mock_get_openai_client): - """Test that greeting method calls OpenAI and returns response.""" - # Setup mock - mock_client = MagicMock() - mock_get_openai_client.return_value = mock_client - - mock_completion = MagicMock() - mock_completion.choices = [MagicMock()] - mock_completion.choices[0].message.content = ( - "Hello! I'm your Wealth Assistant. How can I help you today?" - ) - mock_client.chat.completions.create.return_value = mock_completion - - result = await self.plugin.greeting("Hello") - - assert result == "Hello! I'm your Wealth Assistant. How can I help you today?" - mock_client.chat.completions.create.assert_called_once() - @patch("backend.plugins.chat_with_data_plugin.config") @patch("backend.plugins.chat_with_data_plugin.openai.AzureOpenAI") @patch("backend.plugins.chat_with_data_plugin.get_bearer_token_provider") @@ -102,82 +82,101 @@ def test_get_project_openai_client_success( @pytest.mark.asyncio @patch("backend.plugins.chat_with_data_plugin.get_connection") - @patch.object(ChatWithDataPlugin, "get_openai_client") - async def test_get_sql_response_success( - self, mock_get_openai_client, mock_get_connection - ): - """Test successful SQL response generation with AAD authentication.""" - # Setup mocks - mock_client = MagicMock() - mock_get_openai_client.return_value = mock_client + @patch("backend.plugins.chat_with_data_plugin.config") + @patch("backend.agents.agent_factory.AgentFactory.get_sql_agent") + async def test_get_sql_response_success(self, mock_get_sql_agent, mock_config, mock_get_connection): + mock_config.AI_PROJECT_ENDPOINT = "https://dummy.endpoint" + mock_config.AZURE_OPENAI_MODEL = "gpt-4o-mini" + mock_config.SQL_SYSTEM_PROMPT = "Test prompt" - mock_completion = MagicMock() - mock_completion.choices = [MagicMock()] - mock_completion.choices[0].message.content = ( - "SELECT * FROM Clients WHERE ClientId = 'client123';" - ) - mock_client.chat.completions.create.return_value = mock_completion + mock_agent = MagicMock() + mock_agent.id = "mock-agent-id" + mock_project_client = MagicMock() + mock_thread = MagicMock() + mock_thread.id = "thread123" + mock_project_client.agents.threads.create.return_value = mock_thread + + mock_run = MagicMock() + mock_run.status = "completed" + mock_project_client.agents.runs.create_and_process.return_value = mock_run + + mock_message = MagicMock() + mock_message.text.value = "SELECT * FROM Clients WHERE ClientId = 'client123';" + mock_project_client.agents.messages.get_last_message_text_by_role.return_value = mock_message + + mock_get_sql_agent.return_value = {"agent": mock_agent, "client": mock_project_client} + + # Mock DB execution mock_connection = MagicMock() mock_cursor = MagicMock() - mock_cursor.fetchall.return_value = [ - ("John Doe", "john@example.com", "Engineer") - ] + mock_cursor.fetchall.return_value = [("John Doe", "john@example.com", "Engineer")] mock_connection.cursor.return_value = mock_cursor mock_get_connection.return_value = mock_connection result = await self.plugin.get_SQL_Response("Find client details", "client123") - # Verify the result assert "John Doe" in result assert "john@example.com" in result assert "Engineer" in result - # Verify OpenAI was called - mock_client.chat.completions.create.assert_called_once() - - # Verify database operations using AAD authentication - mock_get_connection.assert_called_once() - mock_cursor.execute.assert_called_once() - mock_cursor.fetchall.assert_called_once() - mock_connection.close.assert_called_once() - @pytest.mark.asyncio @patch("backend.plugins.chat_with_data_plugin.get_connection") - @patch.object(ChatWithDataPlugin, "get_openai_client") - async def test_get_sql_response_database_error( - self, mock_get_openai_client, mock_get_connection - ): - """Test SQL response when database connection fails.""" - mock_client = MagicMock() - mock_get_openai_client.return_value = mock_client + @patch("backend.plugins.chat_with_data_plugin.config") + @patch("backend.agents.agent_factory.AgentFactory.get_sql_agent") + async def test_get_sql_response_database_error(self, mock_get_sql_agent, mock_config, mock_get_connection): + mock_config.AI_PROJECT_ENDPOINT = "https://dummy.endpoint" + mock_config.AZURE_OPENAI_MODEL = "gpt-4o-mini" + + mock_agent = MagicMock() + mock_agent.id = "mock-agent-id" + mock_project_client = MagicMock() - mock_completion = MagicMock() - mock_completion.choices = [MagicMock()] - mock_completion.choices[0].message.content = "SELECT * FROM Clients;" - mock_client.chat.completions.create.return_value = mock_completion + mock_thread = MagicMock() + mock_thread.id = "thread123" + mock_project_client.agents.threads.create.return_value = mock_thread + + mock_run = MagicMock() + mock_run.status = "completed" + mock_project_client.agents.runs.create_and_process.return_value = mock_run + + mock_message = MagicMock() + mock_message.text.value = "SELECT * FROM Clients;" + mock_project_client.agents.messages.get_last_message_text_by_role.return_value = mock_message + + mock_get_sql_agent.return_value = {"agent": mock_agent, "client": mock_project_client} - # Simulate database connection error mock_get_connection.side_effect = Exception("Database connection failed") result = await self.plugin.get_SQL_Response("Get all clients", "client123") - assert "Error retrieving data from SQL" in result + assert "Error retrieving SQL data" in result assert "Database connection failed" in result @pytest.mark.asyncio - @patch.object(ChatWithDataPlugin, "get_openai_client") - async def test_get_sql_response_openai_error(self, mock_get_openai_client): - """Test SQL response when OpenAI call fails.""" - mock_client = MagicMock() - mock_get_openai_client.return_value = mock_client + @patch("backend.plugins.chat_with_data_plugin.config") + @patch("backend.agents.agent_factory.AgentFactory.get_sql_agent") + async def test_get_sql_response_openai_error(self, mock_get_sql_agent, mock_config): + mock_config.AI_PROJECT_ENDPOINT = "https://dummy.endpoint" + mock_config.AZURE_OPENAI_MODEL = "gpt-4o-mini" + + mock_agent = MagicMock() + mock_agent.id = "mock-agent-id" + mock_project_client = MagicMock() + + mock_thread = MagicMock() + mock_thread.id = "thread123" + mock_project_client.agents.threads.create.return_value = mock_thread + + # Simulate error during run processing + mock_project_client.agents.runs.create_and_process.side_effect = Exception("OpenAI API error") - # Simulate OpenAI error - mock_client.chat.completions.create.side_effect = Exception("OpenAI API error") + mock_get_sql_agent.return_value = {"agent": mock_agent, "client": mock_project_client} - result = await self.plugin.get_SQL_Response("Get client data", "client123") + plugin = ChatWithDataPlugin() + result = await plugin.get_SQL_Response("Get client data", "client123") - assert "Error retrieving data from SQL" in result + assert "Error retrieving SQL data" in result assert "OpenAI API error" in result @pytest.mark.asyncio