70 changes: 70 additions & 0 deletions memori/llm/_base.py
@@ -236,6 +236,30 @@ def get_response_content(self, raw_response):

def _extract_user_query(self, kwargs: dict) -> str:
"""Extract the most recent user message from kwargs."""
# Handle Google GenAI's contents parameter
if llm_is_google(
self.config.framework.provider, self.config.llm.provider
) or agno_is_google(self.config.framework.provider, self.config.llm.provider):
contents = kwargs.get("contents", [])
if isinstance(contents, str):
return contents
elif isinstance(contents, list):
for item in reversed(contents):
if isinstance(item, str):
return item
elif isinstance(item, dict):
role = item.get("role", "user")
if role == "user":
parts = item.get("parts", [])
if parts:
first_part = parts[0]
if isinstance(first_part, str):
return first_part
elif isinstance(first_part, dict):
return first_part.get("text", "")
return ""

# Handle standard messages parameter (OpenAI, Anthropic, etc.)
if "messages" not in kwargs or not kwargs["messages"]:
return ""

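For reference, a minimal sketch (not part of the diff) of the three `contents` shapes the Google branch above handles, assuming `invoke` is a `BaseInvoke` configured for Google GenAI as in the tests added below:

# Plain string: returned as-is.
invoke._extract_user_query({"contents": "What is the weather?"})
# -> "What is the weather?"

# List of strings: iterated in reverse, so the last string wins.
invoke._extract_user_query({"contents": ["First message", "Second message"]})
# -> "Second message"

# Structured parts/role format: the most recent "user" entry wins;
# "model" entries are skipped.
invoke._extract_user_query(
    {
        "contents": [
            {"parts": [{"text": "Previous question"}], "role": "user"},
            {"parts": [{"text": "Previous answer"}], "role": "model"},
            {"parts": [{"text": "What do I like?"}], "role": "user"},
        ]
    }
)
# -> "What do I like?"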
@@ -290,6 +314,51 @@ def inject_recalled_facts(self, kwargs: dict) -> dict:
) or llm_is_bedrock(self.config.framework.provider, self.config.llm.provider):
existing_system = kwargs.get("system", "")
kwargs["system"] = existing_system + recall_context
elif llm_is_google(
self.config.framework.provider, self.config.llm.provider
) or agno_is_google(self.config.framework.provider, self.config.llm.provider):
# Google GenAI uses 'contents' instead of 'messages'
# Inject context as a system instruction via config or prepend to contents
if "config" in kwargs and hasattr(kwargs["config"], "system_instruction"):
# If using GenerateContentConfig with system_instruction
existing_instruction = kwargs["config"].system_instruction or ""
kwargs["config"].system_instruction = (
existing_instruction + recall_context
)
else:
# Prepend context as the first user message in contents
existing_contents = kwargs.get("contents", [])
if isinstance(existing_contents, str):
existing_contents = [
{"parts": [{"text": existing_contents}], "role": "user"}
]
elif isinstance(existing_contents, list):
normalized = []
for item in existing_contents:
if isinstance(item, str):
normalized.append(
{"parts": [{"text": item}], "role": "user"}
)
else:
normalized.append(item)
existing_contents = normalized

                    # Prepend context to the first user message's content,
                    # guarding against empty or plain-string parts entries
                    if (
                        existing_contents
                        and existing_contents[0].get("role") == "user"
                        and existing_contents[0].get("parts")
                        and isinstance(existing_contents[0]["parts"][0], dict)
                    ):
                        # Prepend context to existing first user message
                        first_msg = existing_contents[0]
                        original_text = first_msg["parts"][0].get("text", "")
                        first_msg["parts"][0]["text"] = (
                            recall_context.lstrip("\n") + "\n\n" + original_text
                        )
kwargs["contents"] = existing_contents
else:
# Insert context as a user message at the beginning
context_message = {
"parts": [{"text": recall_context.lstrip("\n")}],
"role": "user",
}
kwargs["contents"] = [context_message] + existing_contents
else:
messages = kwargs.get("messages", [])
if messages and messages[0].get("role") == "system":
@@ -300,6 +369,7 @@ def inject_recalled_facts(self, kwargs: dict) -> dict:
"content": recall_context.lstrip("\n"),
}
messages.insert(0, context_message)
kwargs["messages"] = messages

return kwargs

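To make the two Google injection paths above concrete, a minimal sketch (not part of the diff), assuming the same `BaseInvoke` and patched `Recall` setup used in the tests below; the placeholder recall-context text is only indicative, since its exact wording comes from the recall layer:

from types import SimpleNamespace

# String contents, no config object: the string is normalized to the
# structured format and the recall context is prepended to the first
# (and only) user message.
result = invoke.inject_recalled_facts({"contents": "What do I like?"})
# result["contents"] == [
#     {"parts": [{"text": "<recall context>\n\nWhat do I like?"}], "role": "user"}
# ]

# A config object exposing a system_instruction attribute (for example
# GenerateContentConfig): the context is appended there instead, and
# contents is left untouched.
cfg = SimpleNamespace(system_instruction="You are a helpful assistant.")
result = invoke.inject_recalled_facts({"contents": "What do I like?", "config": cfg})
# result["config"].system_instruction now ends with the recall context;
# result["contents"] is still the original string.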
172 changes: 172 additions & 0 deletions tests/llm/test_llm_base.py
@@ -4,6 +4,7 @@
from memori._config import Config
from memori.llm._base import BaseInvoke, BaseLlmAdaptor
from memori.llm._constants import (
GOOGLE_LLM_PROVIDER,
LANGCHAIN_FRAMEWORK_PROVIDER,
LANGCHAIN_OPENAI_LLM_PROVIDER,
OPENAI_LLM_PROVIDER,
@@ -550,3 +551,174 @@ def test_inject_conversation_messages_openai_success():
assert result["messages"][1]["content"] == "Previous answer"
assert result["messages"][2]["content"] == "New question"
assert invoke._injected_message_count == 2


# Google GenAI specific tests
def test_extract_user_query_google_contents_string():
"""Test extracting user query when contents is a simple string (Google GenAI)."""
config = Config()
config.llm.provider = GOOGLE_LLM_PROVIDER
invoke = BaseInvoke(config, "test_method")

kwargs = {"contents": "What is the weather?"}
assert invoke._extract_user_query(kwargs) == "What is the weather?"


def test_extract_user_query_google_contents_list_of_strings():
"""Test extracting user query from a list of strings (Google GenAI)."""
config = Config()
config.llm.provider = GOOGLE_LLM_PROVIDER
invoke = BaseInvoke(config, "test_method")

kwargs = {"contents": ["First message", "Second message"]}
assert invoke._extract_user_query(kwargs) == "Second message"


def test_extract_user_query_google_contents_structured():
"""Test extracting user query from structured contents format (Google GenAI)."""
config = Config()
config.llm.provider = GOOGLE_LLM_PROVIDER
invoke = BaseInvoke(config, "test_method")

kwargs = {
"contents": [
{"parts": [{"text": "Previous question"}], "role": "user"},
{"parts": [{"text": "Previous answer"}], "role": "model"},
{"parts": [{"text": "What do I like?"}], "role": "user"},
]
}
assert invoke._extract_user_query(kwargs) == "What do I like?"


def test_extract_user_query_google_contents_empty():
"""Test extracting user query with empty contents (Google GenAI)."""
config = Config()
config.llm.provider = GOOGLE_LLM_PROVIDER
invoke = BaseInvoke(config, "test_method")

assert invoke._extract_user_query({}) == ""
assert invoke._extract_user_query({"contents": []}) == ""


def test_inject_recalled_facts_google_contents_string():
"""Test injecting recalled facts when contents is a simple string (Google GenAI)."""
config = Config()
config.llm.provider = GOOGLE_LLM_PROVIDER
config.storage = Mock()
config.storage.driver = Mock()
config.storage.driver.entity.create.return_value = 1
config.entity_id = "test-entity"
invoke = BaseInvoke(config, "test_method")

kwargs = {"contents": "What do I like?"}

with patch("memori.memory.recall.Recall") as mock_recall:
mock_recall.return_value.search_facts.return_value = [
{"content": "User likes pizza", "similarity": 0.9},
]
result = invoke.inject_recalled_facts(kwargs)

# Context should be prepended to the first user message
assert "contents" in result
assert isinstance(result["contents"], list)
assert len(result["contents"]) == 1
# First message should contain both context and original query
assert result["contents"][0]["role"] == "user"
assert "User likes pizza" in result["contents"][0]["parts"][0]["text"]
assert "What do I like?" in result["contents"][0]["parts"][0]["text"]


def test_inject_recalled_facts_google_contents_structured():
"""Test injecting recalled facts with structured contents (Google GenAI)."""
config = Config()
config.llm.provider = GOOGLE_LLM_PROVIDER
config.storage = Mock()
config.storage.driver = Mock()
config.storage.driver.entity.create.return_value = 1
config.entity_id = "test-entity"
invoke = BaseInvoke(config, "test_method")

kwargs = {
"contents": [
{"parts": [{"text": "What do I like?"}], "role": "user"},
]
}

with patch("memori.memory.recall.Recall") as mock_recall:
mock_recall.return_value.search_facts.return_value = [
{"content": "User likes coding", "similarity": 0.85},
]
result = invoke.inject_recalled_facts(kwargs)

# Context should be prepended to the first user message
assert "contents" in result
assert len(result["contents"]) == 1
# First message should contain both context and original query
assert "User likes coding" in result["contents"][0]["parts"][0]["text"]
assert "What do I like?" in result["contents"][0]["parts"][0]["text"]


def test_inject_recalled_facts_google_with_config_system_instruction():
"""Test injecting recalled facts via system_instruction in config (Google GenAI)."""
config = Config()
config.llm.provider = GOOGLE_LLM_PROVIDER
config.storage = Mock()
config.storage.driver = Mock()
config.storage.driver.entity.create.return_value = 1
config.entity_id = "test-entity"
invoke = BaseInvoke(config, "test_method")

# Mock a config object with system_instruction attribute
class MockConfig:
system_instruction = "You are a helpful assistant."

kwargs = {
"contents": "What do I like?",
"config": MockConfig(),
}

with patch("memori.memory.recall.Recall") as mock_recall:
mock_recall.return_value.search_facts.return_value = [
{"content": "User likes pizza", "similarity": 0.9},
]
result = invoke.inject_recalled_facts(kwargs)

# System instruction should be updated with context
assert "You are a helpful assistant." in result["config"].system_instruction
assert "User likes pizza" in result["config"].system_instruction
assert "memori_context" in result["config"].system_instruction


def test_inject_recalled_facts_google_first_message_not_user():
"""Test injecting recalled facts when first message is not user role (Google GenAI)."""
config = Config()
config.llm.provider = GOOGLE_LLM_PROVIDER
config.storage = Mock()
config.storage.driver = Mock()
config.storage.driver.entity.create.return_value = 1
config.entity_id = "test-entity"
invoke = BaseInvoke(config, "test_method")

# First message is model role (edge case)
kwargs = {
"contents": [
{"parts": [{"text": "Hello!"}], "role": "model"},
{"parts": [{"text": "What do I like?"}], "role": "user"},
]
}

with patch("memori.memory.recall.Recall") as mock_recall:
mock_recall.return_value.search_facts.return_value = [
{"content": "User likes coding", "similarity": 0.85},
]
result = invoke.inject_recalled_facts(kwargs)

# Should insert a new context message at the beginning
assert "contents" in result
assert len(result["contents"]) == 3
# First message should be the injected context
assert result["contents"][0]["role"] == "user"
assert "User likes coding" in result["contents"][0]["parts"][0]["text"]
# Original messages should follow
assert result["contents"][1]["role"] == "model"
assert result["contents"][2]["role"] == "user"