Skip to content

Commit 4b32f3c

Browse files
authored
Add tests for auth_init.py (#2741)
* Add tests for auth_init.py * Remove type ignore in tests * Move values into constants and remove unneeded code
1 parent e479312 commit 4b32f3c

File tree

4 files changed

+265
-7
lines changed

4 files changed

+265
-7
lines changed

AGENTS.md

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,37 @@ When adding a new feature, add tests for it in the appropriate file.
7878
If the feature is a UI element, add an e2e test for it.
7979
If it is an API endpoint, add an app integration test for it.
8080
If it is a function or method, add a unit test for it.
81-
Use mocks from conftest.py to mock external services.
81+
Use mocks from tests/conftest.py to mock external services. Prefer mocking at the HTTP/requests level when possible.
8282

8383
When you're running tests, make sure you activate the .venv virtual environment first:
8484

85-
```bash
85+
```shell
8686
source .venv/bin/activate
8787
```
8888

8989
## Sending pull requests
9090

9191
When sending pull requests, make sure to follow the PULL_REQUEST_TEMPLATE.md format.
92+
93+
## Upgrading dependencies
94+
95+
To upgrade a particular package in the backend, use the following command, replacing `<package-name>` with the name of the package you want to upgrade:
96+
97+
```shell
98+
cd app/backend && uv pip compile requirements.in -o requirements.txt --python-version 3.9 --upgrade-package package-name
99+
```
100+
101+
## Checking Python type hints
102+
103+
To check Python type hints, use the following command:
104+
105+
```shell
106+
cd app/backend && mypy . --config-file=../pyproject.toml
107+
```
108+
109+
```shell
110+
cd scripts && mypy . --config-file=../pyproject.toml
111+
```
112+
113+
Note that we do not currently enforce type hints in the tests folder, as it would require adding a lot of `# type: ignore` comments to the existing tests.
114+
We only enforce type hints in the main application code and scripts.

app/backend/requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ msal-extensions==1.3.1
199199
# via azure-identity
200200
msgraph-core==1.3.3
201201
# via msgraph-sdk
202-
msgraph-sdk==1.26.0
202+
msgraph-sdk==1.45.0
203203
# via -r requirements.in
204204
msrest==0.7.1
205205
# via azure-monitor-opentelemetry-exporter
@@ -431,7 +431,6 @@ typing-extensions==4.13.2
431431
# pypdf
432432
# quart
433433
# quart-cors
434-
# rich
435434
# taskgroup
436435
# uvicorn
437436
urllib3==2.5.0

tests/mocks.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,17 @@
3434

3535
class MockAzureCredential(AsyncTokenCredential):
3636

37-
async def get_token(self, uri):
38-
return MockToken("", 9999999999, "")
37+
async def get_token(self, *scopes, **kwargs): # accept claims, enable_cae, etc.
38+
# Return a simple mock token structure with required attributes
39+
return MockToken("mock-token", 9999999999, "mock-token")
3940

4041

4142
class MockAzureCredentialExpired(AsyncTokenCredential):
4243

4344
def __init__(self):
4445
self.access_number = 0
4546

46-
async def get_token(self, uri):
47+
async def get_token(self, *scopes, **kwargs):
4748
self.access_number += 1
4849
if self.access_number == 1:
4950
return MockToken("", 0, "")

tests/test_auth_init.py

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
import os
2+
from unittest import mock
3+
4+
import pytest
5+
from msgraph import GraphServiceClient
6+
from msgraph.generated.models.application import Application
7+
from msgraph.generated.models.password_credential import PasswordCredential
8+
from msgraph.generated.models.service_principal import ServicePrincipal
9+
10+
from .mocks import MockAzureCredential
11+
from scripts import auth_init
12+
from scripts.auth_init import (
13+
add_client_secret,
14+
client_app,
15+
create_application,
16+
create_or_update_application_with_secret,
17+
server_app_initial,
18+
server_app_permission_setup,
19+
)
20+
21+
MOCK_OBJECT_ID = "OBJ123"
22+
MOCK_APP_ID = "APP123"
23+
MOCK_SECRET = "SECRET_VALUE"
24+
EXISTING_MOCK_OBJECT_ID = "OBJ999"
25+
26+
27+
@pytest.fixture
28+
def graph_client(monkeypatch):
29+
"""GraphServiceClient whose network layer is intercepted to avoid real HTTP calls.
30+
31+
We exercise real request builders while intercepting the adapter's send_async.
32+
"""
33+
34+
client = GraphServiceClient(credentials=MockAzureCredential(), scopes=["https://graph.microsoft.com/.default"])
35+
36+
calls = {
37+
"applications.post": [],
38+
"applications.patch": [],
39+
"applications.add_password.post": [],
40+
"service_principals.post": [],
41+
}
42+
created_ids = {"object_id": MOCK_OBJECT_ID, "app_id": MOCK_APP_ID}
43+
secret_text_value = {"value": MOCK_SECRET}
44+
45+
async def fake_send_async(request_info, return_type, error_mapping=None):
46+
url = request_info.url or ""
47+
method = request_info.http_method.value
48+
if method == "POST" and url.endswith("/applications"):
49+
body = request_info.content
50+
calls["applications.post"].append(body)
51+
return Application(
52+
id=created_ids["object_id"],
53+
app_id=created_ids["app_id"],
54+
display_name=getattr(body, "display_name", None),
55+
)
56+
if method == "POST" and url.endswith("/servicePrincipals"):
57+
calls["service_principals.post"].append(request_info.content)
58+
return ServicePrincipal()
59+
if method == "PATCH" and "/applications/" in url:
60+
calls["applications.patch"].append(request_info.content)
61+
return Application()
62+
if method == "POST" and url.endswith("/addPassword"):
63+
calls["applications.add_password.post"].append(request_info.content)
64+
return PasswordCredential(secret_text=secret_text_value["value"])
65+
raise AssertionError(f"Unexpected request: {method} {url}")
66+
67+
# Patch the adapter
68+
monkeypatch.setattr(client.request_adapter, "send_async", fake_send_async)
69+
70+
client._test_calls = calls
71+
client._test_secret_text_value = secret_text_value
72+
client._test_ids = created_ids
73+
return client
74+
75+
76+
@pytest.mark.asyncio
77+
async def test_create_application_success(graph_client):
78+
graph = graph_client
79+
request = server_app_initial(42)
80+
object_id, app_id = await create_application(graph, request)
81+
assert object_id == MOCK_OBJECT_ID
82+
assert app_id == MOCK_APP_ID
83+
assert len(graph._test_calls["service_principals.post"]) == 1
84+
85+
86+
@pytest.mark.asyncio
87+
async def test_create_application_missing_ids(graph_client, monkeypatch):
88+
graph = graph_client
89+
90+
original_send_async = graph.request_adapter.send_async
91+
92+
async def bad_send_async(request_info, return_type, error_mapping=None):
93+
url = request_info.url or ""
94+
method = request_info.http_method.value
95+
if method == "POST" and url.endswith("/applications"):
96+
return Application(id=None, app_id=None)
97+
return await original_send_async(request_info, return_type, error_mapping)
98+
99+
monkeypatch.setattr(graph.request_adapter, "send_async", bad_send_async)
100+
with pytest.raises(ValueError):
101+
await create_application(graph, server_app_initial(1))
102+
103+
104+
@pytest.mark.asyncio
105+
async def test_add_client_secret_success(graph_client):
106+
graph = graph_client
107+
secret = await add_client_secret(graph, MOCK_OBJECT_ID)
108+
assert secret == MOCK_SECRET
109+
assert len(graph._test_calls["applications.add_password.post"]) == 1
110+
111+
112+
@pytest.mark.asyncio
113+
async def test_add_client_secret_missing_secret(graph_client):
114+
graph = graph_client
115+
graph._test_secret_text_value["value"] = None
116+
with pytest.raises(ValueError):
117+
await add_client_secret(graph, MOCK_OBJECT_ID)
118+
119+
120+
@pytest.mark.asyncio
121+
async def test_create_or_update_application_creates_and_adds_secret(graph_client, monkeypatch):
122+
graph = graph_client
123+
updates: list[tuple[str, str]] = []
124+
125+
def fake_update_env(name, val):
126+
updates.append((name, val))
127+
128+
# Ensure env vars not set
129+
with mock.patch.dict(os.environ, {}, clear=True):
130+
monkeypatch.setattr(auth_init, "update_azd_env", fake_update_env)
131+
132+
# Force get_application to return None (not found)
133+
async def fake_get_application(graph_client, client_id):
134+
return None
135+
136+
monkeypatch.setattr("scripts.auth_init.get_application", fake_get_application)
137+
object_id, app_id, created = await create_or_update_application_with_secret(
138+
graph,
139+
app_id_env_var="AZURE_SERVER_APP_ID",
140+
app_secret_env_var="AZURE_SERVER_APP_SECRET",
141+
request_app=server_app_initial(55),
142+
)
143+
assert created is True
144+
assert object_id == MOCK_OBJECT_ID
145+
assert app_id == MOCK_APP_ID
146+
# Two updates: app id and secret
147+
assert {u[0] for u in updates} == {"AZURE_SERVER_APP_ID", "AZURE_SERVER_APP_SECRET"}
148+
assert len(graph._test_calls["applications.add_password.post"]) == 1
149+
150+
151+
@pytest.mark.asyncio
152+
async def test_create_or_update_application_existing_adds_secret(graph_client, monkeypatch):
153+
graph = graph_client
154+
updates: list[tuple[str, str]] = []
155+
156+
def fake_update_env(name, val):
157+
updates.append((name, val))
158+
159+
with mock.patch.dict(os.environ, {"AZURE_SERVER_APP_ID": MOCK_APP_ID}, clear=True):
160+
monkeypatch.setattr(auth_init, "update_azd_env", fake_update_env)
161+
162+
async def fake_get_application(graph_client, client_id):
163+
return EXISTING_MOCK_OBJECT_ID
164+
165+
monkeypatch.setattr("scripts.auth_init.get_application", fake_get_application)
166+
object_id, app_id, created = await create_or_update_application_with_secret(
167+
graph,
168+
app_id_env_var="AZURE_SERVER_APP_ID",
169+
app_secret_env_var="AZURE_SERVER_APP_SECRET",
170+
request_app=server_app_initial(77),
171+
)
172+
assert created is False
173+
assert object_id == EXISTING_MOCK_OBJECT_ID
174+
assert app_id == MOCK_APP_ID
175+
# Secret should be added since not in env
176+
assert any(name == "AZURE_SERVER_APP_SECRET" for name, _ in updates)
177+
# Application patch should have been called
178+
# Patch captured
179+
assert len(graph._test_calls["applications.patch"]) == 1
180+
181+
182+
@pytest.mark.asyncio
183+
async def test_create_or_update_application_existing_with_secret(graph_client, monkeypatch):
184+
graph = graph_client
185+
with mock.patch.dict(
186+
os.environ, {"AZURE_SERVER_APP_ID": MOCK_APP_ID, "AZURE_SERVER_APP_SECRET": "EXISTING"}, clear=True
187+
):
188+
189+
async def fake_get_application(graph_client, client_id):
190+
return EXISTING_MOCK_OBJECT_ID
191+
192+
monkeypatch.setattr("scripts.auth_init.get_application", fake_get_application)
193+
object_id, app_id, created = await create_or_update_application_with_secret(
194+
graph,
195+
app_id_env_var="AZURE_SERVER_APP_ID",
196+
app_secret_env_var="AZURE_SERVER_APP_SECRET",
197+
request_app=server_app_initial(88),
198+
)
199+
assert created is False
200+
assert object_id == EXISTING_MOCK_OBJECT_ID
201+
assert app_id == MOCK_APP_ID
202+
# No secret added
203+
assert len(graph._test_calls["applications.add_password.post"]) == 0
204+
205+
206+
def test_client_app_validation_errors():
207+
# Server app without api
208+
server_app = server_app_initial(1)
209+
server_app.api = None
210+
with pytest.raises(ValueError):
211+
client_app("server_app_id", server_app, 2)
212+
213+
# Server app with empty scopes
214+
# attach empty api
215+
server_app_permission = server_app_permission_setup("server_app")
216+
server_app_permission.api.oauth2_permission_scopes = []
217+
with pytest.raises(ValueError):
218+
client_app("server_app_id", server_app_permission, 2)
219+
220+
221+
def test_client_app_success():
222+
server_app_permission = server_app_permission_setup("server_app")
223+
c_app = client_app("server_app", server_app_permission, 123)
224+
assert c_app.web is not None
225+
assert c_app.spa is not None
226+
assert c_app.required_resource_access is not None
227+
assert len(c_app.required_resource_access) >= 1
228+
229+
230+
def test_server_app_permission_setup():
231+
# simulate after creation we know app id
232+
app_with_permissions = server_app_permission_setup("server_app_id")
233+
assert app_with_permissions.identifier_uris == ["api://server_app_id"]
234+
assert app_with_permissions.required_resource_access is not None
235+
assert len(app_with_permissions.required_resource_access) == 1

0 commit comments

Comments
 (0)