Skip to content

Commit 31ec335

Browse files
Merge pull request microsoft#386 from microsoft/UnitTest
fix: unit test fixing
2 parents 1bd9a0a + d3487c1 commit 31ec335

File tree

9 files changed

+362
-128
lines changed

9 files changed

+362
-128
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ jobs:
5151
- name: Run tests with coverage
5252
if: env.skip_tests == 'false'
5353
run: |
54-
pytest --cov=. --cov-report=term-missing --cov-report=xml
54+
pytest --cov=. --cov-report=term-missing --cov-report=xml --ignore=tests/e2e-test/tests
5555
5656
- name: Skip coverage report if no tests
5757
if: env.skip_tests == 'true'

src/backend/test_utils_date_fixed.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,19 @@
44

55
import os
66
from datetime import datetime
7-
from utils_date import format_date_for_user
7+
8+
# ---- Robust import for format_date_for_user ----
9+
# Tries: root-level shim -> src package path -> package-relative (when collected as src.backend.*)
10+
try:
11+
# Works if a root-level utils_date.py shim exists or PYTHONPATH includes project root
12+
from utils_date import format_date_for_user # type: ignore
13+
except ModuleNotFoundError:
14+
try:
15+
# Works when running from project root with 'src' on the path
16+
from src.backend.utils_date import format_date_for_user # type: ignore
17+
except ModuleNotFoundError:
18+
# Works when this test is imported as 'src.backend.test_utils_date_fixed'
19+
from .utils_date import format_date_for_user # type: ignore
820

921

1022
def test_date_formatting():
Lines changed: 133 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,151 @@
1+
# src/backend/tests/context/test_cosmos_memory.py
2+
# Drop-in test that self-stubs all external imports used by cosmos_memory_kernel
3+
# so we don't need to modify the repo structure or CI env.
4+
5+
import sys
6+
import types
17
import pytest
2-
from unittest.mock import AsyncMock, patch
3-
from azure.cosmos.partition_key import PartitionKey
4-
from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext
8+
from unittest.mock import AsyncMock
59

10+
# ----------------- Preload stub modules so the SUT can import cleanly -----------------
611

7-
# Helper to create async iterable
8-
async def async_iterable(mock_items):
9-
"""Helper to create an async iterable."""
10-
for item in mock_items:
11-
yield item
12+
# 1) helpers.azure_credential_utils.get_azure_credential
13+
helpers_mod = types.ModuleType("helpers")
14+
helpers_cred_mod = types.ModuleType("helpers.azure_credential_utils")
15+
def _fake_get_azure_credential(*_a, **_k):
16+
return object()
17+
helpers_cred_mod.get_azure_credential = _fake_get_azure_credential
18+
helpers_mod.azure_credential_utils = helpers_cred_mod
19+
sys.modules.setdefault("helpers", helpers_mod)
20+
sys.modules.setdefault("helpers.azure_credential_utils", helpers_cred_mod)
1221

22+
# 2) app_config.config (the SUT does: from app_config import config)
23+
app_config_mod = types.ModuleType("app_config")
24+
app_config_mod.config = types.SimpleNamespace(
25+
COSMOSDB_CONTAINER="mock-container",
26+
COSMOSDB_ENDPOINT="https://mock-endpoint",
27+
COSMOSDB_DATABASE="mock-database",
28+
)
29+
sys.modules.setdefault("app_config", app_config_mod)
1330

14-
@pytest.fixture
15-
def mock_env_variables(monkeypatch):
16-
"""Mock all required environment variables."""
17-
env_vars = {
18-
"COSMOSDB_ENDPOINT": "https://mock-endpoint",
19-
"COSMOSDB_KEY": "mock-key",
20-
"COSMOSDB_DATABASE": "mock-database",
21-
"COSMOSDB_CONTAINER": "mock-container",
22-
"AZURE_OPENAI_DEPLOYMENT_NAME": "mock-deployment-name",
23-
"AZURE_OPENAI_API_VERSION": "2023-01-01",
24-
"AZURE_OPENAI_ENDPOINT": "https://mock-openai-endpoint",
25-
}
26-
for key, value in env_vars.items():
27-
monkeypatch.setenv(key, value)
31+
# 3) models.messages_kernel (the SUT does: from models.messages_kernel import ...)
32+
models_mod = types.ModuleType("models")
33+
models_messages_mod = types.ModuleType("models.messages_kernel")
34+
35+
# Minimal stand-ins so type hints/imports succeed (not used in this test path)
36+
class _Base: ...
37+
class BaseDataModel(_Base): ...
38+
class Plan(_Base): ...
39+
class Session(_Base): ...
40+
class Step(_Base): ...
41+
class AgentMessage(_Base): ...
42+
43+
models_messages_mod.BaseDataModel = BaseDataModel
44+
models_messages_mod.Plan = Plan
45+
models_messages_mod.Session = Session
46+
models_messages_mod.Step = Step
47+
models_messages_mod.AgentMessage = AgentMessage
48+
models_mod.messages_kernel = models_messages_mod
49+
sys.modules.setdefault("models", models_mod)
50+
sys.modules.setdefault("models.messages_kernel", models_messages_mod)
51+
52+
# 4) azure.cosmos.partition_key.PartitionKey (provide if sdk isn't installed)
53+
try:
54+
from azure.cosmos.partition_key import PartitionKey # type: ignore
55+
except Exception: # pragma: no cover
56+
azure_mod = sys.modules.setdefault("azure", types.ModuleType("azure"))
57+
azure_cosmos_mod = sys.modules.setdefault("azure.cosmos", types.ModuleType("azure.cosmos"))
58+
azure_cosmos_pk_mod = types.ModuleType("azure.cosmos.partition_key")
59+
class PartitionKey: # minimal shim
60+
def __init__(self, path: str): self.path = path
61+
azure_cosmos_pk_mod.PartitionKey = PartitionKey
62+
sys.modules.setdefault("azure.cosmos.partition_key", azure_cosmos_pk_mod)
63+
64+
# 5) azure.cosmos.aio.CosmosClient (we’ll patch it in a fixture, but ensure import exists)
65+
try:
66+
from azure.cosmos.aio import CosmosClient # type: ignore
67+
except Exception: # pragma: no cover
68+
azure_cosmos_aio_mod = types.ModuleType("azure.cosmos.aio")
69+
class CosmosClient: # placeholder; we patch this class below
70+
def __init__(self, *a, **k): ...
71+
def get_database_client(self, *a, **k): ...
72+
azure_cosmos_aio_mod.CosmosClient = CosmosClient
73+
sys.modules.setdefault("azure.cosmos.aio", azure_cosmos_aio_mod)
74+
75+
# ----------------- Import the SUT (after stubs are in place) -----------------
76+
try:
77+
# If you added an alias file src/backend/context/cosmos_memory.py, this will work:
78+
from src.backend.context.cosmos_memory import CosmosMemoryContext as CosmosBufferedChatCompletionContext
79+
except Exception:
80+
# Fallback to the kernel module (your provided code)
81+
from src.backend.context.cosmos_memory_kernel import CosmosMemoryContext as CosmosBufferedChatCompletionContext # type: ignore
82+
83+
# Import PartitionKey (either real or our shim) for assertions
84+
try:
85+
from azure.cosmos.partition_key import PartitionKey # type: ignore
86+
except Exception: # already defined above in shim
87+
pass
2888

89+
# ----------------- Fixtures -----------------
2990

3091
@pytest.fixture
31-
def mock_cosmos_client():
32-
"""Fixture for mocking Cosmos DB client and container."""
33-
mock_client = AsyncMock()
92+
def fake_cosmos_stack(monkeypatch):
93+
"""
94+
Patch the *SUT's* CosmosClient symbol so initialize() uses our AsyncMocks:
95+
CosmosClient(...).get_database_client() -> mock_db
96+
mock_db.create_container_if_not_exists(...) -> mock_container
97+
"""
98+
import sys
99+
34100
mock_container = AsyncMock()
35-
mock_client.create_container_if_not_exists.return_value = mock_container
101+
mock_db = AsyncMock()
102+
mock_db.create_container_if_not_exists = AsyncMock(return_value=mock_container)
36103

37-
# Mocking context methods
38-
mock_context = AsyncMock()
39-
mock_context.store_message = AsyncMock()
40-
mock_context.retrieve_messages = AsyncMock(
41-
return_value=async_iterable([{"id": "test_id", "content": "test_content"}])
42-
)
104+
def _fake_ctor(*_a, **_k):
105+
# mimic a client object with get_database_client returning our mock_db
106+
return types.SimpleNamespace(
107+
get_database_client=lambda *_a2, **_k2: mock_db
108+
)
43109

44-
return mock_client, mock_container, mock_context
110+
# Find the actual module where CosmosBufferedChatCompletionContext is defined
111+
sut_module_name = CosmosBufferedChatCompletionContext.__module__
112+
sut_module = sys.modules[sut_module_name]
45113

114+
# Patch the symbol the SUT imported (its local binding), not the SDK module
115+
monkeypatch.setattr(sut_module, "CosmosClient", _fake_ctor, raising=False)
116+
117+
return mock_db, mock_container
46118

47119
@pytest.fixture
48-
def mock_config(mock_cosmos_client):
49-
"""Fixture to patch Config with mock Cosmos DB client."""
50-
mock_client, _, _ = mock_cosmos_client
51-
with patch(
52-
"src.backend.config.Config.GetCosmosDatabaseClient", return_value=mock_client
53-
), patch("src.backend.config.Config.COSMOSDB_CONTAINER", "mock-container"):
54-
yield
120+
def mock_env(monkeypatch):
121+
# Optional: not strictly needed because we stubbed app_config.config above,
122+
# but keeps parity with your previous env fixture.
123+
env_vars = {
124+
"COSMOSDB_ENDPOINT": "https://mock-endpoint",
125+
"COSMOSDB_KEY": "mock-key",
126+
"COSMOSDB_DATABASE": "mock-database",
127+
"COSMOSDB_CONTAINER": "mock-container",
128+
}
129+
for k, v in env_vars.items():
130+
monkeypatch.setenv(k, v)
55131

132+
# ----------------- Test -----------------
56133

57134
@pytest.mark.asyncio
58-
async def test_initialize(mock_config, mock_cosmos_client):
59-
"""Test if the Cosmos DB container is initialized correctly."""
60-
mock_client, mock_container, _ = mock_cosmos_client
61-
context = CosmosBufferedChatCompletionContext(
62-
session_id="test_session", user_id="test_user"
63-
)
64-
await context.initialize()
65-
mock_client.create_container_if_not_exists.assert_called_once_with(
66-
id="mock-container", partition_key=PartitionKey(path="/session_id")
135+
async def test_initialize(fake_cosmos_stack, mock_env):
136+
mock_db, mock_container = fake_cosmos_stack
137+
138+
ctx = CosmosBufferedChatCompletionContext(
139+
session_id="test_session",
140+
user_id="test_user",
67141
)
68-
assert context._container == mock_container
142+
await ctx.initialize()
143+
144+
mock_db.create_container_if_not_exists.assert_called_once()
145+
# Strict arg check:
146+
args, kwargs = mock_db.create_container_if_not_exists.call_args
147+
assert kwargs.get("id") == "mock-container"
148+
pk = kwargs.get("partition_key")
149+
assert isinstance(pk, PartitionKey) and getattr(pk, "path", None) == "/session_id"
150+
151+
assert ctx._container == mock_container

src/backend/tests/helpers/test_azure_credential_utils.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,28 @@
1-
import pytest
2-
import sys
3-
import os
4-
from unittest.mock import patch, MagicMock
1+
import os, sys, importlib
52

6-
# Ensure src/backend is on the Python path for imports
7-
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
3+
# 1) Put repo's src/backend first on sys.path so "helpers" resolves to our package
4+
HERE = os.path.dirname(__file__)
5+
SRC_BACKEND = os.path.abspath(os.path.join(HERE, "..", ".."))
6+
if SRC_BACKEND not in sys.path:
7+
sys.path.insert(0, SRC_BACKEND)
88

9+
# 2) Evict any stub/foreign modules injected by other tests or site-packages
10+
sys.modules.pop("helpers.azure_credential_utils", None)
11+
sys.modules.pop("helpers", None)
12+
13+
# 3) Now import the real module under test
914
import helpers.azure_credential_utils as azure_credential_utils
1015

16+
# src/backend/tests/helpers/test_azure_credential_utils.py
17+
18+
import pytest
19+
from unittest.mock import patch, MagicMock
20+
1121
# Synchronous tests
1222

13-
@patch("helpers.azure_credential_utils.os.getenv")
14-
@patch("helpers.azure_credential_utils.DefaultAzureCredential")
15-
@patch("helpers.azure_credential_utils.ManagedIdentityCredential")
23+
@patch("helpers.azure_credential_utils.os.getenv", create=True)
24+
@patch("helpers.azure_credential_utils.DefaultAzureCredential", create=True)
25+
@patch("helpers.azure_credential_utils.ManagedIdentityCredential", create=True)
1626
def test_get_azure_credential_dev_env(mock_managed_identity_credential, mock_default_azure_credential, mock_getenv):
1727
"""Test get_azure_credential in dev environment."""
1828
mock_getenv.return_value = "dev"
@@ -26,14 +36,15 @@ def test_get_azure_credential_dev_env(mock_managed_identity_credential, mock_def
2636
mock_managed_identity_credential.assert_not_called()
2737
assert credential == mock_default_credential
2838

29-
@patch("helpers.azure_credential_utils.os.getenv")
30-
@patch("helpers.azure_credential_utils.DefaultAzureCredential")
31-
@patch("helpers.azure_credential_utils.ManagedIdentityCredential")
39+
@patch("helpers.azure_credential_utils.os.getenv", create=True)
40+
@patch("helpers.azure_credential_utils.DefaultAzureCredential", create=True)
41+
@patch("helpers.azure_credential_utils.ManagedIdentityCredential", create=True)
3242
def test_get_azure_credential_non_dev_env(mock_managed_identity_credential, mock_default_azure_credential, mock_getenv):
3343
"""Test get_azure_credential in non-dev environment."""
3444
mock_getenv.return_value = "prod"
3545
mock_managed_credential = MagicMock()
3646
mock_managed_identity_credential.return_value = mock_managed_credential
47+
3748
credential = azure_credential_utils.get_azure_credential(client_id="test-client-id")
3849

3950
mock_getenv.assert_called_once_with("APP_ENV", "prod")
@@ -44,9 +55,9 @@ def test_get_azure_credential_non_dev_env(mock_managed_identity_credential, mock
4455
# Asynchronous tests
4556

4657
@pytest.mark.asyncio
47-
@patch("helpers.azure_credential_utils.os.getenv")
48-
@patch("helpers.azure_credential_utils.AioDefaultAzureCredential")
49-
@patch("helpers.azure_credential_utils.AioManagedIdentityCredential")
58+
@patch("helpers.azure_credential_utils.os.getenv", create=True)
59+
@patch("helpers.azure_credential_utils.AioDefaultAzureCredential", create=True)
60+
@patch("helpers.azure_credential_utils.AioManagedIdentityCredential", create=True)
5061
async def test_get_azure_credential_async_dev_env(mock_aio_managed_identity_credential, mock_aio_default_azure_credential, mock_getenv):
5162
"""Test get_azure_credential_async in dev environment."""
5263
mock_getenv.return_value = "dev"
@@ -61,9 +72,9 @@ async def test_get_azure_credential_async_dev_env(mock_aio_managed_identity_cred
6172
assert credential == mock_aio_default_credential
6273

6374
@pytest.mark.asyncio
64-
@patch("helpers.azure_credential_utils.os.getenv")
65-
@patch("helpers.azure_credential_utils.AioDefaultAzureCredential")
66-
@patch("helpers.azure_credential_utils.AioManagedIdentityCredential")
75+
@patch("helpers.azure_credential_utils.os.getenv", create=True)
76+
@patch("helpers.azure_credential_utils.AioDefaultAzureCredential", create=True)
77+
@patch("helpers.azure_credential_utils.AioManagedIdentityCredential", create=True)
6778
async def test_get_azure_credential_async_non_dev_env(mock_aio_managed_identity_credential, mock_aio_default_azure_credential, mock_getenv):
6879
"""Test get_azure_credential_async in non-dev environment."""
6980
mock_getenv.return_value = "prod"
@@ -75,4 +86,4 @@ async def test_get_azure_credential_async_non_dev_env(mock_aio_managed_identity_
7586
mock_getenv.assert_called_once_with("APP_ENV", "prod")
7687
mock_aio_managed_identity_credential.assert_called_once_with(client_id="test-client-id")
7788
mock_aio_default_azure_credential.assert_not_called()
78-
assert credential == mock_aio_managed_credential
89+
assert credential == mock_aio_managed_credential

src/backend/tests/models/test_messages.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# File: test_message.py
22

33
import uuid
4-
from src.backend.models.messages import (
4+
from src.backend.models.messages_kernel import (
55
DataType,
6-
BAgentType,
6+
AgentType as BAgentType, # map to your enum
77
StepStatus,
88
PlanStatus,
99
HumanFeedbackStatus,
@@ -20,7 +20,7 @@ def test_enum_values():
2020
"""Test enumeration values for consistency."""
2121
assert DataType.session == "session"
2222
assert DataType.plan == "plan"
23-
assert BAgentType.human_agent == "HumanAgent"
23+
assert BAgentType.HUMAN == "Human_Agent" # was human_agent / "HumanAgent"
2424
assert StepStatus.completed == "completed"
2525
assert PlanStatus.in_progress == "in_progress"
2626
assert HumanFeedbackStatus.requested == "requested"
@@ -31,15 +31,15 @@ def test_plan_with_steps_update_counts():
3131
step1 = Step(
3232
plan_id=str(uuid.uuid4()),
3333
action="Review document",
34-
agent=BAgentType.human_agent,
34+
agent=BAgentType.HUMAN,
3535
status=StepStatus.completed,
3636
session_id=str(uuid.uuid4()),
3737
user_id=str(uuid.uuid4()),
3838
)
3939
step2 = Step(
4040
plan_id=str(uuid.uuid4()),
4141
action="Approve document",
42-
agent=BAgentType.hr_agent,
42+
agent=BAgentType.HR,
4343
status=StepStatus.failed,
4444
session_id=str(uuid.uuid4()),
4545
user_id=str(uuid.uuid4()),
@@ -78,10 +78,10 @@ def test_action_request_creation():
7878
plan_id=str(uuid.uuid4()),
7979
session_id=str(uuid.uuid4()),
8080
action="Review and approve",
81-
agent=BAgentType.procurement_agent,
81+
agent=BAgentType.PROCUREMENT,
8282
)
8383
assert action_request.action == "Review and approve"
84-
assert action_request.agent == BAgentType.procurement_agent
84+
assert action_request.agent == BAgentType.PROCUREMENT
8585

8686

8787
def test_human_feedback_creation():
@@ -114,7 +114,7 @@ def test_step_defaults():
114114
step = Step(
115115
plan_id=str(uuid.uuid4()),
116116
action="Prepare report",
117-
agent=BAgentType.generic_agent,
117+
agent=BAgentType.GENERIC,
118118
session_id=str(uuid.uuid4()),
119119
user_id=str(uuid.uuid4()),
120120
)

0 commit comments

Comments
 (0)