Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/rai_core/rai/agents/langchain/core/react_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from rai.agents.langchain.core.tool_runner import ToolRunner
from rai.initialization import get_llm_model
from rai.messages import SystemMultimodalMessage
from rai.messages import AIMultimodalMessage, SystemMultimodalMessage


class ReActAgentState(TypedDict):
Expand Down Expand Up @@ -76,6 +76,7 @@ def llm_node(
if not isinstance(state["messages"][0], SystemMessage):
state["messages"].insert(0, SystemMessage(content=system_prompt))
ai_msg = llm.invoke(state["messages"])
ai_msg = AIMultimodalMessage(content=ai_msg.content)
state["messages"].append(ai_msg)


Expand Down
18 changes: 13 additions & 5 deletions src/rai_core/rai/messages/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
# limitations under the License.


from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Union

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.messages.base import BaseMessage, get_msg_title_repr
from pydantic import Field


class MultimodalMessage(BaseMessage):
Expand All @@ -31,6 +33,8 @@ class MultimodalMessage(BaseMessage):
List of base64 encoded audios.
"""

timestamp: datetime = Field(default_factory=datetime.now)

images: Optional[List[str]] = None
audios: Optional[Any] = None

Expand All @@ -46,7 +50,15 @@ def __init__(
_content: List[Union[str, Dict[str, Union[Dict[str, str], str]]]] = []

if isinstance(self.content, str):
_content.append({"type": "text", "text": self.content})
_content.append(
{
"type": "text",
"text": "Current time: "
+ self.timestamp.isoformat(timespec="seconds")
+ "\n"
+ self.content,
}
)
else:
raise ValueError("Content must be a string") # for now, to guarantee compat

Expand All @@ -63,10 +75,6 @@ def __init__(
_content.extend(_image_content)
self.content = _content

@property
def text(self) -> str:
return self.content[0]["text"]


class HumanMultimodalMessage(HumanMessage, MultimodalMessage):
def __repr_args__(self) -> Any:
Expand Down