Skip to content

Commit b007681

Browse files
committed
moved span generation code and added test coverage
1 parent 8903433 commit b007681

File tree

4 files changed

+200
-81
lines changed

4 files changed

+200
-81
lines changed

instrumentation-genai/opentelemetry-instrumentation-langchain/pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ classifiers = [
2525
"Programming Language :: Python :: 3.13",
2626
]
2727
dependencies = [
28-
"opentelemetry-api == 1.36.0",
29-
"opentelemetry-instrumentation == 0.57b0",
30-
"opentelemetry-semantic-conventions == 0.57b0"
28+
"opentelemetry-api >= 1.36.0",
29+
"opentelemetry-instrumentation >= 0.57b0",
30+
"opentelemetry-semantic-conventions >= 0.57b0"
3131
]
3232

3333
[project.optional-dependencies]

instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/callback_handler.py

Lines changed: 23 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,36 @@
1-
import time
2-
from dataclasses import dataclass, field
1+
# Copyright The OpenTelemetry Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
315
from typing import Any, Dict, List, Optional
416
from uuid import UUID
517

618
from langchain_core.callbacks import BaseCallbackHandler # type: ignore
719
from langchain_core.messages import BaseMessage # type: ignore
820
from langchain_core.outputs import LLMResult # type: ignore
921

10-
from opentelemetry.context import Context, get_current
22+
from opentelemetry.instrumentation.langchain.span_manager import SpanManager
1123
from opentelemetry.instrumentation.langchain.utils import dont_throw
1224
from opentelemetry.semconv._incubating.attributes import (
1325
gen_ai_attributes as GenAI,
1426
)
1527
from opentelemetry.semconv.attributes import (
1628
error_attributes as ErrorAttributes,
1729
)
18-
from opentelemetry.trace import Span, SpanKind, Tracer, set_span_in_context
30+
from opentelemetry.trace import Tracer
1931
from opentelemetry.trace.status import Status, StatusCode
2032

2133

22-
@dataclass
23-
class _SpanState:
24-
span: Span
25-
span_context: Context
26-
start_time: float = field(default_factory=time.time)
27-
children: List[UUID] = field(default_factory=list)
28-
29-
3034
class OpenTelemetryLangChainCallbackHandler(BaseCallbackHandler): # type: ignore[misc]
3135
"""
3236
A callback handler for LangChain that uses OpenTelemetry to create spans for LLM calls and chains, tools etc,. in future.
@@ -37,69 +41,10 @@ def __init__(
3741
tracer: Tracer,
3842
) -> None:
3943
super().__init__() # type: ignore
40-
self._tracer = tracer
41-
42-
# Map from run_id -> _SpanState, to keep track of spans and parent/child relationships
43-
self.spans: Dict[UUID, _SpanState] = {}
44-
self.run_inline = True # Whether to run the callback inline.
45-
46-
def _create_span(
47-
self,
48-
run_id: UUID,
49-
parent_run_id: Optional[UUID],
50-
span_name: str,
51-
kind: SpanKind = SpanKind.INTERNAL,
52-
) -> Span:
53-
if parent_run_id is not None and parent_run_id in self.spans:
54-
parent_span = self.spans[parent_run_id].span
55-
ctx = set_span_in_context(parent_span)
56-
span = self._tracer.start_span(
57-
name=span_name, kind=kind, context=ctx
58-
)
59-
else:
60-
# top-level or missing parent
61-
span = self._tracer.start_span(name=span_name, kind=kind)
62-
63-
span_state = _SpanState(span=span, span_context=get_current())
64-
self.spans[run_id] = span_state
65-
66-
if parent_run_id is not None and parent_run_id in self.spans:
67-
self.spans[parent_run_id].children.append(run_id)
6844

69-
return span
70-
71-
def _create_llm_span(
72-
self,
73-
run_id: UUID,
74-
parent_run_id: Optional[UUID],
75-
name: str,
76-
) -> Span:
77-
span = self._create_span(
78-
run_id=run_id,
79-
parent_run_id=parent_run_id,
80-
span_name=f"{name}.{GenAI.GenAiOperationNameValues.CHAT.value}",
81-
kind=SpanKind.CLIENT,
45+
self.span_manager = SpanManager(
46+
tracer=tracer,
8247
)
83-
span.set_attribute(
84-
GenAI.GEN_AI_OPERATION_NAME,
85-
GenAI.GenAiOperationNameValues.CHAT.value,
86-
)
87-
span.set_attribute(GenAI.GEN_AI_SYSTEM, name)
88-
89-
return span
90-
91-
def _end_span(self, run_id: UUID) -> None:
92-
state = self.spans[run_id]
93-
for child_id in state.children:
94-
child_state = self.spans.get(child_id)
95-
if child_state:
96-
# Always end child spans as OpenTelemetry spans don't expose end_time directly
97-
child_state.span.end()
98-
# Always end the span as OpenTelemetry spans don't expose end_time directly
99-
state.span.end()
100-
101-
def _get_span(self, run_id: UUID) -> Span:
102-
return self.spans[run_id].span
10348

10449
@dont_throw
10550
def on_chat_model_start(
@@ -114,7 +59,7 @@ def on_chat_model_start(
11459
**kwargs: Any,
11560
) -> None:
11661
name = serialized.get("name") or kwargs.get("name") or "ChatLLM"
117-
span = self._create_llm_span(
62+
span = self.span_manager.create_llm_span(
11863
run_id=run_id,
11964
parent_run_id=parent_run_id,
12065
name=name,
@@ -170,7 +115,7 @@ def on_llm_end(
170115
parent_run_id: Optional[UUID] = None,
171116
**kwargs: Any,
172117
) -> None:
173-
span = self._get_span(run_id)
118+
span = self.span_manager.get_span(run_id)
174119

175120
finish_reasons: List[str] = []
176121
for generation in getattr(response, "generations", []): # type: ignore
@@ -218,7 +163,7 @@ def on_llm_end(
218163
)
219164

220165
# End the LLM span
221-
self._end_span(run_id)
166+
self.span_manager.end_span(run_id)
222167

223168
@dont_throw
224169
def on_llm_error(
@@ -232,9 +177,9 @@ def on_llm_error(
232177
self._handle_error(error, run_id)
233178

234179
def _handle_error(self, error: BaseException, run_id: UUID):
235-
span = self._get_span(run_id)
180+
span = self.span_manager.get_span(run_id)
236181
span.set_status(Status(StatusCode.ERROR, str(error)))
237182
span.set_attribute(
238183
ErrorAttributes.ERROR_TYPE, type(error).__qualname__
239184
)
240-
self._end_span(run_id)
185+
self.span_manager.end_span(run_id)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright The OpenTelemetry Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import time
16+
from dataclasses import dataclass, field
17+
from typing import Dict, List, Optional
18+
from uuid import UUID
19+
20+
from langchain_core.callbacks import BaseCallbackHandler # type: ignore
21+
from langchain_core.messages import BaseMessage # type: ignore
22+
from langchain_core.outputs import LLMResult # type: ignore
23+
24+
from opentelemetry.context import Context, get_current
25+
from opentelemetry.semconv._incubating.attributes import (
26+
gen_ai_attributes as GenAI,
27+
)
28+
29+
from opentelemetry.trace import Span, SpanKind, Tracer, set_span_in_context
30+
31+
32+
@dataclass
33+
class _SpanState:
34+
span: Span
35+
context: Context
36+
start_time: float = field(default_factory=time.time)
37+
children: List[UUID] = field(default_factory=list)
38+
39+
class SpanManager:
40+
def __init__(
41+
self,
42+
tracer: Tracer,
43+
) -> None:
44+
self._tracer = tracer
45+
46+
# Map from run_id -> _SpanState, to keep track of spans and parent/child relationships
47+
self.spans: Dict[UUID, _SpanState] = {}
48+
49+
def create_span(
50+
self,
51+
run_id: UUID,
52+
parent_run_id: Optional[UUID],
53+
span_name: str,
54+
kind: SpanKind = SpanKind.INTERNAL,
55+
) -> Span:
56+
if parent_run_id is not None and parent_run_id in self.spans:
57+
parent_span = self.spans[parent_run_id].span
58+
ctx = set_span_in_context(parent_span)
59+
span = self._tracer.start_span(
60+
name=span_name, kind=kind, context=ctx
61+
)
62+
else:
63+
# top-level or missing parent
64+
span = self._tracer.start_span(name=span_name, kind=kind)
65+
66+
span_state = _SpanState(span=span, context=get_current())
67+
self.spans[run_id] = span_state
68+
69+
if parent_run_id is not None and parent_run_id in self.spans:
70+
self.spans[parent_run_id].children.append(run_id)
71+
72+
return span
73+
74+
def create_llm_span(
75+
self,
76+
run_id: UUID,
77+
parent_run_id: Optional[UUID],
78+
name: str,
79+
) -> Span:
80+
span = self.create_span(
81+
run_id=run_id,
82+
parent_run_id=parent_run_id,
83+
span_name=f"{name}.{GenAI.GenAiOperationNameValues.CHAT.value}",
84+
kind=SpanKind.CLIENT,
85+
)
86+
span.set_attribute(
87+
GenAI.GEN_AI_OPERATION_NAME,
88+
GenAI.GenAiOperationNameValues.CHAT.value,
89+
)
90+
span.set_attribute(GenAI.GEN_AI_SYSTEM, name)
91+
92+
return span
93+
94+
def end_span(self, run_id: UUID) -> None:
95+
state = self.spans[run_id]
96+
for child_id in state.children:
97+
child_state = self.spans.get(child_id)
98+
if child_state:
99+
# Always end child spans as OpenTelemetry spans don't expose end_time directly
100+
child_state.span.end()
101+
# Always end the span as OpenTelemetry spans don't expose end_time directly
102+
state.span.end()
103+
104+
def get_span(self, run_id: UUID) -> Span:
105+
return self.spans[run_id].span
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import unittest.mock
2+
import uuid
3+
4+
import pytest
5+
from opentelemetry.trace import SpanKind, get_tracer
6+
from opentelemetry.trace.span import Span
7+
8+
from opentelemetry.instrumentation.langchain.span_manager import (
9+
SpanManager,
10+
_SpanState,
11+
)
12+
13+
14+
class TestSpanManager:
15+
@pytest.fixture
16+
def tracer(self):
17+
return get_tracer("test_tracer")
18+
19+
@pytest.fixture
20+
def handler(self, tracer):
21+
return SpanManager(tracer=tracer)
22+
23+
@pytest.mark.parametrize(
24+
"parent_run_id,parent_in_spans",
25+
[
26+
(None, False), # No parent
27+
(uuid.uuid4(), False), # Parent not in spans
28+
(uuid.uuid4(), True), # Parent in spans
29+
],
30+
)
31+
32+
def test_create_span(self, handler, tracer, parent_run_id, parent_in_spans):
33+
# Arrange
34+
run_id = uuid.uuid4()
35+
span_name = "test_span"
36+
kind = SpanKind.INTERNAL
37+
38+
mock_span = unittest.mock.Mock(spec=Span)
39+
40+
# Setup parent if needed
41+
if parent_run_id is not None and parent_in_spans:
42+
parent_mock_span = unittest.mock.Mock(spec=Span)
43+
handler.spans[parent_run_id] = _SpanState(span=parent_mock_span, context=None)
44+
45+
with unittest.mock.patch.object(
46+
tracer, "start_span", return_value=mock_span
47+
) as mock_start_span, unittest.mock.patch(
48+
"opentelemetry.instrumentation.langchain.span_manager.set_span_in_context"
49+
) as mock_set_span_in_context, unittest.mock.patch(
50+
"opentelemetry.instrumentation.langchain.span_manager.get_current"
51+
) as mock_get_current:
52+
# Act
53+
result = handler.create_span(run_id, parent_run_id, span_name, kind)
54+
55+
# Assert
56+
assert result == mock_span
57+
assert run_id in handler.spans
58+
assert handler.spans[run_id].span == mock_span
59+
60+
# Verify parent-child relationship
61+
if parent_run_id is not None and parent_in_spans:
62+
mock_set_span_in_context.assert_called_once_with(handler.spans[parent_run_id].span)
63+
mock_start_span.assert_called_once_with(
64+
name=span_name, kind=kind, context=mock_set_span_in_context.return_value
65+
)
66+
assert run_id in handler.spans[parent_run_id].children
67+
else:
68+
mock_start_span.assert_called_once_with(name=span_name, kind=kind)
69+
mock_set_span_in_context.assert_not_called()

0 commit comments

Comments
 (0)