Skip to content

Commit e55b1b3

Browse files
committed
Improve LangChain stub typing and typecheck config
1 parent 3ac649a commit e55b1b3

File tree

4 files changed

+131
-56
lines changed

4 files changed

+131
-56
lines changed

instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/__init__.py

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,50 +32,71 @@
3232
result = llm.invoke(messages)
3333
LangChainInstrumentor().uninstrument()
3434
35+
# pyright: reportMissingImports=false
36+
3537
API
3638
---
3739
"""
3840

3941
import os
4042
from importlib import import_module
4143
from types import SimpleNamespace
42-
from typing import Any, Callable, Collection, cast
44+
from typing import (
45+
TYPE_CHECKING,
46+
Any,
47+
Callable,
48+
Collection,
49+
Protocol,
50+
Sequence,
51+
cast,
52+
)
4353

44-
from langchain_core.callbacks import BaseCallbackHandler # type: ignore
4554
from wrapt import wrap_function_wrapper # type: ignore
4655

56+
if TYPE_CHECKING:
57+
58+
class BaseCallbackHandler(Protocol):
59+
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
60+
61+
inheritable_handlers: Sequence[Any]
62+
63+
def add_handler(self, handler: Any, inherit: bool = False) -> None: ...
64+
65+
else:
66+
try:
67+
from langchain_core.callbacks import (
68+
BaseCallbackHandler, # type: ignore[import]
69+
)
70+
except ImportError: # pragma: no cover - optional dependency
71+
72+
class BaseCallbackHandler:
73+
def __init__(self, *args: Any, **kwargs: Any) -> None:
74+
return
75+
76+
inheritable_handlers: Sequence[Any] = ()
77+
78+
def add_handler(self, handler: Any, inherit: bool = False) -> None:
79+
raise RuntimeError(
80+
"LangChain is required for the LangChain instrumentation."
81+
)
82+
83+
84+
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
4785
from opentelemetry.instrumentation.langchain.callback_handler import (
4886
OpenTelemetryLangChainCallbackHandler,
4987
)
5088
from opentelemetry.instrumentation.langchain.package import _instruments
5189
from opentelemetry.instrumentation.langchain.version import __version__
52-
from opentelemetry.trace import get_tracer
5390

54-
_instrumentor_module = SimpleNamespace(BaseInstrumentor=object)
5591
try:
56-
_instrumentor_module = import_module(
57-
"opentelemetry.instrumentation.instrumentor"
58-
)
59-
except ModuleNotFoundError: # pragma: no cover - optional dependency
60-
pass
92+
from opentelemetry.instrumentation.utils import unwrap
93+
except ImportError: # pragma: no cover - optional dependency
6194

62-
BaseInstrumentor = cast(
63-
type,
64-
getattr(_instrumentor_module, "BaseInstrumentor", object),
65-
)
95+
def unwrap(obj: object, attr: str) -> None:
96+
return None
6697

67-
_utils_module = SimpleNamespace(
68-
unwrap=lambda *_args, **_kwargs: None,
69-
)
70-
try:
71-
_utils_module = import_module("opentelemetry.instrumentation.utils")
72-
except ModuleNotFoundError: # pragma: no cover - optional dependency
73-
pass
7498

75-
unwrap = cast(
76-
Callable[..., None],
77-
getattr(_utils_module, "unwrap", lambda *_args, **_kwargs: None),
78-
)
99+
from opentelemetry.trace import get_tracer
79100

80101
_schemas_module = SimpleNamespace()
81102
try:
@@ -95,9 +116,7 @@ class LangChainInstrumentor(BaseInstrumentor):
95116
to capture LLM telemetry.
96117
"""
97118

98-
def __init__(
99-
self,
100-
):
119+
def __init__(self) -> None:
101120
super().__init__()
102121

103122
def instrumentation_dependencies(self) -> Collection[str]:
@@ -161,14 +180,14 @@ def __init__(
161180
def __call__(
162181
self,
163182
wrapped: Callable[..., None],
164-
instance: BaseCallbackHandler, # type: ignore
183+
instance: BaseCallbackHandler,
165184
args: tuple[Any, ...],
166185
kwargs: dict[str, Any],
167186
):
168187
wrapped(*args, **kwargs)
169188
# Ensure our OTel callback is present if not already.
170-
for handler in instance.inheritable_handlers: # type: ignore
189+
for handler in instance.inheritable_handlers:
171190
if isinstance(handler, type(self._otel_handler)):
172191
break
173192
else:
174-
instance.add_handler(self._otel_handler, inherit=True) # type: ignore
193+
instance.add_handler(self._otel_handler, inherit=True)

instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/callback_handler.py

Lines changed: 80 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# pyright: reportMissingImports=false
2+
13
# Copyright The OpenTelemetry Authors
24
#
35
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,21 +17,46 @@
1517
from __future__ import annotations
1618

1719
import json
20+
from collections.abc import Mapping, Sequence
1821
from importlib import import_module
1922
from types import SimpleNamespace
20-
from typing import Any, Mapping, Protocol, Sequence, TypedDict, cast
23+
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast
2124
from urllib.parse import urlparse
2225
from uuid import UUID
2326

2427
from opentelemetry.instrumentation.langchain.span_manager import _SpanManager
2528
from opentelemetry.trace import Span, Tracer
2629

27-
try:
28-
from langchain_core.callbacks import (
29-
BaseCallbackHandler, # type: ignore[import]
30-
)
31-
except ImportError: # pragma: no cover - optional dependency
32-
BaseCallbackHandler = object # type: ignore[assignment]
30+
31+
class _BaseCallbackHandlerProtocol(Protocol):
32+
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
33+
34+
inheritable_handlers: Sequence[Any]
35+
36+
def add_handler(self, handler: Any, inherit: bool = False) -> None: ...
37+
38+
39+
class _BaseCallbackHandlerStub:
40+
def __init__(self, *args: Any, **kwargs: Any) -> None:
41+
return
42+
43+
inheritable_handlers: Sequence[Any] = ()
44+
45+
def add_handler(self, handler: Any, inherit: bool = False) -> None:
46+
raise RuntimeError(
47+
"LangChain is required for the LangChain instrumentation."
48+
)
49+
50+
51+
if TYPE_CHECKING:
52+
BaseCallbackHandler = _BaseCallbackHandlerProtocol
53+
else:
54+
try:
55+
from langchain_core.callbacks import (
56+
BaseCallbackHandler, # type: ignore[import]
57+
)
58+
except ImportError: # pragma: no cover - optional dependency
59+
BaseCallbackHandler = _BaseCallbackHandlerStub
3360

3461

3562
class _SerializedMessage(TypedDict, total=False):
@@ -123,7 +150,7 @@ def __getattr__(self, name: str) -> Any: ...
123150
)
124151

125152

126-
class OpenTelemetryLangChainCallbackHandler(BaseCallbackHandler): # type: ignore[misc]
153+
class OpenTelemetryLangChainCallbackHandler(BaseCallbackHandler):
127154
"""
128155
A callback handler for LangChain that uses OpenTelemetry to create spans for LLM calls and chains, tools etc,. in future.
129156
"""
@@ -156,7 +183,9 @@ def __init__(
156183
tracer: Tracer,
157184
capture_messages: bool,
158185
) -> None:
159-
super().__init__() # type: ignore
186+
base_init: Any = getattr(super(), "__init__", None)
187+
if callable(base_init):
188+
base_init()
160189

161190
self.span_manager = _SpanManager(
162191
tracer=tracer,
@@ -230,17 +259,31 @@ def _resolve_provider(
230259

231260
return provider_key
232261

233-
def _extract_params(
234-
self, kwargs: Mapping[str, Any]
235-
) -> Mapping[str, Any] | None:
262+
def _extract_params(self, kwargs: Mapping[str, Any]) -> dict[str, Any]:
236263
invocation_params = kwargs.get("invocation_params")
237264
if isinstance(invocation_params, Mapping):
238-
params = invocation_params.get("params") or invocation_params
239-
if isinstance(params, Mapping):
240-
return params
241-
return None
242-
243-
return kwargs if kwargs else None
265+
invocation_mapping = cast(Mapping[str, Any], invocation_params)
266+
params_raw = cast(
267+
Mapping[Any, Any] | None, invocation_mapping.get("params")
268+
)
269+
if isinstance(params_raw, Mapping):
270+
params_mapping = params_raw
271+
extracted: dict[str, Any] = {}
272+
for key, value in params_mapping.items():
273+
key_str = key if isinstance(key, str) else str(key)
274+
extracted[key_str] = value
275+
return extracted
276+
invocation_mapping = cast(Mapping[Any, Any], invocation_params)
277+
extracted: dict[str, Any] = {}
278+
for key, value in invocation_mapping.items():
279+
key_str = key if isinstance(key, str) else str(key)
280+
extracted[key_str] = value
281+
return extracted
282+
283+
extracted: dict[str, Any] = {}
284+
for key, value in kwargs.items():
285+
extracted[key] = value
286+
return extracted
244287

245288
def _extract_request_model(
246289
self,
@@ -273,7 +316,7 @@ def _extract_request_model(
273316
def _apply_request_attributes(
274317
self,
275318
span: Span,
276-
params: Mapping[str, Any] | None,
319+
params: dict[str, Any] | None,
277320
metadata: Mapping[str, Any] | None,
278321
) -> None:
279322
if params:
@@ -373,10 +416,17 @@ def _maybe_set_server_attributes(
373416
def _extract_output_type(self, params: Mapping[str, Any]) -> str | None:
374417
response_format = params.get("response_format")
375418
output_type: str | None = None
376-
if isinstance(response_format, dict):
377-
output_type = response_format.get("type")
419+
if isinstance(response_format, Mapping):
420+
response_mapping = cast(Mapping[Any, Any], response_format)
421+
candidate: Any = response_mapping.get("type")
422+
if isinstance(candidate, str):
423+
output_type = candidate
424+
elif candidate is not None:
425+
output_type = str(candidate)
378426
elif isinstance(response_format, str):
379427
output_type = response_format
428+
elif response_format is not None:
429+
output_type = str(response_format)
380430

381431
if not output_type:
382432
return None
@@ -420,7 +470,7 @@ def _serialize_output_messages(
420470
return serialized
421471

422472
def _serialize_message(self, message: _MessageLike) -> _SerializedMessage:
423-
payload: _SerializedMessage = {
473+
payload: dict[str, Any] = {
424474
"type": getattr(message, "type", message.__class__.__name__),
425475
"content": getattr(message, "content", None),
426476
}
@@ -434,9 +484,9 @@ def _serialize_message(self, message: _MessageLike) -> _SerializedMessage:
434484
"name",
435485
):
436486
value = getattr(message, attr, None)
437-
if value:
487+
if value is not None:
438488
payload[attr] = value
439-
return payload
489+
return cast(_SerializedMessage, payload)
440490

441491
def _serialize_to_json(self, payload: Any) -> str:
442492
return json.dumps(payload, default=self._json_default)
@@ -446,9 +496,13 @@ def _json_default(value: Any) -> Any:
446496
if isinstance(value, (str, int, float, bool)) or value is None:
447497
return value
448498
if isinstance(value, dict):
449-
return value
499+
return cast(dict[str, Any], value)
450500
if isinstance(value, (list, tuple)):
451-
return list(value)
501+
seq_value = cast(Sequence[Any], value)
502+
return [
503+
OpenTelemetryLangChainCallbackHandler._json_default(item)
504+
for item in seq_value
505+
]
452506
return getattr(value, "__dict__", str(value))
453507

454508
def on_llm_end(

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,4 +220,5 @@ exclude = [
220220
"instrumentation-genai/opentelemetry-instrumentation-weaviate/tests/**/*.py",
221221
"instrumentation-genai/opentelemetry-instrumentation-weaviate/examples/**/*.py",
222222
"util/opentelemetry-util-genai/tests/**/*.py",
223+
"**/.venv",
223224
]

tox.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,6 +1075,7 @@ deps =
10751075
{toxinidir}/util/opentelemetry-util-genai[upload]
10761076
{toxinidir}/instrumentation-genai/opentelemetry-instrumentation-vertexai[instruments]
10771077
{toxinidir}/instrumentation-genai/opentelemetry-instrumentation-google-genai[instruments]
1078+
{toxinidir}/instrumentation-genai/opentelemetry-instrumentation-langchain[instruments]
10781079
{toxinidir}/instrumentation/opentelemetry-instrumentation-aiokafka[instruments]
10791080
{toxinidir}/instrumentation/opentelemetry-instrumentation-asyncclick[instruments]
10801081
{toxinidir}/exporter/opentelemetry-exporter-credential-provider-gcp

0 commit comments

Comments
 (0)