1414
1515from __future__ import annotations
1616
17+ from contextlib import contextmanager
1718from typing import (
1819 TYPE_CHECKING ,
1920 Any ,
2021 Callable ,
22+ Iterable ,
2123 MutableSequence ,
2224)
2325
@@ -87,17 +89,17 @@ def _extract_params(
8789 )
8890
8991
90- def generate_content_create (
91- tracer : Tracer , event_logger : EventLogger , capture_content : bool
92- ):
93- """Wrap the `generate_content` method of the `GenerativeModel` class to trace it."""
92+ class MethodWrappers :
93+ def __init__ (
94+ self , tracer : Tracer , event_logger : EventLogger , capture_content : bool
95+ ) -> None :
96+ self .tracer = tracer
97+ self .event_logger = event_logger
98+ self .capture_content = capture_content
9499
95- def traced_method (
96- wrapped : Callable [
97- ...,
98- prediction_service .GenerateContentResponse
99- | prediction_service_v1beta1 .GenerateContentResponse ,
100- ],
100+ @contextmanager
101+ def _with_instrumentation (
102+ self ,
101103 instance : client .PredictionServiceClient
102104 | client_v1beta1 .PredictionServiceClient ,
103105 args : Any ,
@@ -111,32 +113,58 @@ def traced_method(
111113 }
112114
113115 span_name = get_span_name (span_attributes )
114- with tracer .start_as_current_span (
116+
117+ with self .tracer .start_as_current_span (
115118 name = span_name ,
116119 kind = SpanKind .CLIENT ,
117120 attributes = span_attributes ,
118121 ) as span :
119122 for event in request_to_events (
120- params = params , capture_content = capture_content
123+ params = params , capture_content = self . capture_content
121124 ):
122- event_logger .emit (event )
125+ self . event_logger .emit (event )
123126
124127 # TODO: set error.type attribute
125128 # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/gen-ai-spans.md
126- response = wrapped (* args , ** kwargs )
127- # TODO: handle streaming
128- # if is_streaming(kwargs):
129- # return StreamWrapper(
130- # result, span, event_logger, capture_content
131- # )
132-
133- if span .is_recording ():
134- span .set_attributes (get_genai_response_attributes (response ))
135- for event in response_to_events (
136- response = response , capture_content = capture_content
137- ):
138- event_logger .emit (event )
139129
130+ def handle_response (
131+ response : prediction_service .GenerateContentResponse
132+ | prediction_service_v1beta1 .GenerateContentResponse ,
133+ ) -> None :
134+ if span .is_recording ():
135+ # When streaming, this is called multiple times so attributes would be
136+ # overwritten. In practice, it looks the API only returns the interesting
137+ # attributes on the last streamed response. However, I couldn't find
138+ # documentation for this and setting attributes shouldn't be too expensive.
139+ span .set_attributes (
140+ get_genai_response_attributes (response )
141+ )
142+
143+ for event in response_to_events (
144+ response = response , capture_content = self .capture_content
145+ ):
146+ self .event_logger .emit (event )
147+
148+ yield handle_response
149+
150+ def generate_content (
151+ self ,
152+ wrapped : Callable [
153+ ...,
154+ prediction_service .GenerateContentResponse
155+ | prediction_service_v1beta1 .GenerateContentResponse ,
156+ ],
157+ instance : client .PredictionServiceClient
158+ | client_v1beta1 .PredictionServiceClient ,
159+ args : Any ,
160+ kwargs : Any ,
161+ ) -> (
162+ prediction_service .GenerateContentResponse
163+ | prediction_service_v1beta1 .GenerateContentResponse
164+ ):
165+ with self ._with_instrumentation (
166+ instance , args , kwargs
167+ ) as handle_response :
168+ response = wrapped (* args , ** kwargs )
169+ handle_response (response )
140170 return response
141-
142- return traced_method
0 commit comments