11from  functools  import  wraps 
2- from  typing  import  Any , Callable , Dict , Optional 
2+ from  typing  import  Any , Awaitable ,  Callable , Dict , Optional 
33
44from  opentelemetry .trace  import  Span 
55
88from  guardrails .utils .hub_telemetry_utils  import  HubTelemetry 
99
1010
11- def  get_guard_attributes (attrs : Dict [str , Any ], guard_self : Any ) ->  Dict [str , Any ]:
12-     attrs ["guard_id" ] =  guard_self .id 
13-     attrs ["user_id" ] =  guard_self ._user_id 
14-     attrs ["custom_reask_prompt" ] =  guard_self ._exec_opts .reask_prompt  is  not None 
15-     attrs ["custom_reask_instructions" ] =  (
16-         guard_self ._exec_opts .reask_instructions  is  not None 
17-     )
18-     attrs ["custom_reask_messages" ] =  guard_self ._exec_opts .reask_messages  is  not None 
19-     attrs ["output_type" ] =  (
20-         "unstructured" 
21-         if  PrimitiveTypes .is_primitive (guard_self .output_schema .type .actual_instance )
22-         else  "structured" 
23-     )
24-     return  attrs 
25- 
11+ def  get_guard_call_attributes (
12+     attrs : Dict [str , Any ], origin : str , * args , ** kwargs 
13+ ) ->  Dict [str , Any ]:
14+     attrs ["stream" ] =  kwargs .get ("stream" , False )
2615
27- def  get_guard_call_attributes (attrs : Dict [str , Any ], * args , ** kwargs ) ->  Dict [str , Any ]:
2816    guard_self  =  safe_get (args , 0 )
2917    if  guard_self  is  not None :
30-         attrs  =  get_guard_attributes (attrs , guard_self )
18+         attrs ["guard_id" ] =  guard_self .id 
19+         attrs ["user_id" ] =  guard_self ._user_id 
20+         attrs ["custom_reask_prompt" ] =  guard_self ._exec_opts .reask_prompt  is  not None 
21+         attrs ["custom_reask_instructions" ] =  (
22+             guard_self ._exec_opts .reask_instructions  is  not None 
23+         )
24+         attrs ["custom_reask_messages" ] =  (
25+             guard_self ._exec_opts .reask_messages  is  not None 
26+         )
27+         attrs ["output_type" ] =  (
28+             "unstructured" 
29+             if  PrimitiveTypes .is_primitive (
30+                 guard_self .output_schema .type .actual_instance 
31+             )
32+             else  "structured" 
33+         )
34+         return  attrs 
3135
3236    llm_api_str  =  ""   # noqa 
33-     llm_api  =  safe_get (args , 1 , kwargs .get ("llm_api" ))
37+     llm_api  =  kwargs .get ("llm_api" )
38+     if  origin  in  ["Guard.__call__" , "AsyncGuard.__call__" ]:
39+         llm_api  =  safe_get (args , 1 , llm_api )
40+ 
3441    if  llm_api :
3542        llm_api_module_name  =  (
3643            llm_api .__module__  if  hasattr (llm_api , "__module__" ) else  "" 
@@ -44,16 +51,24 @@ def get_guard_call_attributes(attrs: Dict[str, Any], *args, **kwargs) -> Dict[st
4451    return  attrs 
4552
4653
47- def  add_attributes (name : str , span : Span , origin : str , * args , ** kwargs ):
48-     attrs  =  {"origin" : origin }
49-     if  origin  ==  "Guard.__call__" :
54+ def  add_attributes (
55+     span : Span , attrs : Dict [str , Any ], name : str , origin : str , * args , ** kwargs 
56+ ):
57+     attrs ["origin" ] =  origin 
58+     if  name  ==  "/guard_call" :
5059        attrs  =  get_guard_call_attributes (attrs , * args , ** kwargs )
5160
5261    for  key , value  in  attrs .items ():
5362        span .set_attribute (key , value )
5463
5564
56- def  trace (* , name : str , origin : str , is_parent : Optional [bool ] =  False ):
65+ def  trace (
66+     * ,
67+     name : str ,
68+     origin : str ,
69+     is_parent : Optional [bool ] =  False ,
70+     ** attrs ,
71+ ):
5772    def  decorator (fn : Callable [..., Any ]):
5873        @wraps (fn ) 
5974        def  wrapper (* args , ** kwargs ):
@@ -69,11 +84,41 @@ def wrapper(*args, **kwargs):
6984                        # Inject the current context 
7085                        hub_telemetry .inject_current_context ()
7186
72-                     add_attributes (name ,  span , origin , * args , ** kwargs )
87+                     add_attributes (span ,  attrs , origin , * args , ** kwargs )
7388                    return  fn (* args , ** kwargs )
7489            else :
7590                return  fn (* args , ** kwargs )
7691
7792        return  wrapper 
7893
7994    return  decorator 
95+ 
96+ 
97+ def  async_trace (
98+     * ,
99+     name : str ,
100+     origin : str ,
101+     is_parent : Optional [bool ] =  False ,
102+ ):
103+     def  decorator (fn : Callable [..., Awaitable [Any ]]):
104+         @wraps (fn ) 
105+         async  def  async_wrapper (* args , ** kwargs ):
106+             hub_telemetry  =  HubTelemetry ()
107+             if  hub_telemetry ._enabled  and  hub_telemetry ._tracer  is  not None :
108+                 context  =  (
109+                     hub_telemetry .extract_current_context () if  not  is_parent  else  None 
110+                 )
111+                 with  hub_telemetry ._tracer .start_as_current_span (
112+                     name , context = context 
113+                 ) as  span :  # noqa 
114+                     if  is_parent :
115+                         # Inject the current context 
116+                         hub_telemetry .inject_current_context ()
117+                     add_attributes (span , {"async" : True }, origin , * args , ** kwargs )
118+                     return  await  fn (* args , ** kwargs )
119+             else :
120+                 return  await  fn (* args , ** kwargs )
121+ 
122+         return  async_wrapper 
123+ 
124+     return  decorator 
0 commit comments