Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/CustomizingAzdParameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ By default this template will use the environment name as the prefix to prevent
| `AZURE_ENV_MODEL_CAPACITY` | int | `150` | Sets the GPT model capacity. |
| `AZURE_ENV_IMAGETAG` | string | `latest` | Docker image tag used for container deployments. |
| `AZURE_ENV_ENABLE_TELEMETRY` | bool | `true` | Enables telemetry for monitoring and diagnostics. |
| `AZURE_ENV_LOG_ANALYTICS_WORKSPACE_ID` | string | `<Existing Workspace Id>` | Set this if you want to reuse an existing Log Analytics Workspace instead of creating a new one. |
| `AZURE_ENV_LOG_ANALYTICS_WORKSPACE_ID` | string | *(none)* — see the [guide to finding your existing Workspace ID](/docs/re-use-log-analytics.md) | Set this if you want to reuse an existing Log Analytics Workspace instead of creating a new one. |
---

## How to Set a Parameter
Expand Down
3 changes: 3 additions & 0 deletions infra/modules/role.bicep
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ resource aiUserAccessFoundry 'Microsoft.Authorization/roleAssignments@2022-04-01
properties: {
roleDefinitionId: aiUser.id
principalId: principalId
principalType: 'ServicePrincipal'
}
}

Expand All @@ -38,6 +39,7 @@ resource aiDeveloperAccessFoundry 'Microsoft.Authorization/roleAssignments@2022-
properties: {
roleDefinitionId: aiDeveloper.id
principalId: principalId
principalType: 'ServicePrincipal'
}
}

Expand All @@ -47,5 +49,6 @@ resource cognitiveServiceOpenAIUserAccessFoundry 'Microsoft.Authorization/roleAs
properties: {
roleDefinitionId: cognitiveServiceOpenAIUser.id
principalId: principalId
principalType: 'ServicePrincipal'
}
}
1 change: 1 addition & 0 deletions infra/old/deploy_ai_foundry.bicep
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ resource aiDevelopertoAIProject 'Microsoft.Authorization/roleAssignments@2022-04
properties: {
roleDefinitionId: aiDeveloper.id
principalId: aiHubProject.identity.principalId
principalType: 'ServicePrincipal'
}
}

Expand Down
2 changes: 2 additions & 0 deletions infra/old/main.bicep
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,7 @@ module aiFoundryStorageAccount 'br/public:avm/res/storage/storage-account:0.18.2
{
principalId: userAssignedIdentity.outputs.principalId
roleDefinitionIdOrName: 'Storage Blob Data Contributor'
principalType: 'ServicePrincipal'
}
]
}
Expand Down Expand Up @@ -760,6 +761,7 @@ module aiFoundryAiProject 'br/public:avm/res/machine-learning-services/workspace
principalId: containerApp.outputs.?systemAssignedMIPrincipalId!
// Assigning the role with the role name instead of the role ID freezes the deployment at this point
roleDefinitionIdOrName: '64702f94-c441-49e6-a78b-ef80e0188fee' //'Azure AI Developer'
principalType: 'ServicePrincipal'
}
]
}
Expand Down
6 changes: 1 addition & 5 deletions infra/scripts/quota_check_params.sh
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,7 @@ for REGION in "${REGIONS[@]}"; do
FOUND=false
INSUFFICIENT_QUOTA=false

if [ "$MODEL_NAME" = "text-embedding-ada-002" ]; then
MODEL_TYPES=("openai.standard.$MODEL_NAME")
else
MODEL_TYPES=("openai.standard.$MODEL_NAME" "openai.globalstandard.$MODEL_NAME")
fi
MODEL_TYPES=("openai.standard.$MODEL_NAME" "openai.globalstandard.$MODEL_NAME")

for MODEL_TYPE in "${MODEL_TYPES[@]}"; do
FOUND=false
Expand Down
2 changes: 1 addition & 1 deletion infra/scripts/validate_model_quota.ps1
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
param (
[string]$Location,
[string]$Model,
[string]$DeploymentType = "Standard",
[string]$DeploymentType = "GlobalStandard",
[int]$Capacity
)

Expand Down
2 changes: 1 addition & 1 deletion infra/scripts/validate_model_quota.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

LOCATION=""
MODEL=""
DEPLOYMENT_TYPE="Standard"
DEPLOYMENT_TYPE="GlobalStandard"
CAPACITY=0

ALL_REGIONS=('australiaeast' 'eastus2' 'francecentral' 'japaneast' 'norwayeast' 'swedencentral' 'uksouth' 'westus')
Expand Down
128 changes: 128 additions & 0 deletions src/backend/app_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
HumanClarification,
HumanFeedback,
InputTask,
Plan,
PlanStatus,
PlanWithSteps,
Step,
)
Expand Down Expand Up @@ -188,6 +190,132 @@ async def input_task_endpoint(input_task: InputTask, request: Request):
raise HTTPException(status_code=400, detail=f"Error creating plan: {e}")


@app.post("/api/create_plan")
async def create_plan_endpoint(input_task: InputTask, request: Request):
    """
    Create a new plan without full processing.

    ---
    tags:
      - Plans
    parameters:
      - name: user_principal_id
        in: header
        type: string
        required: true
        description: User ID extracted from the authentication header
      - name: body
        in: body
        required: true
        schema:
          type: object
          properties:
            session_id:
              type: string
              description: Session ID for the plan
            description:
              type: string
              description: The task description to validate and create plan for
    responses:
      200:
        description: Plan created successfully
        schema:
          type: object
          properties:
            plan_id:
              type: string
              description: The ID of the newly created plan
            status:
              type: string
              description: Success message
            session_id:
              type: string
              description: Session ID associated with the plan
      400:
        description: RAI check failed or invalid input
        schema:
          type: object
          properties:
            detail:
              type: string
              description: Error message
    """
    # Safety screening runs before anything else so that unsafe requests
    # never reach authentication or storage.
    if not await rai_success(input_task.description):
        track_event_if_configured(
            "RAI failed",
            {
                "status": "Plan not created - RAI check failed",
                "description": input_task.description,
                "session_id": input_task.session_id,
            },
        )
        raise HTTPException(
            status_code=400,
            detail="Task description failed safety validation. Please revise your request."
        )

    # Resolve the caller's identity from the authentication headers.
    user_details = get_authenticated_user_details(request_headers=request.headers)
    user_id = user_details["user_principal_id"]
    if not user_id:
        track_event_if_configured(
            "UserIdNotFound", {"status_code": 400, "detail": "no user"}
        )
        raise HTTPException(status_code=400, detail="no user")

    # A request without a session id gets a freshly generated one.
    input_task.session_id = input_task.session_id or str(uuid.uuid4())

    try:
        # Only the memory store is needed here; the kernel is unused.
        _kernel, store = await initialize_runtime_and_context(
            input_task.session_id, user_id
        )

        # Build an in-progress plan attributed to the planner agent and
        # persist it.
        new_plan = Plan(
            session_id=input_task.session_id,
            user_id=user_id,
            initial_goal=input_task.description,
            overall_status=PlanStatus.in_progress,
            source=AgentType.PLANNER.value
        )
        await store.add_plan(new_plan)

        track_event_if_configured(
            "PlanCreated",
            {
                "status": f"Plan created with ID: {new_plan.id}",
                "session_id": input_task.session_id,
                "plan_id": new_plan.id,
                "description": input_task.description,
            },
        )

        return {
            "plan_id": new_plan.id,
            "status": "Plan created successfully",
            "session_id": input_task.session_id,
        }
    except Exception as e:
        # Any failure while persisting is reported back to the client with
        # the same telemetry/shape as the sibling /input_task endpoint.
        track_event_if_configured(
            "CreatePlanError",
            {
                "session_id": input_task.session_id,
                "description": input_task.description,
                "error": str(e),
            },
        )
        raise HTTPException(status_code=400, detail=f"Error creating plan: {e}")


@app.post("/api/human_feedback")
async def human_feedback_endpoint(human_feedback: HumanFeedback, request: Request):
"""
Expand Down
126 changes: 117 additions & 9 deletions src/backend/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
sys.modules["azure.monitor"] = MagicMock()
sys.modules["azure.monitor.events.extension"] = MagicMock()
sys.modules["azure.monitor.opentelemetry"] = MagicMock()
sys.modules["azure.ai.projects"] = MagicMock()
sys.modules["azure.ai.projects.aio"] = MagicMock()

# Mock environment variables before importing app
os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint"
Expand All @@ -23,7 +25,7 @@

# Mock telemetry initialization to prevent errors
with patch("azure.monitor.opentelemetry.configure_azure_monitor", MagicMock()):
from src.backend.app import app
from app_kernel import app

# Initialize FastAPI test client
client = TestClient(app)
Expand All @@ -33,13 +35,9 @@
def mock_dependencies(monkeypatch):
"""Mock dependencies to simplify tests."""
monkeypatch.setattr(
"src.backend.auth.auth_utils.get_authenticated_user_details",
"auth.auth_utils.get_authenticated_user_details",
lambda headers: {"user_principal_id": "mock-user-id"},
)
monkeypatch.setattr(
"src.backend.utils.retrieve_all_agent_tools",
lambda: [{"agent": "test_agent", "function": "test_function"}],
)


def test_input_task_invalid_json():
Expand All @@ -49,9 +47,119 @@ def test_input_task_invalid_json():
headers = {"Authorization": "Bearer mock-token"}
response = client.post("/input_task", data=invalid_json, headers=headers)

# Assert response for invalid JSON
assert response.status_code == 422
assert "detail" in response.json()

def test_create_plan_endpoint_success():
    """Test the /api/create_plan endpoint with valid input.

    Forces the RAI check to pass and stubs the runtime initializer so no
    real Cosmos DB / agent runtime is required; verifies the response
    payload shape and that the plan is persisted exactly once.
    """
    # Local import: AsyncMock may not be imported at module top level.
    from unittest.mock import AsyncMock

    headers = {"Authorization": "Bearer mock-token"}

    with patch("app_kernel.rai_success", return_value=True), \
            patch("app_kernel.initialize_runtime_and_context") as mock_init, \
            patch("app_kernel.track_event_if_configured"):

        # The endpoint awaits memory_store.add_plan(...), so the store must
        # be an AsyncMock — a plain MagicMock cannot be awaited, which would
        # make the request fail and return 400 instead of 200.
        mock_memory_store = AsyncMock()
        mock_init.return_value = (MagicMock(), mock_memory_store)

        test_input = {
            "session_id": "test-session-123",
            "description": "Create a marketing plan for our new product"
        }

        response = client.post("/api/create_plan", json=test_input, headers=headers)

        # Check response
        assert response.status_code == 200
        data = response.json()
        assert "plan_id" in data
        assert "status" in data
        assert "session_id" in data
        assert data["status"] == "Plan created successfully"
        assert data["session_id"] == "test-session-123"

        # Verify memory store was called to add the plan exactly once
        mock_memory_store.add_plan.assert_called_once()


def test_create_plan_endpoint_rai_failure():
    """Test that /api/create_plan returns 400 when the RAI check fails.

    Telemetry is stubbed without an alias — the previous ``as mock_track``
    binding was never used.
    """
    headers = {"Authorization": "Bearer mock-token"}

    # Force the safety check to fail.
    with patch("app_kernel.rai_success", return_value=False), \
            patch("app_kernel.track_event_if_configured"):

        test_input = {
            "session_id": "test-session-123",
            "description": "This is an unsafe description"
        }

        response = client.post("/api/create_plan", json=test_input, headers=headers)

        # The request must be rejected before any storage is touched.
        assert response.status_code == 400
        data = response.json()
        assert "detail" in data
        assert "safety validation" in data["detail"]


def test_create_plan_endpoint_harmful_content():
    """Test /api/create_plan with harmful content that should fail RAI.

    Debug ``print`` calls and the unused ``mock_track`` binding from the
    original version are removed; the assertions are unchanged.
    """
    headers = {"Authorization": "Bearer mock-token"}

    # Simulate the RAI service rejecting the harmful description.
    with patch("app_kernel.rai_success", return_value=False), \
            patch("app_kernel.track_event_if_configured"):

        test_input = {
            "session_id": "test-session-456",
            "description": "I want to kill my neighbors cat"
        }

        response = client.post("/api/create_plan", json=test_input, headers=headers)

        # Check response - should be 400 due to RAI failure
        assert response.status_code == 400
        data = response.json()
        assert "detail" in data
        assert "safety validation" in data["detail"]


def test_create_plan_endpoint_real_rai_check():
    """Test the /api/create_plan endpoint with real RAI check (no mocking).

    NOTE(review): ``rai_success`` is deliberately left unpatched, so this
    test calls whatever real safety service the app is configured against —
    it therefore needs live credentials/network and is nondeterministic.
    Confirm it is intended to run in CI; otherwise it belongs in an
    integration suite, not the unit tests.
    """
    headers = {"Authorization": "Bearer mock-token"}

    # Don't mock RAI - let it run the real check; only the runtime
    # initializer and telemetry are stubbed.
    with patch("app_kernel.initialize_runtime_and_context") as mock_init, \
            patch("app_kernel.track_event_if_configured") as mock_track:

        # Mock memory store (never reached if RAI rejects the request).
        mock_memory_store = MagicMock()
        mock_init.return_value = (MagicMock(), mock_memory_store)

        test_input = {
            "session_id": "test-session-789",
            "description": "I want to kill my neighbors cat"
        }

        response = client.post("/api/create_plan", json=test_input, headers=headers)

        # Print response details for debugging
        print(f"Real RAI Response status: {response.status_code}")
        print(f"Real RAI Response data: {response.json()}")

        # This should fail with real RAI check — assumes the live service
        # rejects the description; verify that expectation holds.
        assert response.status_code == 400
        data = response.json()
        assert "detail" in data


def test_input_task_missing_description():
Expand Down
Loading
Loading