Skip to content

Commit 72ee4a6

Browse files
committed
fix bug in Anthropic instrumentation mishandling tool_use/tool_result content, and add a LangGraph tools example
1 parent 9bd7169 commit 72ee4a6

File tree

3 files changed

+170
-13
lines changed

3 files changed

+170
-13
lines changed

src/examples/langchain_example/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from langtrace_python_sdk import with_langtrace_root_span
33

44
from .groq_example import groq_basic, groq_streaming
5+
from .langgraph_example_tools import basic_graph_tools
56

67

78
class LangChainRunner:
@@ -10,6 +11,7 @@ def run(self):
1011
basic_app()
1112
rag()
1213
load_and_split()
14+
basic_graph_tools()
1315

1416

1517
class GroqRunner:
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
from typing import Annotated
2+
3+
from langchain_anthropic import ChatAnthropic
4+
from langchain_core.messages import HumanMessage
5+
from langchain_core.pydantic_v1 import BaseModel
6+
from typing_extensions import TypedDict
7+
from langchain_core.pydantic_v1 import BaseModel, Field
8+
from langchain_core.tools import Tool
9+
from langgraph.checkpoint.memory import MemorySaver
10+
from langgraph.graph import StateGraph
11+
from langgraph.graph.message import add_messages
12+
from langgraph.prebuilt import ToolNode, tools_condition
13+
from langchain_core.messages import AIMessage, ToolMessage
14+
15+
from langtrace_python_sdk import langtrace
16+
17+
langtrace.init()
18+
19+
primes = {998: 7901, 999: 7907, 1000: 7919}
20+
21+
22+
class PrimeInput(BaseModel):
23+
n: int = Field()
24+
25+
26+
def is_prime(n: int) -> bool:
27+
if n <= 1 or (n % 2 == 0 and n > 2):
28+
return False
29+
for i in range(3, int(n**0.5) + 1, 2):
30+
if n % i == 0:
31+
return False
32+
return True
33+
34+
35+
def get_prime(n: int, primes: dict = primes) -> str:
36+
return str(primes.get(int(n)))
37+
38+
39+
async def aget_prime(n: int, primes: dict = primes) -> str:
40+
return str(primes.get(int(n)))
41+
42+
43+
class State(TypedDict):
44+
messages: Annotated[list, add_messages]
45+
# This flag is new
46+
ask_human: bool
47+
48+
49+
class RequestAssistance(BaseModel):
50+
"""Escalate the conversation to an expert. Use this if you are unable to assist directly or if the user requires support beyond your permissions.
51+
52+
To use this function, relay the user's 'request' so the expert can provide the right guidance.
53+
"""
54+
55+
request: str
56+
57+
58+
llm = ChatAnthropic(model="claude-3-haiku-20240307")
59+
# We can bind the llm to a tool definition, a pydantic model, or a json schema
60+
llm_with_tools = llm.bind_tools([RequestAssistance])
61+
tools = [
62+
Tool(
63+
name="GetPrime",
64+
func=get_prime,
65+
description="A tool that returns the `n`th prime number",
66+
args_schema=PrimeInput,
67+
coroutine=aget_prime,
68+
),
69+
]
70+
71+
72+
def chatbot(state: State):
73+
response = llm_with_tools.invoke(state["messages"])
74+
ask_human = False
75+
if (
76+
response.tool_calls
77+
and response.tool_calls[0]["name"] == RequestAssistance.__name__
78+
):
79+
ask_human = True
80+
return {"messages": [response], "ask_human": ask_human}
81+
82+
83+
graph_builder = StateGraph(State)
84+
85+
graph_builder.add_node("chatbot", chatbot)
86+
graph_builder.add_node("tools", ToolNode(tools=tools))
87+
88+
89+
def create_response(response: str, ai_message: AIMessage):
90+
return ToolMessage(
91+
content=response,
92+
tool_call_id=ai_message.tool_calls[0]["id"],
93+
)
94+
95+
96+
def human_node(state: State):
97+
new_messages = []
98+
if not isinstance(state["messages"][-1], ToolMessage):
99+
# Typically, the user will have updated the state during the interrupt.
100+
# If they choose not to, we will include a placeholder ToolMessage to
101+
# let the LLM continue.
102+
new_messages.append(
103+
create_response("No response from human.", state["messages"][-1])
104+
)
105+
return {
106+
# Append the new messages
107+
"messages": new_messages,
108+
# Unset the flag
109+
"ask_human": False,
110+
}
111+
112+
113+
def select_next_node(state: State):
114+
if state["ask_human"]:
115+
return "human"
116+
# Otherwise, we can route as before
117+
return tools_condition(state)
118+
119+
120+
def basic_graph_tools():
121+
graph_builder.add_node("human", human_node)
122+
graph_builder.add_conditional_edges(
123+
"chatbot",
124+
select_next_node,
125+
{"human": "human", "tools": "tools", "__end__": "__end__"},
126+
)
127+
graph_builder.add_edge("tools", "chatbot")
128+
graph_builder.add_edge("human", "chatbot")
129+
graph_builder.set_entry_point("chatbot")
130+
memory = MemorySaver()
131+
graph = graph_builder.compile(
132+
checkpointer=memory,
133+
interrupt_before=["human"],
134+
)
135+
136+
config = {"configurable": {"thread_id": "1"}}
137+
events = graph.stream(
138+
{
139+
"messages": [
140+
(
141+
"user",
142+
"I'm learning LangGraph. Could you do some research on it for me?",
143+
)
144+
]
145+
},
146+
config,
147+
stream_mode="values",
148+
)
149+
for event in events:
150+
if "messages" in event:
151+
event["messages"][-1]

src/langtrace_python_sdk/instrumentation/anthropic/patch.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def traced_method(wrapped, instance, args, kwargs):
6161
**get_llm_request_attributes(kwargs, prompts=prompts),
6262
**get_llm_url(instance),
6363
SpanAttributes.LLM_PATH: APIS["MESSAGES_CREATE"]["ENDPOINT"],
64-
**get_extra_attributes(),
64+
**get_extra_attributes(), # type: ignore
6565
}
6666

6767
attributes = LLMSpanAttributes(**span_attributes)
@@ -88,22 +88,26 @@ def traced_method(wrapped, instance, args, kwargs):
8888
@silently_fail
8989
def set_response_attributes(result, span, kwargs):
9090
if not is_streaming(kwargs):
91+
9192
if hasattr(result, "content") and result.content is not None:
9293
set_span_attribute(
9394
span, SpanAttributes.LLM_RESPONSE_MODEL, result.model
9495
)
95-
completion = [
96-
{
97-
"role": result.role if result.role else "assistant",
98-
"content": result.content[0].text,
99-
"type": result.content[0].type,
100-
}
101-
]
102-
set_event_completion(span, completion)
103-
104-
else:
105-
responses = []
106-
set_event_completion(span, responses)
96+
if hasattr(result, "content") and result.content[0] is not None:
97+
content = result.content[0]
98+
typ = content.type
99+
role = result.role if result.role else "assistant"
100+
if typ == "tool_result" or typ == "tool_use":
101+
content = content.json()
102+
set_span_attribute(
103+
span, SpanAttributes.LLM_TOOL_RESULTS, json.dumps(content)
104+
)
105+
if typ == "text":
106+
content = result.content[0].text
107+
content = content.text
108+
set_event_completion(
109+
span, [{type: typ, role: role, content: content}]
110+
)
107111

108112
if (
109113
hasattr(result, "system_fingerprint")

0 commit comments

Comments (0)