Skip to content

Commit a2c4ca4

Browse files
committed
feat: add groups and privacy mode support to langchain callback
1 parent 471a8b7 commit a2c4ca4

File tree

2 files changed

+63
-4
lines changed

2 files changed

+63
-4
lines changed

posthog/ai/langchain/callbacks.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from langchain_core.outputs import ChatGeneration, LLMResult
2424
from pydantic import BaseModel
2525

26-
from posthog.ai.utils import get_model_params
26+
from posthog.ai.utils import get_model_params, with_privacy_mode
2727
from posthog.client import Client
2828

2929
log = logging.getLogger("posthog")
@@ -69,18 +69,24 @@ def __init__(
6969
distinct_id: Optional[Union[str, int, float, UUID]] = None,
7070
trace_id: Optional[Union[str, int, float, UUID]] = None,
7171
properties: Optional[Dict[str, Any]] = None,
72+
privacy_mode: bool = False,
73+
groups: Optional[Dict[str, Any]] = None,
7274
):
7375
"""
7476
Args:
7577
client: PostHog client instance.
7678
distinct_id: Optional distinct ID of the user to associate the trace with.
7779
trace_id: Optional trace ID to use for the event.
7880
properties: Optional additional metadata to use for the trace.
81+
privacy_mode: Whether to redact the input and output of the trace.
82+
groups: Optional additional PostHog groups to use for the trace.
7983
"""
8084
self._client = client
8185
self._distinct_id = distinct_id
8286
self._trace_id = trace_id
8387
self._properties = properties or {}
88+
self._privacy_mode = privacy_mode
89+
self._groups = groups or {}
8490
self._runs = {}
8591
self._parent_tree = {}
8692

@@ -164,8 +170,8 @@ def on_llm_end(
164170
"$ai_provider": run.get("provider"),
165171
"$ai_model": run.get("model"),
166172
"$ai_model_parameters": run.get("model_params"),
167-
"$ai_input": run.get("messages"),
168-
"$ai_output": {"choices": output},
173+
"$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("messages")),
174+
"$ai_output": with_privacy_mode(self._client, self._privacy_mode, {"choices": output}),
169175
"$ai_http_status": 200,
170176
"$ai_input_tokens": input_tokens,
171177
"$ai_output_tokens": output_tokens,
@@ -180,6 +186,7 @@ def on_llm_end(
180186
distinct_id=self._distinct_id or trace_id,
181187
event="$ai_generation",
182188
properties=event_properties,
189+
groups=self._groups,
183190
)
184191

185192
def on_chain_error(
@@ -212,7 +219,7 @@ def on_llm_error(
212219
"$ai_provider": run.get("provider"),
213220
"$ai_model": run.get("model"),
214221
"$ai_model_parameters": run.get("model_params"),
215-
"$ai_input": run.get("messages"),
222+
"$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("messages")),
216223
"$ai_http_status": _get_http_status(error),
217224
"$ai_latency": latency,
218225
"$ai_trace_id": trace_id,
@@ -225,6 +232,7 @@ def on_llm_error(
225232
distinct_id=self._distinct_id or trace_id,
226233
event="$ai_generation",
227234
properties=event_properties,
235+
groups=self._groups,
228236
)
229237

230238
def _set_parent_of_run(self, run_id: UUID, parent_run_id: Optional[UUID] = None):

posthog/test/ai/langchain/test_callbacks.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,3 +595,54 @@ def test_base_url_retrieval(mock_client):
595595
assert mock_client.capture.call_count == 1
596596
call = mock_client.capture.call_args[1]
597597
assert call["properties"]["$ai_base_url"] == "https://test.posthog.com"
598+
599+
600+
def test_groups(mock_client):
601+
prompt = ChatPromptTemplate.from_messages(
602+
[
603+
("system", 'You must always answer with "Bar".'),
604+
("user", "Foo"),
605+
]
606+
)
607+
chain = prompt | ChatOpenAI(api_key="test", model="gpt-4o-mini")
608+
callbacks = CallbackHandler(mock_client, groups={"company": "test_company"})
609+
chain.invoke({}, config={"callbacks": [callbacks]})
610+
611+
assert mock_client.capture.call_count == 1
612+
call = mock_client.capture.call_args[1]
613+
assert call["groups"] == {"company": "test_company"}
614+
615+
616+
def test_privacy_mode_local(mock_client):
617+
prompt = ChatPromptTemplate.from_messages(
618+
[
619+
("system", 'You must always answer with "Bar".'),
620+
("user", "Foo"),
621+
]
622+
)
623+
chain = prompt | ChatOpenAI(api_key="test", model="gpt-4o-mini")
624+
callbacks = CallbackHandler(mock_client, privacy_mode=True)
625+
chain.invoke({}, config={"callbacks": [callbacks]})
626+
627+
assert mock_client.capture.call_count == 1
628+
call = mock_client.capture.call_args[1]
629+
assert call["properties"]["$ai_input"] is None
630+
assert call["properties"]["$ai_output"] is None
631+
632+
633+
def test_privacy_mode_global(mock_client):
634+
mock_client.privacy_mode = True
635+
prompt = ChatPromptTemplate.from_messages(
636+
[
637+
("system", 'You must always answer with "Bar".'),
638+
("user", "Foo"),
639+
]
640+
)
641+
chain = prompt | ChatOpenAI(api_key="test", model="gpt-4o-mini")
642+
callbacks = CallbackHandler(mock_client)
643+
chain.invoke({}, config={"callbacks": [callbacks]})
644+
645+
assert mock_client.capture.call_count == 1
646+
call = mock_client.capture.call_args[1]
647+
assert call["properties"]["$ai_input"] is None
648+
assert call["properties"]["$ai_output"] is None

0 commit comments

Comments (0)