4 | 4 | import os |
5 | 5 | import time |
6 | 6 | import uuid |
7 | | -from typing import List, Optional, TypedDict, Union |
| 7 | +from typing import List, Optional, TypedDict, Union, Literal |
8 | 8 | from unittest.mock import patch |
9 | 9 |
10 | 10 | import pytest |
11 | 11 | from langchain_anthropic.chat_models import ChatAnthropic |
12 | 12 | from langchain_community.chat_models.fake import FakeMessagesListChatModel |
13 | 13 | from langchain_community.llms.fake import FakeListLLM, FakeStreamingListLLM |
14 | | -from langchain_core.messages import AIMessage, HumanMessage |
| 14 | +from langchain_core.messages import AIMessage, HumanMessage, ToolCall |
15 | 15 | from langchain_core.prompts import ChatPromptTemplate |
16 | 16 | from langchain_core.runnables import RunnableLambda |
17 | 17 | from langchain_openai.chat_models import ChatOpenAI |
18 | 18 | from langgraph.graph.state import END, START, StateGraph |
| 19 | +from langgraph.prebuilt import create_react_agent |
| 20 | +from langchain_core.tools import tool |
19 | 21 |
20 | 22 | from posthog.ai.langchain import CallbackHandler |
21 | 23 | from posthog.ai.langchain.callbacks import GenerationMetadata, SpanMetadata |
@@ -1234,3 +1236,32 @@ async def sleep(x): # -> Any: |
1234 | 1236 | assert ( |
1235 | 1237 | min(approximate_latency - 1, 0) <= math.floor(third_call[1]["properties"]["$ai_latency"]) <= approximate_latency |
1236 | 1238 | ) |
| 1239 | + |
| 1240 | + |
| 1241 | +@pytest.mark.skipif(not OPENAI_API_KEY, reason="OPENAI_API_KEY is not set") |
| 1242 | +def test_langgraph_agent(mock_client): |
| 1243 | + @tool |
| 1244 | + def get_weather(city: Literal["nyc", "sf"]): |
| 1245 | + """ |
| 1246 | + Use this to get weather information. |
| 1247 | +
| 1248 | + Args: |
| 1249 | + city: The city to get weather information for. |
| 1250 | + """ |
| 1251 | + if city == "sf": |
| 1252 | + return "It's always sunny in sf" |
| 1253 | + return "No info" |
| 1254 | + |
| 1255 | + tools = [get_weather] |
| 1256 | + model = ChatOpenAI(api_key=OPENAI_API_KEY, model="gpt-4o-mini", temperature=0) |
| 1257 | + graph = create_react_agent(model, tools=tools) |
| 1258 | + inputs = {"messages": [("user", "what is the weather in sf")]} |
| 1259 | + cb = CallbackHandler(mock_client, trace_id="test-trace-id", distinct_id="test-distinct-id") |
| 1260 | + res = graph.invoke(inputs, config={"callbacks": [cb]}) |
| 1261 | + calls = [call[1] for call in mock_client.capture.call_args_list] |
| 1262 | + assert len(calls) == 21 |
| 1263 | + for call in calls: |
| 1264 | + assert call["properties"]["$ai_trace_id"] == "test-trace-id" |
| 1265 | + assert len([call for call in calls if call["event"] == "$ai_generation"]) == 2 |
| 1266 | + assert len([call for call in calls if call["event"] == "$ai_span"]) == 18 |
| 1267 | + assert len([call for call in calls if call["event"] == "$ai_trace"]) == 1 |