1818from langgraph .graph .state import END , START , StateGraph
1919
2020from posthog .ai .langchain import CallbackHandler
21+ from posthog .ai .langchain .callbacks import GenerationMetadata , SpanMetadata
2122
2223OPENAI_API_KEY = os .getenv ("OPENAI_API_KEY" )
2324ANTHROPIC_API_KEY = os .getenv ("ANTHROPIC_API_KEY" )
@@ -74,19 +75,23 @@ def test_metadata_capture(mock_client):
7475 messages = [{"role" : "user" , "content" : "Who won the world series in 2020?" }],
7576 invocation_params = {"temperature" : 0.5 },
7677 metadata = {"ls_model_name" : "hog-mini" , "ls_provider" : "posthog" },
78+ name = "test" ,
7779 )
78- expected = {
79- "model" : "hog-mini" ,
80- "input" : [{"role" : "user" , "content" : "Who won the world series in 2020?" }],
81- "start_time" : 1234567890 ,
82- "model_params" : {"temperature" : 0.5 },
83- "provider" : "posthog" ,
84- "base_url" : "https://us.posthog.com" ,
85- }
80+ expected = GenerationMetadata (
81+ model = "hog-mini" ,
82+ input = [{"role" : "user" , "content" : "Who won the world series in 2020?" }],
83+ start_time = 1234567890 ,
84+ model_params = {"temperature" : 0.5 },
85+ provider = "posthog" ,
86+ base_url = "https://us.posthog.com" ,
87+ name = "test" ,
88+ end_time = None ,
89+ )
8690 assert callbacks ._runs [run_id ] == expected
8791 with patch ("time.time" , return_value = 1234567891 ):
8892 run = callbacks ._pop_run_metadata (run_id )
89- assert run == {** expected , "end_time" : 1234567891 }
93+ expected .end_time = 1234567891
94+ assert run == expected
9095 assert callbacks ._runs == {}
9196 callbacks ._pop_run_metadata (uuid .uuid4 ()) # should not raise
9297
@@ -95,12 +100,32 @@ def test_run_metadata_capture(mock_client):
95100 callbacks = CallbackHandler (mock_client )
96101 run_id = uuid .uuid4 ()
97102 with patch ("time.time" , return_value = 1234567890 ):
98- callbacks ._set_span_metadata (run_id , "test" , 1 )
99- expected = {
100- "name" : "test" ,
101- "input" : 1 ,
102- "start_time" : 1234567890 ,
103- }
103+ callbacks ._set_trace_or_span_metadata (None , 1 , run_id )
104+ expected = SpanMetadata (
105+ name = "trace" ,
106+ input = 1 ,
107+ start_time = 1234567890 ,
108+ end_time = None ,
109+ )
110+ assert callbacks ._runs [run_id ] == expected
111+ with patch ("time.time" , return_value = 1234567890 ):
112+ callbacks ._set_trace_or_span_metadata (None , 1 , run_id , uuid .uuid4 ())
113+ expected = SpanMetadata (
114+ name = "span" ,
115+ input = 1 ,
116+ start_time = 1234567890 ,
117+ end_time = None ,
118+ )
119+ assert callbacks ._runs [run_id ] == expected
120+
121+ with patch ("time.time" , return_value = 1234567890 ):
122+ callbacks ._set_trace_or_span_metadata ({"name" : "test" }, 1 , run_id )
123+ expected = SpanMetadata (
124+ name = "test" ,
125+ input = 1 ,
126+ start_time = 1234567890 ,
127+ end_time = None ,
128+ )
104129 assert callbacks ._runs [run_id ] == expected
105130
106131
@@ -132,11 +157,24 @@ def test_basic_chat_chain(mock_client, stream):
132157 result = chain .invoke ({}, config = {"callbacks" : callbacks })
133158
134159 assert result .content == "The Los Angeles Dodgers won the World Series in 2020."
135- assert mock_client .capture .call_count == 2
136- generation_args = mock_client .capture .call_args_list [0 ][1 ]
160+ assert mock_client .capture .call_count == 3
161+
162+ span_args = mock_client .capture .call_args_list [0 ][1 ]
163+ span_props = span_args ["properties" ]
164+
165+ generation_args = mock_client .capture .call_args_list [1 ][1 ]
137166 generation_props = generation_args ["properties" ]
138- trace_args = mock_client .capture .call_args_list [1 ][1 ]
139167
168+ trace_args = mock_client .capture .call_args_list [2 ][1 ]
169+ trace_props = trace_args ["properties" ]
170+
171+ # Span is first
172+ assert span_args ["event" ] == "$ai_span"
173+ assert span_props ["$ai_trace_id" ] == generation_props ["$ai_trace_id" ]
174+ assert span_props ["$ai_parent_id" ] == trace_props ["$ai_trace_id" ]
175+ assert "$ai_span_id" in span_props
176+
177+ # Generation is second
140178 assert generation_args ["event" ] == "$ai_generation"
141179 assert "distinct_id" in generation_args
142180 assert "$ai_model" in generation_props
@@ -154,9 +192,15 @@ def test_basic_chat_chain(mock_client, stream):
154192 assert generation_props ["$ai_input_tokens" ] == 10
155193 assert generation_props ["$ai_output_tokens" ] == 10
156194 assert generation_props ["$ai_http_status" ] == 200
157- assert generation_props ["$ai_trace_id" ] is not None
158195 assert isinstance (generation_props ["$ai_latency" ], float )
196+ assert "$ai_generation_id" in generation_props
197+ assert generation_props ["$ai_parent_id" ] == trace_props ["$ai_trace_id" ]
198+ assert generation_props ["$ai_trace_id" ] == trace_props ["$ai_trace_id" ]
199+
200+ # Trace is last
159201 assert trace_args ["event" ] == "$ai_trace"
202+ assert "$ai_trace_id" in trace_props
203+ assert "$ai_parent_id" not in trace_props
160204
161205
162206@pytest .mark .parametrize ("stream" , [True , False ])
@@ -186,13 +230,22 @@ async def test_async_basic_chat_chain(mock_client, stream):
186230 else :
187231 result = await chain .ainvoke ({}, config = {"callbacks" : callbacks })
188232 assert result .content == "The Los Angeles Dodgers won the World Series in 2020."
189- assert mock_client .capture .call_count == 2
233+ assert mock_client .capture .call_count == 3
190234
191- generation_args = mock_client .capture .call_args_list [0 ][1 ]
235+ span_args = mock_client .capture .call_args_list [0 ][1 ]
236+ span_props = span_args ["properties" ]
237+ generation_args = mock_client .capture .call_args_list [1 ][1 ]
192238 generation_props = generation_args ["properties" ]
193- trace_args = mock_client .capture .call_args_list [1 ][1 ]
239+ trace_args = mock_client .capture .call_args_list [2 ][1 ]
194240 trace_props = trace_args ["properties" ]
195241
242+ # Span is first
243+ assert span_args ["event" ] == "$ai_span"
244+ assert span_props ["$ai_trace_id" ] == generation_props ["$ai_trace_id" ]
245+ assert span_props ["$ai_parent_id" ] == trace_props ["$ai_trace_id" ]
246+ assert "$ai_span_id" in span_props
247+
248+ # Generation is second
196249 assert generation_args ["event" ] == "$ai_generation"
197250 assert "distinct_id" in generation_args
198251 assert "$ai_model" in generation_props
@@ -210,12 +263,16 @@ async def test_async_basic_chat_chain(mock_client, stream):
210263 assert generation_props ["$ai_input_tokens" ] == 10
211264 assert generation_props ["$ai_output_tokens" ] == 10
212265 assert generation_props ["$ai_http_status" ] == 200
213- assert generation_props ["$ai_trace_id" ] is not None
214266 assert isinstance (generation_props ["$ai_latency" ], float )
267+ assert "$ai_generation_id" in generation_props
268+ assert generation_props ["$ai_parent_id" ] == trace_props ["$ai_trace_id" ]
269+ assert generation_props ["$ai_trace_id" ] == trace_props ["$ai_trace_id" ]
215270
271+ # Trace is last
216272 assert trace_args ["event" ] == "$ai_trace"
217273 assert "distinct_id" in generation_args
218274 assert trace_props ["$ai_trace_id" ] == generation_props ["$ai_trace_id" ]
275+ assert "$ai_parent_id" not in trace_props
219276
220277
221278@pytest .mark .parametrize (
0 commit comments