diff --git a/camel/agents/chat_agent.py b/camel/agents/chat_agent.py index 6ef16aecda..d1575ee5cf 100644 --- a/camel/agents/chat_agent.py +++ b/camel/agents/chat_agent.py @@ -30,6 +30,7 @@ import time import uuid import warnings +from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import ( @@ -159,6 +160,53 @@ def _cleanup_temp_files(): ) +@dataclass +class _ToolOutputHistoryEntry: + tool_name: str + tool_call_id: str + result_text: str + record_uuids: List[str] + record_timestamps: List[float] + preview_text: str + cached: bool = False + cache_id: Optional[str] = None + + +class _ToolOutputCacheManager: + r"""Minimal persistent store for caching verbose tool outputs.""" + + def __init__(self, base_dir: Union[str, Path]) -> None: + self.base_dir = Path(base_dir).expanduser().resolve() + self.base_dir.mkdir(parents=True, exist_ok=True) + + def save( + self, + tool_name: str, + tool_call_id: str, + content: str, + ) -> Tuple[str, Path]: + cache_id = uuid.uuid4().hex + filename = f"{cache_id}.txt" + path = self.base_dir / filename + header = ( + f"# Cached tool output\n" + f"tool_name: {tool_name}\n" + f"tool_call_id: {tool_call_id}\n" + f"cache_id: {cache_id}\n" + f"---\n" + ) + path.write_text(f"{header}{content}", encoding="utf-8") + return cache_id, path + + def load(self, cache_id: str) -> str: + path = self.base_dir / f"{cache_id}.txt" + if not path.exists(): + raise FileNotFoundError( + f"Cached tool output {cache_id} not found at {path}" + ) + return path.read_text(encoding="utf-8") + + class StreamContentAccumulator: r"""Manages content accumulation across streaming responses to ensure all responses contain complete cumulative content.""" @@ -459,6 +507,8 @@ def __init__( retry_delay: float = 1.0, step_timeout: Optional[float] = None, stream_accumulate: bool = True, + tool_call_cache_threshold: Optional[int] = None, + tool_call_cache_dir: Optional[Union[str, Path]] = None, ) -> None: if isinstance(model, ModelManager): self.model_backend = model @@ -474,6 +524,21 @@ def __init__( # Assign unique ID self.agent_id = agent_id if agent_id else str(uuid.uuid4()) + # Used for tool call output cache + self._tool_output_cache_threshold = ( + tool_call_cache_threshold if tool_call_cache_threshold else 2000 + ) + self._tool_output_cache_dir: Path = ( + Path(tool_call_cache_dir).expanduser().resolve() + if tool_call_cache_dir + else Path("tool_cache") + ) + self._tool_output_cache_manager: Optional[_ToolOutputCacheManager] = ( + None + ) + self._tool_output_history: List[_ToolOutputHistoryEntry] = [] + self._cache_lookup_tool_name = "retrieve_cached_tool_output" + # Set up memory context_creator = ScoreBasedContextCreator( self.model_backend.token_counter, @@ -556,6 +621,7 @@ def reset(self): r"""Resets the :obj:`ChatAgent` to its initial state.""" self.terminated = False self.init_messages() + self._tool_output_history.clear() for terminator in self.response_terminators: terminator.reset() @@ -776,6 +842,243 @@ def add_tools(self, tools: List[Union[FunctionTool, Callable]]) -> None: for tool in tools: self.add_tool(tool) + def retrieve_cached_tool_output(self, cache_ids: str) -> str: + r"""Load cached tool output(s) by cache identifier(s). 
+ + Supports both single and multiple cache ID retrieval: + - Single ID: Returns the cached content directly as a string + - Multiple IDs (comma-separated): Returns a JSON dictionary mapping + each cache_id to its content + + Args: + cache_ids (str): Single cache identifier or comma-separated list + of cache identifiers. + + Returns: + str: For single ID, returns the cached content directly. + For multiple IDs, returns a JSON-formatted dictionary mapping + cache_ids to their content or error messages. + """ + if not self._tool_output_cache_manager: + return "Tool output caching is disabled for this agent instance." + + # Parse input - check if it's comma-separated + id_list = [cid.strip() for cid in cache_ids.split(',') if cid.strip()] + + if not id_list: + return "Please provide at least one cache_id." + + # Single cache_id - return content directly + if len(id_list) == 1: + cache_id = id_list[0] + try: + return self._tool_output_cache_manager.load(cache_id) + except FileNotFoundError: + return ( + f"Cache entry '{cache_id}' was not found. " + "Verify the identifier and try again." + ) + + # Multiple cache_ids - return JSON dictionary + import json + + results = {} + for cache_id in id_list: + try: + results[cache_id] = self._tool_output_cache_manager.load( + cache_id + ) + except FileNotFoundError: + results[cache_id] = ( + f"[ERROR] Cache entry '{cache_id}' not found" + ) + + return json.dumps(results, indent=2, ensure_ascii=False) + + @property + def _tool_output_cache_enabled(self) -> bool: + """Check if tool output caching is enabled based on threshold.""" + return self._tool_output_cache_threshold > 0 + + def _ensure_tool_cache_lookup_tool(self) -> None: + if not self._tool_output_cache_enabled: + return + + # Register cache lookup tool (supports both single and multiple IDs) + lookup_name = self._cache_lookup_tool_name + if lookup_name not in self._internal_tools: + lookup_tool = convert_to_function_tool( + self.retrieve_cached_tool_output + ) + self._internal_tools[lookup_tool.get_function_name()] = lookup_tool + + def _cache_tool_calls(self) -> int: + r"""Persist eligible tool outputs to the cache store. + + This is a helper function that caches all tool outputs in + `_tool_output_history` that exceed the configured threshold. + + Returns: + int: Number of tool outputs that were newly cached. + """ + if self._tool_output_cache_threshold <= 0: + return 0 + + if self._tool_output_cache_manager is None: + self._tool_output_cache_manager = _ToolOutputCacheManager( + self._tool_output_cache_dir + ) + + cached_count = self._process_tool_output_cache() + + if any(entry.cached for entry in self._tool_output_history): + self._ensure_tool_cache_lookup_tool() + else: + self._internal_tools.pop(self._cache_lookup_tool_name, None) + + return cached_count + + def _serialize_tool_result(self, result: Any) -> str: + if isinstance(result, str): + return result + try: + return json.dumps(result, ensure_ascii=False) + except (TypeError, ValueError): + return str(result) + + def _summarize_tool_result(self, text: str, limit: int = 160) -> str: + normalized = re.sub(r"\s+", " ", text).strip() + if len(normalized) <= limit: + return normalized + return normalized[: max(0, limit - 3)].rstrip() + "..." 
+ + def _register_tool_output_for_cache( + self, + func_name: str, + tool_call_id: str, + result_text: str, + records: List[MemoryRecord], + ) -> None: + if not records: + return + + entry = _ToolOutputHistoryEntry( + tool_name=func_name, + tool_call_id=tool_call_id, + result_text=result_text, + record_uuids=[str(record.uuid) for record in records], + record_timestamps=[record.timestamp for record in records], + preview_text=self._summarize_tool_result(result_text), + ) + self._tool_output_history.append(entry) + + def _process_tool_output_cache(self) -> int: + if ( + not self._tool_output_history + or self._tool_output_cache_manager is None + ): + return 0 + + cached_count = 0 + + # Cache all tool outputs that exceed the threshold + for entry in self._tool_output_history: + if entry.cached: + continue + if len(entry.result_text) < self._tool_output_cache_threshold: + continue + self._cache_tool_output_entry(entry) + if entry.cached: + cached_count += 1 + + return cached_count + + def _cache_tool_output_entry(self, entry: _ToolOutputHistoryEntry) -> None: + if self._tool_output_cache_manager is None or not entry.record_uuids: + return + + try: + cache_id, cache_path = self._tool_output_cache_manager.save( + entry.tool_name, + entry.tool_call_id, + entry.result_text, + ) + except Exception as exc: # pragma: no cover - defensive + logger.warning( + "Failed to persist cached tool output for %s (%s): %s", + entry.tool_name, + entry.tool_call_id, + exc, + ) + return + + timestamp = ( + entry.record_timestamps[0] + if entry.record_timestamps + else time.time_ns() / 1_000_000_000 + ) + reference_message = FunctionCallingMessage( + role_name=self.role_name, + role_type=self.role_type, + meta_dict={ + "cache_id": cache_id, + "cached_preview": entry.preview_text, + "cached_tool_output_path": str(cache_path), + }, + content="", + func_name=entry.tool_name, + result=self._build_cache_reference_text(entry, cache_id), + tool_call_id=entry.tool_call_id, + ) + + chat_history_block = getattr(self.memory, "_chat_history_block", None) + storage = getattr(chat_history_block, "storage", None) + if storage is None: + return + + existing_records = storage.load() + updated_records = [ + record + for record in existing_records + if record["uuid"] not in entry.record_uuids + ] + new_record = MemoryRecord( + message=reference_message, + role_at_backend=OpenAIBackendRole.FUNCTION, + timestamp=timestamp, + agent_id=self.agent_id, + ) + updated_records.append(new_record.to_dict()) + updated_records.sort(key=lambda record: record["timestamp"]) + storage.clear() + storage.save(updated_records) + + logger.info( + "Cached tool output '%s' (%s) to %s with cache_id=%s", + entry.tool_name, + entry.tool_call_id, + cache_path, + cache_id, + ) + + entry.cached = True + entry.cache_id = cache_id + entry.record_uuids = [str(new_record.uuid)] + entry.record_timestamps = [timestamp] + + def _build_cache_reference_text( + self, entry: _ToolOutputHistoryEntry, cache_id: str + ) -> str: + preview = entry.preview_text or "[no preview available]" + return ( + "[cached tool output]\n" + f"tool: {entry.tool_name}\n" + f"cache_id: {cache_id}\n" + f"preview: {preview}\n" + f"Use `{self._cache_lookup_tool_name}` with this cache_id to " + "retrieve the full content." 
+ ) + def add_external_tool( self, tool: Union[FunctionTool, Callable, Dict[str, Any]] ) -> None: @@ -820,7 +1123,8 @@ def update_memory( message: BaseMessage, role: OpenAIBackendRole, timestamp: Optional[float] = None, - ) -> None: + return_records: bool = False, + ) -> Optional[List[MemoryRecord]]: r"""Updates the agent memory with a new message. If the single *message* exceeds the model's context window, it will @@ -840,21 +1144,29 @@ def update_memory( timestamp (Optional[float], optional): Custom timestamp for the memory record. If `None`, the current time will be used. (default: :obj:`None`) - (default: obj:`None`) + return_records (bool, optional): When ``True`` the method returns + the list of :class:`MemoryRecord` objects written to memory. + (default: :obj:`False`) + + Returns: + Optional[List[MemoryRecord]]: The records that were written when + ``return_records`` is ``True``; otherwise ``None``. """ + written_records: List[MemoryRecord] = [] + # 1. Helper to write a record to memory def _write_single_record( message: BaseMessage, role: OpenAIBackendRole, timestamp: float ): - self.memory.write_record( - MemoryRecord( - message=message, - role_at_backend=role, - timestamp=timestamp, - agent_id=self.agent_id, - ) + record = MemoryRecord( + message=message, + role_at_backend=role, + timestamp=timestamp, + agent_id=self.agent_id, ) + written_records.append(record) + self.memory.write_record(record) base_ts = ( timestamp @@ -869,7 +1181,7 @@ def _write_single_record( token_limit = context_creator.token_limit except AttributeError: _write_single_record(message, role, base_ts) - return + return written_records if return_records else None # 3. Check if slicing is necessary try: @@ -885,14 +1197,14 @@ def _write_single_record( if current_tokens <= remaining_budget: _write_single_record(message, role, base_ts) - return + return written_records if return_records else None except Exception as e: logger.warning( f"Token calculation failed before chunking, " f"writing message as-is. Error: {e}" ) _write_single_record(message, role, base_ts) - return + return written_records if return_records else None # 4. Perform slicing logger.warning( @@ -913,18 +1225,18 @@ def _write_single_record( if not text_to_chunk or not text_to_chunk.strip(): _write_single_record(message, role, base_ts) - return + return written_records if return_records else None # Encode the entire text to get a list of all token IDs try: all_token_ids = token_counter.encode(text_to_chunk) except Exception as e: logger.error(f"Failed to encode text for chunking: {e}") _write_single_record(message, role, base_ts) # Fallback - return + return written_records if return_records else None if not all_token_ids: _write_single_record(message, role, base_ts) # Nothing to chunk - return + return written_records if return_records else None # 1. Base chunk size: one-tenth of the smaller of (a) total token # limit and (b) current remaining budget. This prevents us from @@ -990,6 +1302,8 @@ def _write_single_record( # Increment timestamp slightly to maintain order _write_single_record(new_msg, role, base_ts + i * 1e-6) + return written_records if return_records else None + def load_memory(self, memory: AgentMemory) -> None: r"""Load the provided memory into the agent. 
@@ -1310,6 +1624,7 @@ def clear_memory(self) -> None: None """ self.memory.clear() + self._tool_output_history.clear() if self.system_message is not None: self.update_memory(self.system_message, OpenAIBackendRole.SYSTEM) @@ -1672,6 +1987,7 @@ def step( self, input_message: Union[BaseMessage, str], response_format: Optional[Type[BaseModel]] = None, + tool_call_history_cache: bool = False, ) -> Union[ChatAgentResponse, StreamingChatAgentResponse]: r"""Executes a single step in the chat session, generating a response to the input message. @@ -1684,6 +2000,9 @@ def step( model defining the expected structure of the response. Used to generate a structured response if provided. (default: :obj:`None`) + tool_call_history_cache (bool, optional): When ``True``, cache all + tool outputs exceeding the configured threshold after this step + completes. (default: :obj:`False`) Returns: Union[ChatAgentResponse, StreamingChatAgentResponse]: If stream is @@ -1700,7 +2019,9 @@ def step( if stream: # Return wrapped generator that has ChatAgentResponse interface - generator = self._stream(input_message, response_format) + generator = self._stream( + input_message, response_format, tool_call_history_cache + ) return StreamingChatAgentResponse(generator) # Execute with timeout if configured @@ -1709,7 +2030,10 @@ def step( max_workers=1 ) as executor: future = executor.submit( - self._step_impl, input_message, response_format + self._step_impl, + input_message, + response_format, + tool_call_history_cache, ) try: return future.result(timeout=self.step_timeout) @@ -1719,12 +2043,15 @@ def step( f"Step timed out after {self.step_timeout}s" ) else: - return self._step_impl(input_message, response_format) + return self._step_impl( + input_message, response_format, tool_call_history_cache + ) def _step_impl( self, input_message: Union[BaseMessage, str], response_format: Optional[Type[BaseModel]] = None, + tool_call_history_cache: bool = False, ) -> ChatAgentResponse: r"""Implementation of non-streaming step logic.""" # Set Langfuse session_id using agent_id for trace grouping @@ -1869,6 +2196,10 @@ def _step_impl( if self.prune_tool_calls_from_memory and tool_call_records: self.memory.clean_tool_calls() + # Cache tool outputs if requested + if tool_call_history_cache: + self._cache_tool_calls() + return self._convert_to_chatagent_response( response, tool_call_records, @@ -1889,6 +2220,7 @@ async def astep( self, input_message: Union[BaseMessage, str], response_format: Optional[Type[BaseModel]] = None, + tool_call_history_cache: bool = False, ) -> Union[ChatAgentResponse, AsyncStreamingChatAgentResponse]: r"""Performs a single step in the chat session by generating a response to the input message. This agent step can call async function calls. @@ -1905,6 +2237,9 @@ async def astep( used to generate a structured response by LLM. This schema helps in defining the expected output format. (default: :obj:`None`) + tool_call_history_cache (bool, optional): When ``True``, cache all + tool outputs exceeding the configured threshold after this step + completes. (default: :obj:`False`) Returns: Union[ChatAgentResponse, AsyncStreamingChatAgentResponse]: If stream is False, returns a ChatAgentResponse. 
If stream is @@ -1927,14 +2262,18 @@ async def astep( stream = self.model_backend.model_config_dict.get("stream", False) if stream: # Return wrapped async generator that is awaitable - async_generator = self._astream(input_message, response_format) + async_generator = self._astream( + input_message, response_format, tool_call_history_cache + ) return AsyncStreamingChatAgentResponse(async_generator) else: if self.step_timeout is not None: try: return await asyncio.wait_for( self._astep_non_streaming_task( - input_message, response_format + input_message, + response_format, + tool_call_history_cache, ), timeout=self.step_timeout, ) @@ -1944,13 +2283,14 @@ async def astep( ) else: return await self._astep_non_streaming_task( - input_message, response_format + input_message, response_format, tool_call_history_cache ) async def _astep_non_streaming_task( self, input_message: Union[BaseMessage, str], response_format: Optional[Type[BaseModel]] = None, + tool_call_history_cache: bool = False, ) -> ChatAgentResponse: r"""Internal async method for non-streaming astep logic.""" @@ -2091,6 +2431,10 @@ async def _astep_non_streaming_task( if self.prune_tool_calls_from_memory and tool_call_records: self.memory.clean_tool_calls() + # Cache tool outputs if requested + if tool_call_history_cache: + self._cache_tool_calls() + return self._convert_to_chatagent_response( response, tool_call_records, @@ -2740,14 +3084,18 @@ def _record_tool_calling( base_timestamp = current_time_ns / 1_000_000_000 # Convert to seconds self.update_memory( - assist_msg, OpenAIBackendRole.ASSISTANT, timestamp=base_timestamp + assist_msg, + OpenAIBackendRole.ASSISTANT, + timestamp=base_timestamp, + return_records=self._tool_output_cache_enabled, ) # Add minimal increment to ensure function message comes after - self.update_memory( + func_records = self.update_memory( func_msg, OpenAIBackendRole.FUNCTION, timestamp=base_timestamp + 1e-6, + return_records=self._tool_output_cache_enabled, ) # Record information about this tool call @@ -2758,12 +3106,26 @@ def _record_tool_calling( tool_call_id=tool_call_id, ) + if ( + self._tool_output_cache_enabled + and not mask_output + and func_records + ): + serialized_result = self._serialize_tool_result(result) + self._register_tool_output_for_cache( + func_name, + tool_call_id, + serialized_result, + cast(List[MemoryRecord], func_records), + ) + return tool_record def _stream( self, input_message: Union[BaseMessage, str], response_format: Optional[Type[BaseModel]] = None, + tool_call_history_cache: bool = False, ) -> Generator[ChatAgentResponse, None, None]: r"""Executes a streaming step in the chat session, yielding intermediate responses as they are generated. 
@@ -2800,6 +3162,10 @@ def _stream( openai_messages, num_tokens, response_format ) + # Cache tool outputs if requested (after streaming completes) + if tool_call_history_cache: + self._cache_tool_calls() + def _get_token_count(self, content: str) -> int: r"""Get token count for content with fallback.""" if hasattr(self.model_backend, 'token_counter'): @@ -3511,6 +3877,7 @@ async def _astream( self, input_message: Union[BaseMessage, str], response_format: Optional[Type[BaseModel]] = None, + tool_call_history_cache: bool = False, ) -> AsyncGenerator[ChatAgentResponse, None]: r"""Asynchronous version of stream method.""" @@ -3545,6 +3912,10 @@ async def _astream( if tool_calls: self.memory.clean_tool_calls() + # Cache tool outputs if requested (after streaming completes) + if tool_call_history_cache: + self._cache_tool_calls() + async def _astream_response( self, openai_messages: List[OpenAIMessage], @@ -4110,6 +4481,8 @@ def clone(self, with_memory: bool = False) -> ChatAgent: pause_event=self.pause_event, prune_tool_calls_from_memory=self.prune_tool_calls_from_memory, stream_accumulate=self.stream_accumulate, + tool_call_cache_threshold=self._tool_output_cache_threshold, + tool_call_cache_dir=self._tool_output_cache_dir, ) # Copy memory if requested diff --git a/examples/agents/agent_tool_call_cache.py b/examples/agents/agent_tool_call_cache.py new file mode 100644 index 0000000000..c97041536a --- /dev/null +++ b/examples/agents/agent_tool_call_cache.py @@ -0,0 +1,317 @@ +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +"""Real-model example for ChatAgent tool-output caching. + +This demo uses a true LLM backend (configured via the default ``ModelFactory``) +and a mock browser snapshot tool. The workflow is: + +1. Configure caching in ChatAgent initialization with threshold and cache_dir +2. Ask the agent to capture two snapshots: a long smartphone page and a short + weather widget +3. Use tool_call_history_cache=True in a step() call to cache tool outputs + exceeding the threshold +4. Ask the agent a question requiring BOTH cached snapshots - it will use the + automatically registered ``retrieve_cached_tool_output`` tool to access them +5. Verify the agent can also retrieve a single snapshot when needed + +The example demonstrates: +- Automatic caching of large tool outputs (>600 chars) +- Memory efficiency (cached references vs full content) +- Agent's ability to retrieve single or multiple cached outputs +- Seamless access to cached data without manual intervention + +Prerequisites: + - Set up the API credentials required by the default model backend + (for example, ``OPENAI_API_KEY`` if you're using OpenAI models). + - Optionally customize ``MODEL_PLATFORM`` / ``MODEL_TYPE`` via + environment variables to point to a different provider. 
+""" + +from __future__ import annotations + +from pathlib import Path + +from camel.agents import ChatAgent +from camel.messages import FunctionCallingMessage +from camel.models import ModelFactory +from camel.toolkits import FunctionTool +from camel.types import ( + ModelPlatformType, + ModelType, +) + +# Mock payloads ------------------------------------------------------------- +SMARTPHONE_PAGE = """ + + +
+<main class="product-page">
+  <header class="hero">
+    <h1>NovaPhone X Ultra Launch Event</h1>
+  </header>
+  <section class="intro">
+    <p>The flagship with HDR+ Pro display, titanium frame, and satellite SOS.</p>
+    <a class="cta" href="#pre-order">Pre-order now</a>
+  </section>
+  <section class="specs">
+    <table>
+      <tr><th>Model</th><th>Battery</th><th>Charging</th><th>Starting Price</th></tr>
+      <tr><td>NovaPhone X Ultra</td><td>5,500 mAh</td><td>120W wired / 80W wireless</td><td>$1099</td></tr>
+      <tr><td>NovaPhone X</td><td>5,000 mAh</td><td>80W wired / 50W wireless</td><td>$899</td></tr>
+      <tr><td>NovaPhone Air</td><td>4,700 mAh</td><td>45W wired / 25W wireless</td><td>$749</td></tr>
+    </table>
+  </section>
+  <footer class="notes">
+    <p>Pre-orders open March 14, shipping starts March 28 in US, EU, and APAC.</p>
+  </footer>
+</main>
+ + + +""" # noqa: E501 + +WEATHER_DASHBOARD = """ +
+<div class="weather-widget">
+  <h2>City Weather</h2>
+  <p>Currently 68°F, partly cloudy.</p>
+  <p>Next hour: breezy with scattered clouds, no precipitation expected.</p>
+  <p>Sunset at 7:42 PM, UV index moderate.</p>
+</div>
+""" + + +# Tool implementation ------------------------------------------------------- +def cache_browser_snapshot(snapshot: str) -> str: + """Return the provided snapshot verbatim so the cache can persist it.""" + header = ( + f"[browser_snapshot length={len(snapshot)} characters]\n" + "BEGIN_SNAPSHOT\n" + ) + return header + snapshot + "\nEND_SNAPSHOT" + + +# Utility functions --------------------------------------------------------- +def _print_memory(agent: ChatAgent) -> None: + for idx, ctx_record in enumerate(agent.memory.retrieve(), start=1): + record = ctx_record.memory_record + message = record.message + role = record.role_at_backend.value + if isinstance(message, FunctionCallingMessage): + meta = message.meta_dict or {} + cache_id = meta.get("cache_id") + result = ( + message.result + if isinstance(message.result, str) + else str(message.result) + ) + result_length = len(result) + preview = result.replace("\n", " ")[:140] + if cache_id: + print( + f"{idx:02d}. role={role} tool_call_id={message.tool_call_id} " # noqa:E501 + f"(cached reference) cache_id={cache_id} " + f"result_length={result_length}" + ) + print(f" preview: {preview}") + else: + print( + f"{idx:02d}. role={role} tool_call_id={message.tool_call_id} " # noqa:E501 + f"(inline) result_length={result_length}" + ) + print(f" preview: {preview}") + else: + content = getattr(message, "content", "") or "" + print(f"{idx:02d}. role={role} content={content[:140]}") + + +def _find_cached_entry(agent: ChatAgent): + for entry in agent._tool_output_history: + if entry.cached: + return entry + return None + + +# Demo flow ----------------------------------------------------------------- +def main() -> None: + cache_dir = Path(__file__).resolve().parent / "tool_cache" + backend = ModelFactory.create( + model_platform=ModelPlatformType.AZURE, + model_type=ModelType.GPT_4_1_MINI, + ) + agent = ChatAgent( + system_message=("You are a browsing assistant."), + model=backend, + tools=[FunctionTool(cache_browser_snapshot)], + prune_tool_calls_from_memory=False, + max_iteration=3, + tool_call_cache_threshold=600, + tool_call_cache_dir=cache_dir, + ) + + print("\n>>> Step 1: Capture verbose snapshot") + prompt1 = ( + "You just browsed the NovaPhone store." + "Store the current smartphone page exactly as-is " + "so we can reference it later. Here is the full markup:\n\n" + f"{SMARTPHONE_PAGE}" + ) + response1 = agent.step(prompt1) + print(f"Assistant response: {response1.msg.content}") + + print("\n>>> Step 2: Capture weather snapshot") + prompt2 = ( + "Now you are looking at a weather dashboard." + "Save the widget below as a new snapshot " + "without paraphrasing it:\n\n" + f"{WEATHER_DASHBOARD}" + ) + # Print memory before caching + print("\n=== Memory BEFORE tool_call_history_cache ===") + response2 = agent.step(prompt2) + print(f"Assistant response: {response2.msg.content}") + _print_memory(agent) + + # Print memory after caching (using tool_call_history_cache=True) + print("\n=== Memory AFTER tool_call_history_cache ===") + response2_cached = agent.step( + "Confirm that both snapshots have been saved.", + tool_call_history_cache=True, + ) + print(f"Assistant response: {response2_cached.msg.content}") + _print_memory(agent) + + cached_entry = _find_cached_entry(agent) + if not cached_entry or not cached_entry.cache_id: + print( + "\nNo cached entry detected. Ensure the tool was executed and the threshold is high enough." 
# noqa:E501 + ) + return + + print("\n>>> Step 3: Ask question requiring BOTH cached snapshots") + prompt3 = ( + "Compare the information from both snapshots you saved earlier:\n" + "1. From the NovaPhone store page, tell me the battery capacity " + "of the NovaPhone X Ultra\n" + "2. From the weather dashboard, tell me the current temperature\n" + "3. Make a comparison between these two pieces of information.\n\n" + "You'll need to retrieve BOTH snapshots to answer this question." + ) + response3 = agent.step(prompt3) + print(f"Assistant response:\n{response3.msg.content}") + + print("\n>>> Step 4: Verify agent can access single snapshot") + prompt4 = "Just tell me the sunset time from the weather widget." + response4 = agent.step(prompt4) + print(f"Assistant response:\n{response4.msg.content}") + + +if __name__ == "__main__": + main() + + +''' +>>> Step 1: Capture verbose snapshot +Assistant response: The current NovaPhone X Ultra smartphone page has been stored exactly as-is. You can reference this full markup or request details from it at any time. Let me know if you need to retrieve, compare, or analyze any part of this page! + +>>> Step 2: Capture weather snapshot + +=== Memory BEFORE tool_call_history_cache === +Assistant response: The weather dashboard widget has been stored exactly as you provided it. You can reference or retrieve this snapshot any time. Let me know if you need to review, compare, or analyze the widget's contents! +01. role=system content=You are a browsing assistant. +02. role=user content=You just browsed the NovaPhone store.Store the current smartphone page exactly as-is so we can reference it later. Here is the full markup: + +03. role=assistant tool_call_id=call_AajvjXwUeugfncrhwEY9fiqy (inline) result_length=4 + preview: None +04. role=function tool_call_id=call_AajvjXwUeugfncrhwEY9fiqy (inline) result_length=1796 + preview: [browser_snapshot length=1726 characters] BEGIN_SNAPSHOT
NovaPhone X Ultra Launch Event
< +05. role=assistant content=The current NovaPhone X Ultra smartphone page has been stored exactly as-is. You can reference this full markup or request details from it a +06. role=user content=Now you are looking at a weather dashboard.Save the widget below as a new snapshot without paraphrasing it: + + +
+ +07. role=assistant tool_call_id=call_BVR80Veg57rrDVsaxSwronfh (inline) result_length=4 + preview: None +08. role=function tool_call_id=call_BVR80Veg57rrDVsaxSwronfh (inline) result_length=294 + preview: [browser_snapshot length=225 characters] BEGIN_SNAPSHOT
City Weather
Currently 68°F, partly cl +09. role=assistant content=The weather dashboard widget has been stored exactly as you provided it. You can reference or retrieve this snapshot any time. Let me know i + +=== Memory AFTER tool_call_history_cache === +Assistant response: Confirmation: Both snapshots have been successfully saved. + +1. NovaPhone smartphone page — contains the launch event details, specs, comparisons, and availability. +2. Weather dashboard widget — contains the city weather update, next hour forecast, sunset time, and UV index. + +You can request contents or analysis from either snapshot at any time. +01. role=system content=You are a browsing assistant. +02. role=user content=You just browsed the NovaPhone store.Store the current smartphone page exactly as-is so we can reference it later. Here is the full markup: + +03. role=assistant tool_call_id=call_AajvjXwUeugfncrhwEY9fiqy (inline) result_length=4 + preview: None +04. role=function tool_call_id=call_AajvjXwUeugfncrhwEY9fiqy (cached reference) cache_id=4d277c664504420eaa8ae2e5360873e7 result_length=345 + preview: [cached tool output] tool: cache_browser_snapshot cache_id: 4d277c664504420eaa8ae2e5360873e7 preview: [browser_snapshot length=1726 characte +05. role=assistant content=The current NovaPhone X Ultra smartphone page has been stored exactly as-is. You can reference this full markup or request details from it a +06. role=user content=Now you are looking at a weather dashboard.Save the widget below as a new snapshot without paraphrasing it: + + +

+ +07. role=assistant tool_call_id=call_BVR80Veg57rrDVsaxSwronfh (inline) result_length=4 + preview: None +08. role=function tool_call_id=call_BVR80Veg57rrDVsaxSwronfh (inline) result_length=294 + preview: [browser_snapshot length=225 characters] BEGIN_SNAPSHOT
City Weather
Currently 68°F, partly cl +09. role=assistant content=The weather dashboard widget has been stored exactly as you provided it. You can reference or retrieve this snapshot any time. Let me know i +10. role=user content=Confirm that both snapshots have been saved. +11. role=assistant content=Confirmation: Both snapshots have been successfully saved. + +1. NovaPhone smartphone page — contains the launch event details, specs, compari + +>>> Step 3: Ask question requiring BOTH cached snapshots +Assistant response: +Based on the two snapshots I retrieved: + +1. **NovaPhone X Ultra battery capacity:** 5,500 mAh + +2. **Current temperature:** 68°F (partly cloudy) + +3. **Creative comparison:** + Interestingly, the NovaPhone X Ultra's battery capacity (5,500 mAh) is about + 80 times larger than the current temperature in degrees Fahrenheit (68°F)! + + While the phone can power through your day with its massive 5,500 mAh battery + and 120W wired charging, the weather outside is a pleasant 68°F — perfect + conditions for testing that new phone outdoors without worrying about + overheating or cold-induced battery drain. The phone's battery is built for + performance in any weather! + +>>> Step 4: Verify agent can access single snapshot +Assistant response: +According to the weather widget snapshot, the sunset time is **7:42 PM**. +''' # noqa: E501 diff --git a/test/test_chat_agent_tool_cache.py b/test/test_chat_agent_tool_cache.py new file mode 100644 index 0000000000..02a86a6711 --- /dev/null +++ b/test/test_chat_agent_tool_cache.py @@ -0,0 +1,176 @@ +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. 
========= +import pytest + +from camel.agents import ChatAgent +from camel.messages import FunctionCallingMessage +from camel.models.stub_model import StubModel +from camel.types import ModelType + + +@pytest.mark.parametrize("threshold", [20]) +def test_tool_output_caching(tmp_path, threshold): + agent = ChatAgent( + system_message="You are a tester.", + model=StubModel(model_type=ModelType.STUB), + tool_call_cache_threshold=threshold, + tool_call_cache_dir=tmp_path, + ) + + long_result = "A" * (threshold + 10) + short_result = "short" + + agent._record_tool_calling( + "dummy_tool", + args={"value": 1}, + result=long_result, + tool_call_id="call-1", + ) + + history = [ + entry + for entry in agent._tool_output_history + if entry.tool_call_id == "call-1" + ] + assert history and not history[0].cached + + agent._record_tool_calling( + "dummy_tool", + args={"value": 2}, + result=short_result, + tool_call_id="call-2", + ) + + cached_entry = next( + entry + for entry in agent._tool_output_history + if entry.tool_call_id == "call-1" + ) + assert not cached_entry.cached + + flushed = agent._cache_tool_calls() + assert flushed == 1 + + cached_entry = next( + entry + for entry in agent._tool_output_history + if entry.tool_call_id == "call-1" + ) + assert cached_entry.cached + assert cached_entry.cache_id + + cache_file = tmp_path / f"{cached_entry.cache_id}.txt" + assert cache_file.exists() + assert long_result in cache_file.read_text(encoding="utf-8") + + records = agent.memory.retrieve() + cached_message = None + for record in records: + message = record.memory_record.message + if ( + isinstance(message, FunctionCallingMessage) + and getattr(message, "tool_call_id", "") == "call-1" + and getattr(message, "result", None) + ): + cached_message = message + break + + assert cached_message is not None + assert cached_entry.cache_id in cached_message.result + assert agent._cache_lookup_tool_name in cached_message.result + assert cached_message.result != long_result + + retrieved = agent.retrieve_cached_tool_output(cached_entry.cache_id) + assert long_result in retrieved + + +def test_tool_output_history_cleared_on_reset(tmp_path): + agent = ChatAgent( + system_message="Cache reset tester.", + model=StubModel(model_type=ModelType.STUB), + tool_call_cache_threshold=10, + tool_call_cache_dir=tmp_path, + ) + + agent._record_tool_calling( + "dummy_tool", + args={"value": "a"}, + result="A" * 20, + tool_call_id="call-initial", + ) + assert agent._tool_output_history + + agent.clear_memory() + assert agent._tool_output_history == [] + + agent._record_tool_calling( + "dummy_tool", + args={"value": "b"}, + result="B" * 20, + tool_call_id="call-after-clear", + ) + assert len(agent._tool_output_history) == 1 + + agent.reset() + assert agent._tool_output_history == [] + + +def test_retrieve_multiple_cached_outputs(tmp_path): + agent = ChatAgent( + system_message="Multiple cache retrieval tester.", + model=StubModel(model_type=ModelType.STUB), + tool_call_cache_threshold=10, + tool_call_cache_dir=tmp_path, + ) + + # Record multiple tool calls + agent._record_tool_calling( + "tool_1", + args={"value": "a"}, + result="A" * 20, + tool_call_id="call-1", + ) + agent._record_tool_calling( + "tool_2", + args={"value": "b"}, + result="B" * 30, + tool_call_id="call-2", + ) + + # Cache them + cached_count = agent._cache_tool_calls() + assert cached_count == 2 + + # Get cache IDs + cache_ids = [ + entry.cache_id for entry in agent._tool_output_history if entry.cached + ] + assert len(cache_ids) == 2 + + # Test single 
+    # retrieval (same tool, single ID)
+    result1 = agent.retrieve_cached_tool_output(cache_ids[0])
+    assert "A" * 20 in result1
+
+    # Test multiple retrieval (same tool, multiple IDs)
+    import json
+
+    result_multiple = agent.retrieve_cached_tool_output(
+        f"{cache_ids[0]}, {cache_ids[1]}"
+    )
+    results_dict = json.loads(result_multiple)
+
+    assert cache_ids[0] in results_dict
+    assert cache_ids[1] in results_dict
+    assert "A" * 20 in results_dict[cache_ids[0]]
+    assert "B" * 30 in results_dict[cache_ids[1]]
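+
+
+# Minimal sketch of the lookup fallback path. It relies only on behaviour
+# visible in this patch: once the cache manager is initialised,
+# ``retrieve_cached_tool_output`` answers an unknown id with a
+# "was not found" message instead of raising.
+def test_retrieve_missing_cache_id(tmp_path):
+    agent = ChatAgent(
+        system_message="Missing cache id tester.",
+        model=StubModel(model_type=ModelType.STUB),
+        tool_call_cache_threshold=10,
+        tool_call_cache_dir=tmp_path,
+    )
+
+    # Nothing has been recorded yet, so this only initialises the cache
+    # manager and enables lookups; no entries are flushed.
+    assert agent._cache_tool_calls() == 0
+
+    result = agent.retrieve_cached_tool_output("nonexistent-id")
+    assert "was not found" in result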