Skip to content

Commit 6a22249

Browse files
Adds support for JSON and MARKDOWN in Redis agent memory (#6897)
Co-authored-by: Eric Zhu <[email protected]>
1 parent 27c3d3b commit 6a22249

File tree

3 files changed

+149
-26
lines changed

3 files changed

+149
-26
lines changed

python/docs/src/user-guide/agentchat-user-guide/memory.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@
225225
"\n",
226226
"- `autogen_ext.memory.chromadb.SentenceTransformerEmbeddingFunctionConfig`: A configuration class for the SentenceTransformer embedding function used by the `ChromaDBVectorMemory` store. Note that other embedding functions such as `autogen_ext.memory.openai.OpenAIEmbeddingFunctionConfig` can also be used with the `ChromaDBVectorMemory` store.\n",
227227
"\n",
228-
"- `autogen_ext.memory.redis_memory.RedisMemory`: A memory store that uses a Redis vector database to store and retrieve information.\n"
228+
"- `autogen_ext.memory.redis.RedisMemory`: A memory store that uses a Redis vector database to store and retrieve information.\n"
229229
]
230230
},
231231
{
@@ -377,7 +377,7 @@
377377
"from autogen_agentchat.agents import AssistantAgent\n",
378378
"from autogen_agentchat.ui import Console\n",
379379
"from autogen_core.memory import MemoryContent, MemoryMimeType\n",
380-
"from autogen_ext.memory.redis_memory import RedisMemory, RedisMemoryConfig\n",
380+
"from autogen_ext.memory.redis import RedisMemory, RedisMemoryConfig\n",
381381
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
382382
"\n",
383383
"logger = getLogger()\n",

python/packages/autogen-ext/src/autogen_ext/memory/redis/_redis_memory.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any, Literal
2+
from typing import Any, List, Literal
33

44
from autogen_core import CancellationToken, Component
55
from autogen_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
@@ -217,20 +217,31 @@ async def add(self, content: MemoryContent, cancellation_token: CancellationToke
217217
.. note::
218218
219219
To perform semantic search over stored memories RedisMemory creates a vector embedding
220-
from the content field of a MemoryContent object. This content is assumed to be text, and
221-
is passed to the vector embedding model specified in RedisMemoryConfig.
220+
from the content field of a MemoryContent object. This content is assumed to be text,
221+
JSON, or Markdown, and is passed to the vector embedding model specified in
222+
RedisMemoryConfig.
222223
223224
Args:
224225
content (MemoryContent): The memory content to store within Redis.
225226
cancellation_token (CancellationToken): Token passed to cease operation. Not used.
226227
"""
227-
if content.mime_type != MemoryMimeType.TEXT:
228+
if content.mime_type == MemoryMimeType.TEXT:
229+
memory_content = content.content
230+
mime_type = "text/plain"
231+
elif content.mime_type == MemoryMimeType.JSON:
232+
memory_content = serialize(content.content)
233+
mime_type = "application/json"
234+
elif content.mime_type == MemoryMimeType.MARKDOWN:
235+
memory_content = content.content
236+
mime_type = "text/markdown"
237+
else:
228238
raise NotImplementedError(
229-
f"Error: {content.mime_type} is not supported. Only MemoryMimeType.TEXT is currently supported."
239+
f"Error: {content.mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, and MemoryMimeType.MARKDOWN are currently supported."
230240
)
231-
241+
metadata = {"mime_type": mime_type}
242+
metadata.update(content.metadata if content.metadata else {})
232243
self.message_history.add_message(
233-
{"role": "user", "content": content.content, "tool_call_id": serialize(content.metadata)} # type: ignore[reportArgumentType]
244+
{"role": "user", "content": memory_content, "tool_call_id": serialize(metadata)} # type: ignore[reportArgumentType]
234245
)
235246

236247
async def query(
@@ -260,14 +271,19 @@ async def query(
260271
memoryQueryResult: Object containing memories relevant to the provided query.
261272
"""
262273
# get the query string, or raise an error for unsupported MemoryContent types
263-
if isinstance(query, MemoryContent):
264-
if query.mime_type != MemoryMimeType.TEXT:
274+
if isinstance(query, str):
275+
prompt = query
276+
elif isinstance(query, MemoryContent):
277+
if query.mime_type in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN):
278+
prompt = str(query.content)
279+
elif query.mime_type == MemoryMimeType.JSON:
280+
prompt = serialize(query.content)
281+
else:
265282
raise NotImplementedError(
266-
f"Error: {query.mime_type} is not supported. Only MemoryMimeType.TEXT is currently supported."
283+
f"Error: {query.mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, MemoryMimeType.MARKDOWN are currently supported."
267284
)
268-
prompt = query.content
269285
else:
270-
prompt = query
286+
raise TypeError("'query' must be either a string or MemoryContent")
271287

272288
top_k = kwargs.pop("top_k", self.config.top_k)
273289
distance_threshold = kwargs.pop("distance_threshold", self.config.distance_threshold)
@@ -279,12 +295,22 @@ async def query(
279295
raw=False,
280296
)
281297

282-
memories = []
298+
memories: List[MemoryContent] = []
283299
for result in results:
300+
metadata = deserialize(result["tool_call_id"]) # type: ignore[reportArgumentType]
301+
mime_type = MemoryMimeType(metadata.pop("mime_type"))
302+
if mime_type in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN):
303+
memory_content = result["content"] # type: ignore[reportArgumentType]
304+
elif mime_type == MemoryMimeType.JSON:
305+
memory_content = deserialize(result["content"]) # type: ignore[reportArgumentType]
306+
else:
307+
raise NotImplementedError(
308+
f"Error: {mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, and MemoryMimeType.MARKDOWN are currently supported."
309+
)
284310
memory = MemoryContent(
285-
content=result["content"], # type: ignore[reportArgumentType]
286-
mime_type=MemoryMimeType.TEXT,
287-
metadata=deserialize(result["tool_call_id"]), # type: ignore[reportArgumentType]
311+
content=memory_content, # type: ignore[reportArgumentType]
312+
mime_type=mime_type,
313+
metadata=metadata,
288314
)
289315
memories.append(memory) # type: ignore[reportUknownMemberType]
290316

python/packages/autogen-ext/tests/memory/test_redis_memory.py

Lines changed: 105 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ async def test_redis_memory_query_with_mock() -> None:
3535
config = RedisMemoryConfig()
3636
memory = RedisMemory(config=config)
3737

38-
mock_history.get_relevant.return_value = [{"content": "test content", "tool_call_id": '{"foo": "bar"}'}]
38+
mock_history.get_relevant.return_value = [
39+
{"content": "test content", "tool_call_id": '{"foo": "bar", "mime_type": "text/plain"}'}
40+
]
3941
result = await memory.query("test")
4042
assert len(result.results) == 1
4143
assert result.results[0].content == "test content"
@@ -304,8 +306,7 @@ async def test_basic_workflow(semantic_config: RedisMemoryConfig) -> None:
304306

305307
@pytest.mark.asyncio
306308
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
307-
async def test_content_types(semantic_memory: RedisMemory) -> None:
308-
"""Test different content types with semantic memory."""
309+
async def test_text_memory_type(semantic_memory: RedisMemory) -> None:
309310
await semantic_memory.clear()
310311

311312
# Test text content
@@ -317,8 +318,104 @@ async def test_content_types(semantic_memory: RedisMemory) -> None:
317318
assert len(results.results) > 0
318319
assert any("Simple text content" in str(r.content) for r in results.results)
319320

320-
# Test JSON content
321-
json_data = {"key": "value", "number": 42}
322-
json_content = MemoryContent(content=json_data, mime_type=MemoryMimeType.JSON)
323-
with pytest.raises(NotImplementedError):
324-
await semantic_memory.add(json_content)
321+
322+
@pytest.mark.asyncio
323+
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
324+
async def test_json_memory_type(semantic_memory: RedisMemory) -> None:
325+
await semantic_memory.clear()
326+
327+
json_data = {"title": "Hitchhiker's Guide to the Galaxy", "The answer to life, the universe and everything.": 42}
328+
await semantic_memory.add(
329+
MemoryContent(content=json_data, mime_type=MemoryMimeType.JSON, metadata={"author": "Douglas Adams"})
330+
)
331+
332+
results = await semantic_memory.query("what is the ultimate question of the universe?")
333+
assert results.results[0].content == json_data
334+
335+
# metadata should not be searched
336+
results = await semantic_memory.query("who is Douglas Adams?")
337+
assert len(results.results) == 0
338+
339+
# test that we also can't query with a bare JSON dict (one not wrapped in MemoryContent)
340+
with pytest.raises(TypeError):
341+
results = await semantic_memory.query({"question": "what is the ultimate question of the universe?"}) # type: ignore[arg-type]
342+
343+
# but we can if the JSON is within a MemoryContent container
344+
results = await semantic_memory.query(
345+
MemoryContent(
346+
content={"question": "what is the ultimate question of the universe?"}, mime_type=MemoryMimeType.JSON
347+
)
348+
)
349+
assert results.results[0].content == json_data
350+
351+
352+
@pytest.mark.asyncio
353+
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
354+
async def test_markdown_memory_type(semantic_memory: RedisMemory) -> None:
355+
await semantic_memory.clear()
356+
357+
markdown_data = """
358+
This is an H1 header
359+
============
360+
361+
Paragraphs are separated by a blank line.
362+
363+
*Italics are within asteriks*, **bold text is within two asterisks**,
364+
while `monospace is within back tics`.
365+
366+
Itemized lists are made with indented asterisks:
367+
368+
* this one
369+
* that one
370+
* the next one
371+
372+
> Block quotes are make with arrows
373+
> like this.
374+
>
375+
> They can span multiple paragraphs,
376+
> if you like.
377+
378+
Unicode is supported. ☺
379+
"""
380+
381+
await semantic_memory.add(
382+
MemoryContent(content=markdown_data, mime_type=MemoryMimeType.MARKDOWN, metadata={"type": "markdown example"})
383+
)
384+
385+
results = await semantic_memory.query("how can I make itemized lists, or italicize text with asterisks?")
386+
assert results.results[0].content == markdown_data
387+
388+
# test we can query with markdown interpreted as a text string also
389+
results = await semantic_memory.query("")
390+
391+
# we can also if the markdown is within a MemoryContent container
392+
results = await semantic_memory.query(
393+
MemoryContent(
394+
content="**bold text is within 2 asterisks**, and *italics are within 1 asterisk*",
395+
mime_type=MemoryMimeType.MARKDOWN,
396+
)
397+
)
398+
assert results.results[0].content == markdown_data
399+
400+
401+
@pytest.mark.asyncio
402+
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
403+
async def test_query_arguments(semantic_memory: RedisMemory) -> None:
404+
# test that we can utilize the optional query arguments top_k and distance_threshold
405+
await semantic_memory.clear()
406+
407+
await semantic_memory.add(MemoryContent(content="my favorite fruit are apples", mime_type=MemoryMimeType.TEXT))
408+
await semantic_memory.add(MemoryContent(content="I also like cherries", mime_type=MemoryMimeType.TEXT))
409+
await semantic_memory.add(MemoryContent(content="I like plums as well", mime_type=MemoryMimeType.TEXT))
410+
411+
# default search
412+
results = await semantic_memory.query("what fruits do I like?")
413+
assert len(results.results) == 3
414+
415+
# limit search to 2 results
416+
results = await semantic_memory.query("what fruits do I like?", top_k=2)
417+
assert len(results.results) == 2
418+
419+
# limit search to only close matches
420+
results = await semantic_memory.query("my favorite fruit are what?", distance_threshold=0.2)
421+
assert len(results.results) == 1

0 commit comments

Comments
 (0)