Skip to content

Commit 739c88a

Browse files
committed
fix: review comments
1 parent 9f7e094 commit 739c88a

File tree

5 files changed

+32
-37
lines changed

5 files changed

+32
-37
lines changed

posthog/ai/langchain/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .callbacks import PosthogCallbackHandler
1+
from .callbacks import CallbackHandler
22

3-
__all__ = ["PosthogCallbackHandler"]
3+
__all__ = ["CallbackHandler"]

posthog/ai/langchain/callbacks.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from posthog.ai.utils import get_model_params
2727
from posthog.client import Client
2828

29+
log = logging.getLogger("posthog")
30+
2931

3032
class RunMetadata(TypedDict, total=False):
3133
messages: Union[List[Dict[str, Any]], List[str]]
@@ -39,7 +41,7 @@ class RunMetadata(TypedDict, total=False):
3941
RunStorage = Dict[UUID, RunMetadata]
4042

4143

42-
class PosthogCallbackHandler(BaseCallbackHandler):
44+
class CallbackHandler(BaseCallbackHandler):
4345
"""
4446
A callback handler for LangChain that sends events to PostHog LLM Observability.
4547
"""
@@ -80,7 +82,6 @@ def __init__(
8082
self._properties = properties
8183
self._runs = {}
8284
self._parent_tree = {}
83-
self.log = logging.getLogger("posthog")
8485

8586
def on_chain_start(
8687
self,
@@ -274,7 +275,7 @@ def _pop_run_metadata(self, run_id: UUID) -> Optional[RunMetadata]:
274275
try:
275276
run = self._runs.pop(run_id)
276277
except KeyError:
277-
self.log.warning(f"No run metadata found for run {run_id}")
278+
log.warning(f"No run metadata found for run {run_id}")
278279
return None
279280
run["end_time"] = end_time
280281
return run
@@ -395,7 +396,7 @@ def _parse_usage(response: LLMResult):
395396

396397

397398
def _get_http_status(error: BaseException) -> int:
398-
# OpenAI: https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/_exceptions.py
399+
# OpenAI: https://github.com/openai/openai-python/blob/main/src/openai/_exceptions.py
399400
# Anthropic: https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/_exceptions.py
400401
# Google: https://github.com/googleapis/python-api-core/blob/main/google/api_core/exceptions.py
401402
status_code = getattr(error, "status_code", getattr(error, "code", 0))

posthog/ai/openai/openai.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44

55
try:
66
import openai
7+
import openai.resources
78
except ImportError:
89
raise ModuleNotFoundError("Please install the OpenAI SDK to use this feature: 'pip install openai'")
910

10-
import openai.resources
11-
1211
from posthog.ai.utils import call_llm_and_track_usage, get_model_params
1312
from posthog.client import Client as PostHogClient
1413

posthog/ai/openai/openai_async.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44

55
try:
66
import openai
7+
import openai.resources
78
except ImportError:
89
raise ModuleNotFoundError("Please install the OpenAI SDK to use this feature: 'pip install openai'")
910

10-
import openai.resources
11-
1211
from posthog.ai.utils import call_llm_and_track_usage_async, get_model_params
1312
from posthog.client import Client as PostHogClient
1413

posthog/test/ai/langchain/test_callbacks.py

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from langchain_core.runnables import RunnableLambda
1313
from langchain_openai.chat_models import ChatOpenAI
1414

15-
from posthog.ai.langchain import PosthogCallbackHandler
15+
from posthog.ai.langchain import CallbackHandler
1616

1717
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
1818

@@ -24,7 +24,7 @@ def mock_client():
2424

2525

2626
def test_parent_capture(mock_client):
27-
callbacks = PosthogCallbackHandler(mock_client)
27+
callbacks = CallbackHandler(mock_client)
2828
parent_run_id = uuid.uuid4()
2929
run_id = uuid.uuid4()
3030
callbacks._set_parent_of_run(run_id, parent_run_id)
@@ -35,7 +35,7 @@ def test_parent_capture(mock_client):
3535

3636

3737
def test_find_root_run(mock_client):
38-
callbacks = PosthogCallbackHandler(mock_client)
38+
callbacks = CallbackHandler(mock_client)
3939
root_run_id = uuid.uuid4()
4040
parent_run_id = uuid.uuid4()
4141
run_id = uuid.uuid4()
@@ -47,17 +47,17 @@ def test_find_root_run(mock_client):
4747

4848

4949
def test_trace_id_generation(mock_client):
50-
callbacks = PosthogCallbackHandler(mock_client)
50+
callbacks = CallbackHandler(mock_client)
5151
run_id = uuid.uuid4()
5252
with patch("uuid.uuid4", return_value=run_id):
5353
assert callbacks._get_trace_id(run_id) == run_id
5454
run_id = uuid.uuid4()
55-
callbacks = PosthogCallbackHandler(mock_client, trace_id=run_id)
55+
callbacks = CallbackHandler(mock_client, trace_id=run_id)
5656
assert callbacks._get_trace_id(uuid.uuid4()) == run_id
5757

5858

5959
def test_metadata_capture(mock_client):
60-
callbacks = PosthogCallbackHandler(mock_client)
60+
callbacks = CallbackHandler(mock_client)
6161
run_id = uuid.uuid4()
6262
with patch("time.time", return_value=1234567890):
6363
callbacks._set_run_metadata(
@@ -97,7 +97,7 @@ def test_basic_chat_chain(mock_client, stream):
9797
)
9898
]
9999
)
100-
callbacks = [PosthogCallbackHandler(mock_client)]
100+
callbacks = [CallbackHandler(mock_client)]
101101
chain = prompt | model
102102
if stream:
103103
result = [m for m in chain.stream({}, config={"callbacks": callbacks})][0]
@@ -143,7 +143,7 @@ async def test_async_basic_chat_chain(mock_client, stream):
143143
)
144144
]
145145
)
146-
callbacks = [PosthogCallbackHandler(mock_client)]
146+
callbacks = [CallbackHandler(mock_client)]
147147
chain = prompt | model
148148
if stream:
149149
result = [m async for m in chain.astream({}, config={"callbacks": callbacks})][0]
@@ -178,7 +178,7 @@ async def test_async_basic_chat_chain(mock_client, stream):
178178
)
179179
def test_basic_llm_chain(mock_client, Model, stream):
180180
model = Model(responses=["The Los Angeles Dodgers won the World Series in 2020."])
181-
callbacks: list[PosthogCallbackHandler] = [PosthogCallbackHandler(mock_client)]
181+
callbacks: list[CallbackHandler] = [CallbackHandler(mock_client)]
182182

183183
if stream:
184184
result = "".join(
@@ -209,7 +209,7 @@ def test_basic_llm_chain(mock_client, Model, stream):
209209
)
210210
async def test_async_basic_llm_chain(mock_client, Model, stream):
211211
model = Model(responses=["The Los Angeles Dodgers won the World Series in 2020."])
212-
callbacks: list[PosthogCallbackHandler] = [PosthogCallbackHandler(mock_client)]
212+
callbacks: list[CallbackHandler] = [CallbackHandler(mock_client)]
213213

214214
if stream:
215215
result = "".join(
@@ -241,7 +241,7 @@ def test_trace_id_for_multiple_chains(mock_client):
241241
]
242242
)
243243
model = FakeMessagesListChatModel(responses=[AIMessage(content="Bar")])
244-
callbacks = [PosthogCallbackHandler(mock_client)]
244+
callbacks = [CallbackHandler(mock_client)]
245245
chain = prompt | model | RunnableLambda(lambda x: [x]) | model
246246
result = chain.invoke({}, config={"callbacks": callbacks})
247247

@@ -279,13 +279,13 @@ def test_trace_id_for_multiple_chains(mock_client):
279279
def test_personless_mode(mock_client):
280280
prompt = ChatPromptTemplate.from_messages([("user", "Foo")])
281281
chain = prompt | FakeMessagesListChatModel(responses=[AIMessage(content="Bar")])
282-
chain.invoke({}, config={"callbacks": [PosthogCallbackHandler(mock_client)]})
282+
chain.invoke({}, config={"callbacks": [CallbackHandler(mock_client)]})
283283
assert mock_client.capture.call_count == 1
284284
args = mock_client.capture.call_args_list[0][1]
285285
assert args["properties"]["$process_person_profile"] is False
286286

287287
id = uuid.uuid4()
288-
chain.invoke({}, config={"callbacks": [PosthogCallbackHandler(mock_client, distinct_id=id)]})
288+
chain.invoke({}, config={"callbacks": [CallbackHandler(mock_client, distinct_id=id)]})
289289
assert mock_client.capture.call_count == 2
290290
args = mock_client.capture.call_args_list[1][1]
291291
assert "$process_person_profile" not in args["properties"]
@@ -295,7 +295,7 @@ def test_personless_mode(mock_client):
295295
def test_personless_mode_exception(mock_client):
296296
prompt = ChatPromptTemplate.from_messages([("user", "Foo")])
297297
chain = prompt | ChatOpenAI(api_key="test", model="gpt-4o-mini")
298-
callbacks = PosthogCallbackHandler(mock_client)
298+
callbacks = CallbackHandler(mock_client)
299299
with pytest.raises(Exception):
300300
chain.invoke({}, config={"callbacks": [callbacks]})
301301
assert mock_client.capture.call_count == 1
@@ -304,7 +304,7 @@ def test_personless_mode_exception(mock_client):
304304

305305
id = uuid.uuid4()
306306
with pytest.raises(Exception):
307-
chain.invoke({}, config={"callbacks": [PosthogCallbackHandler(mock_client, distinct_id=id)]})
307+
chain.invoke({}, config={"callbacks": [CallbackHandler(mock_client, distinct_id=id)]})
308308
assert mock_client.capture.call_count == 2
309309
args = mock_client.capture.call_args_list[1][1]
310310
assert "$process_person_profile" not in args["properties"]
@@ -319,7 +319,7 @@ def test_metadata(mock_client):
319319
)
320320
model = FakeMessagesListChatModel(responses=[AIMessage(content="Bar")])
321321
callbacks = [
322-
PosthogCallbackHandler(mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"})
322+
CallbackHandler(mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"})
323323
]
324324
chain = prompt | model
325325
result = chain.invoke({}, config={"callbacks": callbacks})
@@ -343,9 +343,7 @@ def test_metadata(mock_client):
343343
def test_callbacks_logic(mock_client):
344344
prompt = ChatPromptTemplate.from_messages([("user", "Foo")])
345345
model = FakeMessagesListChatModel(responses=[AIMessage(content="Bar")])
346-
callbacks = PosthogCallbackHandler(
347-
mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"}
348-
)
346+
callbacks = CallbackHandler(mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"})
349347
chain = prompt | model
350348

351349
chain.invoke({}, config={"callbacks": [callbacks]})
@@ -366,7 +364,7 @@ def test_exception_in_chain(mock_client):
366364
def runnable(_):
367365
raise ValueError("test")
368366

369-
callbacks = PosthogCallbackHandler(mock_client)
367+
callbacks = CallbackHandler(mock_client)
370368
with pytest.raises(ValueError):
371369
RunnableLambda(runnable).invoke({}, config={"callbacks": [callbacks]})
372370

@@ -378,7 +376,7 @@ def runnable(_):
378376
def test_openai_error(mock_client):
379377
prompt = ChatPromptTemplate.from_messages([("user", "Foo")])
380378
chain = prompt | ChatOpenAI(api_key="test", model="gpt-4o-mini")
381-
callbacks = PosthogCallbackHandler(mock_client)
379+
callbacks = CallbackHandler(mock_client)
382380

383381
# 401
384382
with pytest.raises(Exception):
@@ -408,9 +406,7 @@ def test_openai_chain(mock_client):
408406
temperature=0,
409407
max_tokens=1,
410408
)
411-
callbacks = PosthogCallbackHandler(
412-
mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"}
413-
)
409+
callbacks = CallbackHandler(mock_client, trace_id="test-trace-id", distinct_id="test_id", properties={"foo": "bar"})
414410
start_time = time.time()
415411
result = chain.invoke({}, config={"callbacks": [callbacks]})
416412
approximate_latency = math.floor(time.time() - start_time)
@@ -475,7 +471,7 @@ def test_openai_captures_multiple_generations(mock_client):
475471
max_tokens=1,
476472
n=2,
477473
)
478-
callbacks = PosthogCallbackHandler(mock_client)
474+
callbacks = CallbackHandler(mock_client)
479475
result = chain.invoke({}, config={"callbacks": [callbacks]})
480476

481477
assert result.content == "Bar"
@@ -530,7 +526,7 @@ def test_openai_streaming(mock_client):
530526
chain = prompt | ChatOpenAI(
531527
api_key=OPENAI_API_KEY, model="gpt-4o-mini", temperature=0, max_tokens=1, stream=True, stream_usage=True
532528
)
533-
callbacks = PosthogCallbackHandler(mock_client)
529+
callbacks = CallbackHandler(mock_client)
534530
result = [m for m in chain.stream({}, config={"callbacks": [callbacks]})]
535531
result = sum(result[1:], result[0])
536532

@@ -562,7 +558,7 @@ async def test_async_openai_streaming(mock_client):
562558
chain = prompt | ChatOpenAI(
563559
api_key=OPENAI_API_KEY, model="gpt-4o-mini", temperature=0, max_tokens=1, stream=True, stream_usage=True
564560
)
565-
callbacks = PosthogCallbackHandler(mock_client)
561+
callbacks = CallbackHandler(mock_client)
566562
result = [m async for m in chain.astream({}, config={"callbacks": [callbacks]})]
567563
result = sum(result[1:], result[0])
568564

0 commit comments

Comments
 (0)