
Commit 869be78

fix: parallel traces
1 parent: 2835af4

The fix: trace-level metadata (name, input, timing) moves from handler-wide fields into per-run storage keyed by run ID, so traces running in parallel no longer overwrite one another's state.

File tree: 2 files changed, +82 -15 lines

posthog/ai/langchain/callbacks.py
posthog/test/ai/langchain/test_callbacks.py

posthog/ai/langchain/callbacks.py

Lines changed: 30 additions & 11 deletions
@@ -32,13 +32,22 @@


 class RunMetadata(TypedDict, total=False):
-    messages: Union[List[Dict[str, Any]], List[str]]
+    input: Any
+    """Input of the run: messages, prompt variables, etc."""
+    name: str
+    """Name of the run: chain name, model name, etc."""
     provider: str
+    """Provider of the run: OpenAI, Anthropic"""
     model: str
+    """Model used in the run"""
     model_params: Dict[str, Any]
+    """Model parameters of the run: temperature, max_tokens, etc."""
     base_url: str
+    """Base URL of the provider's API used in the run."""
     start_time: float
+    """Start time of the run."""
     end_time: float
+    """End time of the run."""


 RunStorage = Dict[UUID, RunMetadata]

@@ -119,8 +128,7 @@ def on_chain_start(
         self._log_debug_event("on_chain_start", run_id, parent_run_id, inputs=inputs)
         self._set_parent_of_run(run_id, parent_run_id)
         if parent_run_id is None and self._trace_name is None:
-            self._trace_name = self._get_langchain_run_name(serialized, **kwargs)
-            self._trace_input = inputs
+            self._set_span_metadata(run_id, self._get_langchain_run_name(serialized, **kwargs), inputs)

     def on_chat_model_start(
         self,

@@ -134,7 +142,7 @@ def on_chat_model_start(
         self._log_debug_event("on_chat_model_start", run_id, parent_run_id, messages=messages)
         self._set_parent_of_run(run_id, parent_run_id)
         input = [_convert_message_to_dict(message) for row in messages for message in row]
-        self._set_run_metadata(serialized, run_id, input, **kwargs)
+        self._set_llm_metadata(serialized, run_id, input, **kwargs)

     def on_llm_start(
         self,

@@ -147,7 +155,7 @@ def on_llm_start(
     ):
         self._log_debug_event("on_llm_start", run_id, parent_run_id, prompts=prompts)
         self._set_parent_of_run(run_id, parent_run_id)
-        self._set_run_metadata(serialized, run_id, prompts, **kwargs)
+        self._set_llm_metadata(serialized, run_id, prompts, **kwargs)

     def on_llm_new_token(
         self,

@@ -253,7 +261,7 @@ def on_llm_end(
             "$ai_provider": run.get("provider"),
             "$ai_model": run.get("model"),
             "$ai_model_parameters": run.get("model_params"),
-            "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("messages")),
+            "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("input")),
             "$ai_output_choices": with_privacy_mode(self._client, self._privacy_mode, output),
             "$ai_http_status": 200,
             "$ai_input_tokens": input_tokens,

@@ -292,7 +300,7 @@ def on_llm_error(
             "$ai_provider": run.get("provider"),
             "$ai_model": run.get("model"),
             "$ai_model_parameters": run.get("model_params"),
-            "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("messages")),
+            "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("input")),
             "$ai_http_status": _get_http_status(error),
             "$ai_latency": latency,
             "$ai_trace_id": trace_id,

@@ -377,7 +385,14 @@ def _find_root_run(self, run_id: UUID) -> UUID:
             id = self._parent_tree[id]
         return id

-    def _set_run_metadata(
+    def _set_span_metadata(self, run_id: UUID, name: str, input: Any):
+        self._runs[run_id] = {
+            "name": name,
+            "input": input,
+            "start_time": time.time(),
+        }
+
+    def _set_llm_metadata(
         self,
         serialized: Dict[str, Any],
         run_id: UUID,

@@ -387,7 +402,7 @@ def _set_run_metadata(
         **kwargs,
     ):
         run: RunMetadata = {
-            "messages": messages,
+            "input": messages,
             "start_time": time.time(),
         }
         if isinstance(invocation_params, dict):

@@ -452,10 +467,14 @@ def _get_langchain_run_name(self, serialized: Optional[Dict[str, Any]], **kwargs

     def _capture_trace(self, run_id: UUID, *, outputs: Optional[Dict[str, Any]]):
         trace_id = self._get_trace_id(run_id)
+        run = self._pop_run_metadata(run_id)
+        if not run:
+            return
         event_properties = {
-            "$ai_trace_name": self._trace_name,
+            "$ai_trace_name": run.get("name"),
             "$ai_trace_id": trace_id,
-            "$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, self._trace_input),
+            "$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, run.get("input")),
+            "$ai_latency": run.get("end_time", 0) - run.get("start_time", 0),
             **self._properties,
         }
         if outputs is not None:
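
Taken together, these changes replace the handler-wide self._trace_name / self._trace_input fields with per-run records in self._runs keyed by run ID, which is what makes concurrent traces safe. A minimal sketch of the pattern, detached from the SDK (Handler, on_start, and on_end are illustrative names here, not the real API):

import time
from typing import Any, Dict
from uuid import UUID, uuid4


class Handler:
    """Illustrative stand-in for the callback handler's run bookkeeping."""

    def __init__(self) -> None:
        # Per-run metadata, keyed by run ID -- the core of the fix.
        self._runs: Dict[UUID, Dict[str, Any]] = {}

    def on_start(self, run_id: UUID, name: str, input: Any) -> None:
        # Each root run gets its own record; nothing is stored on the handler itself.
        self._runs[run_id] = {"name": name, "input": input, "start_time": time.time()}

    def on_end(self, run_id: UUID) -> None:
        # Pop this run's record; concurrent runs' records are untouched.
        run = self._runs.pop(run_id, None)
        if run is None:
            return
        latency = time.time() - run["start_time"]
        print(f"trace {run['name']!r} input={run['input']!r} latency={latency:.3f}s")


handler = Handler()
a, b = uuid4(), uuid4()
handler.on_start(a, "chain_a", {"q": "foo"})
handler.on_start(b, "chain_b", {"q": "bar"})  # previously this would have clobbered chain_a's trace state
handler.on_end(a)
handler.on_end(b)

Because each record is popped by its own run ID, finishing one trace can no longer discard or mislabel another that is still in flight.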

posthog/test/ai/langchain/test_callbacks.py

Lines changed: 52 additions & 4 deletions
@@ -3,6 +3,7 @@
 import os
 import time
 import uuid
+import asyncio
 from typing import List, Optional, TypedDict, Union
 from unittest.mock import patch

@@ -12,7 +13,7 @@
 from langchain_community.llms.fake import FakeListLLM, FakeStreamingListLLM
 from langchain_core.messages import AIMessage, HumanMessage
 from langchain_core.prompts import ChatPromptTemplate
-from langchain_core.runnables import RunnableLambda
+from langchain_core.runnables import RunnableLambda, RunnableParallel
 from langchain_openai.chat_models import ChatOpenAI
 from langgraph.graph.state import END, START, StateGraph

@@ -67,7 +68,7 @@ def test_metadata_capture(mock_client):
     callbacks = CallbackHandler(mock_client)
     run_id = uuid.uuid4()
     with patch("time.time", return_value=1234567890):
-        callbacks._set_run_metadata(
+        callbacks._set_llm_metadata(
             {"kwargs": {"openai_api_base": "https://us.posthog.com"}},
             run_id,
             messages=[{"role": "user", "content": "Who won the world series in 2020?"}],

@@ -76,7 +77,7 @@ def test_metadata_capture(mock_client):
     )
     expected = {
         "model": "hog-mini",
-        "messages": [{"role": "user", "content": "Who won the world series in 2020?"}],
+        "input": [{"role": "user", "content": "Who won the world series in 2020?"}],
         "start_time": 1234567890,
         "model_params": {"temperature": 0.5},
         "provider": "posthog",

@@ -90,6 +91,19 @@ def test_metadata_capture(mock_client):
     callbacks._pop_run_metadata(uuid.uuid4())  # should not raise


+def test_run_metadata_capture(mock_client):
+    callbacks = CallbackHandler(mock_client)
+    run_id = uuid.uuid4()
+    with patch("time.time", return_value=1234567890):
+        callbacks._set_span_metadata(run_id, "test", 1)
+    expected = {
+        "name": "test",
+        "input": 1,
+        "start_time": 1234567890,
+    }
+    assert callbacks._runs[run_id] == expected
+
+
 @pytest.mark.parametrize("stream", [True, False])
 def test_basic_chat_chain(mock_client, stream):
     prompt = ChatPromptTemplate.from_messages(

@@ -514,7 +528,11 @@ def test_callbacks_logic(mock_client):
     assert callbacks._parent_tree == {}

     def assert_intermediary_run(m):
-        assert callbacks._runs == {}
+        assert len(callbacks._runs) != 0
+        run = next(iter(callbacks._runs.values()))
+        assert run["name"] == "RunnableSequence"
+        assert run["input"] == {}
+        assert run["start_time"] is not None
         assert len(callbacks._parent_tree.items()) == 1
         return [m]

@@ -981,3 +999,33 @@ def test_tool_calls(mock_client):
         }
     ]
     assert "additional_kwargs" not in generation_call["properties"]["$ai_output_choices"][0]
+
+
+async def test_async_traces(mock_client):
+    async def sleep(x):
+        await asyncio.sleep(0.1)
+        return x
+
+    prompt = ChatPromptTemplate.from_messages([("user", "Foo")])
+    chain1 = RunnableLambda(sleep)
+    chain2 = prompt | FakeMessagesListChatModel(responses=[AIMessage(content="Bar")])
+
+    cb = CallbackHandler(mock_client)
+
+    start_time = time.time()
+    await asyncio.gather(
+        chain1.ainvoke({}, config={"callbacks": [cb]}),
+        chain2.ainvoke({}, config={"callbacks": [cb]}),
+    )
+    approximate_latency = math.floor(time.time() - start_time)
+    assert mock_client.capture.call_count == 3
+    first_call, second_call, third_call = mock_client.capture.call_args_list
+    print(approximate_latency, third_call[1]["properties"]["$ai_latency"])
+    assert first_call[1]["event"] == "$ai_generation"
+    assert second_call[1]["event"] == "$ai_trace"
+    assert second_call[1]["properties"]["$ai_trace_name"] == "RunnableSequence"
+    assert third_call[1]["event"] == "$ai_trace"
+    assert third_call[1]["properties"]["$ai_trace_name"] == "sleep"
+    assert (
+        max(approximate_latency - 1, 0) <= math.floor(third_call[1]["properties"]["$ai_latency"]) <= approximate_latency
+    )
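
For application code, the upshot is that one handler can now be shared across concurrently awaited chains and each invocation gets its own trace, just as the new test asserts. A usage sketch, assuming the posthog and langchain-core packages are installed and using a placeholder project API key:

import asyncio

from langchain_core.runnables import RunnableLambda
from posthog import Posthog
from posthog.ai.langchain import CallbackHandler

client = Posthog("<project_api_key>", host="https://us.i.posthog.com")  # placeholder credentials
handler = CallbackHandler(client)


async def main() -> None:
    double = RunnableLambda(lambda x: x * 2)
    triple = RunnableLambda(lambda x: x * 3)
    # One shared handler, two chains awaited in parallel: with per-run
    # metadata, each ainvoke() produces its own $ai_trace event.
    results = await asyncio.gather(
        double.ainvoke(2, config={"callbacks": [handler]}),
        triple.ainvoke(2, config={"callbacks": [handler]}),
    )
    print(results)  # [4, 6]


asyncio.run(main())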
