diff --git a/libs/core/langchain_core/tracers/__init__.py b/libs/core/langchain_core/tracers/__init__.py index b2f41d10f82ee..abc3a2e72e962 100644 --- a/libs/core/langchain_core/tracers/__init__.py +++ b/libs/core/langchain_core/tracers/__init__.py @@ -15,6 +15,11 @@ ) from langchain_core.tracers.schemas import Run from langchain_core.tracers.stdout import ConsoleCallbackHandler + from langchain_core.tracers.utils import ( + count_tool_calls_in_run, + get_tool_call_count_from_run, + store_tool_call_count_in_run, + ) __all__ = ( "BaseTracer", @@ -25,6 +30,9 @@ "Run", "RunLog", "RunLogPatch", + "count_tool_calls_in_run", + "get_tool_call_count_from_run", + "store_tool_call_count_in_run", ) _dynamic_imports = { @@ -36,6 +44,9 @@ "RunLogPatch": "log_stream", "Run": "schemas", "ConsoleCallbackHandler": "stdout", + "count_tool_calls_in_run": "utils", + "get_tool_call_count_from_run": "utils", + "store_tool_call_count_in_run": "utils", } diff --git a/libs/core/langchain_core/tracers/base.py b/libs/core/langchain_core/tracers/base.py index 01bc9da0aa8ad..76d2c69f207fe 100644 --- a/libs/core/langchain_core/tracers/base.py +++ b/libs/core/langchain_core/tracers/base.py @@ -37,6 +37,19 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): def _persist_run(self, run: Run) -> None: """Persist a run.""" + def _store_tool_call_metadata(self, run: Run) -> None: + """Store tool call count in run metadata automatically.""" + try: + # Avoid circular imports + from langchain_core.tracers.utils import ( # noqa: PLC0415 + store_tool_call_count_in_run, + ) + + store_tool_call_count_in_run(run) + except Exception: # noqa: S110 + # Avoid breaking existing functionality + pass + def _start_trace(self, run: Run) -> None: """Start a trace for a run.""" super()._start_trace(run) @@ -44,6 +57,8 @@ def _start_trace(self, run: Run) -> None: def _end_trace(self, run: Run) -> None: """End a trace for a run.""" + self._store_tool_call_metadata(run) + if not run.parent_run_id: self._persist_run(run) self.run_map.pop(str(run.id)) @@ -534,6 +549,19 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): async def _persist_run(self, run: Run) -> None: """Persist a run.""" + async def _store_tool_call_metadata(self, run: Run) -> None: + """Store tool call count in run metadata.""" + try: + # Avoid circular imports + from langchain_core.tracers.utils import ( # noqa: PLC0415 + store_tool_call_count_in_run, + ) + + store_tool_call_count_in_run(run) + except Exception: # noqa: S110 + # Avoid breaking existing functionality + pass + @override async def _start_trace(self, run: Run) -> None: """Start a trace for a run. @@ -551,6 +579,8 @@ async def _end_trace(self, run: Run) -> None: Ending a trace will run concurrently with each _on_[run_type]_end method. No _on_[run_type]_end callback should depend on operations in _end_trace. """ + await self._store_tool_call_metadata(run) + if not run.parent_run_id: await self._persist_run(run) self.run_map.pop(str(run.id)) diff --git a/libs/core/langchain_core/tracers/utils.py b/libs/core/langchain_core/tracers/utils.py new file mode 100644 index 0000000000000..1d4026310ae81 --- /dev/null +++ b/libs/core/langchain_core/tracers/utils.py @@ -0,0 +1,91 @@ +"""Utility functions for working with Run objects and tracers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from langchain_core.tracers.schemas import Run + + +def count_tool_calls_in_run(run: Run) -> int: + """Count tool calls in a `Run` object by examining messages. + + Args: + run: The `Run` object to examine. + + Returns: + The total number of tool calls found in the run's messages. + """ + tool_call_count = 0 + + # Check inputs for messages containing tool calls + inputs = getattr(run, "inputs", {}) or {} + if isinstance(inputs, dict) and "messages" in inputs: + messages = inputs["messages"] + if messages: + for msg in messages: + # Handle both dict and object representations + if hasattr(msg, "tool_calls"): + tool_calls = getattr(msg, "tool_calls", []) + if tool_calls: + tool_call_count += len(tool_calls) + elif isinstance(msg, dict) and "tool_calls" in msg: + tool_calls = msg.get("tool_calls", []) + if tool_calls: + tool_call_count += len(tool_calls) + + # Also check outputs for completeness + outputs = getattr(run, "outputs", {}) or {} + if isinstance(outputs, dict) and "messages" in outputs: + messages = outputs["messages"] + if messages: + for msg in messages: + if hasattr(msg, "tool_calls"): + tool_calls = getattr(msg, "tool_calls", []) + if tool_calls: + tool_call_count += len(tool_calls) + elif isinstance(msg, dict) and "tool_calls" in msg: + tool_calls = msg.get("tool_calls", []) + if tool_calls: + tool_call_count += len(tool_calls) + + return tool_call_count + + +def store_tool_call_count_in_run(run: Run, *, always_store: bool = False) -> int: + """Count tool calls in a `Run` and store the count in run metadata. + + Args: + run: The `Run` object to analyze and modify. + always_store: If `True`, always store the count even if `0`. + If `False`, only store when there are tool calls. + + Returns: + The number of tool calls found and stored. + """ + tool_call_count = count_tool_calls_in_run(run) + + # Only store if there are tool calls or if explicitly requested + if tool_call_count > 0 or always_store: + # Store in run.extra for easy access + if not hasattr(run, "extra") or run.extra is None: + run.extra = {} + run.extra["tool_call_count"] = tool_call_count + + return tool_call_count + + +def get_tool_call_count_from_run(run: Run) -> int | None: + """Get the tool call count from run metadata if available. + + Args: + run: The `Run` object to check. + + Returns: + The tool call count if stored in metadata, otherwise `None`. + """ + extra = getattr(run, "extra", {}) or {} + if isinstance(extra, dict): + return extra.get("tool_call_count") + return None diff --git a/libs/core/tests/unit_tests/tracers/test_automatic_metadata.py b/libs/core/tests/unit_tests/tracers/test_automatic_metadata.py new file mode 100644 index 0000000000000..2fdc4b4e638d3 --- /dev/null +++ b/libs/core/tests/unit_tests/tracers/test_automatic_metadata.py @@ -0,0 +1,127 @@ +"""Test automatic tool call count storage in tracers.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, PropertyMock + +from langchain_core.messages import AIMessage +from langchain_core.messages.tool import ToolCall +from langchain_core.tracers.base import BaseTracer +from langchain_core.tracers.schemas import Run + + +class MockTracer(BaseTracer): + """Mock tracer for testing automatic metadata storage.""" + + def __init__(self) -> None: + super().__init__() + self.persisted_runs: list[Run] = [] + + def _persist_run(self, run: Run) -> None: + """Store the run for inspection.""" + self.persisted_runs.append(run) + + +def test_base_tracer_automatically_stores_tool_call_count() -> None: + """Test that `BaseTracer` automatically stores tool call count.""" + tracer = MockTracer() + + # Create a mock run with tool calls + run = MagicMock(spec=Run) + run.id = "test-run-id" + run.parent_run_id = None # Root run, will be persisted + run.extra = {} + + # Set up messages with tool calls + tool_calls = [ + ToolCall(name="search", args={"query": "test"}, id="call_1"), + ToolCall(name="calculator", args={"expression": "2+2"}, id="call_2"), + ] + messages = [AIMessage(content="Test", tool_calls=tool_calls)] + run.inputs = {"messages": messages} + run.outputs = {} + + # Add run to tracer's run_map to simulate it being tracked + tracer.run_map[str(run.id)] = run + + # End the trace (this should trigger automatic metadata storage) + tracer._end_trace(run) + + # Verify tool call count was automatically stored + assert "tool_call_count" in run.extra + assert run.extra["tool_call_count"] == 2 + + # Verify the run was persisted + assert len(tracer.persisted_runs) == 1 + assert tracer.persisted_runs[0] == run + + +def test_base_tracer_handles_no_tool_calls() -> None: + """Test that `BaseTracer` handles runs with no tool calls gracefully.""" + tracer = MockTracer() + + # Create a mock run without tool calls + run = MagicMock(spec=Run) + run.id = "test-run-id-no-tools" + run.parent_run_id = None + run.extra = {} + + # Set up messages without tool calls + messages = [AIMessage(content="No tools here")] + run.inputs = {"messages": messages} + run.outputs = {} + + # Add run to tracer's run_map + tracer.run_map[str(run.id)] = run + + # End the trace + tracer._end_trace(run) + + # Verify tool call count is not stored when there are no tool calls + assert "tool_call_count" not in run.extra + + +def test_base_tracer_handles_runs_without_messages() -> None: + """Test that `BaseTracer` handles runs without messages gracefully.""" + tracer = MockTracer() + + # Create a mock run without messages + run = MagicMock(spec=Run) + run.id = "test-run-id-no-messages" + run.parent_run_id = None + run.extra = {} + run.inputs = {} + run.outputs = {} + + # Add run to tracer's run_map + tracer.run_map[str(run.id)] = run + + # End the trace + tracer._end_trace(run) + + # Verify tool call count is not stored when there are no messages + assert "tool_call_count" not in run.extra + + +def test_base_tracer_doesnt_break_on_metadata_error() -> None: + """Test that `BaseTracer` continues working if metadata storage fails.""" + tracer = MockTracer() + + # Create a mock run that will cause an error in tool call counting + run = MagicMock(spec=Run) + run.id = "test-run-id-error" + run.parent_run_id = None + run.extra = {} + + # Make the run.inputs property raise an error when accessed + type(run).inputs = PropertyMock(side_effect=RuntimeError("Simulated error")) + + # Add run to tracer's run_map + tracer.run_map[str(run.id)] = run + + # End the trace - this should not raise an exception + tracer._end_trace(run) + + # The run should still be persisted despite the metadata error + assert len(tracer.persisted_runs) == 1 + assert tracer.persisted_runs[0] == run diff --git a/libs/core/tests/unit_tests/tracers/test_imports.py b/libs/core/tests/unit_tests/tracers/test_imports.py index 01cf3260aea25..a9421444901b0 100644 --- a/libs/core/tests/unit_tests/tracers/test_imports.py +++ b/libs/core/tests/unit_tests/tracers/test_imports.py @@ -2,13 +2,16 @@ EXPECTED_ALL = [ "BaseTracer", + "ConsoleCallbackHandler", "EvaluatorCallbackHandler", "LangChainTracer", - "ConsoleCallbackHandler", + "LogStreamCallbackHandler", "Run", "RunLog", "RunLogPatch", - "LogStreamCallbackHandler", + "count_tool_calls_in_run", + "get_tool_call_count_from_run", + "store_tool_call_count_in_run", ] diff --git a/libs/core/tests/unit_tests/tracers/test_utils.py b/libs/core/tests/unit_tests/tracers/test_utils.py new file mode 100644 index 0000000000000..0ad647a89f0eb --- /dev/null +++ b/libs/core/tests/unit_tests/tracers/test_utils.py @@ -0,0 +1,201 @@ +"""Unit tests for tracer utility functions.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.messages.tool import ToolCall +from langchain_core.tracers.utils import ( + count_tool_calls_in_run, + get_tool_call_count_from_run, + store_tool_call_count_in_run, +) + + +def create_mock_run( + inputs: dict[str, Any] | None = None, + outputs: dict[str, Any] | None = None, + extra: dict[str, Any] | None = None, +) -> MagicMock: + """Create a mock Run object for testing.""" + run = MagicMock() + run.inputs = inputs or {} + run.outputs = outputs or {} + run.extra = extra or {} + return run + + +def test_count_tool_calls_in_run_no_messages() -> None: + """Test counting tool calls when there are no messages.""" + run = create_mock_run() + + count = count_tool_calls_in_run(run) + assert count == 0 + + # Test counting tool calls when messages exist but no tool calls. + messages = [ + HumanMessage(content="Hello"), + AIMessage(content="Hi there!"), + ] + run = create_mock_run(inputs={"messages": messages}) + + count = count_tool_calls_in_run(run) + assert count == 0 + + # Test counting when `tool_calls` is empty list + messages = [AIMessage(content="No tools", tool_calls=[])] + run = create_mock_run(inputs={"messages": messages}) + + count = count_tool_calls_in_run(run) + assert count == 0 + + +def test_count_tool_calls_in_run_with_tool_calls() -> None: + """Test counting tool calls when they exist in messages.""" + tool_calls = [ + ToolCall(name="search", args={"query": "test"}, id="call_1"), + ToolCall(name="calculator", args={"expression": "2+2"}, id="call_2"), + ] + + messages = [ + HumanMessage(content="Search for test and calculate 2+2"), + AIMessage(content="I'll help you with that", tool_calls=tool_calls), + ] + run = create_mock_run(inputs={"messages": messages}) + + count = count_tool_calls_in_run(run) + assert count == 2 + + # Test counting tool calls when messages are in dict format. + messages = [ + {"role": "human", "content": "Hello"}, # type: ignore[list-item] + { # type: ignore[list-item] + "role": "assistant", + "content": "Hi!", + "tool_calls": [ + {"name": "search", "args": {"query": "test"}, "id": "call_1"}, + ], + }, + ] + run = create_mock_run(inputs={"messages": messages}) + + count = count_tool_calls_in_run(run) + assert count == 1 + + +def test_count_tool_calls_in_run_outputs_too() -> None: + """Test counting tool calls in both inputs and outputs.""" + input_tool_calls = [ToolCall(name="search", args={"query": "test"}, id="call_1")] + output_tool_calls = [ + ToolCall(name="calculator", args={"expression": "2+2"}, id="call_2") + ] + + input_messages = [AIMessage(content="Input", tool_calls=input_tool_calls)] + output_messages = [AIMessage(content="Output", tool_calls=output_tool_calls)] + + run = create_mock_run( + inputs={"messages": input_messages}, outputs={"messages": output_messages} + ) + + count = count_tool_calls_in_run(run) + assert count == 2 + + +def test_store_tool_call_count_in_run() -> None: + """Test storing tool call count in run metadata.""" + tool_calls = [ToolCall(name="search", args={"query": "test"}, id="call_1")] + messages = [AIMessage(content="Test", tool_calls=tool_calls)] + run = create_mock_run(inputs={"messages": messages}) + + count = store_tool_call_count_in_run(run) + + assert count == 1 + assert run.extra["tool_call_count"] == 1 + + +def test_store_tool_call_count_always_store() -> None: + """Test storing tool call count with `always_store=True`.""" + messages = [AIMessage(content="No tools")] + run = create_mock_run(inputs={"messages": messages}) + + count = store_tool_call_count_in_run(run, always_store=True) + + assert count == 0 + assert run.extra["tool_call_count"] == 0 + + +def test_store_tool_call_count_no_tools_no_always_store() -> None: + """Test that count is not stored when no tools and `always_store=False`.""" + messages = [AIMessage(content="No tools")] + run = create_mock_run(inputs={"messages": messages}) + + count = store_tool_call_count_in_run(run, always_store=False) + + assert count == 0 + assert "tool_call_count" not in run.extra + + +def test_store_tool_call_count_in_run_no_extra() -> None: + """Test storing when `run.extra` is `None`.""" + tool_calls = [ToolCall(name="search", args={"query": "test"}, id="call_1")] + messages = [AIMessage(content="Test", tool_calls=tool_calls)] + run = create_mock_run(inputs={"messages": messages}) + run.extra = None + + count = store_tool_call_count_in_run(run) + + assert count == 1 + assert run.extra["tool_call_count"] == 1 + + +def test_get_tool_call_count_from_run() -> None: + """Test retrieving tool call count from run metadata.""" + run = create_mock_run(extra={"tool_call_count": 5}) + + count = get_tool_call_count_from_run(run) + assert count == 5 + + +def test_get_tool_call_count_from_run_not_stored() -> None: + """Test retrieving when count is not stored.""" + run = create_mock_run() + + count = get_tool_call_count_from_run(run) + assert count is None + + # Test retrieving when run.extra is None. + run = create_mock_run() + run.extra = None + + count = get_tool_call_count_from_run(run) + assert count is None + + +def test_count_tool_calls_handles_none_inputs() -> None: + """Test counting when inputs/outputs are `None`.""" + run = create_mock_run() + run.inputs = None + run.outputs = None + + count = count_tool_calls_in_run(run) + assert count == 0 + + +def test_count_tool_calls_mixed_message_types() -> None: + """Test counting with mixed message object and `dict` types.""" + tool_calls_obj = [ToolCall(name="search", args={"query": "test"}, id="call_1")] + + messages = [ + AIMessage(content="Object message", tool_calls=tool_calls_obj), + { + "role": "assistant", + "content": "Dict message", + "tool_calls": [{"name": "calc", "args": {"expr": "1+1"}, "id": "call_2"}], + }, + ] + run = create_mock_run(inputs={"messages": messages}) + + count = count_tool_calls_in_run(run) + assert count == 2