Commit 60c60f7

VertexAI handle streaming requests
WIP using shared context manager
Properly implement uninstrument
Shared code with a contextmanager
1 parent 4397344 commit 60c60f7

File tree: 4 files changed, +324 -33 lines

instrumentation-genai/opentelemetry-instrumentation-vertexai/src/opentelemetry/instrumentation/vertexai/__init__.py

Lines changed: 19 additions & 3 deletions

@@ -46,12 +46,11 @@
 )
 
 from opentelemetry._events import get_event_logger
+from opentelemetry.instrumentation.utils import unwrap
 from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
 from opentelemetry.instrumentation.utils import unwrap
 from opentelemetry.instrumentation.vertexai.package import _instruments
-from opentelemetry.instrumentation.vertexai.patch import (
-    generate_content_create,
-)
+from opentelemetry.instrumentation.vertexai.patch import PatchedMethods
 from opentelemetry.instrumentation.vertexai.utils import is_content_enabled
 from opentelemetry.semconv.schemas import Schemas
 from opentelemetry.trace import get_tracer
@@ -104,6 +103,23 @@ def _instrument(self, **kwargs: Any):
             ),
         )
 
+        for module in (
+            "google.cloud.aiplatform_v1.services.prediction_service.client",
+            "google.cloud.aiplatform_v1beta1.services.prediction_service.client",
+        ):
+            # non streaming
+            wrap_function_wrapper(
+                module=module,
+                name="PredictionServiceClient.generate_content",
+                wrapper=patched_methods.generate_content,
+            )
+            # streaming
+            wrap_function_wrapper(
+                module=module,
+                name="PredictionServiceClient.stream_generate_content",
+                wrapper=patched_methods.stream_generate_content,
+            )
+
     def _uninstrument(self, **kwargs: Any) -> None:
         for client_class in _client_classes():
             unwrap(client_class, "generate_content")
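
The wrappers registered above follow wrapt's calling convention: wrapt invokes a wrapper as `wrapper(wrapped, instance, args, kwargs)`, which is why bound methods of a `PatchedMethods` instance can be passed directly as `wrapper=`. A minimal, runnable sketch of that contract; the `Client` and `Wrappers` classes here are hypothetical stand-ins, not part of this commit:

```python
# Hypothetical demo of the wrapt contract used above: wrapt calls the
# wrapper as wrapper(wrapped, instance, args, kwargs), so a bound method
# with that signature works as a wrapper.
from wrapt import wrap_function_wrapper


class Client:  # stand-in for PredictionServiceClient
    def generate_content(self, prompt: str) -> str:
        return f"response to {prompt!r}"


class Wrappers:  # stand-in for PatchedMethods
    def generate_content(self, wrapped, instance, args, kwargs):
        # `wrapped` is the original bound method, `instance` is the Client.
        print("span would open here")
        try:
            return wrapped(*args, **kwargs)
        finally:
            print("span would close here")


wrap_function_wrapper(
    module=__name__,
    name="Client.generate_content",
    wrapper=Wrappers().generate_content,
)
print(Client().generate_content("Say this is a test"))
```
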

instrumentation-genai/opentelemetry-instrumentation-vertexai/src/opentelemetry/instrumentation/vertexai/patch.py

Lines changed: 83 additions & 30 deletions

@@ -14,13 +14,22 @@
 
 from __future__ import annotations
 
+from contextlib import contextmanager
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
+    Generator,
+    Iterable,
     MutableSequence,
 )
 
+from google.cloud.aiplatform_v1.types.prediction_service import (
+    GenerateContentResponse,
+)
+from google.cloud.aiplatform_v1beta1.types.prediction_service import (
+    GenerateContentResponse,
+)
 from opentelemetry._events import EventLogger
 from opentelemetry.instrumentation.vertexai.utils import (
     GenerateContentParams,
@@ -87,17 +96,17 @@ def _extract_params(
     )
 
 
-def generate_content_create(
-    tracer: Tracer, event_logger: EventLogger, capture_content: bool
-):
-    """Wrap the `generate_content` method of the `GenerativeModel` class to trace it."""
+class PatchedMethods:
+    def __init__(
+        self, tracer: Tracer, event_logger: EventLogger, capture_content: bool
+    ) -> None:
+        self.tracer = tracer
+        self.event_logger = event_logger
+        self.capture_content = capture_content
 
-    def traced_method(
-        wrapped: Callable[
-            ...,
-            prediction_service.GenerateContentResponse
-            | prediction_service_v1beta1.GenerateContentResponse,
-        ],
+    @contextmanager
+    def _start_as_current_span(
+        self,
         instance: client.PredictionServiceClient
         | client_v1beta1.PredictionServiceClient,
         args: Any,
@@ -111,32 +120,76 @@ def traced_method(
         }
 
         span_name = get_span_name(span_attributes)
-        with tracer.start_as_current_span(
-            name=span_name,
-            kind=SpanKind.CLIENT,
-            attributes=span_attributes,
+
+        with self.tracer.start_as_current_span(
+            name=span_name, kind=SpanKind.CLIENT, attributes=span_attributes
         ) as span:
             for event in request_to_events(
-                params=params, capture_content=capture_content
+                params=params, capture_content=self.capture_content
            ):
-                event_logger.emit(event)
+                self.event_logger.emit(event)
 
             # TODO: set error.type attribute
             # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/gen-ai-spans.md
-            response = wrapped(*args, **kwargs)
-            # TODO: handle streaming
-            # if is_streaming(kwargs):
-            #     return StreamWrapper(
-            #         result, span, event_logger, capture_content
-            #     )
-
-            if span.is_recording():
-                span.set_attributes(get_genai_response_attributes(response))
-            for event in response_to_events(
-                response=response, capture_content=capture_content
-            ):
-                event_logger.emit(event)
 
+            final_response = None
+
+            def handle_response(
+                response: prediction_service.GenerateContentResponse
+                | prediction_service_v1beta1.GenerateContentResponse,
+            ) -> None:
+                nonlocal final_response
+                final_response = response
+                for event in response_to_events(
+                    response=response, capture_content=self.capture_content
+                ):
+                    self.event_logger.emit(event)
+
+            yield handle_response
+
+            # These attributes are only set on the final response in the case of streaming
+            if final_response and span.is_recording():
+                span.set_attributes(
+                    get_genai_response_attributes(final_response)
+                )
+
+    def generate_content(
+        self,
+        wrapped: Callable[
+            ...,
+            prediction_service.GenerateContentResponse
+            | prediction_service_v1beta1.GenerateContentResponse,
+        ],
+        instance: client.PredictionServiceClient
+        | client_v1beta1.PredictionServiceClient,
+        args: Any,
+        kwargs: Any,
+    ) -> GenerateContentResponse | GenerateContentResponse:
+        with self._start_as_current_span(
+            instance, args, kwargs
+        ) as handle_response:
+            response = wrapped(*args, **kwargs)
+            handle_response(response)
             return response
 
-    return traced_method
+    def stream_generate_content(
+        self,
+        wrapped: Callable[
+            ...,
+            Iterable[prediction_service.GenerateContentResponse]
+            | Iterable[prediction_service_v1beta1.GenerateContentResponse],
+        ],
+        instance: client.PredictionServiceClient
+        | client_v1beta1.PredictionServiceClient,
+        args: Any,
+        kwargs: Any,
+    ) -> Generator[
+        GenerateContentResponse | GenerateContentResponse, Any, None
+    ]:
+        print("stream_generate_content() starting ctxmanager")
+        with self._start_as_current_span(
+            instance, args, kwargs
+        ) as handle_response:
+            for response in wrapped(*args, **kwargs):
+                handle_response(response)
+                yield response
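
The shared helper `_start_as_current_span` is the core of this change: it opens the CLIENT span, emits the request events, yields a `handle_response` callback, and on exit records attributes from the last response seen, so the unary and streaming wrappers share all span bookkeeping. A stripped-down, runnable sketch of the same pattern, with hypothetical names:

```python
# Standalone sketch (hypothetical names) of the shared-contextmanager pattern:
# one context manager yields a callback, and both the unary and streaming code
# paths drive it while the "span" stays open.
from contextlib import contextmanager
from typing import Callable, Generator, Iterable, Iterator


@contextmanager
def _record(label: str) -> Iterator[Callable[[str], None]]:
    final = None

    def handle(chunk: str) -> None:
        nonlocal final
        final = chunk  # like final_response above, keeps the last chunk

    print(f"open {label}")
    try:
        yield handle
    finally:
        # runs once the unary call returns or the stream is exhausted
        print(f"close {label}, final={final!r}")


def unary(call: Callable[[], str]) -> str:
    with _record("unary") as handle:
        result = call()
        handle(result)
        return result


def streaming(chunks: Iterable[str]) -> Generator[str, None, None]:
    with _record("stream") as handle:
        for chunk in chunks:
            handle(chunk)
            yield chunk


unary(lambda: "one-shot")                           # open/close bracket one call
list(streaming(["Okay", ", I understand", ".\n"]))  # close fires at exhaustion
```

Because `stream_generate_content` is itself a generator, the `with` block, and therefore the span, stays open until the caller exhausts or closes the returned iterator, which is what lets the final chunk's metadata land on the span.
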
Lines changed: 120 additions & 0 deletions

@@ -0,0 +1,120 @@
+interactions:
+- request:
+    body: |-
+      {
+        "contents": [
+          {
+            "role": "user",
+            "parts": [
+              {
+                "text": "Say this is a test"
+              }
+            ]
+          }
+        ]
+      }
+    headers:
+      Accept:
+      - '*/*'
+      Accept-Encoding:
+      - gzip, deflate
+      Connection:
+      - keep-alive
+      Content-Length:
+      - '141'
+      Content-Type:
+      - application/json
+      User-Agent:
+      - python-requests/2.32.3
+    method: POST
+    uri: https://us-central1-aiplatform.googleapis.com/v1/projects/fake-project/locations/us-central1/publishers/google/models/gemini-1.5-flash-002:streamGenerateContent?%24alt=json%3Benum-encoding%3Dint
+  response:
+    body:
+      string: |-
+        [
+          {
+            "candidates": [
+              {
+                "content": {
+                  "role": "model",
+                  "parts": [
+                    {
+                      "text": "Okay"
+                    }
+                  ]
+                }
+              }
+            ],
+            "usageMetadata": {},
+            "modelVersion": "gemini-1.5-flash-002",
+            "createTime": "2025-03-03T22:23:47.310622Z",
+            "responseId": "8yvGZ976Eu6knvgPpOnW2Q4"
+          },
+          {
+            "candidates": [
+              {
+                "content": {
+                  "role": "model",
+                  "parts": [
+                    {
+                      "text": ", I understand. I'm ready for your test. Please proceed"
+                    }
+                  ]
+                }
+              }
+            ],
+            "modelVersion": "gemini-1.5-flash-002",
+            "createTime": "2025-03-03T22:23:47.310622Z",
+            "responseId": "8yvGZ976Eu6knvgPpOnW2Q4"
+          },
+          {
+            "candidates": [
+              {
+                "content": {
+                  "role": "model",
+                  "parts": [
+                    {
+                      "text": ".\n"
+                    }
+                  ]
+                },
+                "finishReason": 1
+              }
+            ],
+            "usageMetadata": {
+              "promptTokenCount": 5,
+              "candidatesTokenCount": 19,
+              "totalTokenCount": 24,
+              "promptTokensDetails": [
+                {
+                  "modality": 1,
+                  "tokenCount": 5
+                }
+              ],
+              "candidatesTokensDetails": [
+                {
+                  "modality": 1,
+                  "tokenCount": 19
+                }
+              ]
+            },
+            "modelVersion": "gemini-1.5-flash-002",
+            "createTime": "2025-03-03T22:23:47.310622Z",
+            "responseId": "8yvGZ976Eu6knvgPpOnW2Q4"
+          }
+        ]
+    headers:
+      Content-Type:
+      - application/json; charset=UTF-8
+      Transfer-Encoding:
+      - chunked
+      Vary:
+      - Origin
+      - X-Origin
+      - Referer
+      content-length:
+      - '1328'
+  status:
+    code: 200
+    message: OK
+version: 1
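
This cassette records a single `streamGenerateContent` call that returns three chunks. A hedged sketch of how such a cassette might be replayed in a streaming test with pytest-recording; the `span_exporter` fixture and the exact assertions are assumptions, not taken from this commit:

```python
# Hypothetical test sketch replaying a cassette like the one above with
# pytest-recording; fixture names are assumptions.
import pytest
from vertexai.generative_models import GenerativeModel


@pytest.mark.vcr
def test_stream_generate_content(span_exporter):  # assumed fixture
    model = GenerativeModel("gemini-1.5-flash-002")
    chunks = list(model.generate_content("Say this is a test", stream=True))
    assert len(chunks) == 3  # the cassette holds three streamed responses
    spans = span_exporter.get_finished_spans()
    assert len(spans) == 1  # one CLIENT span covers the entire stream
```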
