20
20
Any ,
21
21
Awaitable ,
22
22
Callable ,
23
+ Literal ,
23
24
MutableSequence ,
25
+ Union ,
26
+ cast ,
27
+ overload ,
24
28
)
25
29
26
30
from opentelemetry ._events import EventLogger
31
+ from opentelemetry .instrumentation ._semconv import (
32
+ _StabilityMode ,
33
+ )
27
34
from opentelemetry .instrumentation .vertexai .utils import (
28
35
GenerateContentParams ,
36
+ create_operation_details_event ,
29
37
get_genai_request_attributes ,
30
38
get_genai_response_attributes ,
31
39
get_server_attributes ,
34
42
response_to_events ,
35
43
)
36
44
from opentelemetry .trace import SpanKind , Tracer
45
+ from opentelemetry .util .genai .types import ContentCapturingMode
37
46
38
47
if TYPE_CHECKING :
39
48
from google .cloud .aiplatform_v1 .services .prediction_service import client
@@ -89,17 +98,96 @@ def _extract_params(
89
98
)
90
99
91
100
101
+ # For details about GEN_AI_LATEST_EXPERIMENTAL stability mode see
102
+ # https://github.com/open-telemetry/semantic-conventions/blob/v1.37.0/docs/gen-ai/gen-ai-agent-spans.md?plain=1#L18-L37
92
103
class MethodWrappers :
104
+ @overload
105
+ def __init__ (
106
+ self ,
107
+ tracer : Tracer ,
108
+ event_logger : EventLogger ,
109
+ capture_content : ContentCapturingMode ,
110
+ sem_conv_opt_in_mode : Literal [
111
+ _StabilityMode .GEN_AI_LATEST_EXPERIMENTAL
112
+ ],
113
+ ) -> None : ...
114
+
115
+ @overload
116
+ def __init__ (
117
+ self ,
118
+ tracer : Tracer ,
119
+ event_logger : EventLogger ,
120
+ capture_content : bool ,
121
+ sem_conv_opt_in_mode : Literal [_StabilityMode .DEFAULT ],
122
+ ) -> None : ...
123
+
93
124
def __init__ (
94
- self , tracer : Tracer , event_logger : EventLogger , capture_content : bool
125
+ self ,
126
+ tracer : Tracer ,
127
+ event_logger : EventLogger ,
128
+ capture_content : Union [bool , ContentCapturingMode ],
129
+ sem_conv_opt_in_mode : Union [
130
+ Literal [_StabilityMode .DEFAULT ],
131
+ Literal [_StabilityMode .GEN_AI_LATEST_EXPERIMENTAL ],
132
+ ],
95
133
) -> None :
96
134
self .tracer = tracer
97
135
self .event_logger = event_logger
98
136
self .capture_content = capture_content
137
+ self .sem_conv_opt_in_mode = sem_conv_opt_in_mode
138
+
139
+ @contextmanager
140
+ def _with_new_instrumentation (
141
+ self ,
142
+ capture_content : ContentCapturingMode ,
143
+ instance : client .PredictionServiceClient
144
+ | client_v1beta1 .PredictionServiceClient ,
145
+ args : Any ,
146
+ kwargs : Any ,
147
+ ):
148
+ params = _extract_params (* args , ** kwargs )
149
+ api_endpoint : str = instance .api_endpoint # type: ignore[reportUnknownMemberType]
150
+ span_attributes = {
151
+ ** get_genai_request_attributes (False , params ),
152
+ ** get_server_attributes (api_endpoint ),
153
+ }
154
+
155
+ span_name = get_span_name (span_attributes )
156
+
157
+ with self .tracer .start_as_current_span (
158
+ name = span_name ,
159
+ kind = SpanKind .CLIENT ,
160
+ attributes = span_attributes ,
161
+ ) as span :
162
+
163
+ def handle_response (
164
+ response : prediction_service .GenerateContentResponse
165
+ | prediction_service_v1beta1 .GenerateContentResponse
166
+ | None ,
167
+ ) -> None :
168
+ if span .is_recording () and response :
169
+ # When streaming, this is called multiple times so attributes would be
170
+ # overwritten. In practice, it looks the API only returns the interesting
171
+ # attributes on the last streamed response. However, I couldn't find
172
+ # documentation for this and setting attributes shouldn't be too expensive.
173
+ span .set_attributes (
174
+ get_genai_response_attributes (response )
175
+ )
176
+ self .event_logger .emit (
177
+ create_operation_details_event (
178
+ api_endpoint = api_endpoint ,
179
+ params = params ,
180
+ capture_content = capture_content ,
181
+ response = response ,
182
+ )
183
+ )
184
+
185
+ yield handle_response
99
186
100
187
@contextmanager
101
- def _with_instrumentation (
188
+ def _with_default_instrumentation (
102
189
self ,
190
+ capture_content : bool ,
103
191
instance : client .PredictionServiceClient
104
192
| client_v1beta1 .PredictionServiceClient ,
105
193
args : Any ,
@@ -108,7 +196,7 @@ def _with_instrumentation(
108
196
params = _extract_params (* args , ** kwargs )
109
197
api_endpoint : str = instance .api_endpoint # type: ignore[reportUnknownMemberType]
110
198
span_attributes = {
111
- ** get_genai_request_attributes (params ),
199
+ ** get_genai_request_attributes (False , params ),
112
200
** get_server_attributes (api_endpoint ),
113
201
}
114
202
@@ -120,7 +208,7 @@ def _with_instrumentation(
120
208
attributes = span_attributes ,
121
209
) as span :
122
210
for event in request_to_events (
123
- params = params , capture_content = self . capture_content
211
+ params = params , capture_content = capture_content
124
212
):
125
213
self .event_logger .emit (event )
126
214
@@ -141,7 +229,7 @@ def handle_response(
141
229
)
142
230
143
231
for event in response_to_events (
144
- response = response , capture_content = self . capture_content
232
+ response = response , capture_content = capture_content
145
233
):
146
234
self .event_logger .emit (event )
147
235
@@ -162,12 +250,25 @@ def generate_content(
162
250
prediction_service .GenerateContentResponse
163
251
| prediction_service_v1beta1 .GenerateContentResponse
164
252
):
165
- with self ._with_instrumentation (
166
- instance , args , kwargs
167
- ) as handle_response :
168
- response = wrapped (* args , ** kwargs )
169
- handle_response (response )
170
- return response
253
+ if self .sem_conv_opt_in_mode == _StabilityMode .DEFAULT :
254
+ capture_content_bool = cast (bool , self .capture_content )
255
+ with self ._with_default_instrumentation (
256
+ capture_content_bool , instance , args , kwargs
257
+ ) as handle_response :
258
+ response = wrapped (* args , ** kwargs )
259
+ handle_response (response )
260
+ return response
261
+ else :
262
+ capture_content = cast (ContentCapturingMode , self .capture_content )
263
+ with self ._with_new_instrumentation (
264
+ capture_content , instance , args , kwargs
265
+ ) as handle_response :
266
+ response = None
267
+ try :
268
+ response = wrapped (* args , ** kwargs )
269
+ return response
270
+ finally :
271
+ handle_response (response )
171
272
172
273
async def agenerate_content (
173
274
self ,
@@ -186,9 +287,22 @@ async def agenerate_content(
186
287
prediction_service .GenerateContentResponse
187
288
| prediction_service_v1beta1 .GenerateContentResponse
188
289
):
189
- with self ._with_instrumentation (
190
- instance , args , kwargs
191
- ) as handle_response :
192
- response = await wrapped (* args , ** kwargs )
193
- handle_response (response )
194
- return response
290
+ if self .sem_conv_opt_in_mode == _StabilityMode .DEFAULT :
291
+ capture_content_bool = cast (bool , self .capture_content )
292
+ with self ._with_default_instrumentation (
293
+ capture_content_bool , instance , args , kwargs
294
+ ) as handle_response :
295
+ response = await wrapped (* args , ** kwargs )
296
+ handle_response (response )
297
+ return response
298
+ else :
299
+ capture_content = cast (ContentCapturingMode , self .capture_content )
300
+ with self ._with_new_instrumentation (
301
+ capture_content , instance , args , kwargs
302
+ ) as handle_response :
303
+ response = None
304
+ try :
305
+ response = await wrapped (* args , ** kwargs )
306
+ return response
307
+ finally :
308
+ handle_response (response )
0 commit comments