Skip to content

Commit 93a6a36

Browse files
authored
Merge pull request #1164 from guardrails-ai/temp-fix-async-mlflow
Hotfix for MLFlow validator spans during async execution
2 parents 358cdbf + ee3fdc1 commit 93a6a36

File tree

2 files changed

+128
-0
lines changed

2 files changed

+128
-0
lines changed

guardrails/integrations/databricks/ml_flow_instrumentor.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ def instrument(self):
100100
export.validate
101101
)
102102
setattr(export, "validate", wrapped_validator_validate)
103+
104+
wrapped_validator_async_validate = (
105+
self._instrument_validator_async_validate(export.async_validate)
106+
)
107+
setattr(export, "async_validate", wrapped_validator_async_validate)
108+
103109
setattr(guardrails.hub, validator_name, export) # type: ignore
104110

105111
def _instrument_guard(
@@ -387,6 +393,14 @@ def trace_validator_wrapper(*args, **kwargs):
387393
init_kwargs = validator_self._kwargs
388394

389395
validator_span_name = f"{validator_name}.validate"
396+
397+
# Skip this instrumentation in the case of async
398+
# when the parent span cannot be fetched from the current context
399+
# because Validator.validate is running in a ThreadPoolExecutor
400+
parent_span = mlflow.get_current_active_span()
401+
if not parent_span:
402+
return validator_validate(*args, **kwargs)
403+
390404
with mlflow.start_span(
391405
name=validator_span_name,
392406
span_type="validator",
@@ -425,3 +439,64 @@ def trace_validator_wrapper(*args, **kwargs):
425439
raise e
426440

427441
return trace_validator_wrapper
442+
443+
def _instrument_validator_async_validate(
444+
self,
445+
validator_async_validate: Callable[..., Coroutine[Any, Any, ValidationResult]],
446+
):
447+
@wraps(validator_async_validate)
448+
async def trace_async_validator_wrapper(*args, **kwargs):
449+
validator_name = "validator"
450+
obj_id = id(validator_async_validate)
451+
on_fail_descriptor = "unknown"
452+
init_kwargs = {}
453+
validation_session_id = "unknown"
454+
455+
validator_self = args[0]
456+
if validator_self is not None and isinstance(validator_self, Validator):
457+
validator_name = validator_self.rail_alias
458+
obj_id = id(validator_self)
459+
on_fail_descriptor = validator_self.on_fail_descriptor
460+
init_kwargs = validator_self._kwargs
461+
462+
validator_span_name = f"{validator_name}.validate"
463+
464+
with mlflow.start_span(
465+
name=validator_span_name,
466+
span_type="validator",
467+
attributes={
468+
"guardrails.version": GUARDRAILS_VERSION,
469+
"type": "guardrails/guard/step/validator",
470+
"async": True,
471+
},
472+
) as validator_span:
473+
try:
474+
resp = await validator_async_validate(*args, **kwargs)
475+
add_validator_attributes(
476+
*args,
477+
validator_span=validator_span, # type: ignore
478+
validator_name=validator_name,
479+
obj_id=obj_id,
480+
on_fail_descriptor=on_fail_descriptor,
481+
result=resp,
482+
init_kwargs=init_kwargs,
483+
validation_session_id=validation_session_id,
484+
**kwargs,
485+
)
486+
return resp
487+
except Exception as e:
488+
validator_span.set_status(status=SpanStatusCode.ERROR)
489+
add_validator_attributes(
490+
*args,
491+
validator_span=validator_span, # type: ignore
492+
validator_name=validator_name,
493+
obj_id=obj_id,
494+
on_fail_descriptor=on_fail_descriptor,
495+
result=None,
496+
init_kwargs=init_kwargs,
497+
validation_session_id=validation_session_id,
498+
**kwargs,
499+
)
500+
raise e
501+
502+
return trace_async_validator_wrapper

tests/unit_tests/integrations/databricks/test_ml_flow_instrumentor.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,10 @@ async def test__instrument_async_runner_call(self, mocker):
588588

589589
def test__instrument_validator_validate(self, mocker):
590590
mock_span = MockSpan()
591+
mock_start_span = mocker.patch(
592+
"guardrails.integrations.databricks.ml_flow_instrumentor.mlflow.get_current_active_span",
593+
return_value=mock_span,
594+
)
591595
mock_start_span = mocker.patch(
592596
"guardrails.integrations.databricks.ml_flow_instrumentor.mlflow.start_span",
593597
return_value=mock_span,
@@ -630,3 +634,52 @@ def test__instrument_validator_validate(self, mocker):
630634
init_kwargs={},
631635
validation_session_id="unknown",
632636
)
637+
638+
@pytest.mark.asyncio
639+
async def test__instrument_validator_async_validate(self, mocker):
640+
mock_span = MockSpan()
641+
mock_start_span = mocker.patch(
642+
"guardrails.integrations.databricks.ml_flow_instrumentor.mlflow.start_span",
643+
return_value=mock_span,
644+
)
645+
646+
mock_add_validator_attributes = mocker.patch(
647+
"guardrails.integrations.databricks.ml_flow_instrumentor.add_validator_attributes"
648+
)
649+
650+
from guardrails.integrations.databricks import MlFlowInstrumentor
651+
from tests.unit_tests.mocks.mock_hub import MockValidator
652+
653+
m = MlFlowInstrumentor("mock experiment")
654+
655+
wrapped_async_validate = m._instrument_validator_async_validate(
656+
MockValidator.async_validate
657+
)
658+
659+
mock_validator = MockValidator()
660+
661+
resp = await wrapped_async_validate(mock_validator, True, {})
662+
663+
mock_start_span.assert_called_once_with(
664+
name="mock-validator.validate",
665+
span_type="validator",
666+
attributes={
667+
"guardrails.version": GUARDRAILS_VERSION,
668+
"type": "guardrails/guard/step/validator",
669+
"async": True,
670+
},
671+
)
672+
673+
# Internally called, not the wrapped call above
674+
mock_add_validator_attributes.assert_called_once_with(
675+
mock_validator,
676+
True,
677+
{},
678+
validator_span=mock_span, # type: ignore
679+
validator_name="mock-validator",
680+
obj_id=id(mock_validator),
681+
on_fail_descriptor="exception",
682+
result=resp,
683+
init_kwargs={},
684+
validation_session_id="unknown",
685+
)

0 commit comments

Comments
 (0)