Skip to content

Commit 49a869e

Browse files
authored
feat: Add tracing to LLM invoke (#650)
1 parent f4dc5b2 commit 49a869e

File tree

4 files changed

+96
-5
lines changed

4 files changed

+96
-5
lines changed

src/rai_core/rai/agents/langchain/core/conversational_agent.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@
2222
BaseMessage,
2323
SystemMessage,
2424
)
25+
from langchain_core.runnables import RunnableConfig
2526
from langchain_core.tools import BaseTool
2627
from langgraph.graph import START, StateGraph
2728
from langgraph.graph.state import CompiledStateGraph
2829
from langgraph.prebuilt.tool_node import tools_condition
2930

3031
from rai.agents.langchain.core.tool_runner import ToolRunner
32+
from rai.agents.langchain.invocation_helpers import invoke_llm_with_tracing
3133

3234

3335
class State(TypedDict):
@@ -39,6 +41,7 @@ def agent(
3941
logger: logging.Logger,
4042
system_prompt: str | SystemMessage,
4143
state: State,
44+
config: RunnableConfig,
4245
):
4346
logger.info("Running thinker")
4447

@@ -54,7 +57,9 @@ def agent(
5457
else system_prompt
5558
)
5659
state["messages"].insert(0, system_msg)
57-
ai_msg = llm.invoke(state["messages"])
60+
61+
# Invoke LLM with tracing if it is configured and available
62+
ai_msg = invoke_llm_with_tracing(llm, state["messages"], config)
5863
state["messages"].append(ai_msg)
5964
return state
6065

src/rai_core/rai/agents/langchain/core/react_agent.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@
2222

2323
from langchain_core.language_models import BaseChatModel
2424
from langchain_core.messages import BaseMessage, SystemMessage
25-
from langchain_core.runnables import Runnable
25+
from langchain_core.runnables import Runnable, RunnableConfig
2626
from langchain_core.tools import BaseTool
2727
from langgraph.graph import START, StateGraph
2828
from langgraph.prebuilt.tool_node import tools_condition
2929

3030
from rai.agents.langchain.core.tool_runner import ToolRunner
31+
from rai.agents.langchain.invocation_helpers import invoke_llm_with_tracing
3132
from rai.initialization import get_llm_model
3233
from rai.messages import SystemMultimodalMessage
3334

@@ -48,6 +49,7 @@ def llm_node(
4849
llm: BaseChatModel,
4950
system_prompt: Optional[str | SystemMultimodalMessage],
5051
state: ReActAgentState,
52+
config: RunnableConfig,
5153
):
5254
"""Process messages using the LLM.
5355
@@ -57,6 +59,8 @@ def llm_node(
5759
The language model to use for processing
5860
state : ReActAgentState
5961
Current state containing messages
62+
config : RunnableConfig
63+
Configuration including callbacks for tracing
6064
6165
Returns
6266
-------
@@ -75,7 +79,9 @@ def llm_node(
7579
# at this point, state['messages'] length should at least be 1
7680
if not isinstance(state["messages"][0], SystemMessage):
7781
state["messages"].insert(0, SystemMessage(content=system_prompt))
78-
ai_msg = llm.invoke(state["messages"])
82+
83+
# Invoke LLM with tracing if it is configured and available
84+
ai_msg = invoke_llm_with_tracing(llm, state["messages"], config)
7985
state["messages"].append(ai_msg)
8086

8187

src/rai_core/rai/agents/langchain/core/state_based_agent.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@
2626

2727
from langchain_core.language_models import BaseChatModel
2828
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
29-
from langchain_core.runnables import Runnable
29+
from langchain_core.runnables import Runnable, RunnableConfig
3030
from langchain_core.tools import BaseTool
3131
from langgraph.graph import START, StateGraph
3232
from langgraph.prebuilt.tool_node import tools_condition
3333

3434
from rai.agents.langchain.core.tool_runner import ToolRunner
35+
from rai.agents.langchain.invocation_helpers import invoke_llm_with_tracing
3536
from rai.initialization import get_llm_model
3637
from rai.messages import HumanMultimodalMessage, SystemMultimodalMessage
3738

@@ -52,6 +53,7 @@ def llm_node(
5253
llm: BaseChatModel,
5354
system_prompt: Optional[str | SystemMultimodalMessage],
5455
state: ReActAgentState,
56+
config: RunnableConfig,
5557
):
5658
"""Process messages using the LLM.
5759
@@ -61,6 +63,8 @@ def llm_node(
6163
The language model to use for processing
6264
state : ReActAgentState
6365
Current state containing messages
66+
config : RunnableConfig
67+
Configuration including callbacks for tracing
6468
6569
Returns
6670
-------
@@ -79,7 +83,9 @@ def llm_node(
7983
# at this point, state['messages'] length should at least be 1
8084
if not isinstance(state["messages"][0], SystemMessage):
8185
state["messages"].insert(0, SystemMessage(content=system_prompt))
82-
ai_msg = llm.invoke(state["messages"])
86+
87+
# Invoke LLM with tracing if it is configured and available
88+
ai_msg = invoke_llm_with_tracing(llm, state["messages"], config)
8389
state["messages"].append(ai_msg)
8490

8591

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright (C) 2025 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
from typing import Any, List, Optional
17+
18+
from langchain_core.language_models import BaseChatModel
19+
from langchain_core.messages import BaseMessage
20+
from langchain_core.runnables import RunnableConfig
21+
22+
from rai.initialization import get_tracing_callbacks
23+
24+
logger = logging.getLogger(__name__)
25+
26+
def invoke_llm_with_tracing(
    llm: BaseChatModel,
    messages: List[BaseMessage],
    config: Optional[RunnableConfig] = None,
) -> Any:
    """
    Invoke an LLM with tracing callbacks merged into the run config.

    Automatically adds tracing callbacks (e.g. Langfuse) to LLM calls made
    inside LangGraph nodes, working around the callback propagation issue:
    callbacks attached to the graph-level config are not otherwise forwarded
    to nested ``llm.invoke`` calls.

    Parameters
    ----------
    llm : BaseChatModel
        The language model to invoke.
    messages : List[BaseMessage]
        Messages to send to the LLM.
    config : Optional[RunnableConfig]
        Existing configuration; its ``callbacks`` entry may be a list of
        handlers, a ``BaseCallbackManager``, a single handler, or ``None``.

    Returns
    -------
    Any
        The LLM response (an AI message).
    """
    tracing_callbacks = get_tracing_callbacks()

    if not tracing_callbacks:
        # No tracing configured — pass the caller's config through untouched.
        return llm.invoke(messages, config=config)

    # Shallow copy so the caller's config dict is never mutated.
    enhanced_config = dict(config) if config else {}
    existing = enhanced_config.get("callbacks")

    if existing is None:
        all_callbacks = list(tracing_callbacks)
    elif hasattr(existing, "handlers"):
        # A BaseCallbackManager: flatten to its handler list and append ours.
        # NOTE(review): this drops the manager's inheritable handlers/tags —
        # presumably acceptable for a direct invoke; confirm against usage.
        all_callbacks = list(existing.handlers) + list(tracing_callbacks)
    elif isinstance(existing, list):
        all_callbacks = existing + list(tracing_callbacks)
    else:
        # LangChain's Callbacks type also allows a single handler instance;
        # keep it rather than silently discarding it (bug in the original,
        # which replaced it with the tracing callbacks alone).
        all_callbacks = [existing] + list(tracing_callbacks)

    enhanced_config["callbacks"] = all_callbacks
    return llm.invoke(messages, config=enhanced_config)

0 commit comments

Comments
 (0)