11from functools import wraps
2- from typing import Any , Awaitable , Callable , Dict , Optional
2+ from typing import (
3+ Any ,
4+ AsyncIterator ,
5+ Awaitable ,
6+ Callable ,
7+ Dict ,
8+ Iterator ,
9+ Optional ,
10+ TypeVar ,
11+ )
312
413from opentelemetry .trace import Span
514
918from guardrails .utils .safe_get import safe_get
1019from guardrails .utils .hub_telemetry_utils import HubTelemetry
1120
21+ R = TypeVar ("R" , covariant = True )
22+
1223
1324def get_guard_call_attributes (
1425 attrs : Dict [str , Any ], origin : str , * args , ** kwargs
@@ -76,7 +87,9 @@ def get_validator_inference_attributes(
7687def get_validator_usage_attributes (
7788 attrs : Dict [str , Any ], response , * args , ** kwargs
7889) -> Dict [str , Any ]:
79- validator_self = safe_get (args , 0 )
90+ # We're wrapping a wrapped function,
91+ # so the first arg is the validator service
92+ validator_self = safe_get (args , 1 )
8093 if validator_self is not None :
8194 attrs ["validator_name" ] = validator_self .rail_alias
8295 attrs ["validator_on_fail" ] = validator_self .on_fail_descriptor
@@ -90,11 +103,17 @@ def get_validator_usage_attributes(
90103
91104
92105def add_attributes (
93- span : Span , attrs : Dict [str , Any ], name : str , origin : str , response , * args , ** kwargs
106+ span : Span ,
107+ attrs : Dict [str , Any ],
108+ name : str ,
109+ origin : str ,
110+ * args ,
111+ response = None ,
112+ ** kwargs ,
94113):
95114 attrs ["origin" ] = origin
96115 if name == "/guard_call" :
97- attrs = get_guard_call_attributes (attrs , * args , ** kwargs )
116+ attrs = get_guard_call_attributes (attrs , origin , * args , ** kwargs )
98117 elif name == "/reasks" :
99118 if response is not None and hasattr (response , "iterations" ):
100119 attrs ["reask_count" ] = len (response .iterations ) - 1
@@ -103,7 +122,7 @@ def add_attributes(
103122 elif name == "/validator_inference" :
104123 attrs = get_validator_inference_attributes (attrs , * args , ** kwargs )
105124 elif name == "/validator_usage" :
106- attrs = get_validator_usage_attributes (attrs , response * args , ** kwargs )
125+ attrs = get_validator_usage_attributes (attrs , response , * args , ** kwargs )
107126
108127 for key , value in attrs .items ():
109128 if value is not None :
@@ -117,9 +136,9 @@ def trace(
117136 is_parent : Optional [bool ] = False ,
118137 ** attrs ,
119138):
120- def decorator (fn : Callable [..., Any ]):
139+ def decorator (fn : Callable [..., R ]):
121140 @wraps (fn )
122- def wrapper (* args , ** kwargs ):
141+ def wrapper (* args , ** kwargs ) -> R :
123142 hub_telemetry = HubTelemetry ()
124143 if hub_telemetry ._enabled and hub_telemetry ._tracer is not None :
125144 context = (
@@ -137,7 +156,9 @@ def wrapper(*args, **kwargs):
137156 origin = origin if origin is not None else name
138157
139158 resp = fn (* args , ** kwargs )
140- add_attributes (span , attrs , origin , resp , * args , ** kwargs )
159+ add_attributes (
160+ span , attrs , name , origin , * args , response = resp , ** kwargs
161+ )
141162 return resp
142163 else :
143164 return fn (* args , ** kwargs )
@@ -153,9 +174,9 @@ def async_trace(
153174 origin : Optional [str ] = None ,
154175 is_parent : Optional [bool ] = False ,
155176):
156- def decorator (fn : Callable [..., Awaitable [Any ]]):
177+ def decorator (fn : Callable [..., Awaitable [R ]]):
157178 @wraps (fn )
158- async def async_wrapper (* args , ** kwargs ):
179+ async def async_wrapper (* args , ** kwargs ) -> R :
159180 hub_telemetry = HubTelemetry ()
160181 if hub_telemetry ._enabled and hub_telemetry ._tracer is not None :
161182 context = (
@@ -170,7 +191,7 @@ async def async_wrapper(*args, **kwargs):
170191
171192 nonlocal origin
172193 origin = origin if origin is not None else name
173- add_attributes (span , {"async" : True }, origin , * args , ** kwargs )
194+ add_attributes (span , {"async" : True }, name , origin , * args , ** kwargs )
174195 return await fn (* args , ** kwargs )
175196 else :
176197 return await fn (* args , ** kwargs )
@@ -193,9 +214,9 @@ def trace_stream(
193214 is_parent : Optional [bool ] = False ,
194215 ** attrs ,
195216):
196- def decorator (fn : Callable [..., Any ]):
217+ def decorator (fn : Callable [..., Iterator [ R ] ]):
197218 @wraps (fn )
198- def wrapper (* args , ** kwargs ):
219+ def wrapper (* args , ** kwargs ) -> Iterator [ R ] :
199220 hub_telemetry = HubTelemetry ()
200221 if hub_telemetry ._enabled and hub_telemetry ._tracer is not None :
201222 context = (
@@ -212,7 +233,7 @@ def wrapper(*args, **kwargs):
212233
213234 nonlocal origin
214235 origin = origin if origin is not None else name
215- add_attributes (span , attrs , name , origin , None , * args , ** kwargs )
236+ add_attributes (span , attrs , name , origin , * args , ** kwargs )
216237 return _run_gen (fn , * args , ** kwargs )
217238 else :
218239 return fn (* args , ** kwargs )
@@ -235,9 +256,9 @@ def async_trace_stream(
235256 is_parent : Optional [bool ] = False ,
236257 ** attrs ,
237258):
238- def decorator (fn : Callable [..., Awaitable [ Any ]]):
259+ def decorator (fn : Callable [..., AsyncIterator [ R ]]):
239260 @wraps (fn )
240- async def wrapper (* args , ** kwargs ):
261+ async def wrapper (* args , ** kwargs ) -> AsyncIterator [ R ] :
241262 hub_telemetry = HubTelemetry ()
242263 if hub_telemetry ._enabled and hub_telemetry ._tracer is not None :
243264 context = (
@@ -254,7 +275,7 @@ async def wrapper(*args, **kwargs):
254275
255276 nonlocal origin
256277 origin = origin if origin is not None else name
257- add_attributes (span , attrs , name , origin , None , * args , ** kwargs )
278+ add_attributes (span , attrs , name , origin , * args , ** kwargs )
258279 return _run_async_gen (fn , * args , ** kwargs )
259280 else :
260281 return fn (* args , ** kwargs )
0 commit comments