Skip to content

Commit 9656c51

Browse files
committed
test: fix part 1
1 parent 0ab22fb commit 9656c51

File tree

2 files changed

+81
-24
lines changed

2 files changed

+81
-24
lines changed

posthog/ai/langchain/callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def _capture_generation(
469469
trace_id: Any,
470470
run_id: UUID,
471471
run: GenerationMetadata,
472-
output: Union[LLMResult | BaseException],
472+
output: Union[LLMResult, BaseException],
473473
parent_run_id: Optional[UUID] = None,
474474
):
475475
event_properties = {

posthog/test/ai/langchain/test_callbacks.py

Lines changed: 80 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from langgraph.graph.state import END, START, StateGraph
1919

2020
from posthog.ai.langchain import CallbackHandler
21+
from posthog.ai.langchain.callbacks import GenerationMetadata, SpanMetadata
2122

2223
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
2324
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
@@ -74,19 +75,23 @@ def test_metadata_capture(mock_client):
7475
messages=[{"role": "user", "content": "Who won the world series in 2020?"}],
7576
invocation_params={"temperature": 0.5},
7677
metadata={"ls_model_name": "hog-mini", "ls_provider": "posthog"},
78+
name="test",
7779
)
78-
expected = {
79-
"model": "hog-mini",
80-
"input": [{"role": "user", "content": "Who won the world series in 2020?"}],
81-
"start_time": 1234567890,
82-
"model_params": {"temperature": 0.5},
83-
"provider": "posthog",
84-
"base_url": "https://us.posthog.com",
85-
}
80+
expected = GenerationMetadata(
81+
model="hog-mini",
82+
input=[{"role": "user", "content": "Who won the world series in 2020?"}],
83+
start_time=1234567890,
84+
model_params={"temperature": 0.5},
85+
provider="posthog",
86+
base_url="https://us.posthog.com",
87+
name="test",
88+
end_time=None,
89+
)
8690
assert callbacks._runs[run_id] == expected
8791
with patch("time.time", return_value=1234567891):
8892
run = callbacks._pop_run_metadata(run_id)
89-
assert run == {**expected, "end_time": 1234567891}
93+
expected.end_time = 1234567891
94+
assert run == expected
9095
assert callbacks._runs == {}
9196
callbacks._pop_run_metadata(uuid.uuid4()) # should not raise
9297

@@ -95,12 +100,32 @@ def test_run_metadata_capture(mock_client):
95100
callbacks = CallbackHandler(mock_client)
96101
run_id = uuid.uuid4()
97102
with patch("time.time", return_value=1234567890):
98-
callbacks._set_span_metadata(run_id, "test", 1)
99-
expected = {
100-
"name": "test",
101-
"input": 1,
102-
"start_time": 1234567890,
103-
}
103+
callbacks._set_trace_or_span_metadata(None, 1, run_id)
104+
expected = SpanMetadata(
105+
name="trace",
106+
input=1,
107+
start_time=1234567890,
108+
end_time=None,
109+
)
110+
assert callbacks._runs[run_id] == expected
111+
with patch("time.time", return_value=1234567890):
112+
callbacks._set_trace_or_span_metadata(None, 1, run_id, uuid.uuid4())
113+
expected = SpanMetadata(
114+
name="span",
115+
input=1,
116+
start_time=1234567890,
117+
end_time=None,
118+
)
119+
assert callbacks._runs[run_id] == expected
120+
121+
with patch("time.time", return_value=1234567890):
122+
callbacks._set_trace_or_span_metadata({"name": "test"}, 1, run_id)
123+
expected = SpanMetadata(
124+
name="test",
125+
input=1,
126+
start_time=1234567890,
127+
end_time=None,
128+
)
104129
assert callbacks._runs[run_id] == expected
105130

106131

@@ -132,11 +157,24 @@ def test_basic_chat_chain(mock_client, stream):
132157
result = chain.invoke({}, config={"callbacks": callbacks})
133158

134159
assert result.content == "The Los Angeles Dodgers won the World Series in 2020."
135-
assert mock_client.capture.call_count == 2
136-
generation_args = mock_client.capture.call_args_list[0][1]
160+
assert mock_client.capture.call_count == 3
161+
162+
span_args = mock_client.capture.call_args_list[0][1]
163+
span_props = span_args["properties"]
164+
165+
generation_args = mock_client.capture.call_args_list[1][1]
137166
generation_props = generation_args["properties"]
138-
trace_args = mock_client.capture.call_args_list[1][1]
139167

168+
trace_args = mock_client.capture.call_args_list[2][1]
169+
trace_props = trace_args["properties"]
170+
171+
# Span is first
172+
assert span_args["event"] == "$ai_span"
173+
assert span_props["$ai_trace_id"] == generation_props["$ai_trace_id"]
174+
assert span_props["$ai_parent_id"] == trace_props["$ai_trace_id"]
175+
assert "$ai_span_id" in span_props
176+
177+
# Generation is second
140178
assert generation_args["event"] == "$ai_generation"
141179
assert "distinct_id" in generation_args
142180
assert "$ai_model" in generation_props
@@ -154,9 +192,15 @@ def test_basic_chat_chain(mock_client, stream):
154192
assert generation_props["$ai_input_tokens"] == 10
155193
assert generation_props["$ai_output_tokens"] == 10
156194
assert generation_props["$ai_http_status"] == 200
157-
assert generation_props["$ai_trace_id"] is not None
158195
assert isinstance(generation_props["$ai_latency"], float)
196+
assert "$ai_generation_id" in generation_props
197+
assert generation_props["$ai_parent_id"] == trace_props["$ai_trace_id"]
198+
assert generation_props["$ai_trace_id"] == trace_props["$ai_trace_id"]
199+
200+
# Trace is last
159201
assert trace_args["event"] == "$ai_trace"
202+
assert "$ai_trace_id" in trace_props
203+
assert "$ai_parent_id" not in trace_props
160204

161205

162206
@pytest.mark.parametrize("stream", [True, False])
@@ -186,13 +230,22 @@ async def test_async_basic_chat_chain(mock_client, stream):
186230
else:
187231
result = await chain.ainvoke({}, config={"callbacks": callbacks})
188232
assert result.content == "The Los Angeles Dodgers won the World Series in 2020."
189-
assert mock_client.capture.call_count == 2
233+
assert mock_client.capture.call_count == 3
190234

191-
generation_args = mock_client.capture.call_args_list[0][1]
235+
span_args = mock_client.capture.call_args_list[0][1]
236+
span_props = span_args["properties"]
237+
generation_args = mock_client.capture.call_args_list[1][1]
192238
generation_props = generation_args["properties"]
193-
trace_args = mock_client.capture.call_args_list[1][1]
239+
trace_args = mock_client.capture.call_args_list[2][1]
194240
trace_props = trace_args["properties"]
195241

242+
# Span is first
243+
assert span_args["event"] == "$ai_span"
244+
assert span_props["$ai_trace_id"] == generation_props["$ai_trace_id"]
245+
assert span_props["$ai_parent_id"] == trace_props["$ai_trace_id"]
246+
assert "$ai_span_id" in span_props
247+
248+
# Generation is second
196249
assert generation_args["event"] == "$ai_generation"
197250
assert "distinct_id" in generation_args
198251
assert "$ai_model" in generation_props
@@ -210,12 +263,16 @@ async def test_async_basic_chat_chain(mock_client, stream):
210263
assert generation_props["$ai_input_tokens"] == 10
211264
assert generation_props["$ai_output_tokens"] == 10
212265
assert generation_props["$ai_http_status"] == 200
213-
assert generation_props["$ai_trace_id"] is not None
214266
assert isinstance(generation_props["$ai_latency"], float)
267+
assert "$ai_generation_id" in generation_props
268+
assert generation_props["$ai_parent_id"] == trace_props["$ai_trace_id"]
269+
assert generation_props["$ai_trace_id"] == trace_props["$ai_trace_id"]
215270

271+
# Trace is last
216272
assert trace_args["event"] == "$ai_trace"
217273
assert "distinct_id" in generation_args
218274
assert trace_props["$ai_trace_id"] == generation_props["$ai_trace_id"]
275+
assert "$ai_parent_id" not in trace_props
219276

220277

221278
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)