@@ -100,6 +100,12 @@ def instrument(self):
100100 export .validate
101101 )
102102 setattr (export , "validate" , wrapped_validator_validate )
103+
104+ wrapped_validator_async_validate = (
105+ self ._instrument_validator_async_validate (export .async_validate )
106+ )
107+ setattr (export , "async_validate" , wrapped_validator_async_validate )
108+
103109 setattr (guardrails .hub , validator_name , export ) # type: ignore
104110
105111 def _instrument_guard (
@@ -387,6 +393,14 @@ def trace_validator_wrapper(*args, **kwargs):
387393 init_kwargs = validator_self ._kwargs
388394
389395 validator_span_name = f"{ validator_name } .validate"
396+
397+ # Skip this instrumentation in the case of async
398+ # when the parent span cannot be fetched from the current context
399+ # because Validator.validate is running in a ThreadPoolExecutor
400+ parent_span = mlflow .get_current_active_span ()
401+ if not parent_span :
402+ return validator_validate (* args , ** kwargs )
403+
390404 with mlflow .start_span (
391405 name = validator_span_name ,
392406 span_type = "validator" ,
@@ -425,3 +439,64 @@ def trace_validator_wrapper(*args, **kwargs):
425439 raise e
426440
427441 return trace_validator_wrapper
442+
443+ def _instrument_validator_async_validate (
444+ self ,
445+ validator_async_validate : Callable [..., Coroutine [Any , Any , ValidationResult ]],
446+ ):
447+ @wraps (validator_async_validate )
448+ async def trace_async_validator_wrapper (* args , ** kwargs ):
449+ validator_name = "validator"
450+ obj_id = id (validator_async_validate )
451+ on_fail_descriptor = "unknown"
452+ init_kwargs = {}
453+ validation_session_id = "unknown"
454+
455+ validator_self = args [0 ]
456+ if validator_self is not None and isinstance (validator_self , Validator ):
457+ validator_name = validator_self .rail_alias
458+ obj_id = id (validator_self )
459+ on_fail_descriptor = validator_self .on_fail_descriptor
460+ init_kwargs = validator_self ._kwargs
461+
462+ validator_span_name = f"{ validator_name } .validate"
463+
464+ with mlflow .start_span (
465+ name = validator_span_name ,
466+ span_type = "validator" ,
467+ attributes = {
468+ "guardrails.version" : GUARDRAILS_VERSION ,
469+ "type" : "guardrails/guard/step/validator" ,
470+ "async" : True ,
471+ },
472+ ) as validator_span :
473+ try :
474+ resp = await validator_async_validate (* args , ** kwargs )
475+ add_validator_attributes (
476+ * args ,
477+ validator_span = validator_span , # type: ignore
478+ validator_name = validator_name ,
479+ obj_id = obj_id ,
480+ on_fail_descriptor = on_fail_descriptor ,
481+ result = resp ,
482+ init_kwargs = init_kwargs ,
483+ validation_session_id = validation_session_id ,
484+ ** kwargs ,
485+ )
486+ return resp
487+ except Exception as e :
488+ validator_span .set_status (status = SpanStatusCode .ERROR )
489+ add_validator_attributes (
490+ * args ,
491+ validator_span = validator_span , # type: ignore
492+ validator_name = validator_name ,
493+ obj_id = obj_id ,
494+ on_fail_descriptor = on_fail_descriptor ,
495+ result = None ,
496+ init_kwargs = init_kwargs ,
497+ validation_session_id = validation_session_id ,
498+ ** kwargs ,
499+ )
500+ raise e
501+
502+ return trace_async_validator_wrapper
0 commit comments