Commit 9837cf4
refactor: store span and context token in LLMInvocation instead of SpanGenerator
1 parent 635b7f8 commit 9837cf4

4 files changed: +37 −79 lines

util/opentelemetry-util-genai/src/opentelemetry/util/genai/generators.py

Lines changed: 13 additions & 44 deletions

@@ -31,7 +31,6 @@
 follow the GenAI semantic conventions.
 """

-from contextlib import contextmanager
 from contextvars import Token
 from typing import Dict, Optional
 from uuid import UUID
@@ -95,56 +94,26 @@ def start(self, invocation: LLMInvocation):
             name=f"{GenAI.GenAiOperationNameValues.CHAT.value} {invocation.request_model}",
             kind=SpanKind.CLIENT,
         )
-        token = otel_context.attach(set_span_in_context(span))
-        self._active[invocation.run_id] = (span, token)
-
-    @contextmanager
-    def _start_span_for_invocation(self, invocation: LLMInvocation):
-        """Create/register a span for the invocation and yield it.
-
-        The span is not ended automatically on exiting the context; callers
-        must finalize via _finalize_invocation.
-        """
-        # Create a span and attach it as current; keep the token to detach later
-        span = self._tracer.start_span(
-            name=f"{GenAI.GenAiOperationNameValues.CHAT.value} {invocation.request_model}",
-            kind=SpanKind.CLIENT,
+        invocation.span = span
+        invocation.context_token = otel_context.attach(
+            set_span_in_context(span)
         )
-        token = otel_context.attach(set_span_in_context(span))
-        # store active span and its context attachment token
-        self._active[invocation.run_id] = (span, token)
-        yield span

     def finish(self, invocation: LLMInvocation):
-        active = self._active.get(invocation.run_id)
-        if active is None:
-            # If missing, create a quick span to record attributes and end it
-            with self._tracer.start_as_current_span(
-                name=f"{GenAI.GenAiOperationNameValues.CHAT.value} {invocation.request_model}",
-                kind=SpanKind.CLIENT,
-            ) as span:
-                _apply_finish_attributes(span, invocation)
+        if invocation.context_token is None or invocation.span is None:
             return

-        span, token = active
-        _apply_finish_attributes(span, invocation)
+        _apply_finish_attributes(invocation.span, invocation)
         # Detach context and end span
-        otel_context.detach(token)
-        span.end()
-        del self._active[invocation.run_id]
+        otel_context.detach(invocation.context_token)
+        invocation.span.end()

     def error(self, error: Error, invocation: LLMInvocation):
-        active = self._active.get(invocation.run_id)
-        if active is None:
-            with self._tracer.start_as_current_span(
-                name=f"{GenAI.GenAiOperationNameValues.CHAT.value} {invocation.request_model}",
-                kind=SpanKind.CLIENT,
-            ) as span:
-                _apply_error_attributes(span, error)
+        if invocation.context_token is None or invocation.span is None:
             return

-        span, token = active
-        _apply_error_attributes(span, error)
-        otel_context.detach(token)
-        span.end()
-        del self._active[invocation.run_id]
+        _apply_error_attributes(invocation.span, error)
+        # Detach context and end span
+        otel_context.detach(invocation.context_token)
+        invocation.span.end()
+        return
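
With this change the SpanGenerator no longer keeps a run_id-keyed registry: start() writes the span and the context-attachment token onto the LLMInvocation itself, and finish()/error() read them back from the same object. A minimal standalone sketch of that pattern follows; the Invocation class and the start/finish helpers are hypothetical stand-ins, not the package's public API.

from dataclasses import dataclass
from typing import Optional

from opentelemetry import context as otel_context
from opentelemetry import trace
from opentelemetry.trace import Span, SpanKind
from opentelemetry.trace.propagation import set_span_in_context


@dataclass
class Invocation:
    # Hypothetical stand-in for LLMInvocation: the span and the context
    # attachment token travel with the invocation object itself.
    request_model: str
    span: Optional[Span] = None
    context_token: Optional[object] = None


def start(tracer: trace.Tracer, invocation: Invocation) -> None:
    # Start the span, make it current, and keep the detach token on the invocation.
    invocation.span = tracer.start_span(
        f"chat {invocation.request_model}", kind=SpanKind.CLIENT
    )
    invocation.context_token = otel_context.attach(
        set_span_in_context(invocation.span)
    )


def finish(invocation: Invocation) -> None:
    # No registry lookup: everything needed to close out is on the object.
    if invocation.context_token is None or invocation.span is None:
        return
    otel_context.detach(invocation.context_token)
    invocation.span.end()


inv = Invocation(request_model="demo-model")
start(trace.get_tracer(__name__), inv)
finish(inv)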

util/opentelemetry-util-genai/src/opentelemetry/util/genai/handler.py

Lines changed: 8 additions & 24 deletions

@@ -33,9 +33,7 @@
 """

 import time
-import uuid
 from typing import Any, List, Optional
-from uuid import UUID

 from opentelemetry.semconv.schemas import Schemas
 from opentelemetry.trace import get_tracer
@@ -72,7 +70,7 @@ class TelemetryHandler:
     them as spans, metrics, and events.
     """

-    def __init__(self, emitter_type_full: bool = True, **kwargs: Any):
+    def __init__(self, **kwargs: Any):
         tracer_provider = kwargs.get("tracer_provider")
         self._tracer = get_tracer(
             __name__,
@@ -81,18 +79,14 @@ def __init__(self, emitter_type_full: bool = True, **kwargs: Any):
             schema_url=Schemas.V1_36_0.value,
         )

-        # TODO: trigger span+metric+event generation based on the full emitter flag
         self._generator = SpanGenerator(tracer=self._tracer)

-        self._llm_registry: dict[UUID, LLMInvocation] = {}
-
     def start_llm(
         self,
         request_model: str,
         prompts: List[InputMessage],
-        run_id: Optional[UUID] = None,
         **attributes: Any,
-    ) -> UUID:
+    ) -> LLMInvocation:
         """Start an LLM invocation and create a pending span entry.

         Known attributes provided via ``**attributes`` (``provider``,
@@ -101,29 +95,24 @@ def start_llm(
         ``LLMInvocation``. Any remaining keys are preserved in
         ``invocation.attributes`` for custom metadata.

-        Returns the ``run_id`` used to track the invocation lifecycle.
+        Returns the ``LLMInvocation`` to use with `stop_llm` and `fail_llm`.
         """
-        if run_id is None:
-            run_id = uuid.uuid4()
         invocation = LLMInvocation(
             request_model=request_model,
             messages=prompts,
-            run_id=run_id,
             attributes=attributes,
         )
         _apply_known_attrs_to_invocation(invocation, invocation.attributes)
-        self._llm_registry[invocation.run_id] = invocation
         self._generator.start(invocation)
-        return invocation.run_id
+        return invocation

     def stop_llm(
         self,
-        run_id: UUID,
+        invocation: LLMInvocation,
         chat_generations: List[OutputMessage],
         **attributes: Any,
     ) -> LLMInvocation:
         """Finalize an LLM invocation successfully and end its span."""
-        invocation = self._llm_registry.pop(run_id)
         invocation.end_time = time.time()
         invocation.chat_generations = chat_generations
         _apply_known_attrs_to_invocation(invocation, attributes)
@@ -132,29 +121,24 @@ def stop_llm(
         return invocation

     def fail_llm(
-        self, run_id: UUID, error: Error, **attributes: Any
+        self, invocation: LLMInvocation, error: Error, **attributes: Any
     ) -> LLMInvocation:
         """Fail an LLM invocation and end its span with error status."""
-        invocation = self._llm_registry.pop(run_id)
         invocation.end_time = time.time()
         _apply_known_attrs_to_invocation(invocation, attributes)
         invocation.attributes.update(**attributes)
         self._generator.error(error, invocation)
         return invocation


-def get_telemetry_handler(
-    emitter_type_full: bool = True, **kwargs: Any
-) -> TelemetryHandler:
+def get_telemetry_handler(**kwargs: Any) -> TelemetryHandler:
     """
     Returns a singleton TelemetryHandler instance.
     """
     handler: Optional[TelemetryHandler] = getattr(
         get_telemetry_handler, "_default_handler", None
     )
     if handler is None:
-        handler = TelemetryHandler(
-            emitter_type_full=emitter_type_full, **kwargs
-        )
+        handler = TelemetryHandler(**kwargs)
         setattr(get_telemetry_handler, "_default_handler", handler)
     return handler
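
With the registry gone, callers keep the LLMInvocation returned by start_llm and hand it back to stop_llm or fail_llm. A minimal usage sketch, modeled on the updated tests; the import paths are inferred from the file locations in this commit, and the model, provider, and attribute names are illustrative.

from opentelemetry.util.genai.handler import get_telemetry_handler
from opentelemetry.util.genai.types import InputMessage, Text

handler = get_telemetry_handler()

# start_llm now returns the LLMInvocation instead of a run_id.
invocation = handler.start_llm(
    request_model="test-model",
    prompts=[InputMessage(role="Human", parts=[Text(content="hello world")])],
    provider="test-provider",
    custom_attr="value",
)

# ... call the model ...

# Pass the invocation object back to finalize its span.
handler.stop_llm(
    invocation,
    chat_generations=[],  # OutputMessage results from the model would go here
    extra="info",
)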

util/opentelemetry-util-genai/src/opentelemetry/util/genai/types.py

Lines changed: 9 additions & 1 deletion

@@ -14,13 +14,20 @@


 import time
+from contextvars import Token
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Type, Union
 from uuid import UUID

+from typing_extensions import TypeAlias
+
+from opentelemetry.context import Context
+from opentelemetry.trace import Span
 from opentelemetry.util.types import AttributeValue

+ContextToken: TypeAlias = Token[Context]
+

 class ContentCapturingMode(Enum):
     # Do not capture content (default).
@@ -81,8 +88,9 @@ class LLMInvocation:
     Represents a single LLM call invocation.
     """

-    run_id: UUID
     request_model: str
+    context_token: Optional[ContextToken] = None
+    span: Optional[Span] = None
     parent_run_id: Optional[UUID] = None
     start_time: float = field(default_factory=time.time)
     end_time: Optional[float] = None
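
The new ContextToken alias simply names the token type that opentelemetry.context.attach() hands back, so LLMInvocation can type the value it later passes to detach(). A small standalone illustration of that attach/detach pairing (assuming Python 3.9+ so that Token[Context] is subscriptable at runtime, as in the diff; the annotation on the attach() return value mirrors this commit's usage rather than a guaranteed API signature):

from contextvars import Token

from opentelemetry import context as otel_context
from opentelemetry.context import Context

ContextToken = Token[Context]  # same alias as in types.py

# attach() returns the token needed to restore the previous context later.
token = otel_context.attach(otel_context.get_current())
try:
    pass  # work runs under the attached context here
finally:
    otel_context.detach(token)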

util/opentelemetry-util-genai/tests/test_utils.py

Lines changed: 7 additions & 10 deletions

@@ -122,7 +122,6 @@ def tearDown(self):
         content_capturing="SPAN_ONLY",
     )
     def test_llm_start_and_stop_creates_span(self):  # pylint: disable=no-self-use
-        run_id = uuid4()
         message = InputMessage(
             role="Human", parts=[Text(content="hello world")]
         )
@@ -131,15 +130,14 @@ def test_llm_start_and_stop_creates_span(self):  # pylint: disable=no-self-use
         )

         # Start and stop LLM invocation
-        self.telemetry_handler.start_llm(
+        invocation = self.telemetry_handler.start_llm(
             request_model="test-model",
             prompts=[message],
-            run_id=run_id,
             custom_attr="value",
             provider="test-provider",
         )
-        invocation = self.telemetry_handler.stop_llm(
-            run_id, chat_generations=[chat_generation], extra="info"
+        self.telemetry_handler.stop_llm(
+            invocation, chat_generations=[chat_generation], extra="info"
         )

         # Get the spans that were created
@@ -157,7 +155,6 @@ def test_llm_start_and_stop_creates_span(self):  # pylint: disable=no-self-use
         assert span.start_time is not None
         assert span.end_time is not None
         assert span.end_time > span.start_time
-        assert invocation.run_id == run_id
         assert invocation.attributes.get("custom_attr") == "value"
         assert invocation.attributes.get("extra") == "info"

@@ -183,13 +180,13 @@ def test_parent_child_span_relationship(self):
         )

         # Start parent and child (child references parent_run_id)
-        self.telemetry_handler.start_llm(
+        parent_invocation = self.telemetry_handler.start_llm(
             request_model="parent-model",
             prompts=[message],
             run_id=parent_id,
             provider="test-provider",
         )
-        self.telemetry_handler.start_llm(
+        child_invocation = self.telemetry_handler.start_llm(
             request_model="child-model",
             prompts=[message],
             run_id=child_id,
@@ -199,10 +196,10 @@

         # Stop child first, then parent (order should not matter)
         self.telemetry_handler.stop_llm(
-            child_id, chat_generations=[chat_generation]
+            child_invocation, chat_generations=[chat_generation]
         )
         self.telemetry_handler.stop_llm(
-            parent_id, chat_generations=[chat_generation]
+            parent_invocation, chat_generations=[chat_generation]
         )

         spans = self.span_exporter.get_finished_spans()
