Skip to content

Commit 7d31c8b

Browse files
authored
INTPYTHON-660 Inline removed semantic kernel test file (#81)
1 parent 8ccda91 commit 7d31c8b

File tree

2 files changed

+284
-2
lines changed

2 files changed

+284
-2
lines changed

semantic-kernel-python/run.sh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ make install-python
2424
make install-sk
2525
make install-pre-commit
2626

27+
cp $SCRIPT_DIR/test_mongodb_atlas_memory_store.py .
28+
2729
# shellcheck disable=SC2154
2830
OPENAI_API_KEY="$OPENAI_API_KEY" \
2931
OPENAI_ORG_ID="" \
@@ -32,7 +34,7 @@ OPENAI_API_KEY="$OPENAI_API_KEY" \
3234
AZURE_OPENAI_API_KEY="" \
3335
MONGODB_ATLAS_CONNECTION_STRING=$MONGODB_URI \
3436
Python_Integration_Tests=1 \
35-
uv run pytest tests/integration/memory/memory_stores/test_mongodb_atlas_memory_store.py -k test_collection_knn
37+
uv run pytest test_mongodb_atlas_memory_store.py -k test_collection_knn
3638

3739
# shellcheck disable=SC2154
3840
OPENAI_API_KEY="$OPENAI_API_KEY" \
@@ -42,4 +44,4 @@ OPENAI_API_KEY="$OPENAI_API_KEY" \
4244
AZURE_OPENAI_API_KEY="" \
4345
MONGODB_ATLAS_CONNECTION_STRING=$MONGODB_URI \
4446
Python_Integration_Tests=1 \
45-
uv run pytest tests/integration/memory/memory_stores/test_mongodb_atlas_memory_store.py
47+
uv run pytest test_mongodb_atlas_memory_store.py
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
# Copied from removed file: https://github.com/microsoft/semantic-kernel/pull/12271/files#diff-f3c976da0b5944eb15db3c36c8b98120e8f47197ed05c39ee63e5d50c174758a
3+
import asyncio
4+
import random
5+
6+
import numpy as np
7+
import pytest
8+
import pytest_asyncio
9+
10+
from semantic_kernel.connectors.memory_stores.mongodb_atlas.mongodb_atlas_memory_store import (
11+
MongoDBAtlasMemoryStore,
12+
)
13+
from semantic_kernel.exceptions import MemoryConnectorInitializationError
14+
from semantic_kernel.memory.memory_record import MemoryRecord
15+
16+
import motor # noqa: F401
17+
from pymongo import errors
18+
19+
DUPLICATE_INDEX_ERR_CODE = 68
20+
READ_ONLY_COLLECTION = "nearestSearch"
21+
DIMENSIONS = 3
22+
23+
24+
def is_equal_memory_record(
    mem1: MemoryRecord, mem2: MemoryRecord, with_embeddings: bool
):
    """Assert that two memory records hold the same data.

    Every instance attribute except ``_embedding`` is compared through the
    records' ``__dict__``.  The embeddings are compared (as plain lists) only
    when ``with_embeddings`` is true.  Nothing is returned; a mismatch raises
    ``AssertionError``.
    """
    fields_1 = {
        name: value for name, value in vars(mem1).items() if name != "_embedding"
    }
    fields_2 = {
        name: value for name, value in vars(mem2).items() if name != "_embedding"
    }
    assert fields_1 == fields_2

    if not with_embeddings:
        return
    assert mem1._embedding.tolist() == mem2._embedding.tolist()
35+
36+
37+
@pytest.fixture
def memory_record_gen():
    """Provide a factory that builds a ``MemoryRecord`` from a numeric id."""

    def memory_record(_id):
        # Component k of the embedding is 1 / (_id + k) for k in
        # 0..DIMENSIONS-1, so the same id always produces the same vector.
        vector = np.array([1 / (_id + offset) for offset in range(DIMENSIONS)])
        identifier = str(_id)
        return MemoryRecord(
            id=identifier,
            text=f"{_id} text",
            is_reference=False,
            embedding=vector,
            description=f"{_id} description",
            external_source_name=f"{_id} external source",
            additional_metadata=f"{_id} additional metadata",
            timestamp=None,
            key=identifier,
        )

    return memory_record
53+
54+
55+
@pytest.fixture
def test_collection():
    """Return a collection name ``AVSTest-<n>`` with a random n in 0..9999."""
    suffix = random.randint(0, 9999)
    return f"AVSTest-{suffix}"
58+
59+
60+
@pytest.fixture
def memory():
    """Build the Atlas memory store for database ``pyMSKTest``.

    The test is skipped when the store constructor raises
    ``MemoryConnectorInitializationError``.
    """
    try:
        store = MongoDBAtlasMemoryStore(database_name="pyMSKTest")
    except MemoryConnectorInitializationError:
        pytest.skip("MongoDB Atlas connection string not found in env vars.")
    else:
        return store
66+
67+
68+
@pytest_asyncio.fixture
async def vector_search_store(memory):
    """Yield the store with an emptied database and a retrying ``create_collection``.

    Setup enters the store, deletes every collection reported by
    ``memory.get_collections()`` and wraps ``memory.create_collection`` so
    that duplicate-index failures are retried.  Teardown deletes all
    collections again and always exits the store.

    Errors raised during setup propagate, so the real cause is reported
    instead of the fixture ending without yielding.  Errors raised by the
    teardown cleanup are logged and otherwise ignored (best effort).
    """
    import logging

    async def drop_all_collections():
        """Delete every collection currently reported by the store."""
        for cname in await memory.get_collections():
            await memory.delete_collection(cname)

    def patch_index_exception(fn):
        """Function patch for collection creation call to retry
        on duplicate index errors
        """

        async def _patch(collection_name):
            while True:
                try:
                    await fn(collection_name)
                    break
                except errors.OperationFailure as e:
                    # In this test instance, this error code is indicative
                    # of a previous index not completing teardown
                    if e.code != DUPLICATE_INDEX_ERR_CODE:
                        raise
                    await asyncio.sleep(1)

        return _patch

    await memory.__aenter__()
    try:
        # Delete all collections before and after
        await drop_all_collections()
        memory.create_collection = patch_index_exception(memory.create_collection)

        try:
            yield memory
        finally:
            try:
                await drop_all_collections()
            except Exception:
                # Cleanup is best effort: record the failure, don't raise it.
                logging.getLogger(__name__).warning(
                    "Failed to delete collections during fixture teardown",
                    exc_info=True,
                )
    finally:
        await memory.__aexit__(None, None, None)
107+
108+
109+
@pytest_asyncio.fixture
async def nearest_match_store(memory):
    """Fixture for read only vector store; the URI for test needs atlas configured

    The test is skipped when the ``READ_ONLY_COLLECTION`` collection does not
    exist.  Any other setup error propagates so the real cause is reported,
    and the store is always exited afterwards.
    """
    await memory.__aenter__()
    try:
        if not await memory.does_collection_exist(READ_ONLY_COLLECTION):
            pytest.skip(
                reason=f"db: readOnly collection: {READ_ONLY_COLLECTION} not found, "
                "please ensure your Atlas Test Cluster has this collection configured"
            )
        yield memory
    finally:
        await memory.__aexit__(None, None, None)
124+
125+
126+
async def test_constructor(memory):
    """The ``memory`` fixture must produce a ``MongoDBAtlasMemoryStore``."""
    assert isinstance(memory, MongoDBAtlasMemoryStore)
128+
129+
130+
async def test_collection_create_and_delete(vector_search_store, test_collection):
    """A collection exists after creation and no longer exists after deletion."""
    store = vector_search_store

    await store.create_collection(test_collection)
    present_after_create = await store.does_collection_exist(test_collection)
    assert present_after_create

    await store.delete_collection(test_collection)
    present_after_delete = await store.does_collection_exist(test_collection)
    assert not present_after_delete
135+
136+
137+
async def test_collection_upsert(
    vector_search_store, test_collection, memory_record_gen
):
    """Upserting a single record returns that record's id."""
    # Only one record is upserted, so only one is built.
    mem = memory_record_gen(1)
    upserted_id = await vector_search_store.upsert(test_collection, mem)
    assert upserted_id == mem._id
143+
144+
145+
async def test_collection_batch_upsert(
    vector_search_store, test_collection, memory_record_gen
):
    """A batch upsert returns the ids of the records in input order."""
    records = [memory_record_gen(i) for i in range(1, 4)]
    returned_ids = await vector_search_store.upsert_batch(test_collection, records)
    expected_ids = [record._id for record in records]
    assert expected_ids == returned_ids
151+
152+
153+
async def test_collection_deletion(
    vector_search_store, test_collection, memory_record_gen
):
    """A record can be fetched after upsert and is gone after ``remove``."""
    mem = memory_record_gen(1)
    await vector_search_store.upsert(test_collection, mem)

    insertion_val = await vector_search_store.get(test_collection, mem._id, True)
    # Check for a missing record first so a failed fetch is reported as an
    # assertion failure rather than an AttributeError on ``None``.
    assert insertion_val is not None
    assert mem._id == insertion_val._id
    assert mem._embedding.tolist() == insertion_val._embedding.tolist()

    await vector_search_store.remove(test_collection, mem._id)
    val = await vector_search_store.get(test_collection, mem._id, False)
    assert val is None
165+
166+
167+
async def test_collection_batch_deletion(
    vector_search_store, test_collection, memory_record_gen
):
    """Records upserted in a batch can all be fetched, then removed in a batch."""
    records = [memory_record_gen(i) for i in range(1, 4)]
    await vector_search_store.upsert_batch(test_collection, records)

    record_ids = [record._id for record in records]
    fetched = await vector_search_store.get_batch(test_collection, record_ids, True)
    assert len(fetched) == len(records)

    await vector_search_store.remove_batch(test_collection, record_ids)
    remaining = await vector_search_store.get_batch(test_collection, record_ids, False)
    assert not remaining
177+
178+
179+
async def test_collection_get(vector_search_store, test_collection, memory_record_gen):
    """``get`` returns the upserted record, with and without its embedding."""
    original = memory_record_gen(1)
    await vector_search_store.upsert(test_collection, original)

    # First fetch without the embedding, then again with it.
    without_embedding = await vector_search_store.get(
        test_collection, original._id, False
    )
    is_equal_memory_record(original, without_embedding, False)

    with_embedding = await vector_search_store.get(test_collection, original._id, True)
    is_equal_memory_record(original, with_embedding, True)
187+
188+
189+
async def test_collection_batch_get(
    vector_search_store, test_collection, memory_record_gen
):
    """``get_batch`` returns every upserted record, with and without embeddings."""
    records_by_id = {str(i): memory_record_gen(i) for i in range(1, 4)}
    await vector_search_store.upsert_batch(
        test_collection, list(records_by_id.values())
    )

    # Fetch without embeddings and compare each result to its source record.
    fetched = await vector_search_store.get_batch(
        test_collection, list(records_by_id.keys()), False
    )
    assert len(fetched) == len(records_by_id)
    for record in fetched:
        is_equal_memory_record(records_by_id[record._id], record, False)

    # Fetch again, this time including the embeddings in the comparison.
    refetched = await vector_search_store.get_batch(
        test_collection, list(records_by_id.keys()), True
    )
    for record in refetched:
        is_equal_memory_record(records_by_id[record._id], record, True)
206+
207+
208+
async def test_collection_knn_match(nearest_match_store, memory_record_gen):
    """Searching with a record's own embedding returns that record and a truthy score."""
    record = memory_record_gen(7)
    await nearest_match_store.upsert(READ_ONLY_COLLECTION, record)

    match = await nearest_match_store.get_nearest_match(
        collection_name=READ_ONLY_COLLECTION,
        embedding=record._embedding,
        with_embedding=True,
    )
    result, score = match

    is_equal_memory_record(record, result, True)
    assert score
218+
219+
220+
async def test_collection_knn_match_with_score(nearest_match_store, memory_record_gen):
    """Same as the plain nearest-match test, but passing ``min_relevance_score=0.0``."""
    record = memory_record_gen(7)
    await nearest_match_store.upsert(READ_ONLY_COLLECTION, record)

    match = await nearest_match_store.get_nearest_match(
        collection_name=READ_ONLY_COLLECTION,
        embedding=record._embedding,
        with_embedding=True,
        min_relevance_score=0.0,
    )
    result, score = match

    is_equal_memory_record(record, result, True)
    assert score
231+
232+
233+
async def knn_matcher(
    nearest_match_store,
    test_collection,
    mems,
    query_limit,
    expected_limit,
    min_relevance_score,
):
    """Run a nearest-matches query and check its size, ordering and contents.

    The query uses the embedding of ``mems["2"]``.  The helper asserts that
    ``expected_limit`` results come back, that their scores are in descending
    order, and that each returned record equals the entry of ``mems`` with
    the same id (embeddings included).
    """
    matches = await nearest_match_store.get_nearest_matches(
        collection_name=test_collection,
        embedding=mems["2"]._embedding,
        limit=query_limit,
        with_embeddings=True,
        min_relevance_score=min_relevance_score,
    )
    assert len(matches) == expected_limit

    relevances = [relevance for _record, relevance in matches]
    assert relevances == sorted(relevances, reverse=True)

    for record, _relevance in matches:
        is_equal_memory_record(mems[record._id], record, True)
253+
254+
255+
async def test_collection_knn_matches(nearest_match_store, memory_record_gen):
    """A limit-2 nearest-matches query returns two correctly ordered records."""
    mems = {str(i): memory_record_gen(i) for i in range(1, 4)}
    # Pass a real list, as test_collection_batch_get does, not a dict view.
    await nearest_match_store.upsert_batch(READ_ONLY_COLLECTION, list(mems.values()))
    await knn_matcher(
        nearest_match_store,
        READ_ONLY_COLLECTION,
        mems,
        query_limit=2,
        expected_limit=2,
        min_relevance_score=None,
    )
266+
267+
268+
async def test_collection_knn_matches_with_score(
    nearest_match_store, memory_record_gen
):
    """Same as the plain nearest-matches test, but with ``min_relevance_score=0.0``."""
    mems = {str(i): memory_record_gen(i) for i in range(1, 4)}
    # Pass a real list, as test_collection_batch_get does, not a dict view.
    await nearest_match_store.upsert_batch(READ_ONLY_COLLECTION, list(mems.values()))
    await knn_matcher(
        nearest_match_store,
        READ_ONLY_COLLECTION,
        mems,
        query_limit=2,
        expected_limit=2,
        min_relevance_score=0.0,
    )

0 commit comments

Comments
 (0)