@@ -215,6 +215,28 @@ def _warn(msg: str):
215215 _warn ._LOGGER .warning (msg ) # pyright: ignore[reportFunctionMemberAccess]
216216
217217
218+ def _force_flush_traces ():
219+ try :
220+ import opentelemetry .trace
221+ except (ImportError , AttributeError ):
222+ _warn (
223+ "Could not force flush traces. opentelemetry-api is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
224+ )
225+ return None
226+
227+ try :
228+ import opentelemetry .sdk .trace
229+ except (ImportError , AttributeError ):
230+ _warn (
231+ "Could not force flush traces. opentelemetry-sdk is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
232+ )
233+ return None
234+
235+ provider = opentelemetry .trace .get_tracer_provider ()
236+ if isinstance (provider , opentelemetry .sdk .trace .TracerProvider ):
237+ _ = provider .force_flush ()
238+
239+
218240def _default_instrumentor_builder (
219241 project_id : str ,
220242 * ,
@@ -288,28 +310,23 @@ def _detect_cloud_resource_id(project_id: str) -> Optional[str]:
288310
289311 if enable_tracing :
290312 try :
291- import opentelemetry .exporter .cloud_trace
313+ import opentelemetry .exporter .otlp .proto .http .trace_exporter
314+ import google .auth .transport .requests
292315 except (ImportError , AttributeError ):
293316 return _warn_missing_dependency (
294- "opentelemetry-exporter-gcp-trace" , needed_for_tracing = True
295- )
296-
297- try :
298- import google .cloud .trace_v2
299- except (ImportError , AttributeError ):
300- return _warn_missing_dependency (
301- "google-cloud-trace" , needed_for_tracing = True
317+ "opentelemetry-exporter-otlp-proto-http" , needed_for_tracing = True
302318 )
303319
304320 import google .auth
305321
306322 credentials , _ = google .auth .default ()
307- span_exporter = opentelemetry .exporter .cloud_trace .CloudTraceSpanExporter (
308- project_id = project_id ,
309- client = google .cloud .trace_v2 .TraceServiceClient (
310- credentials = credentials .with_quota_project (project_id ),
311- ),
312- resource_regex = "|" .join (resource .attributes .keys ()),
323+ span_exporter = (
324+ opentelemetry .exporter .otlp .proto .http .trace_exporter .OTLPSpanExporter (
325+ session = google .auth .transport .requests .AuthorizedSession (
326+ credentials = credentials
327+ ),
328+ endpoint = "https://telemetry.googleapis.com/v1/traces" ,
329+ )
313330 )
314331 span_processor = opentelemetry .sdk .trace .export .BatchSpanProcessor (
315332 span_exporter = span_exporter ,
@@ -646,54 +663,17 @@ def set_up(self):
646663 else :
647664 os .environ ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS" ] = "false"
648665
649- GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = (
650- "GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"
651- )
652-
653- def telemetry_enabled () -> Optional [bool ]:
654- return (
655- os .getenv (GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY , "0" ).lower ()
656- in ("true" , "1" )
657- if GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY in os .environ
658- else None
659- )
660-
661- # Tracing enablement follows truth table:
662- def tracing_enabled () -> bool :
663- """Tracing enablement follows true table:
664-
665- | enable_tracing | enable_telemetry(env) | tracing_actually_enabled |
666- |----------------|-----------------------|--------------------------|
667- | false | false | false |
668- | false | true | false |
669- | false | None | false |
670- | true | false | false |
671- | true | true | true |
672- | true | None | true |
673- | None(default) | false | false |
674- | None(default) | true | adk_version >= 1.17 |
675- | None(default) | None | false |
676- """
677- enable_tracing : Optional [bool ] = self ._tmpl_attrs .get ("enable_tracing" )
678- enable_telemetry : Optional [bool ] = telemetry_enabled ()
679-
680- return (enable_tracing is True and enable_telemetry is not False ) or (
681- enable_tracing is None
682- and enable_telemetry is True
683- and is_version_sufficient ("1.17.0" )
684- )
685-
686- enable_logging = bool (telemetry_enabled ())
666+ enable_logging = bool (self ._telemetry_enabled ())
687667
688668 custom_instrumentor = self ._tmpl_attrs .get ("instrumentor_builder" )
689669
690- if custom_instrumentor and tracing_enabled ():
670+ if custom_instrumentor and self . _tracing_enabled ():
691671 self ._tmpl_attrs ["instrumentor" ] = custom_instrumentor (project )
692672
693673 if not custom_instrumentor :
694674 self ._tmpl_attrs ["instrumentor" ] = _default_instrumentor_builder (
695675 project ,
696- enable_tracing = tracing_enabled (),
676+ enable_tracing = self . _tracing_enabled (),
697677 enable_logging = enable_logging ,
698678 )
699679
@@ -847,9 +827,13 @@ async def async_stream_query(
847827 ** kwargs ,
848828 )
849829
850- async for event in events_async :
851- # Yield the event data as a dictionary
852- yield _utils .dump_event_for_json (event )
830+ try :
831+ async for event in events_async :
832+ # Yield the event data as a dictionary
833+ yield _utils .dump_event_for_json (event )
834+ finally :
835+ if self ._tracing_enabled ():
836+ _force_flush_traces ()
853837
854838 async def streaming_agent_run_with_events (self , request_json : str ):
855839 """Streams responses asynchronously from the ADK application.
@@ -920,6 +904,8 @@ async def streaming_agent_run_with_events(self, request_json: str):
920904 user_id = request .user_id ,
921905 session_id = session .id ,
922906 )
907+ if self ._tracing_enabled ():
908+ _force_flush_traces ()
923909
924910 async def async_get_session (
925911 self ,
@@ -1105,3 +1091,42 @@ def register_operations(self) -> Dict[str, List[str]]:
11051091 "streaming_agent_run_with_events" ,
11061092 ],
11071093 }
1094+
1095+ def _telemetry_enabled (self ) -> Optional [bool ]:
1096+ import os
1097+
1098+ GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = (
1099+ "GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"
1100+ )
1101+
1102+ return (
1103+ os .getenv (GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY , "0" ).lower ()
1104+ in ("true" , "1" )
1105+ if GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY in os .environ
1106+ else None
1107+ )
1108+
1109+ # Tracing enablement follows truth table:
1110+ def _tracing_enabled (self ) -> bool :
1111+ """Tracing enablement follows true table:
1112+
1113+ | enable_tracing | enable_telemetry(env) | tracing_actually_enabled |
1114+ |----------------|-----------------------|--------------------------|
1115+ | false | false | false |
1116+ | false | true | false |
1117+ | false | None | false |
1118+ | true | false | false |
1119+ | true | true | true |
1120+ | true | None | true |
1121+ | None(default) | false | false |
1122+ | None(default) | true | adk_version >= 1.17 |
1123+ | None(default) | None | false |
1124+ """
1125+ enable_tracing : Optional [bool ] = self ._tmpl_attrs .get ("enable_tracing" )
1126+ enable_telemetry : Optional [bool ] = self ._telemetry_enabled ()
1127+
1128+ return (enable_tracing is True and enable_telemetry is not False ) or (
1129+ enable_tracing is None
1130+ and enable_telemetry is True
1131+ and is_version_sufficient ("1.17.0" )
1132+ )
0 commit comments