diff --git a/AGENTS.md b/AGENTS.md index e5e8ffe217..0021d98852 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -78,14 +78,37 @@ When adding a new feature, add tests for it in the appropriate file. If the feature is a UI element, add an e2e test for it. If it is an API endpoint, add an app integration test for it. If it is a function or method, add a unit test for it. -Use mocks from conftest.py to mock external services. +Use mocks from tests/conftest.py to mock external services. Prefer mocking at the HTTP/requests level when possible. When you're running tests, make sure you activate the .venv virtual environment first: -```bash +```shell source .venv/bin/activate ``` ## Sending pull requests When sending pull requests, make sure to follow the PULL_REQUEST_TEMPLATE.md format. + +## Upgrading dependencies + +To upgrade a particular package in the backend, use the following command, replacing `` with the name of the package you want to upgrade: + +```shell +cd app/backend && uv pip compile requirements.in -o requirements.txt --python-version 3.9 --upgrade-package package-name +``` + +## Checking Python type hints + +To check Python type hints, use the following command: + +```shell +cd app/backend && mypy . --config-file=../pyproject.toml +``` + +```shell +cd scripts && mypy . --config-file=../pyproject.toml +``` + +Note that we do not currently enforce type hints in the tests folder, as it would require adding a lot of `# type: ignore` comments to the existing tests. +We only enforce type hints in the main application code and scripts. diff --git a/app/backend/requirements.txt b/app/backend/requirements.txt index 1af2e14c29..baa909d210 100644 --- a/app/backend/requirements.txt +++ b/app/backend/requirements.txt @@ -199,7 +199,7 @@ msal-extensions==1.3.1 # via azure-identity msgraph-core==1.3.3 # via msgraph-sdk -msgraph-sdk==1.26.0 +msgraph-sdk==1.45.0 # via -r requirements.in msrest==0.7.1 # via azure-monitor-opentelemetry-exporter @@ -431,7 +431,6 @@ typing-extensions==4.13.2 # pypdf # quart # quart-cors - # rich # taskgroup # uvicorn urllib3==2.5.0 diff --git a/tests/mocks.py b/tests/mocks.py index 503b5b5f99..6228d84e34 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -34,8 +34,9 @@ class MockAzureCredential(AsyncTokenCredential): - async def get_token(self, uri): - return MockToken("", 9999999999, "") + async def get_token(self, *scopes, **kwargs): # accept claims, enable_cae, etc. + # Return a simple mock token structure with required attributes + return MockToken("mock-token", 9999999999, "mock-token") class MockAzureCredentialExpired(AsyncTokenCredential): @@ -43,7 +44,7 @@ class MockAzureCredentialExpired(AsyncTokenCredential): def __init__(self): self.access_number = 0 - async def get_token(self, uri): + async def get_token(self, *scopes, **kwargs): self.access_number += 1 if self.access_number == 1: return MockToken("", 0, "") diff --git a/tests/test_auth_init.py b/tests/test_auth_init.py new file mode 100644 index 0000000000..f9155f3892 --- /dev/null +++ b/tests/test_auth_init.py @@ -0,0 +1,235 @@ +import os +from unittest import mock + +import pytest +from msgraph import GraphServiceClient +from msgraph.generated.models.application import Application +from msgraph.generated.models.password_credential import PasswordCredential +from msgraph.generated.models.service_principal import ServicePrincipal + +from .mocks import MockAzureCredential +from scripts import auth_init +from scripts.auth_init import ( + add_client_secret, + client_app, + create_application, + create_or_update_application_with_secret, + server_app_initial, + server_app_permission_setup, +) + +MOCK_OBJECT_ID = "OBJ123" +MOCK_APP_ID = "APP123" +MOCK_SECRET = "SECRET_VALUE" +EXISTING_MOCK_OBJECT_ID = "OBJ999" + + +@pytest.fixture +def graph_client(monkeypatch): + """GraphServiceClient whose network layer is intercepted to avoid real HTTP calls. + + We exercise real request builders while intercepting the adapter's send_async. + """ + + client = GraphServiceClient(credentials=MockAzureCredential(), scopes=["https://graph.microsoft.com/.default"]) + + calls = { + "applications.post": [], + "applications.patch": [], + "applications.add_password.post": [], + "service_principals.post": [], + } + created_ids = {"object_id": MOCK_OBJECT_ID, "app_id": MOCK_APP_ID} + secret_text_value = {"value": MOCK_SECRET} + + async def fake_send_async(request_info, return_type, error_mapping=None): + url = request_info.url or "" + method = request_info.http_method.value + if method == "POST" and url.endswith("/applications"): + body = request_info.content + calls["applications.post"].append(body) + return Application( + id=created_ids["object_id"], + app_id=created_ids["app_id"], + display_name=getattr(body, "display_name", None), + ) + if method == "POST" and url.endswith("/servicePrincipals"): + calls["service_principals.post"].append(request_info.content) + return ServicePrincipal() + if method == "PATCH" and "/applications/" in url: + calls["applications.patch"].append(request_info.content) + return Application() + if method == "POST" and url.endswith("/addPassword"): + calls["applications.add_password.post"].append(request_info.content) + return PasswordCredential(secret_text=secret_text_value["value"]) + raise AssertionError(f"Unexpected request: {method} {url}") + + # Patch the adapter + monkeypatch.setattr(client.request_adapter, "send_async", fake_send_async) + + client._test_calls = calls + client._test_secret_text_value = secret_text_value + client._test_ids = created_ids + return client + + +@pytest.mark.asyncio +async def test_create_application_success(graph_client): + graph = graph_client + request = server_app_initial(42) + object_id, app_id = await create_application(graph, request) + assert object_id == MOCK_OBJECT_ID + assert app_id == MOCK_APP_ID + assert len(graph._test_calls["service_principals.post"]) == 1 + + +@pytest.mark.asyncio +async def test_create_application_missing_ids(graph_client, monkeypatch): + graph = graph_client + + original_send_async = graph.request_adapter.send_async + + async def bad_send_async(request_info, return_type, error_mapping=None): + url = request_info.url or "" + method = request_info.http_method.value + if method == "POST" and url.endswith("/applications"): + return Application(id=None, app_id=None) + return await original_send_async(request_info, return_type, error_mapping) + + monkeypatch.setattr(graph.request_adapter, "send_async", bad_send_async) + with pytest.raises(ValueError): + await create_application(graph, server_app_initial(1)) + + +@pytest.mark.asyncio +async def test_add_client_secret_success(graph_client): + graph = graph_client + secret = await add_client_secret(graph, MOCK_OBJECT_ID) + assert secret == MOCK_SECRET + assert len(graph._test_calls["applications.add_password.post"]) == 1 + + +@pytest.mark.asyncio +async def test_add_client_secret_missing_secret(graph_client): + graph = graph_client + graph._test_secret_text_value["value"] = None + with pytest.raises(ValueError): + await add_client_secret(graph, MOCK_OBJECT_ID) + + +@pytest.mark.asyncio +async def test_create_or_update_application_creates_and_adds_secret(graph_client, monkeypatch): + graph = graph_client + updates: list[tuple[str, str]] = [] + + def fake_update_env(name, val): + updates.append((name, val)) + + # Ensure env vars not set + with mock.patch.dict(os.environ, {}, clear=True): + monkeypatch.setattr(auth_init, "update_azd_env", fake_update_env) + + # Force get_application to return None (not found) + async def fake_get_application(graph_client, client_id): + return None + + monkeypatch.setattr("scripts.auth_init.get_application", fake_get_application) + object_id, app_id, created = await create_or_update_application_with_secret( + graph, + app_id_env_var="AZURE_SERVER_APP_ID", + app_secret_env_var="AZURE_SERVER_APP_SECRET", + request_app=server_app_initial(55), + ) + assert created is True + assert object_id == MOCK_OBJECT_ID + assert app_id == MOCK_APP_ID + # Two updates: app id and secret + assert {u[0] for u in updates} == {"AZURE_SERVER_APP_ID", "AZURE_SERVER_APP_SECRET"} + assert len(graph._test_calls["applications.add_password.post"]) == 1 + + +@pytest.mark.asyncio +async def test_create_or_update_application_existing_adds_secret(graph_client, monkeypatch): + graph = graph_client + updates: list[tuple[str, str]] = [] + + def fake_update_env(name, val): + updates.append((name, val)) + + with mock.patch.dict(os.environ, {"AZURE_SERVER_APP_ID": MOCK_APP_ID}, clear=True): + monkeypatch.setattr(auth_init, "update_azd_env", fake_update_env) + + async def fake_get_application(graph_client, client_id): + return EXISTING_MOCK_OBJECT_ID + + monkeypatch.setattr("scripts.auth_init.get_application", fake_get_application) + object_id, app_id, created = await create_or_update_application_with_secret( + graph, + app_id_env_var="AZURE_SERVER_APP_ID", + app_secret_env_var="AZURE_SERVER_APP_SECRET", + request_app=server_app_initial(77), + ) + assert created is False + assert object_id == EXISTING_MOCK_OBJECT_ID + assert app_id == MOCK_APP_ID + # Secret should be added since not in env + assert any(name == "AZURE_SERVER_APP_SECRET" for name, _ in updates) + # Application patch should have been called + # Patch captured + assert len(graph._test_calls["applications.patch"]) == 1 + + +@pytest.mark.asyncio +async def test_create_or_update_application_existing_with_secret(graph_client, monkeypatch): + graph = graph_client + with mock.patch.dict( + os.environ, {"AZURE_SERVER_APP_ID": MOCK_APP_ID, "AZURE_SERVER_APP_SECRET": "EXISTING"}, clear=True + ): + + async def fake_get_application(graph_client, client_id): + return EXISTING_MOCK_OBJECT_ID + + monkeypatch.setattr("scripts.auth_init.get_application", fake_get_application) + object_id, app_id, created = await create_or_update_application_with_secret( + graph, + app_id_env_var="AZURE_SERVER_APP_ID", + app_secret_env_var="AZURE_SERVER_APP_SECRET", + request_app=server_app_initial(88), + ) + assert created is False + assert object_id == EXISTING_MOCK_OBJECT_ID + assert app_id == MOCK_APP_ID + # No secret added + assert len(graph._test_calls["applications.add_password.post"]) == 0 + + +def test_client_app_validation_errors(): + # Server app without api + server_app = server_app_initial(1) + server_app.api = None + with pytest.raises(ValueError): + client_app("server_app_id", server_app, 2) + + # Server app with empty scopes + # attach empty api + server_app_permission = server_app_permission_setup("server_app") + server_app_permission.api.oauth2_permission_scopes = [] + with pytest.raises(ValueError): + client_app("server_app_id", server_app_permission, 2) + + +def test_client_app_success(): + server_app_permission = server_app_permission_setup("server_app") + c_app = client_app("server_app", server_app_permission, 123) + assert c_app.web is not None + assert c_app.spa is not None + assert c_app.required_resource_access is not None + assert len(c_app.required_resource_access) >= 1 + + +def test_server_app_permission_setup(): + # simulate after creation we know app id + app_with_permissions = server_app_permission_setup("server_app_id") + assert app_with_permissions.identifier_uris == ["api://server_app_id"] + assert app_with_permissions.required_resource_access is not None + assert len(app_with_permissions.required_resource_access) == 1