44from opentelemetry import context as otel_context
55from typing import (
66 Any ,
7- AsyncIterable ,
7+ AsyncIterator ,
88 Awaitable ,
99 Callable ,
1010 Dict ,
3737 set_tracer ,
3838 set_tracer_context ,
3939)
40+ from guardrails .hub_telemetry .hub_tracing import async_trace
4041from guardrails .types .pydantic import ModelOrListOfModels
4142from guardrails .types .validator import UseManyValidatorSpec , UseValidatorSpec
4243from guardrails .telemetry import trace_async_guard_execution , wrap_with_otel_context
@@ -187,7 +188,7 @@ async def _execute(
187188 ) -> Union [
188189 ValidationOutcome [OT ],
189190 Awaitable [ValidationOutcome [OT ]],
190- AsyncIterable [ValidationOutcome [OT ]],
191+ AsyncIterator [ValidationOutcome [OT ]],
191192 ]:
192193 self ._fill_validator_map ()
193194 self ._fill_validators ()
@@ -219,49 +220,13 @@ async def __exec(
219220 ) -> Union [
220221 ValidationOutcome [OT ],
221222 Awaitable [ValidationOutcome [OT ]],
222- AsyncIterable [ValidationOutcome [OT ]],
223+ AsyncIterator [ValidationOutcome [OT ]],
223224 ]:
224225 prompt_params = prompt_params or {}
225226 metadata = metadata or {}
226227 if full_schema_reask is None :
227228 full_schema_reask = self ._base_model is not None
228229
229- if self ._allow_metrics_collection :
230- llm_api_str = ""
231- if llm_api :
232- llm_api_module_name = (
233- llm_api .__module__ if hasattr (llm_api , "__module__" ) else ""
234- )
235- llm_api_name = (
236- llm_api .__name__
237- if hasattr (llm_api , "__name__" )
238- else type (llm_api ).__name__
239- )
240- llm_api_str = f"{ llm_api_module_name } .{ llm_api_name } "
241- # Create a new span for this guard call
242- self ._hub_telemetry .create_new_span (
243- span_name = "/guard_call" ,
244- attributes = [
245- ("guard_id" , self .id ),
246- ("user_id" , self ._user_id ),
247- ("llm_api" , llm_api_str ),
248- (
249- "custom_reask_prompt" ,
250- self ._exec_opts .reask_prompt is not None ,
251- ),
252- (
253- "custom_reask_instructions" ,
254- self ._exec_opts .reask_instructions is not None ,
255- ),
256- (
257- "custom_reask_messages" ,
258- self ._exec_opts .reask_messages is not None ,
259- ),
260- ],
261- is_parent = True , # It will have children
262- has_parent = False , # Has no parents
263- )
264-
265230 set_call_kwargs (kwargs )
266231 set_tracer (self ._tracer )
267232 set_tracer_context (self ._tracer_context )
@@ -369,7 +334,7 @@ async def _exec(
369334 ) -> Union [
370335 ValidationOutcome [OT ],
371336 Awaitable [ValidationOutcome [OT ]],
372- AsyncIterable [ValidationOutcome [OT ]],
337+ AsyncIterator [ValidationOutcome [OT ]],
373338 ]:
374339 """Call the LLM asynchronously and validate the output.
375340
@@ -435,6 +400,7 @@ async def _exec(
435400 )
436401 return ValidationOutcome [OT ].from_guard_history (call )
437402
403+ @async_trace (name = "/guard_call" , origin = "AsyncGuard.__call__" )
438404 async def __call__ (
439405 self ,
440406 llm_api : Optional [Callable [..., Awaitable [Any ]]] = None ,
@@ -450,7 +416,7 @@ async def __call__(
450416 ) -> Union [
451417 ValidationOutcome [OT ],
452418 Awaitable [ValidationOutcome [OT ]],
453- AsyncIterable [ValidationOutcome [OT ]],
419+ AsyncIterator [ValidationOutcome [OT ]],
454420 ]:
455421 """Call the LLM and validate the output. Pass an async LLM API to
456422 return a coroutine.
@@ -501,6 +467,7 @@ async def __call__(
501467 ** kwargs ,
502468 )
503469
470+ @async_trace (name = "/guard_call" , origin = "AsyncGuard.parse" )
504471 async def parse (
505472 self ,
506473 llm_output : str ,
@@ -567,7 +534,7 @@ async def parse(
567534
568535 async def _stream_server_call (
569536 self , * , payload : Dict [str , Any ]
570- ) -> AsyncIterable [ValidationOutcome [OT ]]:
537+ ) -> AsyncIterator [ValidationOutcome [OT ]]:
571538 # TODO: Once server side supports async streaming, this function will need to
572539 # yield async generators, not generators
573540 if self ._api_client :
@@ -609,6 +576,7 @@ async def _stream_server_call(
609576 else :
610577 raise ValueError ("AsyncGuard does not have an api client!" )
611578
579+ @async_trace (name = "/guard_call" , origin = "AsyncGuard.validate" )
612580 async def validate (
613581 self , llm_output : str , * args , ** kwargs
614582 ) -> Awaitable [ValidationOutcome [OT ]]:
0 commit comments