|
| 1 | +# src/backend/tests/context/test_cosmos_memory.py |
| 2 | +# Drop-in test that self-stubs all external imports used by cosmos_memory_kernel |
| 3 | +# so we don't need to modify the repo structure or CI env. |
| 4 | + |
| 5 | +import sys |
| 6 | +import types |
1 | 7 | import pytest |
2 | | -from unittest.mock import AsyncMock, patch |
3 | | -from azure.cosmos.partition_key import PartitionKey |
4 | | -from src.backend.context.cosmos_memory import CosmosBufferedChatCompletionContext |
| 8 | +from unittest.mock import AsyncMock |
5 | 9 |
|
| 10 | +# ----------------- Preload stub modules so the SUT can import cleanly ----------------- |
6 | 11 |
|
7 | | -# Helper to create async iterable |
8 | | -async def async_iterable(mock_items): |
9 | | - """Helper to create an async iterable.""" |
10 | | - for item in mock_items: |
11 | | - yield item |
| 12 | +# 1) helpers.azure_credential_utils.get_azure_credential |
| 13 | +helpers_mod = types.ModuleType("helpers") |
| 14 | +helpers_cred_mod = types.ModuleType("helpers.azure_credential_utils") |
| 15 | +def _fake_get_azure_credential(*_a, **_k): |
| 16 | + return object() |
| 17 | +helpers_cred_mod.get_azure_credential = _fake_get_azure_credential |
| 18 | +helpers_mod.azure_credential_utils = helpers_cred_mod |
| 19 | +sys.modules.setdefault("helpers", helpers_mod) |
| 20 | +sys.modules.setdefault("helpers.azure_credential_utils", helpers_cred_mod) |
12 | 21 |
|
| 22 | +# 2) app_config.config (the SUT does: from app_config import config) |
| 23 | +app_config_mod = types.ModuleType("app_config") |
| 24 | +app_config_mod.config = types.SimpleNamespace( |
| 25 | + COSMOSDB_CONTAINER="mock-container", |
| 26 | + COSMOSDB_ENDPOINT="https://mock-endpoint", |
| 27 | + COSMOSDB_DATABASE="mock-database", |
| 28 | +) |
| 29 | +sys.modules.setdefault("app_config", app_config_mod) |
13 | 30 |
|
14 | | -@pytest.fixture |
15 | | -def mock_env_variables(monkeypatch): |
16 | | - """Mock all required environment variables.""" |
17 | | - env_vars = { |
18 | | - "COSMOSDB_ENDPOINT": "https://mock-endpoint", |
19 | | - "COSMOSDB_KEY": "mock-key", |
20 | | - "COSMOSDB_DATABASE": "mock-database", |
21 | | - "COSMOSDB_CONTAINER": "mock-container", |
22 | | - "AZURE_OPENAI_DEPLOYMENT_NAME": "mock-deployment-name", |
23 | | - "AZURE_OPENAI_API_VERSION": "2023-01-01", |
24 | | - "AZURE_OPENAI_ENDPOINT": "https://mock-openai-endpoint", |
25 | | - } |
26 | | - for key, value in env_vars.items(): |
27 | | - monkeypatch.setenv(key, value) |
| 31 | +# 3) models.messages_kernel (the SUT does: from models.messages_kernel import ...) |
| 32 | +models_mod = types.ModuleType("models") |
| 33 | +models_messages_mod = types.ModuleType("models.messages_kernel") |
| 34 | + |
| 35 | +# Minimal stand-ins so type hints/imports succeed (not used in this test path) |
| 36 | +class _Base: ... |
| 37 | +class BaseDataModel(_Base): ... |
| 38 | +class Plan(_Base): ... |
| 39 | +class Session(_Base): ... |
| 40 | +class Step(_Base): ... |
| 41 | +class AgentMessage(_Base): ... |
| 42 | + |
| 43 | +models_messages_mod.BaseDataModel = BaseDataModel |
| 44 | +models_messages_mod.Plan = Plan |
| 45 | +models_messages_mod.Session = Session |
| 46 | +models_messages_mod.Step = Step |
| 47 | +models_messages_mod.AgentMessage = AgentMessage |
| 48 | +models_mod.messages_kernel = models_messages_mod |
| 49 | +sys.modules.setdefault("models", models_mod) |
| 50 | +sys.modules.setdefault("models.messages_kernel", models_messages_mod) |
| 51 | + |
| 52 | +# 4) azure.cosmos.partition_key.PartitionKey (provide if sdk isn't installed) |
| 53 | +try: |
| 54 | + from azure.cosmos.partition_key import PartitionKey # type: ignore |
| 55 | +except Exception: # pragma: no cover |
| 56 | + azure_mod = sys.modules.setdefault("azure", types.ModuleType("azure")) |
| 57 | + azure_cosmos_mod = sys.modules.setdefault("azure.cosmos", types.ModuleType("azure.cosmos")) |
| 58 | + azure_cosmos_pk_mod = types.ModuleType("azure.cosmos.partition_key") |
| 59 | + class PartitionKey: # minimal shim |
| 60 | + def __init__(self, path: str): self.path = path |
| 61 | + azure_cosmos_pk_mod.PartitionKey = PartitionKey |
| 62 | + sys.modules.setdefault("azure.cosmos.partition_key", azure_cosmos_pk_mod) |
| 63 | + |
| 64 | +# 5) azure.cosmos.aio.CosmosClient (we’ll patch it in a fixture, but ensure import exists) |
| 65 | +try: |
| 66 | + from azure.cosmos.aio import CosmosClient # type: ignore |
| 67 | +except Exception: # pragma: no cover |
| 68 | + azure_cosmos_aio_mod = types.ModuleType("azure.cosmos.aio") |
| 69 | + class CosmosClient: # placeholder; we patch this class below |
| 70 | + def __init__(self, *a, **k): ... |
| 71 | + def get_database_client(self, *a, **k): ... |
| 72 | + azure_cosmos_aio_mod.CosmosClient = CosmosClient |
| 73 | + sys.modules.setdefault("azure.cosmos.aio", azure_cosmos_aio_mod) |
| 74 | + |
| 75 | +# ----------------- Import the SUT (after stubs are in place) ----------------- |
| 76 | +try: |
| 77 | + # If you added an alias file src/backend/context/cosmos_memory.py, this will work: |
| 78 | + from src.backend.context.cosmos_memory import CosmosMemoryContext as CosmosBufferedChatCompletionContext |
| 79 | +except Exception: |
| 80 | + # Fallback to the kernel module (your provided code) |
| 81 | + from src.backend.context.cosmos_memory_kernel import CosmosMemoryContext as CosmosBufferedChatCompletionContext # type: ignore |
| 82 | + |
| 83 | +# Import PartitionKey (either real or our shim) for assertions |
| 84 | +try: |
| 85 | + from azure.cosmos.partition_key import PartitionKey # type: ignore |
| 86 | +except Exception: # already defined above in shim |
| 87 | + pass |
28 | 88 |
|
| 89 | +# ----------------- Fixtures ----------------- |
29 | 90 |
|
30 | 91 | @pytest.fixture |
31 | | -def mock_cosmos_client(): |
32 | | - """Fixture for mocking Cosmos DB client and container.""" |
33 | | - mock_client = AsyncMock() |
| 92 | +def fake_cosmos_stack(monkeypatch): |
| 93 | + """ |
| 94 | + Patch the *SUT's* CosmosClient symbol so initialize() uses our AsyncMocks: |
| 95 | + CosmosClient(...).get_database_client() -> mock_db |
| 96 | + mock_db.create_container_if_not_exists(...) -> mock_container |
| 97 | + """ |
| 98 | + import sys |
| 99 | + |
34 | 100 | mock_container = AsyncMock() |
35 | | - mock_client.create_container_if_not_exists.return_value = mock_container |
| 101 | + mock_db = AsyncMock() |
| 102 | + mock_db.create_container_if_not_exists = AsyncMock(return_value=mock_container) |
36 | 103 |
|
37 | | - # Mocking context methods |
38 | | - mock_context = AsyncMock() |
39 | | - mock_context.store_message = AsyncMock() |
40 | | - mock_context.retrieve_messages = AsyncMock( |
41 | | - return_value=async_iterable([{"id": "test_id", "content": "test_content"}]) |
42 | | - ) |
| 104 | + def _fake_ctor(*_a, **_k): |
| 105 | + # mimic a client object with get_database_client returning our mock_db |
| 106 | + return types.SimpleNamespace( |
| 107 | + get_database_client=lambda *_a2, **_k2: mock_db |
| 108 | + ) |
43 | 109 |
|
44 | | - return mock_client, mock_container, mock_context |
| 110 | + # Find the actual module where CosmosBufferedChatCompletionContext is defined |
| 111 | + sut_module_name = CosmosBufferedChatCompletionContext.__module__ |
| 112 | + sut_module = sys.modules[sut_module_name] |
45 | 113 |
|
| 114 | + # Patch the symbol the SUT imported (its local binding), not the SDK module |
| 115 | + monkeypatch.setattr(sut_module, "CosmosClient", _fake_ctor, raising=False) |
| 116 | + |
| 117 | + return mock_db, mock_container |
46 | 118 |
|
47 | 119 | @pytest.fixture |
48 | | -def mock_config(mock_cosmos_client): |
49 | | - """Fixture to patch Config with mock Cosmos DB client.""" |
50 | | - mock_client, _, _ = mock_cosmos_client |
51 | | - with patch( |
52 | | - "src.backend.config.Config.GetCosmosDatabaseClient", return_value=mock_client |
53 | | - ), patch("src.backend.config.Config.COSMOSDB_CONTAINER", "mock-container"): |
54 | | - yield |
| 120 | +def mock_env(monkeypatch): |
| 121 | + # Optional: not strictly needed because we stubbed app_config.config above, |
| 122 | + # but keeps parity with your previous env fixture. |
| 123 | + env_vars = { |
| 124 | + "COSMOSDB_ENDPOINT": "https://mock-endpoint", |
| 125 | + "COSMOSDB_KEY": "mock-key", |
| 126 | + "COSMOSDB_DATABASE": "mock-database", |
| 127 | + "COSMOSDB_CONTAINER": "mock-container", |
| 128 | + } |
| 129 | + for k, v in env_vars.items(): |
| 130 | + monkeypatch.setenv(k, v) |
55 | 131 |
|
| 132 | +# ----------------- Test ----------------- |
56 | 133 |
|
57 | 134 | @pytest.mark.asyncio |
58 | | -async def test_initialize(mock_config, mock_cosmos_client): |
59 | | - """Test if the Cosmos DB container is initialized correctly.""" |
60 | | - mock_client, mock_container, _ = mock_cosmos_client |
61 | | - context = CosmosBufferedChatCompletionContext( |
62 | | - session_id="test_session", user_id="test_user" |
63 | | - ) |
64 | | - await context.initialize() |
65 | | - mock_client.create_container_if_not_exists.assert_called_once_with( |
66 | | - id="mock-container", partition_key=PartitionKey(path="/session_id") |
| 135 | +async def test_initialize(fake_cosmos_stack, mock_env): |
| 136 | + mock_db, mock_container = fake_cosmos_stack |
| 137 | + |
| 138 | + ctx = CosmosBufferedChatCompletionContext( |
| 139 | + session_id="test_session", |
| 140 | + user_id="test_user", |
67 | 141 | ) |
68 | | - assert context._container == mock_container |
| 142 | + await ctx.initialize() |
| 143 | + |
| 144 | + mock_db.create_container_if_not_exists.assert_called_once() |
| 145 | + # Strict arg check: |
| 146 | + args, kwargs = mock_db.create_container_if_not_exists.call_args |
| 147 | + assert kwargs.get("id") == "mock-container" |
| 148 | + pk = kwargs.get("partition_key") |
| 149 | + assert isinstance(pk, PartitionKey) and getattr(pk, "path", None) == "/session_id" |
| 150 | + |
| 151 | + assert ctx._container == mock_container |
0 commit comments