From 35bd561da6506f0315d47e8764ad10bb4f2a0293 Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Thu, 18 Sep 2025 11:07:41 -0700 Subject: [PATCH 1/3] Add tests for auth_init.py --- AGENTS.md | 12 +- app/backend/requirements.txt | 3 +- tests/mocks.py | 7 +- tests/test_auth_init.py | 239 +++++++++++++++++++++++++++++++++++ 4 files changed, 254 insertions(+), 7 deletions(-) create mode 100644 tests/test_auth_init.py diff --git a/AGENTS.md b/AGENTS.md index e5e8ffe217..e6250cbf5a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -78,14 +78,22 @@ When adding a new feature, add tests for it in the appropriate file. If the feature is a UI element, add an e2e test for it. If it is an API endpoint, add an app integration test for it. If it is a function or method, add a unit test for it. -Use mocks from conftest.py to mock external services. +Use mocks from tests/conftest.py to mock external services. Prefer mocking at the HTTP/requests level when possible. When you're running tests, make sure you activate the .venv virtual environment first: -```bash +```shell source .venv/bin/activate ``` ## Sending pull requests When sending pull requests, make sure to follow the PULL_REQUEST_TEMPLATE.md format. + +## Upgrading dependencies + +To upgrade a particular package in the backend, use the following command, replacing `` with the name of the package you want to upgrade: + +```shell +cd app/backend && uv pip compile requirements.in -o requirements.txt --python-version 3.9 --upgrade-package package-name +``` diff --git a/app/backend/requirements.txt b/app/backend/requirements.txt index 80d36ceba0..7211d34f20 100644 --- a/app/backend/requirements.txt +++ b/app/backend/requirements.txt @@ -199,7 +199,7 @@ msal-extensions==1.3.1 # via azure-identity msgraph-core==1.3.3 # via msgraph-sdk -msgraph-sdk==1.26.0 +msgraph-sdk==1.45.0 # via -r requirements.in msrest==0.7.1 # via azure-monitor-opentelemetry-exporter @@ -431,7 +431,6 @@ typing-extensions==4.13.2 # pypdf # quart # quart-cors - # rich # taskgroup # uvicorn urllib3==2.5.0 diff --git a/tests/mocks.py b/tests/mocks.py index 503b5b5f99..6228d84e34 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -34,8 +34,9 @@ class MockAzureCredential(AsyncTokenCredential): - async def get_token(self, uri): - return MockToken("", 9999999999, "") + async def get_token(self, *scopes, **kwargs): # accept claims, enable_cae, etc. + # Return a simple mock token structure with required attributes + return MockToken("mock-token", 9999999999, "mock-token") class MockAzureCredentialExpired(AsyncTokenCredential): @@ -43,7 +44,7 @@ class MockAzureCredentialExpired(AsyncTokenCredential): def __init__(self): self.access_number = 0 - async def get_token(self, uri): + async def get_token(self, *scopes, **kwargs): self.access_number += 1 if self.access_number == 1: return MockToken("", 0, "") diff --git a/tests/test_auth_init.py b/tests/test_auth_init.py new file mode 100644 index 0000000000..d88f942249 --- /dev/null +++ b/tests/test_auth_init.py @@ -0,0 +1,239 @@ +import os +from unittest import mock + +import pytest +from msgraph import GraphServiceClient +from msgraph.generated.models.application import Application +from msgraph.generated.models.password_credential import PasswordCredential +from msgraph.generated.models.service_principal import ServicePrincipal + +from .mocks import MockAzureCredential +from scripts import auth_init +from scripts.auth_init import ( + add_client_secret, + client_app, + create_application, + create_or_update_application_with_secret, + server_app_initial, + server_app_permission_setup, +) + + +@pytest.fixture +def graph_client(monkeypatch): + """GraphServiceClient whose network layer is intercepted to avoid real HTTP calls. + + We exercise real request builders while intercepting the adapter's send_async. + """ + + client = GraphServiceClient(credentials=MockAzureCredential(), scopes=["https://graph.microsoft.com/.default"]) + + calls = { + "applications.post": [], + "applications.patch": [], + "applications.add_password.post": [], + "service_principals.post": [], + } + created_ids = {"object_id": "OBJ123", "client_id": "APP123"} + secret_text_value = {"value": "SECRET_VALUE"} + + async def fake_send_async(request_info, return_type, error_mapping=None): + url = request_info.url or "" + method = ( + request_info.http_method.value + if hasattr(request_info.http_method, "value") + else str(request_info.http_method) + ) + if method == "POST" and url.endswith("/applications"): + body = request_info.content + calls["applications.post"].append(body) + return Application( + id=created_ids["object_id"], + app_id=created_ids["client_id"], + display_name=getattr(body, "display_name", None), + ) + if method == "POST" and url.endswith("/servicePrincipals"): + calls["service_principals.post"].append(request_info.content) + return ServicePrincipal() + if method == "PATCH" and "/applications/" in url: + calls["applications.patch"].append(request_info.content) + return Application() + if method == "POST" and url.endswith("/addPassword"): + calls["applications.add_password.post"].append(request_info.content) + return PasswordCredential(secret_text=secret_text_value["value"]) + raise AssertionError(f"Unexpected request: {method} {url}") + + # Patch the adapter + monkeypatch.setattr(client.request_adapter, "send_async", fake_send_async) + + client._test_calls = calls # type: ignore[attr-defined] + client._test_secret_text_value = secret_text_value # type: ignore[attr-defined] + client._test_ids = created_ids # type: ignore[attr-defined] + return client + + +@pytest.mark.asyncio +async def test_create_application_success(graph_client): + graph = graph_client + request = server_app_initial(42) + object_id, client_id = await create_application(graph, request) + assert object_id == "OBJ123" + assert client_id == "APP123" + assert len(graph._test_calls["service_principals.post"]) == 1 + + +@pytest.mark.asyncio +async def test_create_application_missing_ids(graph_client, monkeypatch): + graph = graph_client + + original_send_async = graph.request_adapter.send_async + + async def bad_send_async(request_info, return_type, error_mapping=None): # type: ignore[unused-argument] + url = request_info.url or "" + method = ( + request_info.http_method.value + if hasattr(request_info.http_method, "value") + else str(request_info.http_method) + ) + if method == "POST" and url.endswith("/applications"): + return Application(id=None, app_id=None) + return await original_send_async(request_info, return_type, error_mapping) + + monkeypatch.setattr(graph.request_adapter, "send_async", bad_send_async) + with pytest.raises(ValueError): + await create_application(graph, server_app_initial(1)) + + +@pytest.mark.asyncio +async def test_add_client_secret_success(graph_client): + graph = graph_client + secret = await add_client_secret(graph, "OBJ123") + assert secret == "SECRET_VALUE" + assert len(graph._test_calls["applications.add_password.post"]) == 1 + + +@pytest.mark.asyncio +async def test_add_client_secret_missing_secret(graph_client): + graph = graph_client + graph._test_secret_text_value["value"] = None # type: ignore + with pytest.raises(ValueError): + await add_client_secret(graph, "OBJ123") + + +@pytest.mark.asyncio +async def test_create_or_update_application_creates_and_adds_secret(graph_client, monkeypatch): + graph = graph_client + updates: list[tuple[str, str]] = [] + + def fake_update_env(name, val): + updates.append((name, val)) + + # Ensure env vars not set + with mock.patch.dict(os.environ, {}, clear=True): + monkeypatch.setattr(auth_init, "update_azd_env", fake_update_env) + + # Force get_application to return None (not found) + async def fake_get_application(graph_client, client_id): + return None + + monkeypatch.setattr("scripts.auth_init.get_application", fake_get_application) + object_id, app_id, created = await create_or_update_application_with_secret( + graph, + app_id_env_var="AZURE_SERVER_APP_ID", + app_secret_env_var="AZURE_SERVER_APP_SECRET", + request_app=server_app_initial(55), + ) + assert created is True + assert object_id == "OBJ123" + assert app_id == "APP123" + # Two updates: app id and secret + assert {u[0] for u in updates} == {"AZURE_SERVER_APP_ID", "AZURE_SERVER_APP_SECRET"} + assert len(graph._test_calls["applications.add_password.post"]) == 1 + + +@pytest.mark.asyncio +async def test_create_or_update_application_existing_adds_secret(graph_client, monkeypatch): + graph = graph_client + updates: list[tuple[str, str]] = [] + + def fake_update_env(name, val): + updates.append((name, val)) + + with mock.patch.dict(os.environ, {"AZURE_SERVER_APP_ID": "APP123"}, clear=True): + monkeypatch.setattr(auth_init, "update_azd_env", fake_update_env) + + async def fake_get_application(graph_client, client_id): + # Return existing object id for provided app id + return "OBJ999" + + monkeypatch.setattr("scripts.auth_init.get_application", fake_get_application) + object_id, app_id, created = await create_or_update_application_with_secret( + graph, + app_id_env_var="AZURE_SERVER_APP_ID", + app_secret_env_var="AZURE_SERVER_APP_SECRET", + request_app=server_app_initial(77), + ) + assert created is False + assert object_id == "OBJ999" + assert app_id == "APP123" + # Secret should be added since not in env + assert any(name == "AZURE_SERVER_APP_SECRET" for name, _ in updates) + # Application patch should have been called + # Patch captured + assert len(graph._test_calls["applications.patch"]) == 1 + + +@pytest.mark.asyncio +async def test_create_or_update_application_existing_with_secret(graph_client, monkeypatch): + graph = graph_client + with mock.patch.dict( + os.environ, {"AZURE_SERVER_APP_ID": "APP123", "AZURE_SERVER_APP_SECRET": "EXISTING"}, clear=True + ): + + async def fake_get_application(graph_client, client_id): + return "OBJ999" + + monkeypatch.setattr("scripts.auth_init.get_application", fake_get_application) + object_id, app_id, created = await create_or_update_application_with_secret( + graph, + app_id_env_var="AZURE_SERVER_APP_ID", + app_secret_env_var="AZURE_SERVER_APP_SECRET", + request_app=server_app_initial(88), + ) + assert created is False + assert object_id == "OBJ999" + assert app_id == "APP123" + # No secret added + assert len(graph._test_calls["applications.add_password.post"]) == 0 + + +def test_client_app_validation_errors(): + # Server app without api + server_app = server_app_initial(1) + server_app.api = None # type: ignore + with pytest.raises(ValueError): + client_app("server_app_id", server_app, 2) + + # Server app with empty scopes + # attach empty api + server_app_permission = server_app_permission_setup("server_app") + server_app_permission.api.oauth2_permission_scopes = [] # type: ignore + with pytest.raises(ValueError): + client_app("server_app_id", server_app_permission, 2) + + +def test_client_app_success(): + server_app_permission = server_app_permission_setup("server_app") + c_app = client_app("server_app", server_app_permission, 123) + assert c_app.web is not None + assert c_app.spa is not None + assert c_app.required_resource_access is not None + assert len(c_app.required_resource_access) >= 1 + + +def test_server_app_permission_setup(): + # simulate after creation we know app id + app_with_permissions = server_app_permission_setup("server_app_id") + assert app_with_permissions.identifier_uris == ["api://server_app_id"] + assert app_with_permissions.required_resource_access is not None + assert len(app_with_permissions.required_resource_access) == 1 From a6fc0e86c720f058e3e3e75ac601bfc0bc901ba5 Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Thu, 18 Sep 2025 11:21:49 -0700 Subject: [PATCH 2/3] Remove type ignore in tests --- AGENTS.md | 15 +++++++++++++++ tests/test_auth_init.py | 14 +++++++------- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index e6250cbf5a..0021d98852 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -97,3 +97,18 @@ To upgrade a particular package in the backend, use the following command, repla ```shell cd app/backend && uv pip compile requirements.in -o requirements.txt --python-version 3.9 --upgrade-package package-name ``` + +## Checking Python type hints + +To check Python type hints, use the following command: + +```shell +cd app/backend && mypy . --config-file=../pyproject.toml +``` + +```shell +cd scripts && mypy . --config-file=../pyproject.toml +``` + +Note that we do not currently enforce type hints in the tests folder, as it would require adding a lot of `# type: ignore` comments to the existing tests. +We only enforce type hints in the main application code and scripts. diff --git a/tests/test_auth_init.py b/tests/test_auth_init.py index d88f942249..0f6276f9fb 100644 --- a/tests/test_auth_init.py +++ b/tests/test_auth_init.py @@ -66,9 +66,9 @@ async def fake_send_async(request_info, return_type, error_mapping=None): # Patch the adapter monkeypatch.setattr(client.request_adapter, "send_async", fake_send_async) - client._test_calls = calls # type: ignore[attr-defined] - client._test_secret_text_value = secret_text_value # type: ignore[attr-defined] - client._test_ids = created_ids # type: ignore[attr-defined] + client._test_calls = calls + client._test_secret_text_value = secret_text_value + client._test_ids = created_ids return client @@ -88,7 +88,7 @@ async def test_create_application_missing_ids(graph_client, monkeypatch): original_send_async = graph.request_adapter.send_async - async def bad_send_async(request_info, return_type, error_mapping=None): # type: ignore[unused-argument] + async def bad_send_async(request_info, return_type, error_mapping=None): url = request_info.url or "" method = ( request_info.http_method.value @@ -115,7 +115,7 @@ async def test_add_client_secret_success(graph_client): @pytest.mark.asyncio async def test_add_client_secret_missing_secret(graph_client): graph = graph_client - graph._test_secret_text_value["value"] = None # type: ignore + graph._test_secret_text_value["value"] = None with pytest.raises(ValueError): await add_client_secret(graph, "OBJ123") @@ -210,14 +210,14 @@ async def fake_get_application(graph_client, client_id): def test_client_app_validation_errors(): # Server app without api server_app = server_app_initial(1) - server_app.api = None # type: ignore + server_app.api = None with pytest.raises(ValueError): client_app("server_app_id", server_app, 2) # Server app with empty scopes # attach empty api server_app_permission = server_app_permission_setup("server_app") - server_app_permission.api.oauth2_permission_scopes = [] # type: ignore + server_app_permission.api.oauth2_permission_scopes = [] with pytest.raises(ValueError): client_app("server_app_id", server_app_permission, 2) From 02c1b190ed0a6660bf2c6f798f0f7207c6e5fb06 Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Thu, 18 Sep 2025 11:40:28 -0700 Subject: [PATCH 3/3] Move values into constants and remove unneeded code --- tests/test_auth_init.py | 56 +++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 30 deletions(-) diff --git a/tests/test_auth_init.py b/tests/test_auth_init.py index 0f6276f9fb..f9155f3892 100644 --- a/tests/test_auth_init.py +++ b/tests/test_auth_init.py @@ -18,6 +18,11 @@ server_app_permission_setup, ) +MOCK_OBJECT_ID = "OBJ123" +MOCK_APP_ID = "APP123" +MOCK_SECRET = "SECRET_VALUE" +EXISTING_MOCK_OBJECT_ID = "OBJ999" + @pytest.fixture def graph_client(monkeypatch): @@ -34,22 +39,18 @@ def graph_client(monkeypatch): "applications.add_password.post": [], "service_principals.post": [], } - created_ids = {"object_id": "OBJ123", "client_id": "APP123"} - secret_text_value = {"value": "SECRET_VALUE"} + created_ids = {"object_id": MOCK_OBJECT_ID, "app_id": MOCK_APP_ID} + secret_text_value = {"value": MOCK_SECRET} async def fake_send_async(request_info, return_type, error_mapping=None): url = request_info.url or "" - method = ( - request_info.http_method.value - if hasattr(request_info.http_method, "value") - else str(request_info.http_method) - ) + method = request_info.http_method.value if method == "POST" and url.endswith("/applications"): body = request_info.content calls["applications.post"].append(body) return Application( id=created_ids["object_id"], - app_id=created_ids["client_id"], + app_id=created_ids["app_id"], display_name=getattr(body, "display_name", None), ) if method == "POST" and url.endswith("/servicePrincipals"): @@ -76,9 +77,9 @@ async def fake_send_async(request_info, return_type, error_mapping=None): async def test_create_application_success(graph_client): graph = graph_client request = server_app_initial(42) - object_id, client_id = await create_application(graph, request) - assert object_id == "OBJ123" - assert client_id == "APP123" + object_id, app_id = await create_application(graph, request) + assert object_id == MOCK_OBJECT_ID + assert app_id == MOCK_APP_ID assert len(graph._test_calls["service_principals.post"]) == 1 @@ -90,11 +91,7 @@ async def test_create_application_missing_ids(graph_client, monkeypatch): async def bad_send_async(request_info, return_type, error_mapping=None): url = request_info.url or "" - method = ( - request_info.http_method.value - if hasattr(request_info.http_method, "value") - else str(request_info.http_method) - ) + method = request_info.http_method.value if method == "POST" and url.endswith("/applications"): return Application(id=None, app_id=None) return await original_send_async(request_info, return_type, error_mapping) @@ -107,8 +104,8 @@ async def bad_send_async(request_info, return_type, error_mapping=None): @pytest.mark.asyncio async def test_add_client_secret_success(graph_client): graph = graph_client - secret = await add_client_secret(graph, "OBJ123") - assert secret == "SECRET_VALUE" + secret = await add_client_secret(graph, MOCK_OBJECT_ID) + assert secret == MOCK_SECRET assert len(graph._test_calls["applications.add_password.post"]) == 1 @@ -117,7 +114,7 @@ async def test_add_client_secret_missing_secret(graph_client): graph = graph_client graph._test_secret_text_value["value"] = None with pytest.raises(ValueError): - await add_client_secret(graph, "OBJ123") + await add_client_secret(graph, MOCK_OBJECT_ID) @pytest.mark.asyncio @@ -144,8 +141,8 @@ async def fake_get_application(graph_client, client_id): request_app=server_app_initial(55), ) assert created is True - assert object_id == "OBJ123" - assert app_id == "APP123" + assert object_id == MOCK_OBJECT_ID + assert app_id == MOCK_APP_ID # Two updates: app id and secret assert {u[0] for u in updates} == {"AZURE_SERVER_APP_ID", "AZURE_SERVER_APP_SECRET"} assert len(graph._test_calls["applications.add_password.post"]) == 1 @@ -159,12 +156,11 @@ async def test_create_or_update_application_existing_adds_secret(graph_client, m def fake_update_env(name, val): updates.append((name, val)) - with mock.patch.dict(os.environ, {"AZURE_SERVER_APP_ID": "APP123"}, clear=True): + with mock.patch.dict(os.environ, {"AZURE_SERVER_APP_ID": MOCK_APP_ID}, clear=True): monkeypatch.setattr(auth_init, "update_azd_env", fake_update_env) async def fake_get_application(graph_client, client_id): - # Return existing object id for provided app id - return "OBJ999" + return EXISTING_MOCK_OBJECT_ID monkeypatch.setattr("scripts.auth_init.get_application", fake_get_application) object_id, app_id, created = await create_or_update_application_with_secret( @@ -174,8 +170,8 @@ async def fake_get_application(graph_client, client_id): request_app=server_app_initial(77), ) assert created is False - assert object_id == "OBJ999" - assert app_id == "APP123" + assert object_id == EXISTING_MOCK_OBJECT_ID + assert app_id == MOCK_APP_ID # Secret should be added since not in env assert any(name == "AZURE_SERVER_APP_SECRET" for name, _ in updates) # Application patch should have been called @@ -187,11 +183,11 @@ async def fake_get_application(graph_client, client_id): async def test_create_or_update_application_existing_with_secret(graph_client, monkeypatch): graph = graph_client with mock.patch.dict( - os.environ, {"AZURE_SERVER_APP_ID": "APP123", "AZURE_SERVER_APP_SECRET": "EXISTING"}, clear=True + os.environ, {"AZURE_SERVER_APP_ID": MOCK_APP_ID, "AZURE_SERVER_APP_SECRET": "EXISTING"}, clear=True ): async def fake_get_application(graph_client, client_id): - return "OBJ999" + return EXISTING_MOCK_OBJECT_ID monkeypatch.setattr("scripts.auth_init.get_application", fake_get_application) object_id, app_id, created = await create_or_update_application_with_secret( @@ -201,8 +197,8 @@ async def fake_get_application(graph_client, client_id): request_app=server_app_initial(88), ) assert created is False - assert object_id == "OBJ999" - assert app_id == "APP123" + assert object_id == EXISTING_MOCK_OBJECT_ID + assert app_id == MOCK_APP_ID # No secret added assert len(graph._test_calls["applications.add_password.post"]) == 0