Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/semantic_kernel/connectors/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ def _add_key(self, key: TKey, record: dict[str, Any]) -> dict[str, Any]:

@override
async def _inner_delete(self, keys: Sequence[str], **kwargs: Any) -> None:
await asyncio.gather(*[self.redis_database.json().delete(key, **kwargs) for key in keys])
await asyncio.gather(*[self.redis_database.json().delete(self._get_redis_key(key), **kwargs) for key in keys])

@override
def _serialize_dicts_to_store_models(
Expand Down
348 changes: 348 additions & 0 deletions python/tests/integration/memory/test_redis_vector_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,348 @@
# Copyright (c) Microsoft. All rights reserved.

"""Extended Redis connector integration tests.

These supplement the single-record round-trip covered by
``test_vector_store.py`` and exercise the rest of the public surface
(`RedisStore.list_collection_names`, vector search with filters, batch
CRUD, `include_vectors`, manual index creation, and no-prefix mode)
to validate end-to-end compatibility with Redis and with Valkey +
valkey-search.

All tests require a running Redis/Valkey server reachable via
``REDIS_CONNECTION_STRING``.
"""

import asyncio
import contextlib
from dataclasses import dataclass, field
from typing import Annotated
from uuid import uuid4

import pytest

from semantic_kernel.connectors.redis import (
RedisHashsetCollection,
RedisJsonCollection,
RedisStore,
)
from semantic_kernel.data.vector import VectorStoreField, vectorstoremodel
from semantic_kernel.exceptions import MemoryConnectorConnectionException

# Vector search is broken on main due to redisvl 0.5+ API change and a
# missing guard in the hashset deserializer. See:
# https://github.com/microsoft/semantic-kernel/issues/13896
# https://github.com/microsoft/semantic-kernel/pull/13899
# Once that PR merges, remove the xfail markers below.
# Non-strict xfail: once the upstream fix merges these tests start passing
# (XPASS) without failing the suite; remove the marker at that point.
_SEARCH_XFAIL = pytest.mark.xfail(
    reason="Vector search broken on main — process_results API mismatch + KeyError on missing vector field (#13896)",
    raises=(Exception,),
    strict=False,
)


@vectorstoremodel
@dataclass
class CoverageModel:
    """Shared record shape for all coverage tests.

    Field roles are declared through ``VectorStoreField`` annotations:
    a 5-dimensional HNSW/cosine vector, a string key, and a
    full-text-indexed data field (so lambda filters on ``content`` work).
    """

    # Embedding vector; defaults to None so records can be built without one.
    vector: Annotated[
        list[float] | None,
        VectorStoreField(
            "vector",
            index_kind="hnsw",
            dimensions=5,
            distance_function="cosine_similarity",
            type="float",
        ),
    ] = None
    # Record key; a fresh UUID per instance unless supplied explicitly.
    id: Annotated[str, VectorStoreField("key", type="str")] = field(default_factory=lambda: str(uuid4()))
    # Payload field, full-text indexed for the filter-translation tests.
    content: Annotated[str, VectorStoreField("data", type="str", is_full_text_indexed=True)] = "content"


def _records() -> list[CoverageModel]:
    """Build the fixed three-record dataset shared by the coverage tests."""
    rows = [
        ("cov-1", "alpha", [0.1, 0.2, 0.3, 0.4, 0.5]),
        ("cov-2", "beta", [0.2, 0.3, 0.4, 0.5, 0.6]),
        ("cov-3", "gamma", [0.9, 0.8, 0.7, 0.6, 0.5]),
    ]
    return [CoverageModel(id=key, content=text, vector=vec) for key, text, vec in rows]


async def _collect(results):
"""Consume KernelSearchResults.results into a list."""
return [r async for r in results.results]


@pytest.fixture
def collection_cls(request):
    """Parametrized fixture selecting the concrete collection class.

    Tests opt in with ``@pytest.mark.parametrize("collection_cls", ...,
    indirect=True)``; ``request.param`` then carries the class to use.
    """
    return request.param


@pytest.fixture
async def collection(collection_cls):
    """Yields a freshly-created collection; cleans the index up at teardown.

    Uses ``prefix_collection_name_to_key_names=True`` so each parametrized
    run has its own keyspace and hashset/json tests do not collide on
    raw keys.
    """
    # Random suffix keeps repeated or concurrent runs from sharing an index.
    name = f"sk_cov_{uuid4().hex[:8]}"
    try:
        col = collection_cls(
            record_type=CoverageModel,
            collection_name=name,
            prefix_collection_name_to_key_names=True,
        )
    except MemoryConnectorConnectionException as exc:
        # No reachable server: xfail instead of erroring out every test.
        pytest.xfail(f"Failed to connect to store: {exc}")

    async with col:
        try:
            # Delete first in case a previous run leaked an index of this name.
            await col.ensure_collection_deleted()
            await col.ensure_collection_exists()
            yield col
        finally:
            # Best-effort teardown: the server may already be unreachable.
            with contextlib.suppress(Exception):
                await col.ensure_collection_deleted()


# Both concrete collection flavors, labelled for readable parametrized test IDs.
_COLLECTION_CLASSES = [
    pytest.param(RedisHashsetCollection, id="hashset"),
    pytest.param(RedisJsonCollection, id="json"),
]


@pytest.mark.parametrize("collection_cls", _COLLECTION_CLASSES, indirect=True)
class TestRedisCoverage:
    """Core lifecycle, batch CRUD, and vector-search coverage.

    Runs once per collection flavor (hashset and json) via the indirect
    ``collection_cls`` fixture.
    """

    async def test_collection_exists_lifecycle(self, collection):
        """collection_exists tracks ensure_collection_exists / _deleted."""
        assert await collection.collection_exists() is True
        await collection.ensure_collection_deleted()
        assert await collection.collection_exists() is False
        # Re-create so the fixture's teardown delete has an index to remove.
        await collection.ensure_collection_exists()
        assert await collection.collection_exists() is True

    async def test_list_collection_names_includes_created(self, collection):
        """RedisStore.list_collection_names surfaces the created index via FT._LIST."""
        try:
            store = RedisStore()
        except MemoryConnectorConnectionException as exc:
            pytest.xfail(f"Failed to connect to store: {exc}")
        try:
            names = await store.list_collection_names()
            assert collection.collection_name in names
        finally:
            # Close the extra client; the fixture only manages `collection`.
            await store.redis_database.aclose()

    async def test_batch_upsert_get_delete(self, collection):
        """Multi-record upsert, get, and delete round-trip."""
        records = _records()
        await collection.upsert(records)

        fetched = await collection.get([r.id for r in records])
        assert fetched is not None
        # Result order is not guaranteed, so compare as sets of ids.
        assert {r.id for r in fetched} == {r.id for r in records}

        await collection.delete([r.id for r in records])
        after = await collection.get([r.id for r in records])
        assert not after

    async def test_get_include_vectors(self, collection):
        """get with include_vectors=True returns the vector, False hides it."""
        [first, *_] = _records()
        await collection.upsert([first])

        with_vec = await collection.get(first.id, include_vectors=True)
        without_vec = await collection.get(first.id, include_vectors=False)

        assert with_vec is not None
        assert without_vec is not None
        assert with_vec.vector is not None
        assert without_vec.vector is None

    @_SEARCH_XFAIL
    async def test_vector_search_basic(self, collection):
        """FT.SEARCH with an HNSW query returns results ordered by distance."""
        records = _records()
        await collection.upsert(records)
        # Short pause to let the server finish indexing the new documents.
        await asyncio.sleep(0.2)

        results = await _collect(await collection.search(vector=[0.1, 0.2, 0.3, 0.4, 0.5], top=3))
        assert len(results) == 3
        # Query vector equals cov-1's vector exactly, so it must rank first.
        assert results[0].record.id == "cov-1"

    @_SEARCH_XFAIL
    async def test_vector_search_top_skip(self, collection):
        """top/skip paging works end-to-end."""
        await collection.upsert(_records())
        await asyncio.sleep(0.2)

        page1 = await _collect(await collection.search(vector=[0.1, 0.2, 0.3, 0.4, 0.5], top=2, skip=0))
        page2 = await _collect(await collection.search(vector=[0.1, 0.2, 0.3, 0.4, 0.5], top=2, skip=2))
        assert len(page1) == 2
        assert len(page2) == 1
        # The two pages must be disjoint and together cover all three records.
        seen = {r.record.id for r in page1} | {r.record.id for r in page2}
        assert seen == {"cov-1", "cov-2", "cov-3"}

    @_SEARCH_XFAIL
    async def test_vector_search_with_tag_filter(self, collection):
        """Lambda filter on a text field is translated and honoured."""
        await collection.upsert(_records())
        await asyncio.sleep(0.2)

        results = await _collect(
            await collection.search(
                vector=[0.1, 0.2, 0.3, 0.4, 0.5],
                top=5,
                filter=lambda r: r.content == "beta",
            )
        )
        assert len(results) == 1
        assert results[0].record.id == "cov-2"

    @_SEARCH_XFAIL
    async def test_vector_search_include_vectors(self, collection):
        """include_vectors toggles whether the vector is returned on search hits."""
        await collection.upsert(_records())
        await asyncio.sleep(0.2)

        with_vec = await _collect(
            await collection.search(vector=[0.1, 0.2, 0.3, 0.4, 0.5], top=1, include_vectors=True)
        )
        without_vec = await _collect(
            await collection.search(vector=[0.1, 0.2, 0.3, 0.4, 0.5], top=1, include_vectors=False)
        )
        assert with_vec[0].record.vector is not None
        assert without_vec[0].record.vector is None


class TestRedisCoverageNoPrefix:
    """prefix_collection_name_to_key_names=False should round-trip by raw key."""

    # Direct (non-indirect) parametrize: the class itself is injected as the
    # `collection_cls` argument, bypassing the module fixture of that name.
    @pytest.mark.parametrize(
        "collection_cls",
        [
            pytest.param(RedisHashsetCollection, id="hashset"),
            pytest.param(RedisJsonCollection, id="json"),
        ],
    )
    async def test_upsert_get_delete_without_prefix(self, collection_cls):
        """Upsert/get/delete one record when keys are stored unprefixed."""
        name = f"sk_cov_np_{uuid4().hex[:8]}"
        try:
            col = collection_cls(
                record_type=CoverageModel,
                collection_name=name,
                prefix_collection_name_to_key_names=False,
            )
        except MemoryConnectorConnectionException as exc:
            pytest.xfail(f"Failed to connect to store: {exc}")

        async with col:
            await col.ensure_collection_deleted()
            await col.ensure_collection_exists()
            try:
                # Random key so re-runs against a dirty server never collide
                # (unprefixed keys share the global keyspace).
                rec = CoverageModel(
                    id=f"np-{uuid4().hex[:6]}",
                    content="alpha",
                    vector=[0.1, 0.2, 0.3, 0.4, 0.5],
                )
                await col.upsert([rec])
                fetched = await col.get(rec.id)
                assert fetched is not None
                assert fetched.id == rec.id
                await col.delete(rec.id)
                assert not await col.get(rec.id)
            finally:
                await col.ensure_collection_deleted()


@pytest.mark.parametrize("collection_cls", _COLLECTION_CLASSES, indirect=True)
class TestRedisCoverageExtended:
    """Extra coverage for paths not exercised by TestRedisCoverage."""

    async def test_ensure_collection_exists_with_explicit_index(self, collection):
        """ensure_collection_exists(index_definition=..., fields=...) uses the provided override."""
        from redis.commands.search.field import TextField, VectorField
        from redis.commands.search.index_definition import IndexDefinition, IndexType

        # Drop the fixture-created index so the explicit definition is the one built.
        await collection.ensure_collection_deleted()
        assert await collection.collection_exists() is False

        # JSON indexes address fields by JSONPath with an alias; hash ones by name.
        index_type = IndexType.JSON if isinstance(collection, RedisJsonCollection) else IndexType.HASH
        content_field = (
            TextField("$.content", as_name="content") if index_type == IndexType.JSON else TextField("content")
        )
        vector_field = (
            VectorField(
                "$.vector",
                "HNSW",
                {"TYPE": "FLOAT32", "DIM": 5, "DISTANCE_METRIC": "COSINE"},
                as_name="vector",
            )
            if index_type == IndexType.JSON
            else VectorField(
                "vector",
                "HNSW",
                {"TYPE": "FLOAT32", "DIM": 5, "DISTANCE_METRIC": "COSINE"},
            )
        )
        await collection.ensure_collection_exists(
            index_definition=IndexDefinition(prefix=[f"{collection.collection_name}:"], index_type=index_type),
            fields=[content_field, vector_field],
        )
        assert await collection.collection_exists() is True

    async def test_ensure_collection_exists_invalid_index_definition(self, collection):
        """Passing a non-IndexDefinition with fields should raise."""
        from semantic_kernel.exceptions import VectorStoreOperationException

        with pytest.raises(VectorStoreOperationException, match="Invalid index type supplied."):
            await collection.ensure_collection_exists(index_definition="not-an-IndexDefinition", fields=["content"])

    @_SEARCH_XFAIL
    async def test_vector_search_not_equal_filter(self, collection):
        """A != lambda filter excludes exactly the matching record."""
        await collection.upsert(_records())
        await asyncio.sleep(0.2)
        results = await _collect(
            await collection.search(
                vector=[0.1, 0.2, 0.3, 0.4, 0.5],
                top=5,
                filter=lambda r: r.content != "alpha",
            )
        )
        ids = {r.record.id for r in results}
        assert ids == {"cov-2", "cov-3"}

    @_SEARCH_XFAIL
    async def test_vector_search_and_filter(self, collection):
        """Conjunction of two != clauses is translated and honoured."""
        await collection.upsert(_records())
        await asyncio.sleep(0.2)
        results = await _collect(
            await collection.search(
                vector=[0.1, 0.2, 0.3, 0.4, 0.5],
                top=5,
                filter=lambda r: (r.content != "alpha") and (r.content != "gamma"),
            )
        )
        assert {r.record.id for r in results} == {"cov-2"}

    @_SEARCH_XFAIL
    async def test_vector_search_or_filter(self, collection):
        """Disjunction of two == clauses is translated and honoured."""
        await collection.upsert(_records())
        await asyncio.sleep(0.2)
        results = await _collect(
            await collection.search(
                vector=[0.1, 0.2, 0.3, 0.4, 0.5],
                top=5,
                filter=lambda r: (r.content == "alpha") or (r.content == "gamma"),
            )
        )
        assert {r.record.id for r in results} == {"cov-1", "cov-3"}

    async def test_get_without_keys_not_implemented(self, collection):
        """get with no keys should raise NotImplementedError via the connector."""
        from semantic_kernel.data.vector import GetFilteredRecordOptions

        # NOTE(review): pokes the private _inner_get directly to reach the
        # keyless code path; brittle if the connector internals change.
        with pytest.raises(NotImplementedError):
            await collection._inner_get(keys=None, options=GetFilteredRecordOptions())
12 changes: 12 additions & 0 deletions python/tests/unit/connectors/memory/test_redis_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,18 @@ async def test_delete(collection_hash, collection_json, type_):
await collection._inner_delete(["id1"])


async def test_delete_with_prefix_json(collection_with_prefix_json, mock_delete_json):
    """Verify JSON delete prefixes keys when prefix_collection_name_to_key_names=True (#13904).

    Regression test: _inner_delete must pass the prefixed key ("test:id1"),
    not the raw caller-supplied key, to the underlying JSON delete call.
    """
    await collection_with_prefix_json._inner_delete(["id1"])
    mock_delete_json.assert_called_once_with("test:id1")


async def test_delete_with_prefix_hash(collection_with_prefix_hash, mock_delete_hash):
    """Verify hashset delete prefixes keys when prefix_collection_name_to_key_names=True.

    Mirror of the JSON regression test: the hashset path must also send the
    prefixed key ("test:id1") to the underlying delete call.
    """
    await collection_with_prefix_hash._inner_delete(["id1"])
    mock_delete_hash.assert_called_once_with("test:id1")


async def test_collection_exists(collection_hash, mock_collection_exists):
await collection_hash.collection_exists()

Expand Down
Loading