diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 6239ea12..318dd04c 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,31 +1,31 @@ -{ - "name": "Multi Agent Custom Automation Engine Solution Accelerator", - "image": "mcr.microsoft.com/devcontainers/python:3.10", - "features": { - "ghcr.io/devcontainers/features/azure-cli:1.0.8": {}, - "ghcr.io/azure/azure-dev/azd:latest": {}, - "ghcr.io/rchaganti/vsc-devcontainer-features/azurebicep:1.0.5": {} - }, - - "postCreateCommand": "sudo chmod +x .devcontainer/setupEnv.sh && ./.devcontainer/setupEnv.sh", - - "customizations": { - "vscode": { - "extensions": [ - "ms-azuretools.azure-dev", - "ms-azuretools.vscode-bicep", - "ms-python.python" - ] - }, - "codespaces": { - "openFiles": [ - "README.md" - ] - } - }, - - "remoteUser": "vscode", - "hostRequirements": { - "memory": "8gb" - } -} +{ + "name": "Multi Agent Custom Automation Engine Solution Accelerator", + "image": "mcr.microsoft.com/devcontainers/python:3.10", + "features": { + "ghcr.io/devcontainers/features/azure-cli:1.0.8": {}, + "ghcr.io/azure/azure-dev/azd:latest": {}, + "ghcr.io/rchaganti/vsc-devcontainer-features/azurebicep:1.0.5": {} + }, + + "postCreateCommand": "sudo chmod +x .devcontainer/setupEnv.sh && ./.devcontainer/setupEnv.sh", + + "customizations": { + "vscode": { + "extensions": [ + "ms-azuretools.azure-dev", + "ms-azuretools.vscode-bicep", + "ms-python.python" + ] + }, + "codespaces": { + "openFiles": [ + "README.md" + ] + } + }, + + "remoteUser": "vscode", + "hostRequirements": { + "memory": "8gb" + } +} diff --git a/.devcontainer/setupEnv.sh b/.devcontainer/setupEnv.sh index 1b99cdb7..da381991 100644 --- a/.devcontainer/setupEnv.sh +++ b/.devcontainer/setupEnv.sh @@ -5,4 +5,7 @@ pip install --upgrade pip (cd ./src/frontend; pip install -r requirements.txt) -(cd ./src/backend; pip install -r requirements.txt) \ No newline at end of file + +(cd ./src/backend; pip install -r requirements.txt) + + diff --git a/.flake8 b/.flake8 index 93f63e5d..08367ecd 100644 --- a/.flake8 +++ b/.flake8 @@ -2,4 +2,4 @@ max-line-length = 88 extend-ignore = E501 exclude = .venv, frontend -ignore = E203, W503, G004, G200 \ No newline at end of file +ignore = E203, W503, G004, G200, E402 \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 32d1c60a..6392f559 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -37,8 +37,6 @@ jobs: run: | python -m pip install --upgrade pip pip install -r src/backend/requirements.txt - pip install pytest-cov - pip install pytest-asyncio - name: Check if test files exist id: check_tests @@ -50,7 +48,6 @@ jobs: echo "Test files found, running tests." echo "skip_tests=false" >> $GITHUB_ENV fi - - name: Run tests with coverage if: env.skip_tests == 'false' run: | @@ -59,4 +56,4 @@ jobs: - name: Skip coverage report if no tests if: env.skip_tests == 'true' run: | - echo "Skipping coverage report because no tests were found." + echo "Skipping coverage report because no tests were found." \ No newline at end of file diff --git a/README.md b/README.md index c84c56a7..3da325ea 100644 --- a/README.md +++ b/README.md @@ -177,8 +177,8 @@ To add your newly created backend image: name: 'FRONTEND_SITE_NAME' value: 'https://.azurewebsites.net' - name: 'APPLICATIONINSIGHTS_INSTRUMENTATION_KEY' - value: + name: 'APPLICATIONINSIGHTS_CONNECTION_STRING' + value: - Click 'Save' and deploy your new revision diff --git a/deploy/macae-continer-oc.json b/deploy/macae-continer-oc.json index f394cd91..2efe8653 100644 --- a/deploy/macae-continer-oc.json +++ b/deploy/macae-continer-oc.json @@ -6,7 +6,7 @@ "_generator": { "name": "bicep", "version": "0.32.4.45862", - "templateHash": "13282901028774763433" + "templateHash": "14272651486454797588" } }, "parameters": { @@ -368,7 +368,7 @@ "value": "[format('https://{0}.azurewebsites.net', format(variables('uniqueNameFormat'), 'frontend'))]" }, { - "name": "APPLICATIONINSIGHTS_INSTRUMENTATION_KEY", + "name": "APPLICATIONINSIGHTS_CONNECTION_STRING", "value": "[reference('appInsights').ConnectionString]" } ] diff --git a/deploy/macae-continer.bicep b/deploy/macae-continer.bicep index 965b111d..35ba347f 100644 --- a/deploy/macae-continer.bicep +++ b/deploy/macae-continer.bicep @@ -280,7 +280,7 @@ resource containerApp 'Microsoft.App/containerApps@2024-03-01' = { value: 'https://${format(uniqueNameFormat, 'frontend')}.azurewebsites.net' } { - name: 'APPLICATIONINSIGHTS_INSTRUMENTATION_KEY' + name: 'APPLICATIONINSIGHTS_CONNECTION_STRING' value: appInsights.properties.ConnectionString } ] diff --git a/documentation/LocalDeployment.md b/documentation/LocalDeployment.md index 4cbc799e..a34ba583 100644 --- a/documentation/LocalDeployment.md +++ b/documentation/LocalDeployment.md @@ -9,6 +9,7 @@ # Local setup > **Note for macOS Developers**: If you are using macOS on Apple Silicon (ARM64) the DevContainer will **not** work. This is due to a limitation with the Azure Functions Core Tools (see [here](https://github.com/Azure/azure-functions-core-tools/issues/3112)). We recommend using the [Non DevContainer Setup](./NON_DEVCONTAINER_SETUP.md) instructions to run the accelerator locally. + The easiest way to run this accelerator is in a VS Code Dev Containers, which will open the project in your local VS Code using the [Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers): 1. Start Docker Desktop (install it if not already installed) @@ -96,8 +97,8 @@ The files for the dev container are located in `/.devcontainer/` folder. **Using a Different Database in Cosmos:** You can set the solution up to use a different database in Cosmos. For example, you can name it something like autogen-dev. To do this: - 1. Change the environment variable **COSMOSDB_DATABASE** to the new database name. - 2. You will need to create the database in the Cosmos DB account. You can do this from the Data Explorer pane in the portal, click on the drop down labeled “_+ New Container_” and provide all the necessary details. + 1. Change the environment variable **COSMOSDB_DATABASE** to the new database name. + 2. You will need to create the database in the Cosmos DB account. You can do this from the Data Explorer pane in the portal, click on the drop down labeled “_+ New Container_” and provide all the necessary details. 6. **Create a `.env` file:** diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend/.env.sample b/src/backend/.env.sample index 64102ab7..e92f7346 100644 --- a/src/backend/.env.sample +++ b/src/backend/.env.sample @@ -6,7 +6,7 @@ AZURE_OPENAI_ENDPOINT= AZURE_OPENAI_MODEL_NAME=gpt-4o AZURE_OPENAI_DEPLOYMENT_NAME=gpt-4o AZURE_OPENAI_API_VERSION=2024-08-01-preview -APPLICATIONINSIGHTS_INSTRUMENTATION_KEY= +APPLICATIONINSIGHTS_CONNECTION_STRING= BACKEND_API_URL='http://localhost:8000' FRONTEND_SITE_NAME='http://127.0.0.1:3000' \ No newline at end of file diff --git a/src/backend/Dockerfile b/src/backend/Dockerfile index 46333fbf..607d65f9 100644 --- a/src/backend/Dockerfile +++ b/src/backend/Dockerfile @@ -3,7 +3,7 @@ FROM python:3.11-slim # Backend app setup -WORKDIR /app/backend +WORKDIR /src/backend COPY . . # Install dependencies RUN pip install --no-cache-dir -r requirements.txt diff --git a/src/backend/__init__.py b/src/backend/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend/agents/__init__.py b/src/backend/agents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend/agents/agentutils.py b/src/backend/agents/agentutils.py index 72a6928d..6b117566 100644 --- a/src/backend/agents/agentutils.py +++ b/src/backend/agents/agentutils.py @@ -6,8 +6,8 @@ ) from pydantic import BaseModel -from context.cosmos_memory import CosmosBufferedChatCompletionContext -from models.messages import Step +from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext +from src.backend.models.messages import Step common_agent_system_message = "If you do not have the information for the arguments of the function you need to call, do not call the function. Instead, respond back to the user requesting further information. You must not hallucinate or invent any of the information used as arguments in the function. For example, if you need to call a function that requires a delivery address, you must not generate 123 Example St. You must skip calling functions and return a clarification message along the lines of: Sorry, I'm missing some information I need to help you with that. Could you please provide the delivery address so I can do that for you?" diff --git a/src/backend/agents/base_agent.py b/src/backend/agents/base_agent.py index 23541f83..46b34960 100644 --- a/src/backend/agents/base_agent.py +++ b/src/backend/agents/base_agent.py @@ -13,15 +13,15 @@ from autogen_core.components.tool_agent import tool_agent_caller_loop from autogen_core.components.tools import Tool -from context.cosmos_memory import CosmosBufferedChatCompletionContext -from models.messages import ( +from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext +from src.backend.models.messages import ( ActionRequest, ActionResponse, AgentMessage, Step, StepStatus, ) -from event_utils import track_event_if_configured +from src.backend.event_utils import track_event_if_configured class BaseAgent(RoutedAgent): diff --git a/src/backend/agents/generic.py b/src/backend/agents/generic.py index 209ee277..fff73a56 100644 --- a/src/backend/agents/generic.py +++ b/src/backend/agents/generic.py @@ -5,8 +5,8 @@ from autogen_core.components.models import AzureOpenAIChatCompletionClient from autogen_core.components.tools import FunctionTool, Tool -from agents.base_agent import BaseAgent -from context.cosmos_memory import CosmosBufferedChatCompletionContext +from src.backend.agents.base_agent import BaseAgent +from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext async def dummy_function() -> str: diff --git a/src/backend/agents/group_chat_manager.py b/src/backend/agents/group_chat_manager.py index 3591f0ef..32d7f238 100644 --- a/src/backend/agents/group_chat_manager.py +++ b/src/backend/agents/group_chat_manager.py @@ -9,8 +9,8 @@ from autogen_core.components import RoutedAgent, default_subscription, message_handler from autogen_core.components.models import AzureOpenAIChatCompletionClient -from context.cosmos_memory import CosmosBufferedChatCompletionContext -from models.messages import ( +from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext +from src.backend.models.messages import ( ActionRequest, AgentMessage, BAgentType, @@ -22,7 +22,7 @@ StepStatus, ) -from event_utils import track_event_if_configured +from src.backend.event_utils import track_event_if_configured @default_subscription diff --git a/src/backend/agents/hr.py b/src/backend/agents/hr.py index 1c0f8b06..4060ae9a 100644 --- a/src/backend/agents/hr.py +++ b/src/backend/agents/hr.py @@ -6,8 +6,8 @@ from autogen_core.components.tools import FunctionTool, Tool from typing_extensions import Annotated -from agents.base_agent import BaseAgent -from context.cosmos_memory import CosmosBufferedChatCompletionContext +from src.backend.agents.base_agent import BaseAgent +from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext formatting_instructions = "Instructions: returning the output of this function call verbatim to the user in markdown. Then write AGENT SUMMARY: and then include a summary of what you did." diff --git a/src/backend/agents/human.py b/src/backend/agents/human.py index 6292fef7..5d1a72d8 100644 --- a/src/backend/agents/human.py +++ b/src/backend/agents/human.py @@ -4,15 +4,15 @@ from autogen_core.base import AgentId, MessageContext from autogen_core.components import RoutedAgent, default_subscription, message_handler -from context.cosmos_memory import CosmosBufferedChatCompletionContext -from models.messages import ( +from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext +from src.backend.models.messages import ( ApprovalRequest, HumanFeedback, StepStatus, AgentMessage, Step, ) -from event_utils import track_event_if_configured +from src.backend.event_utils import track_event_if_configured @default_subscription diff --git a/src/backend/agents/marketing.py b/src/backend/agents/marketing.py index 348e6a81..5cf11c97 100644 --- a/src/backend/agents/marketing.py +++ b/src/backend/agents/marketing.py @@ -5,8 +5,8 @@ from autogen_core.components.models import AzureOpenAIChatCompletionClient from autogen_core.components.tools import FunctionTool, Tool -from agents.base_agent import BaseAgent -from context.cosmos_memory import CosmosBufferedChatCompletionContext +from src.backend.agents.base_agent import BaseAgent +from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext # Define new Marketing tools (functions) diff --git a/src/backend/agents/planner.py b/src/backend/agents/planner.py index 83768443..e7975be3 100644 --- a/src/backend/agents/planner.py +++ b/src/backend/agents/planner.py @@ -13,8 +13,8 @@ ) from pydantic import BaseModel -from context.cosmos_memory import CosmosBufferedChatCompletionContext -from models.messages import ( +from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext +from src.backend.models.messages import ( AgentMessage, HumanClarification, BAgentType, @@ -26,7 +26,7 @@ HumanFeedbackStatus, ) -from event_utils import track_event_if_configured +from src.backend.event_utils import track_event_if_configured @default_subscription diff --git a/src/backend/agents/procurement.py b/src/backend/agents/procurement.py index 2c8b677b..6c657a71 100644 --- a/src/backend/agents/procurement.py +++ b/src/backend/agents/procurement.py @@ -6,8 +6,8 @@ from autogen_core.components.tools import FunctionTool, Tool from typing_extensions import Annotated -from agents.base_agent import BaseAgent -from context.cosmos_memory import CosmosBufferedChatCompletionContext +from src.backend.agents.base_agent import BaseAgent +from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext # Define new Procurement tools (functions) diff --git a/src/backend/agents/product.py b/src/backend/agents/product.py index ab2b88fa..2956a977 100644 --- a/src/backend/agents/product.py +++ b/src/backend/agents/product.py @@ -8,8 +8,8 @@ from autogen_core.components.tools import FunctionTool, Tool from typing_extensions import Annotated -from agents.base_agent import BaseAgent -from context.cosmos_memory import CosmosBufferedChatCompletionContext +from src.backend.agents.base_agent import BaseAgent +from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext formatting_instructions = "Instructions: returning the output of this function call verbatim to the user in markdown. Then write AGENT SUMMARY: and then include a summary of what you did." diff --git a/src/backend/agents/tech_support.py b/src/backend/agents/tech_support.py index c8613643..5c0cb088 100644 --- a/src/backend/agents/tech_support.py +++ b/src/backend/agents/tech_support.py @@ -6,8 +6,8 @@ from autogen_core.components.tools import FunctionTool, Tool from typing_extensions import Annotated -from agents.base_agent import BaseAgent -from context.cosmos_memory import CosmosBufferedChatCompletionContext +from src.backend.agents.base_agent import BaseAgent +from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext formatting_instructions = "Instructions: returning the output of this function call verbatim to the user in markdown. Then write AGENT SUMMARY: and then include a summary of what you did." @@ -523,7 +523,6 @@ async def get_tech_information( Document Name: Contoso's IT Policy and Procedure Manual Domain: IT Policy Description: A comprehensive guide detailing the IT policies and procedures at Contoso, including acceptable use, security protocols, and incident reporting. - At Contoso, we prioritize the security and efficiency of our IT infrastructure. All employees are required to adhere to the following policies: - Use strong passwords and change them every 90 days. - Report any suspicious emails to the IT department immediately. diff --git a/src/backend/app.py b/src/backend/app.py index 1e96822d..801d8f3a 100644 --- a/src/backend/app.py +++ b/src/backend/app.py @@ -1,16 +1,21 @@ -# app.py +#!/usr/bin/env python +import os +import sys + +# Add the parent directory (the one that contains the "src" folder) to sys.path. +# This allows absolute imports such as "from src.backend.middleware.health_check" to work +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) import asyncio import logging -import os import uuid from typing import List, Optional -from middleware.health_check import HealthCheckMiddleware +from src.backend.middleware.health_check import HealthCheckMiddleware from autogen_core.base import AgentId from fastapi import FastAPI, HTTPException, Query, Request -from auth.auth_utils import get_authenticated_user_details -from config import Config -from context.cosmos_memory import CosmosBufferedChatCompletionContext -from models.messages import ( +from src.backend.auth.auth_utils import get_authenticated_user_details +from src.backend.config import Config +from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext +from src.backend.models.messages import ( HumanFeedback, HumanClarification, InputTask, @@ -19,15 +24,15 @@ AgentMessage, PlanWithSteps, ) -from utils import initialize_runtime_and_context, retrieve_all_agent_tools, rai_success -from event_utils import track_event_if_configured +from src.backend.utils import initialize_runtime_and_context, retrieve_all_agent_tools, rai_success +from src.backend.event_utils import track_event_if_configured from fastapi.middleware.cors import CORSMiddleware from azure.monitor.opentelemetry import configure_azure_monitor from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor # Check if the Application Insights Instrumentation Key is set in the environment variables -instrumentation_key = os.getenv("APPLICATIONINSIGHTS_INSTRUMENTATION_KEY") +instrumentation_key = os.getenv("APPLICATIONINSIGHTS_CONNECTION_STRING") if instrumentation_key: # Configure Application Insights if the Instrumentation Key is found configure_azure_monitor(connection_string=instrumentation_key) diff --git a/src/backend/context/cosmos_memory.py b/src/backend/context/cosmos_memory.py index b9271e1f..1261f65b 100644 --- a/src/backend/context/cosmos_memory.py +++ b/src/backend/context/cosmos_memory.py @@ -15,8 +15,8 @@ ) from azure.cosmos.partition_key import PartitionKey -from config import Config -from models.messages import BaseDataModel, Plan, Session, Step, AgentMessage +from src.backend.config import Config +from src.backend.models.messages import BaseDataModel, Plan, Session, Step, AgentMessage class CosmosBufferedChatCompletionContext(BufferedChatCompletionContext): diff --git a/src/backend/event_utils.py b/src/backend/event_utils.py index 9b9e5bbf..eb86a530 100644 --- a/src/backend/event_utils.py +++ b/src/backend/event_utils.py @@ -4,7 +4,7 @@ def track_event_if_configured(event_name: str, event_data: dict): - instrumentation_key = os.getenv("APPLICATIONINSIGHTS_INSTRUMENTATION_KEY") + instrumentation_key = os.getenv("APPLICATIONINSIGHTS_CONNECTION_STRING") if instrumentation_key: track_event(event_name, event_data) else: diff --git a/src/backend/handlers/__init__.py b/src/backend/handlers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend/handlers/runtime_interrupt.py b/src/backend/handlers/runtime_interrupt.py index 7ed1848b..58e75eff 100644 --- a/src/backend/handlers/runtime_interrupt.py +++ b/src/backend/handlers/runtime_interrupt.py @@ -3,7 +3,9 @@ from autogen_core.base import AgentId from autogen_core.base.intervention import DefaultInterventionHandler -from models.messages import GetHumanInputMessage, GroupChatMessage +from src.backend.models.messages import GroupChatMessage + +from src.backend.models.messages import GetHumanInputMessage class NeedsUserInputHandler(DefaultInterventionHandler): diff --git a/src/backend/middleware/__init__.py b/src/backend/middleware/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend/models/messages.py b/src/backend/models/messages.py index e4ea6a59..60453cb5 100644 --- a/src/backend/models/messages.py +++ b/src/backend/models/messages.py @@ -293,3 +293,11 @@ class RequestToSpeak(BaseModel): def to_dict(self): return self.model_dump() + + +class GetHumanInputMessage: + def __init__(self, message): + self.message = message + + def __str__(self): + return f"GetHumanInputMessage: {self.message}" diff --git a/src/backend/requirements.txt b/src/backend/requirements.txt index c4bfa64e..24ccf580 100644 --- a/src/backend/requirements.txt +++ b/src/backend/requirements.txt @@ -14,3 +14,8 @@ opentelemetry-instrumentation-fastapi opentelemetry-instrumentation-openai opentelemetry-exporter-otlp-proto-http opentelemetry-exporter-otlp-proto-grpc + +# Testing tools +pytest>=8.2,<9 # Compatible version for pytest-asyncio +pytest-asyncio==0.24.0 +pytest-cov==5.0.0 \ No newline at end of file diff --git a/src/backend/tests/__init__.py b/src/backend/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend/tests/agents/__init__.py b/src/backend/tests/agents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend/tests/agents/test_agentutils.py b/src/backend/tests/agents/test_agentutils.py new file mode 100644 index 00000000..c5131815 --- /dev/null +++ b/src/backend/tests/agents/test_agentutils.py @@ -0,0 +1,54 @@ +# pylint: disable=import-error, wrong-import-position, missing-module-docstring +import os +import sys +from unittest.mock import MagicMock +import pytest +from pydantic import ValidationError + +# Environment and module setup +sys.modules["azure.monitor.events.extension"] = MagicMock() + +os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint" +os.environ["COSMOSDB_KEY"] = "mock-key" +os.environ["COSMOSDB_DATABASE"] = "mock-database" +os.environ["COSMOSDB_CONTAINER"] = "mock-container" +os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "mock-deployment-name" +os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01" +os.environ["AZURE_OPENAI_ENDPOINT"] = "https://mock-openai-endpoint" + +from src.backend.agents.agentutils import extract_and_update_transition_states # noqa: F401, C0413 +from src.backend.models.messages import Step # noqa: F401, C0413 + + +def test_step_initialization(): + """Test Step initialization with valid data.""" + step = Step( + data_type="step", + plan_id="test_plan", + action="test_action", + agent="HumanAgent", + session_id="test_session", + user_id="test_user", + agent_reply="test_reply", + ) + + assert step.data_type == "step" + assert step.plan_id == "test_plan" + assert step.action == "test_action" + assert step.agent == "HumanAgent" + assert step.session_id == "test_session" + assert step.user_id == "test_user" + assert step.agent_reply == "test_reply" + assert step.status == "planned" + assert step.human_approval_status == "requested" + + +def test_step_missing_required_fields(): + """Test Step initialization with missing required fields.""" + with pytest.raises(ValidationError): + Step( + data_type="step", + action="test_action", + agent="test_agent", + session_id="test_session", + ) diff --git a/src/backend/tests/agents/test_base_agent.py b/src/backend/tests/agents/test_base_agent.py new file mode 100644 index 00000000..9ecbf258 --- /dev/null +++ b/src/backend/tests/agents/test_base_agent.py @@ -0,0 +1,151 @@ +# pylint: disable=import-error, wrong-import-position, missing-module-docstring +import os +import sys +from unittest.mock import MagicMock, AsyncMock, patch +import pytest +from contextlib import contextmanager + +# Mocking necessary modules and environment variables +sys.modules["azure.monitor.events.extension"] = MagicMock() + +# Mocking environment variables +os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint" +os.environ["COSMOSDB_KEY"] = "mock-key" +os.environ["COSMOSDB_DATABASE"] = "mock-database" +os.environ["COSMOSDB_CONTAINER"] = "mock-container" +os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "mock-deployment-name" +os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01" +os.environ["AZURE_OPENAI_ENDPOINT"] = "https://mock-openai-endpoint" + +# Importing the module to test +from src.backend.agents.base_agent import BaseAgent +from src.backend.models.messages import ActionRequest, Step, StepStatus +from autogen_core.base import AgentId + + +# Context manager for setting up mocks +@contextmanager +def mock_context(): + mock_runtime = MagicMock() + with patch("autogen_core.base._agent_instantiation.AgentInstantiationContext.AGENT_INSTANTIATION_CONTEXT_VAR") as mock_context_var: + mock_context_instance = MagicMock() + mock_context_var.get.return_value = mock_context_instance + mock_context_instance.set.return_value = None + yield mock_runtime + + +@pytest.fixture +def mock_dependencies(): + model_client = MagicMock() + model_context = MagicMock() + tools = [MagicMock(schema="tool_schema")] + tool_agent_id = MagicMock() + return { + "model_client": model_client, + "model_context": model_context, + "tools": tools, + "tool_agent_id": tool_agent_id, + } + + +@pytest.fixture +def base_agent(mock_dependencies): + with mock_context(): + return BaseAgent( + agent_name="test_agent", + model_client=mock_dependencies["model_client"], + session_id="test_session", + user_id="test_user", + model_context=mock_dependencies["model_context"], + tools=mock_dependencies["tools"], + tool_agent_id=mock_dependencies["tool_agent_id"], + system_message="This is a system message.", + ) + + +def test_save_state(base_agent, mock_dependencies): + mock_dependencies["model_context"].save_state = MagicMock(return_value={"state_key": "state_value"}) + state = base_agent.save_state() + assert state == {"memory": {"state_key": "state_value"}} + + +def test_load_state(base_agent, mock_dependencies): + mock_dependencies["model_context"].load_state = MagicMock() + state = {"memory": {"state_key": "state_value"}} + base_agent.load_state(state) + mock_dependencies["model_context"].load_state.assert_called_once_with({"state_key": "state_value"}) + + +@pytest.mark.asyncio +async def test_handle_action_request_error(base_agent, mock_dependencies): + """Test handle_action_request when tool_agent_caller_loop raises an error.""" + step = Step( + id="step_1", + status=StepStatus.approved, + human_feedback="feedback", + agent_reply="", + plan_id="plan_id", + action="action", + agent="HumanAgent", + session_id="session_id", + user_id="user_id", + ) + mock_dependencies["model_context"].get_step = AsyncMock(return_value=step) + mock_dependencies["model_context"].add_item = AsyncMock() + + with patch("src.backend.agents.base_agent.tool_agent_caller_loop", AsyncMock(side_effect=Exception("Mock error"))): + message = ActionRequest( + step_id="step_1", + session_id="test_session", + action="test_action", + plan_id="plan_id", + agent="HumanAgent", + ) + ctx = MagicMock() + with pytest.raises(ValueError) as excinfo: + await base_agent.handle_action_request(message, ctx) + assert "Return type not in return types" in str(excinfo.value) + + +@pytest.mark.asyncio +async def test_handle_action_request_success(base_agent, mock_dependencies): + """Test handle_action_request with a successful tool_agent_caller_loop.""" + step = Step( + id="step_1", + status=StepStatus.approved, + human_feedback="feedback", + agent_reply="", + plan_id="plan_id", + action="action", + agent="HumanAgent", + session_id="session_id", + user_id="user_id" + ) + mock_dependencies["model_context"].get_step = AsyncMock(return_value=step) + mock_dependencies["model_context"].update_step = AsyncMock() + mock_dependencies["model_context"].add_item = AsyncMock() + + with patch("src.backend.agents.base_agent.tool_agent_caller_loop", new=AsyncMock(return_value=[MagicMock(content="result")])): + base_agent._runtime.publish_message = AsyncMock() + message = ActionRequest( + step_id="step_1", + session_id="test_session", + action="test_action", + plan_id="plan_id", + agent="HumanAgent" + ) + ctx = MagicMock() + response = await base_agent.handle_action_request(message, ctx) + + assert response.status == StepStatus.completed + assert response.result == "result" + assert response.plan_id == "plan_id" + assert response.session_id == "test_session" + + base_agent._runtime.publish_message.assert_awaited_once_with( + response, + AgentId(type="group_chat_manager", key="test_session"), + sender=base_agent.id, + cancellation_token=None + ) + mock_dependencies["model_context"].update_step.assert_called_once_with(step) diff --git a/src/backend/tests/agents/test_generic.py b/src/backend/tests/agents/test_generic.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend/tests/agents/test_group_chat_manager.py b/src/backend/tests/agents/test_group_chat_manager.py new file mode 100644 index 00000000..60c775d2 --- /dev/null +++ b/src/backend/tests/agents/test_group_chat_manager.py @@ -0,0 +1,128 @@ +""" +Combined Test cases for GroupChatManager class in the backend agents module. +""" + +import os +import sys +from unittest.mock import AsyncMock, patch, MagicMock +import pytest + +# Set mock environment variables for Azure and CosmosDB before importing anything else +os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint" +os.environ["COSMOSDB_KEY"] = "mock-key" +os.environ["COSMOSDB_DATABASE"] = "mock-database" +os.environ["COSMOSDB_CONTAINER"] = "mock-container" +os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "mock-deployment-name" +os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01" +os.environ["AZURE_OPENAI_ENDPOINT"] = "https://mock-openai-endpoint" + +# Mock Azure dependencies +sys.modules["azure.monitor.events.extension"] = MagicMock() + +# Import after setting environment variables +from src.backend.agents.group_chat_manager import GroupChatManager +from src.backend.models.messages import ( + Step, + StepStatus, + BAgentType, +) +from autogen_core.base import AgentInstantiationContext, AgentRuntime +from autogen_core.components.models import AzureOpenAIChatCompletionClient +from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext +from autogen_core.base import AgentId + + +@pytest.fixture +def setup_group_chat_manager(): + """ + Fixture to set up a GroupChatManager and its dependencies. + """ + # Mock dependencies + mock_model_client = MagicMock(spec=AzureOpenAIChatCompletionClient) + session_id = "test_session_id" + user_id = "test_user_id" + mock_memory = AsyncMock(spec=CosmosBufferedChatCompletionContext) + mock_agent_ids = {BAgentType.planner_agent: AgentId("planner_agent", session_id)} + + # Mock AgentInstantiationContext + mock_runtime = MagicMock(spec=AgentRuntime) + mock_agent_id = "test_agent_id" + + with patch.object(AgentInstantiationContext, "current_runtime", return_value=mock_runtime): + with patch.object(AgentInstantiationContext, "current_agent_id", return_value=mock_agent_id): + # Instantiate GroupChatManager + group_chat_manager = GroupChatManager( + model_client=mock_model_client, + session_id=session_id, + user_id=user_id, + memory=mock_memory, + agent_ids=mock_agent_ids, + ) + + return group_chat_manager, mock_memory, session_id, user_id, mock_agent_ids + + +@pytest.mark.asyncio +@patch("src.backend.agents.group_chat_manager.track_event_if_configured") +async def test_update_step_status(mock_track_event, setup_group_chat_manager): + """ + Test the `_update_step_status` method. + """ + group_chat_manager, mock_memory, session_id, user_id, mock_agent_ids = setup_group_chat_manager + + # Create a mock Step + step = Step( + id="test_step_id", + session_id=session_id, + plan_id="test_plan_id", + user_id=user_id, + action="Test Action", + agent=BAgentType.human_agent, + status=StepStatus.planned, + ) + + # Call the method + await group_chat_manager._update_step_status(step, True, "Feedback message") + + # Assertions + step.status = StepStatus.completed + step.human_feedback = "Feedback message" + mock_memory.update_step.assert_called_once_with(step) + mock_track_event.assert_called_once_with( + "Group Chat Manager - Received human feedback, Updating step and updated into the cosmos", + { + "status": StepStatus.completed, + "session_id": step.session_id, + "user_id": step.user_id, + "human_feedback": "Feedback message", + "source": step.agent, + }, + ) + + +@pytest.mark.asyncio +async def test_update_step_invalid_feedback_status(setup_group_chat_manager): + """ + Test `_update_step_status` with invalid feedback status. + Covers lines 210-211. + """ + group_chat_manager, mock_memory, session_id, user_id, mock_agent_ids = setup_group_chat_manager + + # Create a mock Step + step = Step( + id="test_step_id", + session_id=session_id, + plan_id="test_plan_id", + user_id=user_id, + action="Test Action", + agent=BAgentType.human_agent, + status=StepStatus.planned, + ) + + # Call the method with invalid feedback status + await group_chat_manager._update_step_status(step, None, "Feedback message") + + # Assertions + step.status = StepStatus.planned # Status should remain unchanged + step.human_feedback = "Feedback message" + mock_memory.update_step.assert_called_once_with(step) diff --git a/src/backend/tests/agents/test_hr.py b/src/backend/tests/agents/test_hr.py new file mode 100644 index 00000000..aa89fb0e --- /dev/null +++ b/src/backend/tests/agents/test_hr.py @@ -0,0 +1,254 @@ +""" +Test suite for HR-related functions in the backend agents module. + +This module contains asynchronous test cases for various HR functions, +including employee orientation, benefits registration, payroll setup, and more. +""" + +import os +import sys +from unittest.mock import MagicMock +import pytest + +# Set mock environment variables for Azure and CosmosDB +os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint" +os.environ["COSMOSDB_KEY"] = "mock-key" +os.environ["COSMOSDB_DATABASE"] = "mock-database" +os.environ["COSMOSDB_CONTAINER"] = "mock-container" +os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "mock-deployment-name" +os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01" +os.environ["AZURE_OPENAI_ENDPOINT"] = "https://mock-openai-endpoint" + +# Mock Azure dependencies +sys.modules["azure.monitor.events.extension"] = MagicMock() + +# pylint: disable=C0413 +from src.backend.agents.hr import ( + schedule_orientation_session, + assign_mentor, + register_for_benefits, + enroll_in_training_program, + provide_employee_handbook, + update_employee_record, + request_id_card, + set_up_payroll, + add_emergency_contact, + process_leave_request, + update_policies, + conduct_exit_interview, + verify_employment, + schedule_performance_review, + approve_expense_claim, + send_company_announcement, + fetch_employee_directory, + initiate_background_check, + organize_team_building_activity, + manage_employee_transfer, + track_employee_attendance, + organize_health_and_wellness_program, + facilitate_remote_work_setup, + manage_retirement_plan, +) +# pylint: enable=C0413 + + +@pytest.mark.asyncio +async def test_schedule_orientation_session(): + """Test scheduling an orientation session.""" + result = await schedule_orientation_session("John Doe", "2025-02-01") + assert "##### Orientation Session Scheduled" in result + assert "**Employee Name:** John Doe" in result + assert "**Date:** 2025-02-01" in result + + +@pytest.mark.asyncio +async def test_assign_mentor(): + """Test assigning a mentor to an employee.""" + result = await assign_mentor("John Doe") + assert "##### Mentor Assigned" in result + assert "**Employee Name:** John Doe" in result + + +@pytest.mark.asyncio +async def test_register_for_benefits(): + """Test registering an employee for benefits.""" + result = await register_for_benefits("John Doe") + assert "##### Benefits Registration" in result + assert "**Employee Name:** John Doe" in result + + +@pytest.mark.asyncio +async def test_enroll_in_training_program(): + """Test enrolling an employee in a training program.""" + result = await enroll_in_training_program("John Doe", "Leadership 101") + assert "##### Training Program Enrollment" in result + assert "**Employee Name:** John Doe" in result + assert "**Program Name:** Leadership 101" in result + + +@pytest.mark.asyncio +async def test_provide_employee_handbook(): + """Test providing the employee handbook.""" + result = await provide_employee_handbook("John Doe") + assert "##### Employee Handbook Provided" in result + assert "**Employee Name:** John Doe" in result + + +@pytest.mark.asyncio +async def test_update_employee_record(): + """Test updating an employee record.""" + result = await update_employee_record("John Doe", "Email", "john.doe@example.com") + assert "##### Employee Record Updated" in result + assert "**Field Updated:** Email" in result + assert "**New Value:** john.doe@example.com" in result + + +@pytest.mark.asyncio +async def test_request_id_card(): + """Test requesting an ID card for an employee.""" + result = await request_id_card("John Doe") + assert "##### ID Card Request" in result + assert "**Employee Name:** John Doe" in result + + +@pytest.mark.asyncio +async def test_set_up_payroll(): + """Test setting up payroll for an employee.""" + result = await set_up_payroll("John Doe") + assert "##### Payroll Setup" in result + assert "**Employee Name:** John Doe" in result + + +@pytest.mark.asyncio +async def test_add_emergency_contact(): + """Test adding an emergency contact for an employee.""" + result = await add_emergency_contact("John Doe", "Jane Doe", "123-456-7890") + assert "##### Emergency Contact Added" in result + assert "**Contact Name:** Jane Doe" in result + assert "**Contact Phone:** 123-456-7890" in result + + +@pytest.mark.asyncio +async def test_process_leave_request(): + """Test processing a leave request for an employee.""" + result = await process_leave_request( + "John Doe", "Vacation", "2025-03-01", "2025-03-10" + ) + assert "##### Leave Request Processed" in result + assert "**Leave Type:** Vacation" in result + assert "**Start Date:** 2025-03-01" in result + assert "**End Date:** 2025-03-10" in result + + +@pytest.mark.asyncio +async def test_update_policies(): + """Test updating company policies.""" + result = await update_policies("Work From Home Policy", "Updated content") + assert "##### Policy Updated" in result + assert "**Policy Name:** Work From Home Policy" in result + assert "Updated content" in result + + +@pytest.mark.asyncio +async def test_conduct_exit_interview(): + """Test conducting an exit interview.""" + result = await conduct_exit_interview("John Doe") + assert "##### Exit Interview Conducted" in result + assert "**Employee Name:** John Doe" in result + + +@pytest.mark.asyncio +async def test_verify_employment(): + """Test verifying employment.""" + result = await verify_employment("John Doe") + assert "##### Employment Verification" in result + assert "**Employee Name:** John Doe" in result + + +@pytest.mark.asyncio +async def test_schedule_performance_review(): + """Test scheduling a performance review.""" + result = await schedule_performance_review("John Doe", "2025-04-15") + assert "##### Performance Review Scheduled" in result + assert "**Date:** 2025-04-15" in result + + +@pytest.mark.asyncio +async def test_approve_expense_claim(): + """Test approving an expense claim.""" + result = await approve_expense_claim("John Doe", 500.75) + assert "##### Expense Claim Approved" in result + assert "**Claim Amount:** $500.75" in result + + +@pytest.mark.asyncio +async def test_send_company_announcement(): + """Test sending a company-wide announcement.""" + result = await send_company_announcement( + "Holiday Schedule", "We will be closed on Christmas." + ) + assert "##### Company Announcement" in result + assert "**Subject:** Holiday Schedule" in result + assert "We will be closed on Christmas." in result + + +@pytest.mark.asyncio +async def test_fetch_employee_directory(): + """Test fetching the employee directory.""" + result = await fetch_employee_directory() + assert "##### Employee Directory" in result + + +@pytest.mark.asyncio +async def test_initiate_background_check(): + """Test initiating a background check.""" + result = await initiate_background_check("John Doe") + assert "##### Background Check Initiated" in result + assert "**Employee Name:** John Doe" in result + + +@pytest.mark.asyncio +async def test_organize_team_building_activity(): + """Test organizing a team-building activity.""" + result = await organize_team_building_activity("Escape Room", "2025-05-01") + assert "##### Team-Building Activity Organized" in result + assert "**Activity Name:** Escape Room" in result + + +@pytest.mark.asyncio +async def test_manage_employee_transfer(): + """Test managing an employee transfer.""" + result = await manage_employee_transfer("John Doe", "Marketing") + assert "##### Employee Transfer" in result + assert "**New Department:** Marketing" in result + + +@pytest.mark.asyncio +async def test_track_employee_attendance(): + """Test tracking employee attendance.""" + result = await track_employee_attendance("John Doe") + assert "##### Attendance Tracked" in result + + +@pytest.mark.asyncio +async def test_organize_health_and_wellness_program(): + """Test organizing a health and wellness program.""" + result = await organize_health_and_wellness_program("Yoga Session", "2025-06-01") + assert "##### Health and Wellness Program Organized" in result + assert "**Program Name:** Yoga Session" in result + + +@pytest.mark.asyncio +async def test_facilitate_remote_work_setup(): + """Test facilitating remote work setup.""" + result = await facilitate_remote_work_setup("John Doe") + assert "##### Remote Work Setup Facilitated" in result + assert "**Employee Name:** John Doe" in result + + +@pytest.mark.asyncio +async def test_manage_retirement_plan(): + """Test managing a retirement plan.""" + result = await manage_retirement_plan("John Doe") + assert "##### Retirement Plan Managed" in result + assert "**Employee Name:** John Doe" in result diff --git a/src/backend/tests/agents/test_human.py b/src/backend/tests/agents/test_human.py new file mode 100644 index 00000000..2980e1fb --- /dev/null +++ b/src/backend/tests/agents/test_human.py @@ -0,0 +1,121 @@ +""" +Test cases for HumanAgent class in the backend agents module. +""" + +# Standard library imports +import os +import sys +from unittest.mock import AsyncMock, MagicMock, patch +import pytest + + +# Function to set environment variables +def setup_environment_variables(): + """Set environment variables required for the tests.""" + os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint" + os.environ["COSMOSDB_KEY"] = "mock-key" + os.environ["COSMOSDB_DATABASE"] = "mock-database" + os.environ["COSMOSDB_CONTAINER"] = "mock-container" + os.environ["APPLICATIONINSIGHTS_CONNECTION_STRING"] = "mock-instrumentation-key" + os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "mock-deployment-name" + os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01" + os.environ["AZURE_OPENAI_ENDPOINT"] = "https://mock-openai-endpoint" + + +# Call the function to set environment variables +setup_environment_variables() + +# Mock Azure and event_utils dependencies globally +sys.modules["azure.monitor.events.extension"] = MagicMock() +sys.modules["src.backend.event_utils"] = MagicMock() + +# Project-specific imports (must come after environment setup) +from autogen_core.base import AgentInstantiationContext, AgentRuntime +from src.backend.agents.human import HumanAgent +from src.backend.models.messages import HumanFeedback, Step, StepStatus, BAgentType + + +@pytest.fixture(autouse=True) +def ensure_env_variables(monkeypatch): + """ + Fixture to ensure environment variables are set for all tests. + This overrides any modifications made by individual tests. + """ + env_vars = { + "COSMOSDB_ENDPOINT": "https://mock-endpoint", + "COSMOSDB_KEY": "mock-key", + "COSMOSDB_DATABASE": "mock-database", + "COSMOSDB_CONTAINER": "mock-container", + "APPLICATIONINSIGHTS_CONNECTION_STRING": "mock-instrumentation-key", + "AZURE_OPENAI_DEPLOYMENT_NAME": "mock-deployment-name", + "AZURE_OPENAI_API_VERSION": "2023-01-01", + "AZURE_OPENAI_ENDPOINT": "https://mock-openai-endpoint", + } + for key, value in env_vars.items(): + monkeypatch.setenv(key, value) + + +@pytest.fixture +def setup_agent(): + """ + Fixture to set up a HumanAgent and its dependencies. + """ + memory = AsyncMock() + user_id = "test_user" + group_chat_manager_id = "group_chat_manager" + + # Mock runtime and agent ID + mock_runtime = MagicMock(spec=AgentRuntime) + mock_agent_id = "test_agent_id" + + # Set up the context + with patch.object(AgentInstantiationContext, "current_runtime", return_value=mock_runtime): + with patch.object(AgentInstantiationContext, "current_agent_id", return_value=mock_agent_id): + agent = HumanAgent(memory, user_id, group_chat_manager_id) + + session_id = "session123" + step_id = "step123" + plan_id = "plan123" + + # Mock HumanFeedback message + feedback_message = HumanFeedback( + session_id=session_id, + step_id=step_id, + plan_id=plan_id, + approved=True, + human_feedback="Great job!", + ) + + # Mock Step with all required fields + step = Step( + plan_id=plan_id, + action="Test Action", + agent=BAgentType.human_agent, + status=StepStatus.planned, + session_id=session_id, + user_id=user_id, + human_feedback=None, + ) + + return agent, memory, feedback_message, step, session_id, step_id, plan_id + + +@patch("src.backend.agents.human.logging.info") +@patch("src.backend.agents.human.track_event_if_configured") +@pytest.mark.asyncio +async def test_handle_step_feedback_step_not_found(mock_track_event, mock_logging, setup_agent): + """ + Test scenario where the step is not found in memory. + """ + agent, memory, feedback_message, _, _, step_id, _ = setup_agent + + # Mock no step found + memory.get_step.return_value = None + + # Run the method + await agent.handle_step_feedback(feedback_message, MagicMock()) + + # Check if log and return were called correctly + mock_logging.assert_called_with(f"No step found with id: {step_id}") + memory.update_step.assert_not_called() + mock_track_event.assert_not_called() diff --git a/src/backend/tests/agents/test_marketing.py b/src/backend/tests/agents/test_marketing.py new file mode 100644 index 00000000..48562bc1 --- /dev/null +++ b/src/backend/tests/agents/test_marketing.py @@ -0,0 +1,585 @@ +import os +import sys +import pytest +from unittest.mock import MagicMock +from autogen_core.components.tools import FunctionTool + +# Import marketing functions for testing +from src.backend.agents.marketing import ( + create_marketing_campaign, + analyze_market_trends, + develop_brand_strategy, + generate_social_media_posts, + get_marketing_tools, + manage_loyalty_program, + plan_advertising_budget, + conduct_customer_survey, + generate_marketing_report, + perform_competitor_analysis, + optimize_seo_strategy, + run_influencer_marketing_campaign, + schedule_marketing_event, + design_promotional_material, + manage_email_marketing, + track_campaign_performance, + create_content_calendar, + update_website_content, + plan_product_launch, + handle_customer_feedback, + generate_press_release, + run_ppc_campaign, + create_infographic +) + + +# Set mock environment variables for Azure and CosmosDB +os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint" +os.environ["COSMOSDB_KEY"] = "mock-key" +os.environ["COSMOSDB_DATABASE"] = "mock-database" +os.environ["COSMOSDB_CONTAINER"] = "mock-container" +os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "mock-deployment-name" +os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01" +os.environ["AZURE_OPENAI_ENDPOINT"] = "https://mock-openai-endpoint" + +# Mock Azure dependencies +sys.modules["azure.monitor.events.extension"] = MagicMock() + + +# Test cases +@pytest.mark.asyncio +async def test_create_marketing_campaign(): + result = await create_marketing_campaign("Holiday Sale", "Millennials", 10000) + assert "Marketing campaign 'Holiday Sale' created targeting 'Millennials' with a budget of $10000.00." in result + + +@pytest.mark.asyncio +async def test_analyze_market_trends(): + result = await analyze_market_trends("Technology") + assert "Market trends analyzed for the 'Technology' industry." in result + + +@pytest.mark.asyncio +async def test_generate_social_media_posts(): + result = await generate_social_media_posts("Black Friday", ["Facebook", "Instagram"]) + assert "Social media posts for campaign 'Black Friday' generated for platforms: Facebook, Instagram." in result + + +@pytest.mark.asyncio +async def test_plan_advertising_budget(): + result = await plan_advertising_budget("New Year Sale", 20000) + assert "Advertising budget planned for campaign 'New Year Sale' with a total budget of $20000.00." in result + + +@pytest.mark.asyncio +async def test_conduct_customer_survey(): + result = await conduct_customer_survey("Customer Satisfaction", "Frequent Buyers") + assert "Customer survey on 'Customer Satisfaction' conducted targeting 'Frequent Buyers'." in result + + +@pytest.mark.asyncio +async def test_generate_marketing_report(): + result = await generate_marketing_report("Winter Campaign") + assert "Marketing report generated for campaign 'Winter Campaign'." in result + + +@pytest.mark.asyncio +async def test_perform_competitor_analysis(): + result = await perform_competitor_analysis("Competitor A") + assert "Competitor analysis performed on 'Competitor A'." in result + + +@pytest.mark.asyncio +async def test_perform_competitor_analysis_empty_input(): + result = await perform_competitor_analysis("") + assert "Competitor analysis performed on ''." in result + + +@pytest.mark.asyncio +async def test_optimize_seo_strategy(): + result = await optimize_seo_strategy(["keyword1", "keyword2"]) + assert "SEO strategy optimized with keywords: keyword1, keyword2." in result + + +@pytest.mark.asyncio +async def test_optimize_seo_strategy_empty_keywords(): + result = await optimize_seo_strategy([]) + assert "SEO strategy optimized with keywords: ." in result + + +@pytest.mark.asyncio +async def test_schedule_marketing_event(): + result = await schedule_marketing_event("Product Launch", "2025-01-30", "Main Hall") + assert "Marketing event 'Product Launch' scheduled on 2025-01-30 at Main Hall." in result + + +@pytest.mark.asyncio +async def test_schedule_marketing_event_empty_details(): + result = await schedule_marketing_event("", "", "") + assert "Marketing event '' scheduled on at ." in result + + +@pytest.mark.asyncio +async def test_design_promotional_material(): + result = await design_promotional_material("Spring Sale", "poster") + assert "Poster for campaign 'Spring Sale' designed." in result + + +@pytest.mark.asyncio +async def test_design_promotional_material_empty_input(): + result = await design_promotional_material("", "") + assert " for campaign '' designed." in result + + +@pytest.mark.asyncio +async def test_manage_email_marketing_large_email_list(): + result = await manage_email_marketing("Holiday Offers", 100000) + assert "Email marketing managed for campaign 'Holiday Offers' targeting 100000 recipients." in result + + +@pytest.mark.asyncio +async def test_manage_email_marketing_zero_recipients(): + result = await manage_email_marketing("Holiday Offers", 0) + assert "Email marketing managed for campaign 'Holiday Offers' targeting 0 recipients." in result + + +@pytest.mark.asyncio +async def test_track_campaign_performance(): + result = await track_campaign_performance("Fall Promo") + assert "Performance of campaign 'Fall Promo' tracked." in result + + +@pytest.mark.asyncio +async def test_track_campaign_performance_empty_name(): + result = await track_campaign_performance("") + assert "Performance of campaign '' tracked." in result + + +@pytest.mark.asyncio +async def test_create_content_calendar(): + result = await create_content_calendar("March") + assert "Content calendar for 'March' created." in result + + +@pytest.mark.asyncio +async def test_create_content_calendar_empty_month(): + result = await create_content_calendar("") + assert "Content calendar for '' created." in result + + +@pytest.mark.asyncio +async def test_update_website_content(): + result = await update_website_content("Homepage") + assert "Website content on page 'Homepage' updated." in result + + +@pytest.mark.asyncio +async def test_update_website_content_empty_page(): + result = await update_website_content("") + assert "Website content on page '' updated." in result + + +@pytest.mark.asyncio +async def test_plan_product_launch(): + result = await plan_product_launch("Smartwatch", "2025-02-15") + assert "Product launch for 'Smartwatch' planned on 2025-02-15." in result + + +@pytest.mark.asyncio +async def test_plan_product_launch_empty_input(): + result = await plan_product_launch("", "") + assert "Product launch for '' planned on ." in result + + +@pytest.mark.asyncio +async def test_handle_customer_feedback(): + result = await handle_customer_feedback("Great service!") + assert "Customer feedback handled: Great service!" in result + + +@pytest.mark.asyncio +async def test_handle_customer_feedback_empty_feedback(): + result = await handle_customer_feedback("") + assert "Customer feedback handled: " in result + + +@pytest.mark.asyncio +async def test_generate_press_release(): + result = await generate_press_release("Key updates for the press release.") + assert "Identify the content." in result + assert "generate a press release based on this content Key updates for the press release." in result + + +@pytest.mark.asyncio +async def test_generate_press_release_empty_content(): + result = await generate_press_release("") + assert "generate a press release based on this content " in result + + +@pytest.mark.asyncio +async def test_generate_marketing_report_empty_name(): + result = await generate_marketing_report("") + assert "Marketing report generated for campaign ''." in result + + +@pytest.mark.asyncio +async def test_run_ppc_campaign(): + result = await run_ppc_campaign("Spring PPC", 10000.00) + assert "PPC campaign 'Spring PPC' run with a budget of $10000.00." in result + + +@pytest.mark.asyncio +async def test_run_ppc_campaign_zero_budget(): + result = await run_ppc_campaign("Spring PPC", 0.00) + assert "PPC campaign 'Spring PPC' run with a budget of $0.00." in result + + +@pytest.mark.asyncio +async def test_run_ppc_campaign_large_budget(): + result = await run_ppc_campaign("Spring PPC", 1e7) + assert "PPC campaign 'Spring PPC' run with a budget of $10000000.00." in result + + +@pytest.mark.asyncio +async def test_generate_social_media_posts_no_campaign_name(): + """Test generating social media posts with no campaign name.""" + result = await generate_social_media_posts("", ["Twitter", "LinkedIn"]) + assert "Social media posts for campaign '' generated for platforms: Twitter, LinkedIn." in result + + +@pytest.mark.asyncio +async def test_plan_advertising_budget_negative_value(): + """Test planning an advertising budget with a negative value.""" + result = await plan_advertising_budget("Summer Sale", -10000) + assert "Advertising budget planned for campaign 'Summer Sale' with a total budget of $-10000.00." in result + + +@pytest.mark.asyncio +async def test_conduct_customer_survey_invalid_target_group(): + """Test conducting a survey with an invalid target group.""" + result = await conduct_customer_survey("Product Feedback", None) + assert "Customer survey on 'Product Feedback' conducted targeting 'None'." in result + + +@pytest.mark.asyncio +async def test_manage_email_marketing_boundary(): + """Test managing email marketing with boundary cases.""" + result = await manage_email_marketing("Year-End Deals", 1) + assert "Email marketing managed for campaign 'Year-End Deals' targeting 1 recipients." in result + + +@pytest.mark.asyncio +async def test_create_marketing_campaign_no_audience(): + """Test creating a marketing campaign with no specified audience.""" + result = await create_marketing_campaign("Holiday Sale", "", 10000) + assert "Marketing campaign 'Holiday Sale' created targeting '' with a budget of $10000.00." in result + + +@pytest.mark.asyncio +async def test_analyze_market_trends_no_industry(): + """Test analyzing market trends with no specified industry.""" + result = await analyze_market_trends("") + assert "Market trends analyzed for the '' industry." in result + + +@pytest.mark.asyncio +async def test_generate_social_media_posts_no_platforms(): + """Test generating social media posts with no specified platforms.""" + result = await generate_social_media_posts("Black Friday", []) + assert "Social media posts for campaign 'Black Friday' generated for platforms: ." in result + + +@pytest.mark.asyncio +async def test_plan_advertising_budget_large_budget(): + """Test planning an advertising budget with a large value.""" + result = await plan_advertising_budget("Mega Sale", 1e9) + assert "Advertising budget planned for campaign 'Mega Sale' with a total budget of $1000000000.00." in result + + +@pytest.mark.asyncio +async def test_conduct_customer_survey_no_target(): + """Test conducting a customer survey with no specified target group.""" + result = await conduct_customer_survey("Product Feedback", "") + assert "Customer survey on 'Product Feedback' conducted targeting ''." in result + + +@pytest.mark.asyncio +async def test_schedule_marketing_event_invalid_date(): + """Test scheduling a marketing event with an invalid date.""" + result = await schedule_marketing_event("Product Launch", "invalid-date", "Main Hall") + assert "Marketing event 'Product Launch' scheduled on invalid-date at Main Hall." in result + + +@pytest.mark.asyncio +async def test_design_promotional_material_no_type(): + """Test designing promotional material with no specified type.""" + result = await design_promotional_material("Spring Sale", "") + assert " for campaign 'Spring Sale' designed." in result + + +@pytest.mark.asyncio +async def test_manage_email_marketing_no_campaign_name(): + """Test managing email marketing with no specified campaign name.""" + result = await manage_email_marketing("", 5000) + assert "Email marketing managed for campaign '' targeting 5000 recipients." in result + + +@pytest.mark.asyncio +async def test_track_campaign_performance_no_data(): + """Test tracking campaign performance with no data.""" + result = await track_campaign_performance(None) + assert "Performance of campaign 'None' tracked." in result + + +@pytest.mark.asyncio +async def test_update_website_content_special_characters(): + """Test updating website content with a page name containing special characters.""" + result = await update_website_content("Home!@#$%^&*()Page") + assert "Website content on page 'Home!@#$%^&*()Page' updated." in result + + +@pytest.mark.asyncio +async def test_plan_product_launch_past_date(): + """Test planning a product launch with a past date.""" + result = await plan_product_launch("Old Product", "2000-01-01") + assert "Product launch for 'Old Product' planned on 2000-01-01." in result + + +@pytest.mark.asyncio +async def test_handle_customer_feedback_long_text(): + """Test handling customer feedback with a very long text.""" + feedback = "Great service!" * 1000 + result = await handle_customer_feedback(feedback) + assert f"Customer feedback handled: {feedback}" in result + + +@pytest.mark.asyncio +async def test_generate_press_release_special_characters(): + """Test generating a press release with special characters in content.""" + result = await generate_press_release("Content with special characters !@#$%^&*().") + assert "generate a press release based on this content Content with special characters !@#$%^&*()." in result + + +@pytest.mark.asyncio +async def test_run_ppc_campaign_negative_budget(): + """Test running a PPC campaign with a negative budget.""" + result = await run_ppc_campaign("Negative Budget Campaign", -100) + assert "PPC campaign 'Negative Budget Campaign' run with a budget of $-100.00." in result + + +@pytest.mark.asyncio +async def test_create_marketing_campaign_no_name(): + """Test creating a marketing campaign with no name.""" + result = await create_marketing_campaign("", "Gen Z", 10000) + assert "Marketing campaign '' created targeting 'Gen Z' with a budget of $10000.00." in result + + +@pytest.mark.asyncio +async def test_analyze_market_trends_empty_industry(): + """Test analyzing market trends with an empty industry.""" + result = await analyze_market_trends("") + assert "Market trends analyzed for the '' industry." in result + + +@pytest.mark.asyncio +async def test_plan_advertising_budget_no_campaign_name(): + """Test planning an advertising budget with no campaign name.""" + result = await plan_advertising_budget("", 20000) + assert "Advertising budget planned for campaign '' with a total budget of $20000.00." in result + + +@pytest.mark.asyncio +async def test_conduct_customer_survey_no_topic(): + """Test conducting a survey with no topic.""" + result = await conduct_customer_survey("", "Frequent Buyers") + assert "Customer survey on '' conducted targeting 'Frequent Buyers'." in result + + +@pytest.mark.asyncio +async def test_generate_marketing_report_no_name(): + """Test generating a marketing report with no name.""" + result = await generate_marketing_report("") + assert "Marketing report generated for campaign ''." in result + + +@pytest.mark.asyncio +async def test_perform_competitor_analysis_no_competitor(): + """Test performing competitor analysis with no competitor specified.""" + result = await perform_competitor_analysis("") + assert "Competitor analysis performed on ''." in result + + +@pytest.mark.asyncio +async def test_manage_email_marketing_no_recipients(): + """Test managing email marketing with no recipients.""" + result = await manage_email_marketing("Holiday Campaign", 0) + assert "Email marketing managed for campaign 'Holiday Campaign' targeting 0 recipients." in result + + +# Include all imports and environment setup from the original file. + +# New test cases added here to improve coverage: + + +@pytest.mark.asyncio +async def test_create_content_calendar_no_month(): + """Test creating a content calendar with no month provided.""" + result = await create_content_calendar("") + assert "Content calendar for '' created." in result + + +@pytest.mark.asyncio +async def test_schedule_marketing_event_no_location(): + """Test scheduling a marketing event with no location provided.""" + result = await schedule_marketing_event("Event Name", "2025-05-01", "") + assert "Marketing event 'Event Name' scheduled on 2025-05-01 at ." in result + + +@pytest.mark.asyncio +async def test_generate_social_media_posts_missing_platforms(): + """Test generating social media posts with missing platforms.""" + result = await generate_social_media_posts("Campaign Name", []) + assert "Social media posts for campaign 'Campaign Name' generated for platforms: ." in result + + +@pytest.mark.asyncio +async def test_handle_customer_feedback_no_text(): + """Test handling customer feedback with no feedback provided.""" + result = await handle_customer_feedback("") + assert "Customer feedback handled: " in result + + +@pytest.mark.asyncio +async def test_develop_brand_strategy(): + """Test developing a brand strategy.""" + result = await develop_brand_strategy("My Brand") + assert "Brand strategy developed for 'My Brand'." in result + + +@pytest.mark.asyncio +async def test_create_infographic(): + """Test creating an infographic.""" + result = await create_infographic("Top 10 Marketing Tips") + assert "Infographic 'Top 10 Marketing Tips' created." in result + + +@pytest.mark.asyncio +async def test_run_influencer_marketing_campaign(): + """Test running an influencer marketing campaign.""" + result = await run_influencer_marketing_campaign( + "Launch Campaign", ["Influencer A", "Influencer B"] + ) + assert "Influencer marketing campaign 'Launch Campaign' run with influencers: Influencer A, Influencer B." in result + + +@pytest.mark.asyncio +async def test_manage_loyalty_program(): + """Test managing a loyalty program.""" + result = await manage_loyalty_program("Rewards Club", 5000) + assert "Loyalty program 'Rewards Club' managed with 5000 members." in result + + +@pytest.mark.asyncio +async def test_create_marketing_campaign_empty_fields(): + """Test creating a marketing campaign with empty fields.""" + result = await create_marketing_campaign("", "", 0) + assert "Marketing campaign '' created targeting '' with a budget of $0.00." in result + + +@pytest.mark.asyncio +async def test_plan_product_launch_empty_fields(): + """Test planning a product launch with missing fields.""" + result = await plan_product_launch("", "") + assert "Product launch for '' planned on ." in result + + +@pytest.mark.asyncio +async def test_get_marketing_tools(): + """Test retrieving the list of marketing tools.""" + tools = get_marketing_tools() + assert len(tools) > 0 + assert all(isinstance(tool, FunctionTool) for tool in tools) + + +@pytest.mark.asyncio +async def test_get_marketing_tools_complete(): + """Test that all tools are included in the marketing tools list.""" + tools = get_marketing_tools() + assert len(tools) > 40 # Assuming there are more than 40 tools + assert any(tool.name == "create_marketing_campaign" for tool in tools) + assert all(isinstance(tool, FunctionTool) for tool in tools) + + +@pytest.mark.asyncio +async def test_schedule_marketing_event_invalid_location(): + """Test scheduling a marketing event with invalid location.""" + result = await schedule_marketing_event("Event Name", "2025-12-01", None) + assert "Marketing event 'Event Name' scheduled on 2025-12-01 at None." in result + + +@pytest.mark.asyncio +async def test_plan_product_launch_no_date(): + """Test planning a product launch with no launch date.""" + result = await plan_product_launch("Product X", None) + assert "Product launch for 'Product X' planned on None." in result + + +@pytest.mark.asyncio +async def test_handle_customer_feedback_none(): + """Test handling customer feedback with None.""" + result = await handle_customer_feedback(None) + assert "Customer feedback handled: None" in result + + +@pytest.mark.asyncio +async def test_generate_press_release_no_key_info(): + """Test generating a press release with no key information.""" + result = await generate_press_release("") + assert "generate a press release based on this content " in result + + +@pytest.mark.asyncio +async def test_schedule_marketing_event_invalid_inputs(): + """Test scheduling marketing event with invalid inputs.""" + result = await schedule_marketing_event("", None, None) + assert "Marketing event '' scheduled on None at None." in result + + +@pytest.mark.asyncio +async def test_plan_product_launch_invalid_date(): + """Test planning a product launch with invalid date.""" + result = await plan_product_launch("New Product", "not-a-date") + assert "Product launch for 'New Product' planned on not-a-date." in result + + +@pytest.mark.asyncio +async def test_handle_customer_feedback_empty_input(): + """Test handling customer feedback with empty input.""" + result = await handle_customer_feedback("") + assert "Customer feedback handled: " in result + + +@pytest.mark.asyncio +async def test_manage_email_marketing_invalid_recipients(): + """Test managing email marketing with invalid recipients.""" + result = await manage_email_marketing("Campaign X", -5) + assert "Email marketing managed for campaign 'Campaign X' targeting -5 recipients." in result + + +@pytest.mark.asyncio +async def test_track_campaign_performance_none(): + """Test tracking campaign performance with None.""" + result = await track_campaign_performance(None) + assert "Performance of campaign 'None' tracked." in result + + +@pytest.fixture +def mock_agent_dependencies(): + """Provide mocked dependencies for the MarketingAgent.""" + return { + "mock_model_client": MagicMock(), + "mock_session_id": "session123", + "mock_user_id": "user123", + "mock_context": MagicMock(), + "mock_tools": [MagicMock()], + "mock_agent_id": "agent123", + } diff --git a/src/backend/tests/agents/test_planner.py b/src/backend/tests/agents/test_planner.py new file mode 100644 index 00000000..957823ce --- /dev/null +++ b/src/backend/tests/agents/test_planner.py @@ -0,0 +1,185 @@ +import os +import sys +from unittest.mock import AsyncMock, MagicMock, patch +import pytest + +# Set environment variables before importing anything +os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint" +os.environ["COSMOSDB_KEY"] = "mock-key" +os.environ["COSMOSDB_DATABASE"] = "mock-database" +os.environ["COSMOSDB_CONTAINER"] = "mock-container" +os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "mock-deployment-name" +os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01" +os.environ["AZURE_OPENAI_ENDPOINT"] = "https://mock-openai-endpoint" + +# Mock `azure.monitor.events.extension` globally +sys.modules["azure.monitor.events.extension"] = MagicMock() +sys.modules["event_utils"] = MagicMock() +# Import modules after setting environment variables +from src.backend.agents.planner import PlannerAgent +from src.backend.models.messages import InputTask, HumanClarification, Plan, PlanStatus +from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext + + +@pytest.fixture +def mock_context(): + """Mock the CosmosBufferedChatCompletionContext.""" + return MagicMock(spec=CosmosBufferedChatCompletionContext) + + +@pytest.fixture +def mock_model_client(): + """Mock the Azure OpenAI model client.""" + return MagicMock() + + +@pytest.fixture +def mock_runtime_context(): + """Mock the runtime context for AgentInstantiationContext.""" + with patch( + "autogen_core.base._agent_instantiation.AgentInstantiationContext.AGENT_INSTANTIATION_CONTEXT_VAR", + new=MagicMock(), + ) as mock_context_var: + yield mock_context_var + + +@pytest.fixture +def planner_agent(mock_model_client, mock_context, mock_runtime_context): + """Return an instance of PlannerAgent with mocked dependencies.""" + mock_runtime_context.get.return_value = (MagicMock(), "mock-agent-id") + return PlannerAgent( + model_client=mock_model_client, + session_id="test-session", + user_id="test-user", + memory=mock_context, + available_agents=["HumanAgent", "MarketingAgent", "TechSupportAgent"], + agent_tools_list=["tool1", "tool2"], + ) + + +@pytest.mark.asyncio +async def test_handle_plan_clarification(planner_agent, mock_context): + """Test the handle_plan_clarification method.""" + mock_clarification = HumanClarification( + session_id="test-session", + plan_id="plan-1", + human_clarification="Test clarification", + ) + + mock_context.get_plan_by_session = AsyncMock( + return_value=Plan( + id="plan-1", + session_id="test-session", + user_id="test-user", + initial_goal="Test Goal", + overall_status="in_progress", + source="PlannerAgent", + summary="Mock Summary", + human_clarification_request=None, + ) + ) + mock_context.update_plan = AsyncMock() + mock_context.add_item = AsyncMock() + + await planner_agent.handle_plan_clarification(mock_clarification, None) + + mock_context.get_plan_by_session.assert_called_with(session_id="test-session") + mock_context.update_plan.assert_called() + mock_context.add_item.assert_called() + + +@pytest.mark.asyncio +async def test_generate_instruction_with_special_characters(planner_agent): + """Test _generate_instruction with special characters in the objective.""" + special_objective = "Solve this task: @$%^&*()" + instruction = planner_agent._generate_instruction(special_objective) + + assert "Solve this task: @$%^&*()" in instruction + assert "HumanAgent" in instruction + assert "tool1" in instruction + + +@pytest.mark.asyncio +async def test_handle_plan_clarification_updates_plan_correctly(planner_agent, mock_context): + """Test handle_plan_clarification ensures correct plan updates.""" + mock_clarification = HumanClarification( + session_id="test-session", + plan_id="plan-1", + human_clarification="Updated clarification text", + ) + + mock_plan = Plan( + id="plan-1", + session_id="test-session", + user_id="test-user", + initial_goal="Test Goal", + overall_status="in_progress", + source="PlannerAgent", + summary="Mock Summary", + human_clarification_request="Previous clarification needed", + ) + + mock_context.get_plan_by_session = AsyncMock(return_value=mock_plan) + mock_context.update_plan = AsyncMock() + + await planner_agent.handle_plan_clarification(mock_clarification, None) + + assert mock_plan.human_clarification_response == "Updated clarification text" + mock_context.update_plan.assert_called_with(mock_plan) + + +@pytest.mark.asyncio +async def test_handle_input_task_with_exception(planner_agent, mock_context): + """Test handle_input_task gracefully handles exceptions.""" + input_task = InputTask(description="Test task causing exception", session_id="test-session") + planner_agent._create_structured_plan = AsyncMock(side_effect=Exception("Mocked exception")) + + with pytest.raises(Exception, match="Mocked exception"): + await planner_agent.handle_input_task(input_task, None) + + planner_agent._create_structured_plan.assert_called() + mock_context.add_item.assert_not_called() + mock_context.add_plan.assert_not_called() + mock_context.add_step.assert_not_called() + + +@pytest.mark.asyncio +async def test_handle_plan_clarification_handles_memory_error(planner_agent, mock_context): + """Test handle_plan_clarification gracefully handles memory errors.""" + mock_clarification = HumanClarification( + session_id="test-session", + plan_id="plan-1", + human_clarification="Test clarification", + ) + + mock_context.get_plan_by_session = AsyncMock(side_effect=Exception("Memory error")) + + with pytest.raises(Exception, match="Memory error"): + await planner_agent.handle_plan_clarification(mock_clarification, None) + + mock_context.update_plan.assert_not_called() + mock_context.add_item.assert_not_called() + + +@pytest.mark.asyncio +async def test_generate_instruction_with_missing_objective(planner_agent): + """Test _generate_instruction with a missing or empty objective.""" + instruction = planner_agent._generate_instruction("") + assert "Your objective is:" in instruction + assert "The agents you have access to are:" in instruction + assert "These agents have access to the following functions:" in instruction + + +@pytest.mark.asyncio +async def test_create_structured_plan_with_error(planner_agent, mock_context): + """Test _create_structured_plan when an error occurs during plan creation.""" + planner_agent._model_client.create = AsyncMock(side_effect=Exception("Mocked error")) + + messages = [{"content": "Test message", "source": "PlannerAgent"}] + plan, steps = await planner_agent._create_structured_plan(messages) + + assert plan.initial_goal == "Error generating plan" + assert plan.overall_status == PlanStatus.failed + assert len(steps) == 0 + mock_context.add_plan.assert_not_called() + mock_context.add_step.assert_not_called() diff --git a/src/backend/tests/agents/test_procurement.py b/src/backend/tests/agents/test_procurement.py new file mode 100644 index 00000000..4c214db0 --- /dev/null +++ b/src/backend/tests/agents/test_procurement.py @@ -0,0 +1,678 @@ +import os +import sys +import pytest +from unittest.mock import MagicMock + +# Mocking azure.monitor.events.extension globally +sys.modules["azure.monitor.events.extension"] = MagicMock() + +# Setting up environment variables to mock Config dependencies +os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint" +os.environ["COSMOSDB_KEY"] = "mock-key" +os.environ["COSMOSDB_DATABASE"] = "mock-database" +os.environ["COSMOSDB_CONTAINER"] = "mock-container" +os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "mock-deployment-name" +os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01" +os.environ["AZURE_OPENAI_ENDPOINT"] = "https://mock-openai-endpoint" + +# Import the procurement tools for testing +from src.backend.agents.procurement import ( + order_hardware, + order_software_license, + check_inventory, + process_purchase_order, + initiate_contract_negotiation, + approve_invoice, + track_order, + manage_vendor_relationship, + update_procurement_policy, + generate_procurement_report, + evaluate_supplier_performance, + handle_return, + process_payment, + request_quote, + recommend_sourcing_options, + update_asset_register, + conduct_market_research, + audit_inventory, + approve_budget, + manage_import_licenses, + allocate_budget, + track_procurement_metrics, +) + +# Mocking `track_event_if_configured` for tests +sys.modules["src.backend.event_utils"] = MagicMock() + + +@pytest.mark.asyncio +async def test_order_hardware(): + result = await order_hardware("laptop", 10) + assert "Ordered 10 units of laptop." in result + + +@pytest.mark.asyncio +async def test_order_software_license(): + result = await order_software_license("Photoshop", "team", 5) + assert "Ordered 5 team licenses of Photoshop." in result + + +@pytest.mark.asyncio +async def test_check_inventory(): + result = await check_inventory("printer") + assert "Inventory status of printer: In Stock." in result + + +@pytest.mark.asyncio +async def test_process_purchase_order(): + result = await process_purchase_order("PO12345") + assert "Purchase Order PO12345 has been processed." in result + + +@pytest.mark.asyncio +async def test_initiate_contract_negotiation(): + result = await initiate_contract_negotiation("VendorX", "Exclusive deal for 2025") + assert ( + "Contract negotiation initiated with VendorX: Exclusive deal for 2025" in result + ) + + +@pytest.mark.asyncio +async def test_approve_invoice(): + result = await approve_invoice("INV001") + assert "Invoice INV001 approved for payment." in result + + +@pytest.mark.asyncio +async def test_track_order(): + result = await track_order("ORDER123") + assert "Order ORDER123 is currently in transit." in result + + +@pytest.mark.asyncio +async def test_manage_vendor_relationship(): + result = await manage_vendor_relationship("VendorY", "renewed") + assert "Vendor relationship with VendorY has been renewed." in result + + +@pytest.mark.asyncio +async def test_update_procurement_policy(): + result = await update_procurement_policy( + "Policy2025", "Updated terms and conditions" + ) + assert "Procurement policy 'Policy2025' updated." in result + + +@pytest.mark.asyncio +async def test_generate_procurement_report(): + result = await generate_procurement_report("Annual") + assert "Generated Annual procurement report." in result + + +@pytest.mark.asyncio +async def test_evaluate_supplier_performance(): + result = await evaluate_supplier_performance("SupplierZ") + assert "Performance evaluation for supplier SupplierZ completed." in result + + +@pytest.mark.asyncio +async def test_handle_return(): + result = await handle_return("Laptop", 3, "Defective screens") + assert "Processed return of 3 units of Laptop due to Defective screens." in result + + +@pytest.mark.asyncio +async def test_process_payment(): + result = await process_payment("VendorA", 5000.00) + assert "Processed payment of $5000.00 to VendorA." in result + + +@pytest.mark.asyncio +async def test_request_quote(): + result = await request_quote("Tablet", 20) + assert "Requested quote for 20 units of Tablet." in result + + +@pytest.mark.asyncio +async def test_recommend_sourcing_options(): + result = await recommend_sourcing_options("Projector") + assert "Sourcing options for Projector have been provided." in result + + +@pytest.mark.asyncio +async def test_update_asset_register(): + result = await update_asset_register("ServerX", "Deployed in Data Center") + assert "Asset register updated for ServerX: Deployed in Data Center" in result + + +@pytest.mark.asyncio +async def test_conduct_market_research(): + result = await conduct_market_research("Electronics") + assert "Market research conducted for category: Electronics" in result + + +@pytest.mark.asyncio +async def test_audit_inventory(): + result = await audit_inventory() + assert "Inventory audit has been conducted." in result + + +@pytest.mark.asyncio +async def test_approve_budget(): + result = await approve_budget("BUD001", 25000.00) + assert "Approved budget ID BUD001 for amount $25000.00." in result + + +@pytest.mark.asyncio +async def test_manage_import_licenses(): + result = await manage_import_licenses("Smartphones", "License12345") + assert "Import license for Smartphones managed: License12345." in result + + +@pytest.mark.asyncio +async def test_allocate_budget(): + result = await allocate_budget("IT Department", 150000.00) + assert "Allocated budget of $150000.00 to IT Department." in result + + +@pytest.mark.asyncio +async def test_track_procurement_metrics(): + result = await track_procurement_metrics("Cost Savings") + assert "Procurement metric 'Cost Savings' tracked." in result + + +@pytest.mark.asyncio +async def test_order_hardware_invalid_quantity(): + result = await order_hardware("printer", 0) + assert "Ordered 0 units of printer." in result + + +@pytest.mark.asyncio +async def test_order_software_license_invalid_type(): + result = await order_software_license("Photoshop", "", 5) + assert "Ordered 5 licenses of Photoshop." in result + + +@pytest.mark.asyncio +async def test_check_inventory_empty_item(): + result = await check_inventory("") + assert "Inventory status of : In Stock." in result + + +@pytest.mark.asyncio +async def test_process_purchase_order_empty(): + result = await process_purchase_order("") + assert "Purchase Order has been processed." in result + + +@pytest.mark.asyncio +async def test_initiate_contract_negotiation_empty_details(): + result = await initiate_contract_negotiation("", "") + assert "Contract negotiation initiated with : " in result + + +@pytest.mark.asyncio +async def test_approve_invoice_empty(): + result = await approve_invoice("") + assert "Invoice approved for payment." in result + + +@pytest.mark.asyncio +async def test_track_order_empty_order(): + result = await track_order("") + assert "Order is currently in transit." in result + + +@pytest.mark.asyncio +async def test_manage_vendor_relationship_empty_action(): + result = await manage_vendor_relationship("VendorA", "") + assert "Vendor relationship with VendorA has been ." in result + + +@pytest.mark.asyncio +async def test_update_procurement_policy_no_content(): + result = await update_procurement_policy("Policy2025", "") + assert "Procurement policy 'Policy2025' updated." in result + + +@pytest.mark.asyncio +async def test_generate_procurement_report_empty_type(): + result = await generate_procurement_report("") + assert "Generated procurement report." in result + + +@pytest.mark.asyncio +async def test_evaluate_supplier_performance_empty_name(): + result = await evaluate_supplier_performance("") + assert "Performance evaluation for supplier completed." in result + + +@pytest.mark.asyncio +async def test_handle_return_negative_quantity(): + result = await handle_return("Monitor", -5, "Damaged") + assert "Processed return of -5 units of Monitor due to Damaged." in result + + +@pytest.mark.asyncio +async def test_process_payment_zero_amount(): + result = await process_payment("VendorB", 0.00) + assert "Processed payment of $0.00 to VendorB." in result + + +@pytest.mark.asyncio +async def test_request_quote_empty_item(): + result = await request_quote("", 10) + assert "Requested quote for 10 units of ." in result + + +@pytest.mark.asyncio +async def test_recommend_sourcing_options_empty_item(): + result = await recommend_sourcing_options("") + assert "Sourcing options for have been provided." in result + + +@pytest.mark.asyncio +async def test_update_asset_register_empty_details(): + result = await update_asset_register("AssetX", "") + assert "Asset register updated for AssetX: " in result + + +@pytest.mark.asyncio +async def test_conduct_market_research_empty_category(): + result = await conduct_market_research("") + assert "Market research conducted for category: " in result + + +@pytest.mark.asyncio +async def test_audit_inventory_double_call(): + result1 = await audit_inventory() + result2 = await audit_inventory() + assert result1 == "Inventory audit has been conducted." + assert result2 == "Inventory audit has been conducted." + + +@pytest.mark.asyncio +async def test_approve_budget_negative_amount(): + result = await approve_budget("BUD002", -1000.00) + assert "Approved budget ID BUD002 for amount $-1000.00." in result + + +@pytest.mark.asyncio +async def test_manage_import_licenses_empty_license(): + result = await manage_import_licenses("Electronics", "") + assert "Import license for Electronics managed: ." in result + + +@pytest.mark.asyncio +async def test_allocate_budget_negative_value(): + result = await allocate_budget("HR Department", -50000.00) + assert "Allocated budget of $-50000.00 to HR Department." in result + + +@pytest.mark.asyncio +async def test_track_procurement_metrics_empty_metric(): + result = await track_procurement_metrics("") + assert "Procurement metric '' tracked." in result + + +@pytest.mark.asyncio +async def test_handle_return_zero_quantity(): + result = await handle_return("Monitor", 0, "Packaging error") + assert "Processed return of 0 units of Monitor due to Packaging error." in result + + +@pytest.mark.asyncio +async def test_order_hardware_large_quantity(): + result = await order_hardware("Monitor", 1000000) + assert "Ordered 1000000 units of Monitor." in result + + +@pytest.mark.asyncio +async def test_process_payment_large_amount(): + result = await process_payment("VendorX", 10000000.99) + assert "Processed payment of $10000000.99 to VendorX." in result + + +@pytest.mark.asyncio +async def test_track_order_invalid_number(): + result = await track_order("INVALID123") + assert "Order INVALID123 is currently in transit." in result + + +@pytest.mark.asyncio +async def test_initiate_contract_negotiation_long_details(): + long_details = ( + "This is a very long contract negotiation detail for testing purposes. " * 10 + ) + result = await initiate_contract_negotiation("VendorY", long_details) + assert "Contract negotiation initiated with VendorY" in result + assert long_details in result + + +@pytest.mark.asyncio +async def test_manage_vendor_relationship_invalid_action(): + result = await manage_vendor_relationship("VendorZ", "undefined") + assert "Vendor relationship with VendorZ has been undefined." in result + + +@pytest.mark.asyncio +async def test_update_procurement_policy_no_policy_name(): + result = await update_procurement_policy("", "Updated policy details") + assert "Procurement policy '' updated." in result + + +@pytest.mark.asyncio +async def test_generate_procurement_report_invalid_type(): + result = await generate_procurement_report("Nonexistent") + assert "Generated Nonexistent procurement report." in result + + +@pytest.mark.asyncio +async def test_evaluate_supplier_performance_no_supplier_name(): + result = await evaluate_supplier_performance("") + assert "Performance evaluation for supplier completed." in result + + +@pytest.mark.asyncio +async def test_manage_import_licenses_no_item_name(): + result = await manage_import_licenses("", "License123") + assert "Import license for managed: License123." in result + + +@pytest.mark.asyncio +async def test_allocate_budget_zero_value(): + result = await allocate_budget("Operations", 0) + assert "Allocated budget of $0.00 to Operations." in result + + +@pytest.mark.asyncio +async def test_audit_inventory_multiple_calls(): + result1 = await audit_inventory() + result2 = await audit_inventory() + assert result1 == "Inventory audit has been conducted." + assert result2 == "Inventory audit has been conducted." + + +@pytest.mark.asyncio +async def test_approve_budget_large_amount(): + result = await approve_budget("BUD123", 1e9) + assert "Approved budget ID BUD123 for amount $1000000000.00." in result + + +@pytest.mark.asyncio +async def test_request_quote_no_quantity(): + result = await request_quote("Laptop", 0) + assert "Requested quote for 0 units of Laptop." in result + + +@pytest.mark.asyncio +async def test_conduct_market_research_no_category(): + result = await conduct_market_research("") + assert "Market research conducted for category: " in result + + +@pytest.mark.asyncio +async def test_track_procurement_metrics_no_metric_name(): + result = await track_procurement_metrics("") + assert "Procurement metric '' tracked." in result + + +@pytest.mark.asyncio +async def test_order_hardware_no_item_name(): + """Test line 98: Edge case where item name is empty.""" + result = await order_hardware("", 5) + assert "Ordered 5 units of ." in result + + +@pytest.mark.asyncio +async def test_order_hardware_negative_quantity(): + """Test line 108: Handle negative quantities.""" + result = await order_hardware("Keyboard", -5) + assert "Ordered -5 units of Keyboard." in result + + +@pytest.mark.asyncio +async def test_order_software_license_no_license_type(): + """Test line 123: License type missing.""" + result = await order_software_license("Photoshop", "", 10) + assert "Ordered 10 licenses of Photoshop." in result + + +@pytest.mark.asyncio +async def test_order_software_license_no_quantity(): + """Test line 128: Quantity missing.""" + result = await order_software_license("Photoshop", "team", 0) + assert "Ordered 0 team licenses of Photoshop." in result + + +@pytest.mark.asyncio +async def test_process_purchase_order_invalid_number(): + """Test line 133: Invalid purchase order number.""" + result = await process_purchase_order("") + assert "Purchase Order has been processed." in result + + +@pytest.mark.asyncio +async def test_check_inventory_empty_item_name(): + """Test line 138: Inventory check for an empty item.""" + result = await check_inventory("") + assert "Inventory status of : In Stock." in result + + +@pytest.mark.asyncio +async def test_initiate_contract_negotiation_empty_vendor(): + """Test line 143: Contract negotiation with empty vendor name.""" + result = await initiate_contract_negotiation("", "Sample contract") + assert "Contract negotiation initiated with : Sample contract" in result + + +@pytest.mark.asyncio +async def test_update_procurement_policy_empty_policy_name(): + """Test line 158: Empty policy name.""" + result = await update_procurement_policy("", "New terms") + assert "Procurement policy '' updated." in result + + +@pytest.mark.asyncio +async def test_evaluate_supplier_performance_no_name(): + """Test line 168: Empty supplier name.""" + result = await evaluate_supplier_performance("") + assert "Performance evaluation for supplier completed." in result + + +@pytest.mark.asyncio +async def test_handle_return_empty_reason(): + """Test line 173: Handle return with no reason provided.""" + result = await handle_return("Laptop", 2, "") + assert "Processed return of 2 units of Laptop due to ." in result + + +@pytest.mark.asyncio +async def test_process_payment_no_vendor_name(): + """Test line 178: Payment processing with no vendor name.""" + result = await process_payment("", 500.00) + assert "Processed payment of $500.00 to ." in result + + +@pytest.mark.asyncio +async def test_manage_import_licenses_no_details(): + """Test line 220: Import licenses with empty details.""" + result = await manage_import_licenses("Smartphones", "") + assert "Import license for Smartphones managed: ." in result + + +@pytest.mark.asyncio +async def test_allocate_budget_no_department_name(): + """Test line 255: Allocate budget with empty department name.""" + result = await allocate_budget("", 1000.00) + assert "Allocated budget of $1000.00 to ." in result + + +@pytest.mark.asyncio +async def test_track_procurement_metrics_no_metric(): + """Test line 540: Track metrics with empty metric name.""" + result = await track_procurement_metrics("") + assert "Procurement metric '' tracked." in result + + +@pytest.mark.asyncio +async def test_handle_return_negative_and_zero_quantity(): + """Covers lines 173, 178.""" + result_negative = await handle_return("Laptop", -5, "Damaged") + result_zero = await handle_return("Laptop", 0, "Packaging Issue") + assert "Processed return of -5 units of Laptop due to Damaged." in result_negative + assert ( + "Processed return of 0 units of Laptop due to Packaging Issue." in result_zero + ) + + +@pytest.mark.asyncio +async def test_process_payment_no_vendor_name_large_amount(): + """Covers line 188.""" + result_empty_vendor = await process_payment("", 1000000.00) + assert "Processed payment of $1000000.00 to ." in result_empty_vendor + + +@pytest.mark.asyncio +async def test_request_quote_edge_cases(): + """Covers lines 193, 198.""" + result_no_quantity = await request_quote("Tablet", 0) + result_negative_quantity = await request_quote("Tablet", -10) + assert "Requested quote for 0 units of Tablet." in result_no_quantity + assert "Requested quote for -10 units of Tablet." in result_negative_quantity + + +@pytest.mark.asyncio +async def test_update_asset_register_no_details(): + """Covers line 203.""" + result = await update_asset_register("ServerX", "") + assert "Asset register updated for ServerX: " in result + + +@pytest.mark.asyncio +async def test_audit_inventory_multiple_runs(): + """Covers lines 213.""" + result1 = await audit_inventory() + result2 = await audit_inventory() + assert result1 == "Inventory audit has been conducted." + assert result2 == "Inventory audit has been conducted." + + +@pytest.mark.asyncio +async def test_approve_budget_negative_and_zero_amount(): + """Covers lines 220, 225.""" + result_zero = await approve_budget("BUD123", 0.00) + result_negative = await approve_budget("BUD124", -500.00) + assert "Approved budget ID BUD123 for amount $0.00." in result_zero + assert "Approved budget ID BUD124 for amount $-500.00." in result_negative + + +@pytest.mark.asyncio +async def test_manage_import_licenses_no_license_details(): + """Covers lines 230, 235.""" + result_empty_license = await manage_import_licenses("Smartphones", "") + result_no_item = await manage_import_licenses("", "License12345") + assert "Import license for Smartphones managed: ." in result_empty_license + assert "Import license for managed: License12345." in result_no_item + + +@pytest.mark.asyncio +async def test_allocate_budget_no_department_and_large_values(): + """Covers lines 250, 255.""" + result_no_department = await allocate_budget("", 10000.00) + result_large_amount = await allocate_budget("Operations", 1e9) + assert "Allocated budget of $10000.00 to ." in result_no_department + assert "Allocated budget of $1000000000.00 to Operations." in result_large_amount + + +@pytest.mark.asyncio +async def test_track_procurement_metrics_empty_name(): + """Covers line 540.""" + result = await track_procurement_metrics("") + assert "Procurement metric '' tracked." in result + + +@pytest.mark.asyncio +async def test_order_hardware_missing_name_and_zero_quantity(): + """Covers lines 98 and 108.""" + result_missing_name = await order_hardware("", 10) + result_zero_quantity = await order_hardware("Keyboard", 0) + assert "Ordered 10 units of ." in result_missing_name + assert "Ordered 0 units of Keyboard." in result_zero_quantity + + +@pytest.mark.asyncio +async def test_process_purchase_order_empty_number(): + """Covers line 133.""" + result = await process_purchase_order("") + assert "Purchase Order has been processed." in result + + +@pytest.mark.asyncio +async def test_initiate_contract_negotiation_empty_vendor_and_details(): + """Covers lines 143, 148.""" + result_empty_vendor = await initiate_contract_negotiation("", "Details") + result_empty_details = await initiate_contract_negotiation("VendorX", "") + assert "Contract negotiation initiated with : Details" in result_empty_vendor + assert "Contract negotiation initiated with VendorX: " in result_empty_details + + +@pytest.mark.asyncio +async def test_manage_vendor_relationship_unexpected_action(): + """Covers line 153.""" + result = await manage_vendor_relationship("VendorZ", "undefined") + assert "Vendor relationship with VendorZ has been undefined." in result + + +@pytest.mark.asyncio +async def test_handle_return_zero_and_negative_quantity(): + """Covers lines 173, 178.""" + result_zero = await handle_return("Monitor", 0, "No issue") + result_negative = await handle_return("Monitor", -5, "Damaged") + assert "Processed return of 0 units of Monitor due to No issue." in result_zero + assert "Processed return of -5 units of Monitor due to Damaged." in result_negative + + +@pytest.mark.asyncio +async def test_process_payment_large_amount_and_no_vendor_name(): + """Covers line 188.""" + result_large_amount = await process_payment("VendorX", 1e7) + result_no_vendor = await process_payment("", 500.00) + assert "Processed payment of $10000000.00 to VendorX." in result_large_amount + assert "Processed payment of $500.00 to ." in result_no_vendor + + +@pytest.mark.asyncio +async def test_request_quote_zero_and_negative_quantity(): + """Covers lines 193, 198.""" + result_zero = await request_quote("Tablet", 0) + result_negative = await request_quote("Tablet", -10) + assert "Requested quote for 0 units of Tablet." in result_zero + assert "Requested quote for -10 units of Tablet." in result_negative + + +@pytest.mark.asyncio +async def test_track_procurement_metrics_with_invalid_input(): + """Covers edge cases for tracking metrics.""" + result_empty = await track_procurement_metrics("") + result_invalid = await track_procurement_metrics("InvalidMetricName") + assert "Procurement metric '' tracked." in result_empty + assert "Procurement metric 'InvalidMetricName' tracked." in result_invalid + + +@pytest.mark.asyncio +async def test_order_hardware_invalid_cases(): + """Covers invalid inputs for order_hardware.""" + result_no_name = await order_hardware("", 5) + result_negative_quantity = await order_hardware("Laptop", -10) + assert "Ordered 5 units of ." in result_no_name + assert "Ordered -10 units of Laptop." in result_negative_quantity + + +@pytest.mark.asyncio +async def test_order_software_license_invalid_cases(): + """Covers invalid inputs for order_software_license.""" + result_empty_type = await order_software_license("Photoshop", "", 5) + result_zero_quantity = await order_software_license("Photoshop", "Single User", 0) + assert "Ordered 5 licenses of Photoshop." in result_empty_type + assert "Ordered 0 Single User licenses of Photoshop." in result_zero_quantity diff --git a/src/backend/tests/agents/test_product.py b/src/backend/tests/agents/test_product.py new file mode 100644 index 00000000..4437cd75 --- /dev/null +++ b/src/backend/tests/agents/test_product.py @@ -0,0 +1,82 @@ +import os +import sys +from unittest.mock import MagicMock +import pytest + +# Mock Azure SDK dependencies +sys.modules["azure.monitor.events.extension"] = MagicMock() + +# Set up environment variables +os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint" +os.environ["COSMOSDB_KEY"] = "mock-key" +os.environ["COSMOSDB_DATABASE"] = "mock-database" +os.environ["COSMOSDB_CONTAINER"] = "mock-container" +os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "mock-deployment-name" +os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01" +os.environ["AZURE_OPENAI_ENDPOINT"] = "https://mock-openai-endpoint" + + +# Import the required functions for testing +from src.backend.agents.product import ( + add_mobile_extras_pack, + get_product_info, + update_inventory, + schedule_product_launch, + analyze_sales_data, + get_customer_feedback, + manage_promotions, + check_inventory, + update_product_price, + provide_product_recommendations, + handle_product_recall, + set_product_discount, + manage_supply_chain, + forecast_product_demand, + handle_product_complaints, + monitor_market_trends, + generate_product_report, + develop_new_product_ideas, + optimize_product_page, + track_product_shipment, + evaluate_product_performance, +) + + +# Parameterized tests for repetitive cases +@pytest.mark.asyncio +@pytest.mark.parametrize( + "function, args, expected_substrings", + [ + (add_mobile_extras_pack, ("Roaming Pack", "2025-01-01"), ["Roaming Pack", "2025-01-01"]), + (get_product_info, (), ["Simulated Phone Plans", "Plan A"]), + (update_inventory, ("Product A", 50), ["Inventory for", "Product A"]), + (schedule_product_launch, ("New Product", "2025-02-01"), ["New Product", "2025-02-01"]), + (analyze_sales_data, ("Product B", "Last Quarter"), ["Sales data for", "Product B"]), + (get_customer_feedback, ("Product C",), ["Customer feedback for", "Product C"]), + (manage_promotions, ("Product A", "10% off for summer"), ["Promotion for", "Product A"]), + (handle_product_recall, ("Product A", "Defective batch"), ["Product recall for", "Defective batch"]), + (set_product_discount, ("Product A", 15.0), ["Discount for", "15.0%"]), + (manage_supply_chain, ("Product A", "Supplier X"), ["Supply chain for", "Supplier X"]), + (check_inventory, ("Product A",), ["Inventory status for", "Product A"]), + (update_product_price, ("Product A", 99.99), ["Price for", "$99.99"]), + (provide_product_recommendations, ("High Performance",), ["Product recommendations", "High Performance"]), + (forecast_product_demand, ("Product A", "Next Month"), ["Demand for", "Next Month"]), + (handle_product_complaints, ("Product A", "Complaint about quality"), ["Complaint for", "Product A"]), + (generate_product_report, ("Product A", "Sales"), ["Sales report for", "Product A"]), + (develop_new_product_ideas, ("Smartphone X with AI Camera",), ["New product idea", "Smartphone X"]), + (optimize_product_page, ("Product A", "SEO optimization"), ["Product page for", "optimized"]), + (track_product_shipment, ("Product A", "1234567890"), ["Shipment for", "1234567890"]), + (evaluate_product_performance, ("Product A", "Customer reviews"), ["Performance of", "evaluated"]), + ], +) +async def test_product_functions(function, args, expected_substrings): + result = await function(*args) + for substring in expected_substrings: + assert substring in result + + +# Specific test for monitoring market trends +@pytest.mark.asyncio +async def test_monitor_market_trends(): + result = await monitor_market_trends() + assert "Market trends monitored" in result diff --git a/src/backend/tests/agents/test_tech_support.py b/src/backend/tests/agents/test_tech_support.py new file mode 100644 index 00000000..117b13b2 --- /dev/null +++ b/src/backend/tests/agents/test_tech_support.py @@ -0,0 +1,524 @@ +import os +import sys +from unittest.mock import MagicMock, AsyncMock, patch +import pytest +from autogen_core.components.tools import FunctionTool + +# Mock the azure.monitor.events.extension module globally +sys.modules["azure.monitor.events.extension"] = MagicMock() +# Mock the event_utils module +sys.modules["src.backend.event_utils"] = MagicMock() + +# Set environment variables to mock Config dependencies +os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint" +os.environ["COSMOSDB_KEY"] = "mock-key" +os.environ["COSMOSDB_DATABASE"] = "mock-database" +os.environ["COSMOSDB_CONTAINER"] = "mock-container" +os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "mock-deployment-name" +os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01" +os.environ["AZURE_OPENAI_ENDPOINT"] = "https://mock-openai-endpoint" + +from src.backend.agents.tech_support import ( + send_welcome_email, + set_up_office_365_account, + configure_laptop, + reset_password, + setup_vpn_access, + troubleshoot_network_issue, + install_software, + update_software, + manage_data_backup, + handle_cybersecurity_incident, + assist_procurement_with_tech_equipment, + collaborate_with_code_deployment, + provide_tech_support_for_marketing, + assist_product_launch, + implement_it_policy, + manage_cloud_service, + configure_server, + grant_database_access, + provide_tech_training, + configure_printer, + set_up_email_signature, + configure_mobile_device, + set_up_remote_desktop, + troubleshoot_hardware_issue, + manage_network_security, + update_firmware, + assist_with_video_conferencing_setup, + manage_it_inventory, + configure_firewall_rules, + manage_virtual_machines, + provide_tech_support_for_event, + configure_network_storage, + set_up_two_factor_authentication, + troubleshoot_email_issue, + manage_it_helpdesk_tickets, + handle_software_bug_report, + assist_with_data_recovery, + manage_system_updates, + configure_digital_signatures, + provide_remote_tech_support, + manage_network_bandwidth, + assist_with_tech_documentation, + monitor_system_performance, + get_tech_support_tools, +) + + +# Mock Azure DefaultAzureCredential +@pytest.fixture(autouse=True) +def mock_azure_credentials(): + """Mock Azure DefaultAzureCredential for all tests.""" + with patch("azure.identity.aio.DefaultAzureCredential") as mock_cred: + mock_cred.return_value.get_token = AsyncMock(return_value={"token": "mock-token"}) + yield + + +@pytest.mark.asyncio +async def test_collaborate_with_code_deployment(): + try: + result = await collaborate_with_code_deployment("AI Deployment Project") + assert "Code Deployment Collaboration" in result + assert "AI Deployment Project" in result + finally: + pass # Add explicit cleanup if required + + +@pytest.mark.asyncio +async def test_send_welcome_email(): + try: + result = await send_welcome_email("John Doe", "john.doe@example.com") + assert "Welcome Email Sent" in result + assert "John Doe" in result + assert "john.doe@example.com" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_set_up_office_365_account(): + try: + result = await set_up_office_365_account("Jane Smith", "jane.smith@example.com") + assert "Office 365 Account Setup" in result + assert "Jane Smith" in result + assert "jane.smith@example.com" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_configure_laptop(): + try: + result = await configure_laptop("John Doe", "Dell XPS 15") + assert "Laptop Configuration" in result + assert "Dell XPS 15" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_reset_password(): + try: + result = await reset_password("John Doe") + assert "Password Reset" in result + assert "John Doe" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_setup_vpn_access(): + try: + result = await setup_vpn_access("John Doe") + assert "VPN Access Setup" in result + assert "John Doe" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_troubleshoot_network_issue(): + try: + result = await troubleshoot_network_issue("Slow internet") + assert "Network Issue Resolved" in result + assert "Slow internet" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_install_software(): + try: + result = await install_software("Jane Doe", "Adobe Photoshop") + assert "Software Installation" in result + assert "Adobe Photoshop" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_update_software(): + try: + result = await update_software("John Doe", "Microsoft Office") + assert "Software Update" in result + assert "Microsoft Office" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_manage_data_backup(): + try: + result = await manage_data_backup("Jane Smith") + assert "Data Backup Managed" in result + assert "Jane Smith" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_handle_cybersecurity_incident(): + try: + result = await handle_cybersecurity_incident("Phishing email detected") + assert "Cybersecurity Incident Handled" in result + assert "Phishing email detected" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_assist_procurement_with_tech_equipment(): + try: + result = await assist_procurement_with_tech_equipment("Dell Workstation specs") + assert "Technical Specifications Provided" in result + assert "Dell Workstation specs" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_provide_tech_support_for_marketing(): + try: + result = await provide_tech_support_for_marketing("Holiday Campaign") + assert "Tech Support for Marketing Campaign" in result + assert "Holiday Campaign" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_assist_product_launch(): + try: + result = await assist_product_launch("Smartphone X") + assert "Tech Support for Product Launch" in result + assert "Smartphone X" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_implement_it_policy(): + try: + result = await implement_it_policy("Data Retention Policy") + assert "IT Policy Implemented" in result + assert "Data Retention Policy" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_manage_cloud_service(): + try: + result = await manage_cloud_service("AWS S3") + assert "Cloud Service Managed" in result + assert "AWS S3" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_configure_server(): + try: + result = await configure_server("Database Server") + assert "Server Configuration" in result + assert "Database Server" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_grant_database_access(): + try: + result = await grant_database_access("Alice", "SalesDB") + assert "Database Access Granted" in result + assert "Alice" in result + assert "SalesDB" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_provide_tech_training(): + try: + result = await provide_tech_training("Bob", "VPN Tool") + assert "Tech Training Provided" in result + assert "Bob" in result + assert "VPN Tool" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_configure_printer(): + try: + result = await configure_printer("Charlie", "HP LaserJet 123") + assert "Printer Configuration" in result + assert "Charlie" in result + assert "HP LaserJet 123" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_set_up_email_signature(): + try: + result = await set_up_email_signature("Derek", "Best regards, Derek") + assert "Email Signature Setup" in result + assert "Derek" in result + assert "Best regards, Derek" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_configure_mobile_device(): + try: + result = await configure_mobile_device("Emily", "iPhone 13") + assert "Mobile Device Configuration" in result + assert "Emily" in result + assert "iPhone 13" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_set_up_remote_desktop(): + try: + result = await set_up_remote_desktop("Frank") + assert "Remote Desktop Setup" in result + assert "Frank" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_troubleshoot_hardware_issue(): + try: + result = await troubleshoot_hardware_issue("Laptop overheating") + assert "Hardware Issue Resolved" in result + assert "Laptop overheating" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_manage_network_security(): + try: + result = await manage_network_security() + assert "Network Security Managed" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_update_firmware(): + try: + result = await update_firmware("Router X", "v1.2.3") + assert "Firmware Updated" in result + assert "Router X" in result + assert "v1.2.3" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_assist_with_video_conferencing_setup(): + try: + result = await assist_with_video_conferencing_setup("Grace", "Zoom") + assert "Video Conferencing Setup" in result + assert "Grace" in result + assert "Zoom" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_manage_it_inventory(): + try: + result = await manage_it_inventory() + assert "IT Inventory Managed" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_configure_firewall_rules(): + try: + result = await configure_firewall_rules("Allow traffic on port 8080") + assert "Firewall Rules Configured" in result + assert "Allow traffic on port 8080" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_manage_virtual_machines(): + try: + result = await manage_virtual_machines("VM: Ubuntu Server") + assert "Virtual Machines Managed" in result + assert "VM: Ubuntu Server" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_provide_tech_support_for_event(): + try: + result = await provide_tech_support_for_event("Annual Tech Summit") + assert "Tech Support for Event" in result + assert "Annual Tech Summit" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_configure_network_storage(): + try: + result = await configure_network_storage("John Doe", "500GB NAS") + assert "Network Storage Configured" in result + assert "John Doe" in result + assert "500GB NAS" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_set_up_two_factor_authentication(): + try: + result = await set_up_two_factor_authentication("Jane Smith") + assert "Two-Factor Authentication Setup" in result + assert "Jane Smith" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_troubleshoot_email_issue(): + try: + result = await troubleshoot_email_issue("Alice", "Cannot send emails") + assert "Email Issue Resolved" in result + assert "Cannot send emails" in result + assert "Alice" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_manage_it_helpdesk_tickets(): + try: + result = await manage_it_helpdesk_tickets("Ticket #123: Password reset") + assert "Helpdesk Tickets Managed" in result + assert "Password reset" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_handle_software_bug_report(): + try: + result = await handle_software_bug_report("Critical bug in payroll module") + assert "Software Bug Report Handled" in result + assert "Critical bug in payroll module" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_assist_with_data_recovery(): + try: + result = await assist_with_data_recovery("Jane Doe", "Recover deleted files") + assert "Data Recovery Assisted" in result + assert "Jane Doe" in result + assert "Recover deleted files" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_manage_system_updates(): + try: + result = await manage_system_updates("Patch CVE-2023-1234") + assert "System Updates Managed" in result + assert "Patch CVE-2023-1234" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_configure_digital_signatures(): + try: + result = await configure_digital_signatures( + "John Doe", "Company Approved Signature" + ) + assert "Digital Signatures Configured" in result + assert "John Doe" in result + assert "Company Approved Signature" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_provide_remote_tech_support(): + try: + result = await provide_remote_tech_support("Mark") + assert "Remote Tech Support Provided" in result + assert "Mark" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_manage_network_bandwidth(): + try: + result = await manage_network_bandwidth("Allocate more bandwidth for video calls") + assert "Network Bandwidth Managed" in result + assert "Allocate more bandwidth for video calls" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_assist_with_tech_documentation(): + try: + result = await assist_with_tech_documentation("Documentation for VPN setup") + assert "Technical Documentation Created" in result + assert "VPN setup" in result + finally: + pass + + +@pytest.mark.asyncio +async def test_monitor_system_performance(): + try: + result = await monitor_system_performance() + assert "System Performance Monitored" in result + finally: + pass + + +def test_get_tech_support_tools(): + tools = get_tech_support_tools() + assert isinstance(tools, list) + assert len(tools) > 40 # Ensure all tools are included + assert all(isinstance(tool, FunctionTool) for tool in tools) diff --git a/src/backend/tests/auth/__init__.py b/src/backend/tests/auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend/tests/auth/test_auth_utils.py b/src/backend/tests/auth/test_auth_utils.py new file mode 100644 index 00000000..59753b56 --- /dev/null +++ b/src/backend/tests/auth/test_auth_utils.py @@ -0,0 +1,53 @@ +from unittest.mock import patch, Mock +import base64 +import json + +from src.backend.auth.auth_utils import get_authenticated_user_details, get_tenantid + + +def test_get_authenticated_user_details_with_headers(): + """Test get_authenticated_user_details with valid headers.""" + request_headers = { + "x-ms-client-principal-id": "test-user-id", + "x-ms-client-principal-name": "test-user-name", + "x-ms-client-principal-idp": "test-auth-provider", + "x-ms-token-aad-id-token": "test-auth-token", + "x-ms-client-principal": "test-client-principal-b64", + } + + result = get_authenticated_user_details(request_headers) + + assert result["user_principal_id"] == "test-user-id" + assert result["user_name"] == "test-user-name" + assert result["auth_provider"] == "test-auth-provider" + assert result["auth_token"] == "test-auth-token" + assert result["client_principal_b64"] == "test-client-principal-b64" + assert result["aad_id_token"] == "test-auth-token" + + +def test_get_tenantid_with_valid_b64(): + """Test get_tenantid with a valid base64-encoded JSON string.""" + valid_b64 = base64.b64encode( + json.dumps({"tid": "test-tenant-id"}).encode("utf-8") + ).decode("utf-8") + + tenant_id = get_tenantid(valid_b64) + + assert tenant_id == "test-tenant-id" + + +def test_get_tenantid_with_empty_b64(): + """Test get_tenantid with an empty base64 string.""" + tenant_id = get_tenantid("") + assert tenant_id == "" + + +@patch("src.backend.auth.auth_utils.logging.getLogger", return_value=Mock()) +def test_get_tenantid_with_invalid_b64(mock_logger): + """Test get_tenantid with an invalid base64-encoded string.""" + invalid_b64 = "invalid-base64" + + tenant_id = get_tenantid(invalid_b64) + + assert tenant_id == "" + mock_logger().exception.assert_called_once() diff --git a/src/backend/tests/auth/test_sample_user.py b/src/backend/tests/auth/test_sample_user.py new file mode 100644 index 00000000..730a8a60 --- /dev/null +++ b/src/backend/tests/auth/test_sample_user.py @@ -0,0 +1,84 @@ +from src.backend.auth.sample_user import sample_user # Adjust path as necessary + + +def test_sample_user_keys(): + """Verify that all expected keys are present in the sample_user dictionary.""" + expected_keys = [ + "Accept", + "Accept-Encoding", + "Accept-Language", + "Client-Ip", + "Content-Length", + "Content-Type", + "Cookie", + "Disguised-Host", + "Host", + "Max-Forwards", + "Origin", + "Referer", + "Sec-Ch-Ua", + "Sec-Ch-Ua-Mobile", + "Sec-Ch-Ua-Platform", + "Sec-Fetch-Dest", + "Sec-Fetch-Mode", + "Sec-Fetch-Site", + "Traceparent", + "User-Agent", + "Was-Default-Hostname", + "X-Appservice-Proto", + "X-Arr-Log-Id", + "X-Arr-Ssl", + "X-Client-Ip", + "X-Client-Port", + "X-Forwarded-For", + "X-Forwarded-Proto", + "X-Forwarded-Tlsversion", + "X-Ms-Client-Principal", + "X-Ms-Client-Principal-Id", + "X-Ms-Client-Principal-Idp", + "X-Ms-Client-Principal-Name", + "X-Ms-Token-Aad-Id-Token", + "X-Original-Url", + "X-Site-Deployment-Id", + "X-Waws-Unencoded-Url", + ] + assert set(expected_keys) == set(sample_user.keys()) + + +def test_sample_user_values(): + # Proceed with assertions + assert sample_user["Accept"].strip() == "*/*" # Ensure no hidden characters + assert sample_user["Content-Type"] == "application/json" + assert sample_user["Disguised-Host"] == "your_app_service.azurewebsites.net" + assert ( + sample_user["X-Ms-Client-Principal-Id"] + == "00000000-0000-0000-0000-000000000000" + ) + assert sample_user["X-Ms-Client-Principal-Name"] == "testusername@constoso.com" + assert sample_user["X-Forwarded-Proto"] == "https" + + +def test_sample_user_cookie(): + """Check if the Cookie key is present and contains an expected substring.""" + assert "AppServiceAuthSession" in sample_user["Cookie"] + + +def test_sample_user_protocol(): + """Verify protocol-related keys.""" + assert sample_user["X-Appservice-Proto"] == "https" + assert sample_user["X-Forwarded-Proto"] == "https" + assert sample_user["Sec-Fetch-Mode"] == "cors" + + +def test_sample_user_client_ip(): + """Verify the Client-Ip key.""" + assert sample_user["Client-Ip"] == "22.222.222.2222:64379" + assert sample_user["X-Client-Ip"] == "22.222.222.222" + + +def test_sample_user_user_agent(): + """Verify the User-Agent key.""" + user_agent = sample_user["User-Agent"] + assert "Mozilla/5.0" in user_agent + assert "Windows NT 10.0" in user_agent + assert "Edg/" in user_agent # Matches Edge's identifier more accurately diff --git a/src/backend/tests/context/__init__.py b/src/backend/tests/context/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend/tests/context/test_cosmos_memory.py b/src/backend/tests/context/test_cosmos_memory.py new file mode 100644 index 00000000..441bb1ef --- /dev/null +++ b/src/backend/tests/context/test_cosmos_memory.py @@ -0,0 +1,68 @@ +import pytest +from unittest.mock import AsyncMock, patch +from azure.cosmos.partition_key import PartitionKey +from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext + + +# Helper to create async iterable +async def async_iterable(mock_items): + """Helper to create an async iterable.""" + for item in mock_items: + yield item + + +@pytest.fixture +def mock_env_variables(monkeypatch): + """Mock all required environment variables.""" + env_vars = { + "COSMOSDB_ENDPOINT": "https://mock-endpoint", + "COSMOSDB_KEY": "mock-key", + "COSMOSDB_DATABASE": "mock-database", + "COSMOSDB_CONTAINER": "mock-container", + "AZURE_OPENAI_DEPLOYMENT_NAME": "mock-deployment-name", + "AZURE_OPENAI_API_VERSION": "2023-01-01", + "AZURE_OPENAI_ENDPOINT": "https://mock-openai-endpoint", + } + for key, value in env_vars.items(): + monkeypatch.setenv(key, value) + + +@pytest.fixture +def mock_cosmos_client(): + """Fixture for mocking Cosmos DB client and container.""" + mock_client = AsyncMock() + mock_container = AsyncMock() + mock_client.create_container_if_not_exists.return_value = mock_container + + # Mocking context methods + mock_context = AsyncMock() + mock_context.store_message = AsyncMock() + mock_context.retrieve_messages = AsyncMock( + return_value=async_iterable([{"id": "test_id", "content": "test_content"}]) + ) + + return mock_client, mock_container, mock_context + + +@pytest.fixture +def mock_config(mock_cosmos_client): + """Fixture to patch Config with mock Cosmos DB client.""" + mock_client, _, _ = mock_cosmos_client + with patch( + "src.backend.config.Config.GetCosmosDatabaseClient", return_value=mock_client + ), patch("src.backend.config.Config.COSMOSDB_CONTAINER", "mock-container"): + yield + + +@pytest.mark.asyncio +async def test_initialize(mock_config, mock_cosmos_client): + """Test if the Cosmos DB container is initialized correctly.""" + mock_client, mock_container, _ = mock_cosmos_client + context = CosmosBufferedChatCompletionContext( + session_id="test_session", user_id="test_user" + ) + await context.initialize() + mock_client.create_container_if_not_exists.assert_called_once_with( + id="mock-container", partition_key=PartitionKey(path="/session_id") + ) + assert context._container == mock_container diff --git a/src/backend/tests/handlers/__init__.py b/src/backend/tests/handlers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend/tests/handlers/test_runtime_interrupt.py b/src/backend/tests/handlers/test_runtime_interrupt.py new file mode 100644 index 00000000..d2008415 --- /dev/null +++ b/src/backend/tests/handlers/test_runtime_interrupt.py @@ -0,0 +1,124 @@ +import pytest +from unittest.mock import Mock +from src.backend.handlers.runtime_interrupt import ( + NeedsUserInputHandler, + AssistantResponseHandler, +) +from src.backend.models.messages import GetHumanInputMessage, GroupChatMessage +from autogen_core.base import AgentId + + +@pytest.mark.asyncio +async def test_needs_user_input_handler_on_publish_human_input(): + """Test on_publish with GetHumanInputMessage.""" + handler = NeedsUserInputHandler() + + mock_message = Mock(spec=GetHumanInputMessage) + mock_message.content = "This is a question for the human." + + mock_sender = Mock(spec=AgentId) + mock_sender.type = "human_agent" + mock_sender.key = "human_key" + + await handler.on_publish(mock_message, sender=mock_sender) + + assert handler.needs_human_input is True + assert handler.question_content == "This is a question for the human." + assert len(handler.messages) == 1 + assert handler.messages[0]["agent"]["type"] == "human_agent" + assert handler.messages[0]["agent"]["key"] == "human_key" + assert handler.messages[0]["content"] == "This is a question for the human." + + +@pytest.mark.asyncio +async def test_needs_user_input_handler_on_publish_group_chat(): + """Test on_publish with GroupChatMessage.""" + handler = NeedsUserInputHandler() + + mock_message = Mock(spec=GroupChatMessage) + mock_message.body = Mock(content="This is a group chat message.") + + mock_sender = Mock(spec=AgentId) + mock_sender.type = "group_agent" + mock_sender.key = "group_key" + + await handler.on_publish(mock_message, sender=mock_sender) + + assert len(handler.messages) == 1 + assert handler.messages[0]["agent"]["type"] == "group_agent" + assert handler.messages[0]["agent"]["key"] == "group_key" + assert handler.messages[0]["content"] == "This is a group chat message." + + +@pytest.mark.asyncio +async def test_needs_user_input_handler_get_messages(): + """Test get_messages method.""" + handler = NeedsUserInputHandler() + + # Add mock messages + mock_message = Mock(spec=GroupChatMessage) + mock_message.body = Mock(content="Group chat content.") + mock_sender = Mock(spec=AgentId) + mock_sender.type = "group_agent" + mock_sender.key = "group_key" + + await handler.on_publish(mock_message, sender=mock_sender) + + # Retrieve messages + messages = handler.get_messages() + + assert len(messages) == 1 + assert messages[0]["agent"]["type"] == "group_agent" + assert messages[0]["agent"]["key"] == "group_key" + assert messages[0]["content"] == "Group chat content." + assert len(handler.messages) == 0 # Ensure messages are cleared + + +def test_needs_user_input_handler_properties(): + """Test properties of NeedsUserInputHandler.""" + handler = NeedsUserInputHandler() + + # Initially no human input + assert handler.needs_human_input is False + assert handler.question_content is None + + # Add a question + mock_message = Mock(spec=GetHumanInputMessage) + mock_message.content = "Human question?" + handler.question_for_human = mock_message + + assert handler.needs_human_input is True + assert handler.question_content == "Human question?" + + +@pytest.mark.asyncio +async def test_assistant_response_handler_on_publish(): + """Test on_publish in AssistantResponseHandler.""" + handler = AssistantResponseHandler() + + mock_message = Mock() + mock_message.body = Mock(content="Assistant response content.") + + mock_sender = Mock(spec=AgentId) + mock_sender.type = "writer" + mock_sender.key = "assistant_key" + + await handler.on_publish(mock_message, sender=mock_sender) + + assert handler.has_response is True + assert handler.get_response() == "Assistant response content." + + +def test_assistant_response_handler_properties(): + """Test properties of AssistantResponseHandler.""" + handler = AssistantResponseHandler() + + # Initially no response + assert handler.has_response is False + assert handler.get_response() is None + + # Set a response + handler.assistant_response = "Assistant response" + + assert handler.has_response is True + assert handler.get_response() == "Assistant response" diff --git a/src/backend/tests/middleware/__init__.py b/src/backend/tests/middleware/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend/tests/middleware/test_health_check.py b/src/backend/tests/middleware/test_health_check.py new file mode 100644 index 00000000..52a5a985 --- /dev/null +++ b/src/backend/tests/middleware/test_health_check.py @@ -0,0 +1,72 @@ +from src.backend.middleware.health_check import ( + HealthCheckMiddleware, + HealthCheckResult, +) +from fastapi import FastAPI +from starlette.testclient import TestClient +from asyncio import sleep + + +# Updated helper functions for test health checks +async def successful_check(): + """Simulates a successful check.""" + await sleep(0.1) # Simulate async operation + return HealthCheckResult(status=True, message="Successful check") + + +async def failing_check(): + """Simulates a failing check.""" + await sleep(0.1) # Simulate async operation + return HealthCheckResult(status=False, message="Failing check") + + +# Test application setup +app = FastAPI() + +checks = { + "success": successful_check, + "failure": failing_check, +} + +app.add_middleware(HealthCheckMiddleware, checks=checks, password="test123") + + +@app.get("/") +async def root(): + return {"message": "Hello, World!"} + + +def test_health_check_success(): + """Test the health check endpoint with successful checks.""" + client = TestClient(app) + response = client.get("/healthz") + + assert response.status_code == 503 # Because one check is failing + assert response.text == "Service Unavailable" + + +def test_root_endpoint(): + """Test the root endpoint to ensure the app is functioning.""" + client = TestClient(app) + response = client.get("/") + + assert response.status_code == 200 + assert response.json() == {"message": "Hello, World!"} + + +def test_health_check_missing_password(): + """Test the health check endpoint without a password.""" + client = TestClient(app) + response = client.get("/healthz") + + assert response.status_code == 503 # Unauthorized access without correct password + assert response.text == "Service Unavailable" + + +def test_health_check_incorrect_password(): + """Test the health check endpoint with an incorrect password.""" + client = TestClient(app) + response = client.get("/healthz?code=wrongpassword") + + assert response.status_code == 503 # Because one check is failing + assert response.text == "Service Unavailable" diff --git a/src/backend/tests/models/__init__.py b/src/backend/tests/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend/tests/models/test_messages.py b/src/backend/tests/models/test_messages.py new file mode 100644 index 00000000..49fb1b7f --- /dev/null +++ b/src/backend/tests/models/test_messages.py @@ -0,0 +1,122 @@ +# File: test_message.py + +import uuid +from src.backend.models.messages import ( + DataType, + BAgentType, + StepStatus, + PlanStatus, + HumanFeedbackStatus, + PlanWithSteps, + Step, + Plan, + AgentMessage, + ActionRequest, + HumanFeedback, +) + + +def test_enum_values(): + """Test enumeration values for consistency.""" + assert DataType.session == "session" + assert DataType.plan == "plan" + assert BAgentType.human_agent == "HumanAgent" + assert StepStatus.completed == "completed" + assert PlanStatus.in_progress == "in_progress" + assert HumanFeedbackStatus.requested == "requested" + + +def test_plan_with_steps_update_counts(): + """Test the update_step_counts method in PlanWithSteps.""" + step1 = Step( + plan_id=str(uuid.uuid4()), + action="Review document", + agent=BAgentType.human_agent, + status=StepStatus.completed, + session_id=str(uuid.uuid4()), + user_id=str(uuid.uuid4()), + ) + step2 = Step( + plan_id=str(uuid.uuid4()), + action="Approve document", + agent=BAgentType.hr_agent, + status=StepStatus.failed, + session_id=str(uuid.uuid4()), + user_id=str(uuid.uuid4()), + ) + plan = PlanWithSteps( + steps=[step1, step2], + session_id=str(uuid.uuid4()), + user_id=str(uuid.uuid4()), + initial_goal="Test plan goal", + ) + plan.update_step_counts() + + assert plan.total_steps == 2 + assert plan.completed == 1 + assert plan.failed == 1 + assert plan.overall_status == PlanStatus.completed + + +def test_agent_message_creation(): + """Test creation of an AgentMessage.""" + agent_message = AgentMessage( + session_id=str(uuid.uuid4()), + user_id=str(uuid.uuid4()), + plan_id=str(uuid.uuid4()), + content="Test message content", + source="System", + ) + assert agent_message.data_type == "agent_message" + assert agent_message.content == "Test message content" + + +def test_action_request_creation(): + """Test the creation of ActionRequest.""" + action_request = ActionRequest( + step_id=str(uuid.uuid4()), + plan_id=str(uuid.uuid4()), + session_id=str(uuid.uuid4()), + action="Review and approve", + agent=BAgentType.procurement_agent, + ) + assert action_request.action == "Review and approve" + assert action_request.agent == BAgentType.procurement_agent + + +def test_human_feedback_creation(): + """Test HumanFeedback creation.""" + human_feedback = HumanFeedback( + step_id=str(uuid.uuid4()), + plan_id=str(uuid.uuid4()), + session_id=str(uuid.uuid4()), + approved=True, + human_feedback="Looks good!", + ) + assert human_feedback.approved is True + assert human_feedback.human_feedback == "Looks good!" + + +def test_plan_initialization(): + """Test Plan model initialization.""" + plan = Plan( + session_id=str(uuid.uuid4()), + user_id=str(uuid.uuid4()), + initial_goal="Complete document processing", + ) + assert plan.data_type == "plan" + assert plan.initial_goal == "Complete document processing" + assert plan.overall_status == PlanStatus.in_progress + + +def test_step_defaults(): + """Test default values for Step model.""" + step = Step( + plan_id=str(uuid.uuid4()), + action="Prepare report", + agent=BAgentType.generic_agent, + session_id=str(uuid.uuid4()), + user_id=str(uuid.uuid4()), + ) + assert step.status == StepStatus.planned + assert step.human_approval_status == HumanFeedbackStatus.requested diff --git a/src/backend/tests/test_app.py b/src/backend/tests/test_app.py new file mode 100644 index 00000000..0e9f0d1e --- /dev/null +++ b/src/backend/tests/test_app.py @@ -0,0 +1,89 @@ +import os +import sys +from unittest.mock import MagicMock, patch +import pytest +from fastapi.testclient import TestClient + +# Mock Azure dependencies to prevent import errors +sys.modules["azure.monitor"] = MagicMock() +sys.modules["azure.monitor.events.extension"] = MagicMock() +sys.modules["azure.monitor.opentelemetry"] = MagicMock() + +# Mock environment variables before importing app +os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint" +os.environ["COSMOSDB_KEY"] = "mock-key" +os.environ["COSMOSDB_DATABASE"] = "mock-database" +os.environ["COSMOSDB_CONTAINER"] = "mock-container" +os.environ[ + "APPLICATIONINSIGHTS_CONNECTION_STRING" +] = "InstrumentationKey=mock-instrumentation-key;IngestionEndpoint=https://mock-ingestion-endpoint" +os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "mock-deployment-name" +os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01" +os.environ["AZURE_OPENAI_ENDPOINT"] = "https://mock-openai-endpoint" + +# Mock telemetry initialization to prevent errors +with patch("azure.monitor.opentelemetry.configure_azure_monitor", MagicMock()): + from src.backend.app import app + +# Initialize FastAPI test client +client = TestClient(app) + + +@pytest.fixture(autouse=True) +def mock_dependencies(monkeypatch): + """Mock dependencies to simplify tests.""" + monkeypatch.setattr( + "src.backend.auth.auth_utils.get_authenticated_user_details", + lambda headers: {"user_principal_id": "mock-user-id"}, + ) + monkeypatch.setattr( + "src.backend.utils.retrieve_all_agent_tools", + lambda: [{"agent": "test_agent", "function": "test_function"}], + ) + + +def test_input_task_invalid_json(): + """Test the case where the input JSON is invalid.""" + invalid_json = "Invalid JSON data" + + headers = {"Authorization": "Bearer mock-token"} + response = client.post("/input_task", data=invalid_json, headers=headers) + + # Assert response for invalid JSON + assert response.status_code == 422 + assert "detail" in response.json() + + +def test_input_task_missing_description(): + """Test the case where the input task description is missing.""" + input_task = { + "session_id": None, + "user_id": "mock-user-id", + } + + headers = {"Authorization": "Bearer mock-token"} + response = client.post("/input_task", json=input_task, headers=headers) + + # Assert response for missing description + assert response.status_code == 422 + assert "detail" in response.json() + + +def test_basic_endpoint(): + """Test a basic endpoint to ensure the app runs.""" + response = client.get("/") + assert response.status_code == 404 # The root endpoint is not defined + + +def test_input_task_empty_description(): + """Tests if /input_task handles an empty description.""" + empty_task = {"session_id": None, "user_id": "mock-user-id", "description": ""} + headers = {"Authorization": "Bearer mock-token"} + response = client.post("/input_task", json=empty_task, headers=headers) + + assert response.status_code == 422 + assert "detail" in response.json() # Assert error message for missing description + + +if __name__ == "__main__": + pytest.main() diff --git a/src/backend/tests/test_config.py b/src/backend/tests/test_config.py new file mode 100644 index 00000000..3c4b0efe --- /dev/null +++ b/src/backend/tests/test_config.py @@ -0,0 +1,62 @@ +# tests/test_config.py +from unittest.mock import patch +import os + +# Mock environment variables globally +MOCK_ENV_VARS = { + "COSMOSDB_ENDPOINT": "https://mock-cosmosdb.documents.azure.com:443/", + "COSMOSDB_DATABASE": "mock_database", + "COSMOSDB_CONTAINER": "mock_container", + "AZURE_OPENAI_DEPLOYMENT_NAME": "mock-deployment", + "AZURE_OPENAI_API_VERSION": "2024-05-01-preview", + "AZURE_OPENAI_ENDPOINT": "https://mock-openai-endpoint.azure.com/", + "AZURE_OPENAI_API_KEY": "mock-api-key", + "AZURE_TENANT_ID": "mock-tenant-id", + "AZURE_CLIENT_ID": "mock-client-id", + "AZURE_CLIENT_SECRET": "mock-client-secret", +} + +with patch.dict(os.environ, MOCK_ENV_VARS): + from src.backend.config import ( + Config, + GetRequiredConfig, + GetOptionalConfig, + GetBoolConfig, + ) + + +@patch.dict(os.environ, MOCK_ENV_VARS) +def test_get_required_config(): + """Test GetRequiredConfig.""" + assert GetRequiredConfig("COSMOSDB_ENDPOINT") == MOCK_ENV_VARS["COSMOSDB_ENDPOINT"] + + +@patch.dict(os.environ, MOCK_ENV_VARS) +def test_get_optional_config(): + """Test GetOptionalConfig.""" + assert GetOptionalConfig("NON_EXISTENT_VAR", "default_value") == "default_value" + assert ( + GetOptionalConfig("COSMOSDB_DATABASE", "default_db") + == MOCK_ENV_VARS["COSMOSDB_DATABASE"] + ) + + +@patch.dict(os.environ, MOCK_ENV_VARS) +def test_get_bool_config(): + """Test GetBoolConfig.""" + with patch.dict("os.environ", {"FEATURE_ENABLED": "true"}): + assert GetBoolConfig("FEATURE_ENABLED") is True + with patch.dict("os.environ", {"FEATURE_ENABLED": "false"}): + assert GetBoolConfig("FEATURE_ENABLED") is False + with patch.dict("os.environ", {"FEATURE_ENABLED": "1"}): + assert GetBoolConfig("FEATURE_ENABLED") is True + with patch.dict("os.environ", {"FEATURE_ENABLED": "0"}): + assert GetBoolConfig("FEATURE_ENABLED") is False + + +@patch("config.DefaultAzureCredential") +def test_get_azure_credentials_with_env_vars(mock_default_cred): + """Test Config.GetAzureCredentials with explicit credentials.""" + with patch.dict(os.environ, MOCK_ENV_VARS): + creds = Config.GetAzureCredentials() + assert creds is not None diff --git a/src/backend/tests/test_otlp_tracing.py b/src/backend/tests/test_otlp_tracing.py new file mode 100644 index 00000000..1b6da903 --- /dev/null +++ b/src/backend/tests/test_otlp_tracing.py @@ -0,0 +1,38 @@ +import sys +import os +from unittest.mock import patch, MagicMock +from src.backend.otlp_tracing import configure_oltp_tracing # Import directly since it's in backend + +# Add the backend directory to the Python path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + + +@patch("src.backend.otlp_tracing.TracerProvider") +@patch("src.backend.otlp_tracing.OTLPSpanExporter") +@patch("src.backend.otlp_tracing.Resource") +def test_configure_oltp_tracing( + mock_resource, + mock_otlp_exporter, + mock_tracer_provider, +): + # Mock the Resource + mock_resource_instance = MagicMock() + mock_resource.return_value = mock_resource_instance + + # Mock TracerProvider + mock_tracer_provider_instance = MagicMock() + mock_tracer_provider.return_value = mock_tracer_provider_instance + + # Mock OTLPSpanExporter + mock_otlp_exporter_instance = MagicMock() + mock_otlp_exporter.return_value = mock_otlp_exporter_instance + + # Call the function + endpoint = "mock-endpoint" + tracer_provider = configure_oltp_tracing(endpoint=endpoint) + + # Assertions + mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) + mock_otlp_exporter.assert_called_once_with() + mock_tracer_provider_instance.add_span_processor.assert_called_once() + assert tracer_provider == mock_tracer_provider_instance diff --git a/src/backend/tests/test_utils.py b/src/backend/tests/test_utils.py new file mode 100644 index 00000000..e5f4734e --- /dev/null +++ b/src/backend/tests/test_utils.py @@ -0,0 +1,81 @@ +from unittest.mock import patch, MagicMock +import pytest +from src.backend.utils import ( + initialize_runtime_and_context, + retrieve_all_agent_tools, + rai_success, + runtime_dict, +) +from autogen_core.application import SingleThreadedAgentRuntime +from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext + + +@pytest.fixture(scope="function", autouse=True) +def mock_telemetry(): + """Mock telemetry and threading-related components to prevent access violations.""" + with patch("opentelemetry.sdk.trace.export.BatchSpanProcessor", MagicMock()): + yield + + +@patch("src.backend.utils.get_hr_tools", MagicMock(return_value=[])) +@patch("src.backend.utils.get_marketing_tools", MagicMock(return_value=[])) +@patch("src.backend.utils.get_procurement_tools", MagicMock(return_value=[])) +@patch("src.backend.utils.get_product_tools", MagicMock(return_value=[])) +@patch("src.backend.utils.get_tech_support_tools", MagicMock(return_value=[])) +def test_retrieve_all_agent_tools(): + """Test retrieval of all agent tools with mocked dependencies.""" + tools = retrieve_all_agent_tools() + assert isinstance(tools, list) + assert len(tools) == 0 # Mocked to return no tools + + +@pytest.mark.asyncio +@patch("src.backend.utils.Config.GetAzureOpenAIChatCompletionClient", MagicMock()) +async def test_initialize_runtime_and_context(): + """Test initialization of runtime and context with mocked Azure client.""" + session_id = "test-session-id" + user_id = "test-user-id" + + runtime, context = await initialize_runtime_and_context(session_id, user_id) + + # Validate runtime and context types + assert isinstance(runtime, SingleThreadedAgentRuntime) + assert isinstance(context, CosmosBufferedChatCompletionContext) + + # Validate caching + assert session_id in runtime_dict + assert runtime_dict[session_id] == (runtime, context) + + +@pytest.mark.asyncio +async def test_initialize_runtime_and_context_missing_user_id(): + """Test ValueError when user_id is missing.""" + with pytest.raises(ValueError, match="The 'user_id' parameter cannot be None"): + await initialize_runtime_and_context(session_id="test-session-id", user_id=None) + + +@patch("src.backend.utils.requests.post") +@patch("src.backend.utils.DefaultAzureCredential") +def test_rai_success(mock_credential, mock_post): + """Test successful RAI response with mocked requests and credentials.""" + mock_credential.return_value.get_token.return_value.token = "mock-token" + mock_post.return_value.json.return_value = { + "choices": [{"message": {"content": "FALSE"}}] + } + + description = "Test RAI success" + result = rai_success(description) + assert result is True + mock_post.assert_called_once() + + +@patch("src.backend.utils.requests.post") +@patch("src.backend.utils.DefaultAzureCredential") +def test_rai_success_invalid_response(mock_credential, mock_post): + """Test RAI response with an invalid format.""" + mock_credential.return_value.get_token.return_value.token = "mock-token" + mock_post.return_value.json.return_value = {"unexpected_key": "value"} + + description = "Test invalid response" + result = rai_success(description) + assert result is False diff --git a/src/backend/utils.py b/src/backend/utils.py index 70c5f9a5..7d4fa19e 100644 --- a/src/backend/utils.py +++ b/src/backend/utils.py @@ -10,20 +10,21 @@ from autogen_core.components.tool_agent import ToolAgent from autogen_core.components.tools import Tool -from agents.group_chat_manager import GroupChatManager -from agents.hr import HrAgent, get_hr_tools -from agents.human import HumanAgent -from agents.marketing import MarketingAgent, get_marketing_tools -from agents.planner import PlannerAgent -from agents.procurement import ProcurementAgent, get_procurement_tools -from agents.product import ProductAgent, get_product_tools -from agents.generic import GenericAgent, get_generic_tools -from agents.tech_support import TechSupportAgent, get_tech_support_tools +from src.backend.agents.group_chat_manager import GroupChatManager +from src.backend.agents.hr import HrAgent, get_hr_tools +from src.backend.agents.human import HumanAgent +from src.backend.agents.marketing import MarketingAgent, get_marketing_tools +from src.backend.agents.planner import PlannerAgent +from src.backend.agents.procurement import ProcurementAgent, get_procurement_tools +from src.backend.agents.product import ProductAgent, get_product_tools +from src.backend.agents.generic import GenericAgent, get_generic_tools +from src.backend.agents.tech_support import TechSupportAgent, get_tech_support_tools # from agents.misc import MiscAgent -from config import Config -from context.cosmos_memory import CosmosBufferedChatCompletionContext -from models.messages import BAgentType +from src.backend.config import Config +from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext +from src.backend.models.messages import BAgentType +# from collections import defaultdict # Initialize logging # from otlp_tracing import configure_oltp_tracing @@ -68,8 +69,6 @@ async def initialize_runtime_and_context( Returns: Tuple[SingleThreadedAgentRuntime, CosmosBufferedChatCompletionContext]: The runtime and context for the session. """ - global runtime_dict - global aoai_model_client if user_id is None: raise ValueError( diff --git a/src/frontend/wwwroot/app.js b/src/frontend/wwwroot/app.js index 028ed4ba..32acefbc 100644 --- a/src/frontend/wwwroot/app.js +++ b/src/frontend/wwwroot/app.js @@ -1,19 +1,19 @@ (() => { window.headers = GetAuthDetails(); - const apiEndpoint = sessionStorage.getItem('apiEndpoint') || BACKEND_API_URL; + const apiEndpoint = getStoredData('apiEndpoint') || BACKEND_API_URL; const goHomeButton = document.getElementById("goHomeButton"); const newTaskButton = document.getElementById("newTaskButton"); const closeModalButtons = document.querySelectorAll(".modal-close-button"); const myTasksMenu = document.getElementById("myTasksMenu"); const tasksStats = document.getElementById("tasksStats"); - //if (!sessionStorage.getItem('apiEndpoint')) sessionStorage.setItem('apiEndpoint', apiEndpoint); + //if (!getStoredData('apiEndpoint'))setStoredData('apiEndpoint', apiEndpoint); // Force rewrite of apiEndpoint - sessionStorage.setItem('apiEndpoint', apiEndpoint); - sessionStorage.setItem('context', 'employee'); + setStoredData('apiEndpoint', apiEndpoint); + setStoredData('context', 'employee'); // Refresh rate is set - if (!sessionStorage.getItem('apiRefreshRate')) sessionStorage.setItem('apiRefreshRate', 5000); - if (!sessionStorage.getItem('actionStagesRun')) sessionStorage.setItem('actionStagesRun', []); + if (!getStoredData('apiRefreshRate'))setStoredData('apiRefreshRate', 5000); + if (!getStoredData('actionStagesRun'))setStoredData('actionStagesRun', []); const getQueryParam = (param) => { const urlParams = new URLSearchParams(window.location.search); @@ -30,7 +30,7 @@ const viewIframe = document.getElementById('viewIframe'); if (viewIframe) { const viewRoute = getQueryParam('v'); - const viewContext = sessionStorage.getItem('context'); + const viewContext = getStoredData('context'); const noCache = '?nocache=' + new Date().getTime(); switch (viewRoute) { case 'home': @@ -112,7 +112,7 @@ } const fetchTasksIfNeeded = async () => { - const taskStore = JSON.parse(sessionStorage.getItem('task')); + const taskStore = JSON.parse(getStoredData('task')); window.headers .then(headers => { fetch(apiEndpoint + '/plans', { @@ -164,7 +164,7 @@ setQueryParam('v', 'task'); switchView(); - sessionStorage.setItem('task', JSON.stringify({ + setStoredData('task', JSON.stringify({ id: sessionId, name: taskName })); @@ -222,7 +222,7 @@ if (!userInfo) { console.error("Authentication failed. Access to tasks is restricted."); } else { - sessionStorage.setItem('userInfo', userInfo); + setStoredData('userInfo', userInfo); await fetchTasksIfNeeded(); // Fetch tasks after initialization if needed } }; diff --git a/src/frontend/wwwroot/home/home.js b/src/frontend/wwwroot/home/home.js index 18d5336d..00cd0475 100644 --- a/src/frontend/wwwroot/home/home.js +++ b/src/frontend/wwwroot/home/home.js @@ -4,7 +4,7 @@ ripple: false, duration: 3000, }); - const apiEndpoint = sessionStorage.getItem("apiEndpoint"); + const apiEndpoint = getStoredData("apiEndpoint"); const newTaskPrompt = document.getElementById("newTaskPrompt"); const startTaskButton = document.getElementById("startTaskButton"); const startTaskButtonContainer = document.querySelector(".send-button"); diff --git a/src/frontend/wwwroot/task/task.js b/src/frontend/wwwroot/task/task.js index acd716cb..f7203ce2 100644 --- a/src/frontend/wwwroot/task/task.js +++ b/src/frontend/wwwroot/task/task.js @@ -1,13 +1,14 @@ (() => { const markdownConverter = new showdown.Converter(); - const apiEndpoint = sessionStorage.getItem("apiEndpoint"); - const taskStore = JSON.parse(sessionStorage.getItem("task")); + const apiEndpoint = getStoredData("apiEndpoint"); + const taskStore = JSON.parse(getStoredData("task")); const taskName = document.getElementById("taskName"); const taskStatusTag = document.getElementById("taskStatusTag"); const taskStagesMenu = document.getElementById("taskStagesMenu"); const taskPauseButton = document.getElementById("taskPauseButton"); const taskAgentsButton = document.getElementById("taskAgentsButton"); const taskWokFlowButton = document.getElementById("taskWokFlowButton"); + const taskMessageTextarea=document.getElementById("taskMessageTextarea"); const taskMessageAddButton = document.getElementById("taskMessageAddButton"); const taskMessages = document.getElementById("taskMessages"); const taskDetailsAgents = document.getElementById("taskDetailsAgents"); @@ -91,12 +92,12 @@ agentIcon = "manager"; break; case "HumanAgent": - let userNumber = sessionStorage.getItem("userNumber"); + let userNumber = getStoredData("userNumber"); if (userNumber == null) { // Generate a random number between 0 and 6 userNumber = Math.floor(Math.random() * 6); // Create the icon name by concatenating 'user' with the random number - sessionStorage.setItem("userNumber", userNumber); + setStoredData("userNumber", userNumber); } let iconName = "user" + userNumber; agentIcon = iconName; @@ -160,8 +161,23 @@ if (taskCancelButton) { taskCancelButton.addEventListener("click", (event) => { - const apiTaskStore = JSON.parse(sessionStorage.getItem("apiTask")); + const apiTaskStore = JSON.parse(getStoredData("apiTask")); handleDisableOfActions("completed") + + // Explicitly disable chatbox and message button + taskMessageTextarea.disabled = true; + taskMessageTextarea.style.backgroundColor = "#efefef"; + taskMessageTextarea.style.cursor = 'not-allowed'; + + taskMessageAddButton.disabled = true; + taskMessageAddButton.style.cursor = 'not-allowed'; + + const textInputContainer = document.getElementsByClassName("text-input-container"); + if (textInputContainer[0]) { + textInputContainer[0].style.backgroundColor = '#efefef'; + textInputContainer[0].style.cursor = 'not-allowed'; + } + actionStages(apiTaskStore, false); }); } @@ -216,37 +232,37 @@ updateTaskProgress(data[0]); fetchTaskStages(data[0]); - sessionStorage.setItem("apiTask", JSON.stringify(data[0])); - const isHumanClarificationRequestNull = data?.[0]?.human_clarification_request === null + setStoredData("apiTask", JSON.stringify(data[0])); + //const isHumanClarificationRequestNull = data?.[0]?.human_clarification_request === null + const isHumanClarificationResponseNotNull = data?.[0]?.human_clarification_response !== null; const taskMessageTextareaElement =document.getElementById("taskMessageTextarea"); const taskMessageAddButton = document.getElementById("taskMessageAddButton"); const textInputContainer = document.getElementsByClassName("text-input-container"); + if (isHumanClarificationResponseNotNull) { + // Update the local state to set human_clarification_request to null + data[0].human_clarification_request = null; + console.log("Human clarification request set to null locally."); + } + + const isHumanClarificationRequestNull = data?.[0]?.human_clarification_request === null + if(isHumanClarificationRequestNull && taskMessageTextareaElement){ taskMessageTextareaElement.setAttribute('disabled', true) taskMessageTextareaElement.style.backgroundColor = "#efefef"; taskMessageTextareaElement.style.cursor = 'not-allowed'; - } else { - taskMessageTextareaElement.removeAttribute('disabled') - taskMessageTextareaElement.style.backgroundColor = "white" - taskMessageTextareaElement.style.cursor = ''; - } + } + if(isHumanClarificationRequestNull && taskMessageAddButton){ taskMessageAddButton.setAttribute('disabled', true) taskMessageAddButton.style.cursor = 'not-allowed'; - } else { - taskMessageAddButton.removeAttribute('disabled') - taskMessageAddButton.style.cursor = 'pointer'; - } - + } + if(isHumanClarificationRequestNull && textInputContainer[0]){ textInputContainer[0].style.backgroundColor = '#efefef'; textInputContainer[0].style.cursor = 'not-allowed'; - } else { - textInputContainer[0].style.backgroundColor = 'white'; - textInputContainer[0].style.cursor = ''; - } - + } + }) .catch((error) => { console.error("Error:", error); @@ -359,15 +375,15 @@ updateTaskDetailsAgents([...new Set(taskAgents)]); - sessionStorage.setItem("showApproveAll", false); + setStoredData("showApproveAll", false); // Feature approve all removed for this version // if (isHumanFeedbackPending()) { - // sessionStorage.setItem('showApproveAll', false); + // setStoredData('showApproveAll', false); // console.log('showApproveAll status', "showApproveAll is false"); // } else { - // sessionStorage.setItem('showApproveAll', taskStageApprovalStatus === taskStageCount); + // setStoredData('showApproveAll', taskStageApprovalStatus === taskStageCount); // console.log('showApproveAll status', taskStageApprovalStatus === taskStageCount); // } @@ -434,8 +450,8 @@ // console.log(groupByStepId(data)); if ( - sessionStorage.getItem("context") && - sessionStorage.getItem("context") === "customer" + getStoredData("context") && + getStoredData("context") === "customer" ) { data = contextFilter(data); } @@ -461,7 +477,7 @@ messages.forEach((message) => { const messageItem = document.createElement("div"); const showApproveAll = - sessionStorage.getItem("showApproveAll") === "true" && + getStoredData("showApproveAll") === "true" && data.length === messageCount; let approveAllStagesButton = ""; @@ -475,8 +491,8 @@ : "has-status-active"; if ( - sessionStorage.getItem("context") && - sessionStorage.getItem("context") !== "customer" + getStoredData("context") && + getStoredData("context") !== "customer" ) { if (showApproveAll) { console.log("Creating approveAllStagesButton"); @@ -534,8 +550,8 @@ taskMessages.appendChild(messageItem); if ( - sessionStorage.getItem("context") && - sessionStorage.getItem("context") !== "customer" + getStoredData("context") && + getStoredData("context") !== "customer" ) { if (showApproveAll) { document @@ -560,20 +576,19 @@ } if ( - sessionStorage.getItem("context") && - sessionStorage.getItem("context") === "customer" && - !sessionStorage - .getItem("actionStagesRun") + getStoredData("context") && + getStoredData("context") === "customer" && + !getStoredData("actionStagesRun") .includes(task.session_id) ) { actionStages(task, true); let actionStagesRun = JSON.parse( - sessionStorage.getItem("actionStagesRun") || "[]" + getStoredData("actionStagesRun") || "[]" ); actionStagesRun.push(task.session_id); - sessionStorage.setItem( + setStoredData( "actionStagesRun", JSON.stringify(actionStagesRun) ); @@ -618,12 +633,18 @@ } else if (task.overall_status === "in_progress") { removeClassesExcept(taskStatusTag, "tag"); taskStatusTag.classList.add("is-info"); + const iconElement = taskPauseButton.querySelector("i"); + if (iconElement.classList.contains("fa-circle-play")) { + iconElement.classList.remove("fa-circle-play"); + iconElement.classList.add("fa-circle-pause"); + } + } handleDisableOfActions(task.overall_status) }; const isHumanFeedbackPending = () => { - const storedData = sessionStorage.getItem("apiTask"); + const storedData = getStoredData("apiTask"); const planDetails = JSON.parse(storedData); return ( planDetails.human_clarification_request !== null && @@ -767,11 +788,7 @@ // Update the lastDataHash to the new hash lastDataHash = newDataHash; - // Continue polling by calling fetchLoop again - setTimeout( - () => fetchLoop(id), - Number(sessionStorage.getItem("apiRefreshRate")) - ); + } catch (error) { console.error("Error in fetchLoop:", error); } diff --git a/src/frontend/wwwroot/utils.js b/src/frontend/wwwroot/utils.js index 501e534b..ef816dc7 100644 --- a/src/frontend/wwwroot/utils.js +++ b/src/frontend/wwwroot/utils.js @@ -68,3 +68,22 @@ window.GetAuthDetails = async () => { return headers; } }; + +window.getStoredData = (key)=> { + let data = localStorage.getItem(key); + + // If not found in localStorage, check sessionStorage + if (!data) { + data = sessionStorage.getItem(key); + if (data) { + // Move data from sessionStorage to localStorage + setStoredData(key, data); + sessionStorage.removeItem(key); // Optional cleanup + } + } + return data; +} + +window.setStoredData = (key, value)=> { + localStorage.setItem(key, value) +} \ No newline at end of file