Skip to content

Commit 9cd5360

Browse files
committed
checkout parts from streaming squash me
1 parent ad29af3 commit 9cd5360

File tree

2 files changed

+86
-46
lines changed
  • instrumentation-genai/opentelemetry-instrumentation-vertexai/src/opentelemetry/instrumentation/vertexai

2 files changed

+86
-46
lines changed

instrumentation-genai/opentelemetry-instrumentation-vertexai/src/opentelemetry/instrumentation/vertexai/__init__.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,42 +39,50 @@
3939
---
4040
"""
4141

42+
from __future__ import annotations
43+
4244
from typing import Any, Collection
4345

4446
from wrapt import (
45-
wrap_function_wrapper, # type: ignore[reportUnknownVariableType]
47+
wrap_function_wrapper, # pyright: ignore[reportUnknownVariableType]
4648
)
4749

4850
from opentelemetry._events import get_event_logger
4951
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
5052
from opentelemetry.instrumentation.utils import unwrap
5153
from opentelemetry.instrumentation.vertexai.package import _instruments
52-
from opentelemetry.instrumentation.vertexai.patch import (
53-
generate_content_create,
54-
)
54+
from opentelemetry.instrumentation.vertexai.patch import MethodWrappers
5555
from opentelemetry.instrumentation.vertexai.utils import is_content_enabled
5656
from opentelemetry.semconv.schemas import Schemas
5757
from opentelemetry.trace import get_tracer
5858

5959

60-
def _client_classes():
60+
def _methods_to_wrap(
61+
method_wrappers: MethodWrappers,
62+
):
6163
# This import is very slow, do it lazily in case instrument() is not called
62-
6364
# pylint: disable=import-outside-toplevel
64-
from google.cloud.aiplatform_v1.services.prediction_service import (
65-
client,
66-
)
65+
from google.cloud.aiplatform_v1.services.prediction_service import client
6766
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
6867
client as client_v1beta1,
6968
)
7069

71-
return (
70+
for client_class in (
7271
client.PredictionServiceClient,
7372
client_v1beta1.PredictionServiceClient,
74-
)
73+
):
74+
yield (
75+
client_class,
76+
client_class.generate_content.__name__, # pyright: ignore[reportUnknownMemberType]
77+
method_wrappers.generate_content,
78+
)
7579

7680

7781
class VertexAIInstrumentor(BaseInstrumentor):
82+
def __init__(self) -> None:
83+
super().__init__()
84+
self._methods_to_unwrap: list[tuple[Any, str]] = []
85+
7886
def instrumentation_dependencies(self) -> Collection[str]:
7987
return _instruments
8088

@@ -95,15 +103,19 @@ def _instrument(self, **kwargs: Any):
95103
event_logger_provider=event_logger_provider,
96104
)
97105

98-
for client_class in _client_classes():
106+
method_wrappers = MethodWrappers(
107+
tracer, event_logger, is_content_enabled()
108+
)
109+
for client_class, method_name, wrapper in _methods_to_wrap(
110+
method_wrappers
111+
):
99112
wrap_function_wrapper(
100113
client_class,
101-
name="generate_content",
102-
wrapper=generate_content_create(
103-
tracer, event_logger, is_content_enabled()
104-
),
114+
name=method_name,
115+
wrapper=wrapper,
105116
)
117+
self._methods_to_unwrap.append((client_class, method_name))
106118

107119
def _uninstrument(self, **kwargs: Any) -> None:
108-
for client_class in _client_classes():
109-
unwrap(client_class, "generate_content")
120+
for client_class, method_name in self._methods_to_unwrap:
121+
unwrap(client_class, method_name)

instrumentation-genai/opentelemetry-instrumentation-vertexai/src/opentelemetry/instrumentation/vertexai/patch.py

Lines changed: 56 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414

1515
from __future__ import annotations
1616

17+
from contextlib import contextmanager
1718
from typing import (
1819
TYPE_CHECKING,
1920
Any,
2021
Callable,
22+
Iterable,
2123
MutableSequence,
2224
)
2325

@@ -87,17 +89,17 @@ def _extract_params(
8789
)
8890

8991

90-
def generate_content_create(
91-
tracer: Tracer, event_logger: EventLogger, capture_content: bool
92-
):
93-
"""Wrap the `generate_content` method of the `GenerativeModel` class to trace it."""
92+
class MethodWrappers:
93+
def __init__(
94+
self, tracer: Tracer, event_logger: EventLogger, capture_content: bool
95+
) -> None:
96+
self.tracer = tracer
97+
self.event_logger = event_logger
98+
self.capture_content = capture_content
9499

95-
def traced_method(
96-
wrapped: Callable[
97-
...,
98-
prediction_service.GenerateContentResponse
99-
| prediction_service_v1beta1.GenerateContentResponse,
100-
],
100+
@contextmanager
101+
def _with_instrumentation(
102+
self,
101103
instance: client.PredictionServiceClient
102104
| client_v1beta1.PredictionServiceClient,
103105
args: Any,
@@ -111,32 +113,58 @@ def traced_method(
111113
}
112114

113115
span_name = get_span_name(span_attributes)
114-
with tracer.start_as_current_span(
116+
117+
with self.tracer.start_as_current_span(
115118
name=span_name,
116119
kind=SpanKind.CLIENT,
117120
attributes=span_attributes,
118121
) as span:
119122
for event in request_to_events(
120-
params=params, capture_content=capture_content
123+
params=params, capture_content=self.capture_content
121124
):
122-
event_logger.emit(event)
125+
self.event_logger.emit(event)
123126

124127
# TODO: set error.type attribute
125128
# https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/gen-ai-spans.md
126-
response = wrapped(*args, **kwargs)
127-
# TODO: handle streaming
128-
# if is_streaming(kwargs):
129-
# return StreamWrapper(
130-
# result, span, event_logger, capture_content
131-
# )
132-
133-
if span.is_recording():
134-
span.set_attributes(get_genai_response_attributes(response))
135-
for event in response_to_events(
136-
response=response, capture_content=capture_content
137-
):
138-
event_logger.emit(event)
139129

130+
def handle_response(
131+
response: prediction_service.GenerateContentResponse
132+
| prediction_service_v1beta1.GenerateContentResponse,
133+
) -> None:
134+
if span.is_recording():
135+
# When streaming, this is called multiple times so attributes would be
136+
# overwritten. In practice, it looks like the API only returns the interesting
137+
# attributes on the last streamed response. However, I couldn't find
138+
# documentation for this and setting attributes shouldn't be too expensive.
139+
span.set_attributes(
140+
get_genai_response_attributes(response)
141+
)
142+
143+
for event in response_to_events(
144+
response=response, capture_content=self.capture_content
145+
):
146+
self.event_logger.emit(event)
147+
148+
yield handle_response
149+
150+
def generate_content(
151+
self,
152+
wrapped: Callable[
153+
...,
154+
prediction_service.GenerateContentResponse
155+
| prediction_service_v1beta1.GenerateContentResponse,
156+
],
157+
instance: client.PredictionServiceClient
158+
| client_v1beta1.PredictionServiceClient,
159+
args: Any,
160+
kwargs: Any,
161+
) -> (
162+
prediction_service.GenerateContentResponse
163+
| prediction_service_v1beta1.GenerateContentResponse
164+
):
165+
with self._with_instrumentation(
166+
instance, args, kwargs
167+
) as handle_response:
168+
response = wrapped(*args, **kwargs)
169+
handle_response(response)
140170
return response
141-
142-
return traced_method

0 commit comments

Comments
 (0)