Skip to content

Commit 27eb0b4

Browse files
committed
guard execution tracing with isolated context
1 parent 8e58413 commit 27eb0b4

File tree

5 files changed

+104
-42
lines changed

5 files changed

+104
-42
lines changed

guardrails/guard.py

Lines changed: 4 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
set_tracer,
6565
set_tracer_context,
6666
)
67+
from guardrails.telemetry.hub_tracing import trace
6768
from guardrails.types.on_fail import OnFailAction
6869
from guardrails.types.pydantic import ModelOrListOfModels
6970
from guardrails.utils.naming_utils import random_id
@@ -738,42 +739,6 @@ def __exec(
738739
if full_schema_reask is None:
739740
full_schema_reask = self._base_model is not None
740741

741-
if self._allow_metrics_collection and self._hub_telemetry:
742-
# Create a new span for this guard call
743-
llm_api_str = ""
744-
if llm_api:
745-
llm_api_module_name = (
746-
llm_api.__module__ if hasattr(llm_api, "__module__") else ""
747-
)
748-
llm_api_name = (
749-
llm_api.__name__
750-
if hasattr(llm_api, "__name__")
751-
else type(llm_api).__name__
752-
)
753-
llm_api_str = f"{llm_api_module_name}.{llm_api_name}"
754-
self._hub_telemetry.create_new_span(
755-
span_name="/guard_call",
756-
attributes=[
757-
("guard_id", self.id),
758-
("user_id", self._user_id),
759-
("llm_api", llm_api_str if llm_api_str else "None"),
760-
(
761-
"custom_reask_prompt",
762-
self._exec_opts.reask_prompt is not None,
763-
),
764-
(
765-
"custom_reask_instructions",
766-
self._exec_opts.reask_instructions is not None,
767-
),
768-
(
769-
"custom_reask_messages",
770-
self._exec_opts.reask_messages is not None,
771-
),
772-
],
773-
is_parent=True, # It will have children
774-
has_parent=False, # Has no parents
775-
)
776-
777742
set_call_kwargs(kwargs)
778743
set_tracer(self._tracer)
779744
set_tracer_context(self._tracer_context)
@@ -923,7 +888,7 @@ def _exec(
923888
call = runner(call_log=call_log, prompt_params=prompt_params)
924889
return ValidationOutcome[OT].from_guard_history(call)
925890

926-
# @trace(name="Guard.__call__")
891+
@trace(name="/guard_call", origin="Guard.__call__")
927892
def __call__(
928893
self,
929894
llm_api: Optional[Callable] = None,
@@ -982,6 +947,7 @@ def __call__(
982947
**kwargs,
983948
)
984949

950+
@trace(name="/guard_call", origin="Guard.parse")
985951
def parse(
986952
self,
987953
llm_output: str,
@@ -1158,6 +1124,7 @@ def use_many(
11581124
self._save()
11591125
return self
11601126

1127+
@trace(name="/guard_call", origin="Guard.validate")
11611128
def validate(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[OT]:
11621129
return self.parse(llm_output=llm_output, *args, **kwargs)
11631130

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from functools import wraps
2+
from typing import Any, Callable, Dict, Optional
3+
4+
from opentelemetry.trace import Span
5+
6+
from guardrails.types.primitives import PrimitiveTypes
7+
from guardrails.utils.safe_get import safe_get
8+
from guardrails.utils.hub_telemetry_utils import HubTelemetry
9+
10+
11+
def get_guard_attributes(attrs: Dict[str, Any], guard_self: Any) -> Dict[str, Any]:
12+
attrs["guard_id"] = guard_self.id
13+
attrs["user_id"] = guard_self._user_id
14+
attrs["custom_reask_prompt"] = guard_self._exec_opts.reask_prompt is not None
15+
attrs["custom_reask_instructions"] = (
16+
guard_self._exec_opts.reask_instructions is not None
17+
)
18+
attrs["custom_reask_messages"] = guard_self._exec_opts.reask_messages is not None
19+
attrs["output_type"] = (
20+
"unstructured"
21+
if PrimitiveTypes.is_primitive(guard_self.output_schema.type.actual_instance)
22+
else "structured"
23+
)
24+
return attrs
25+
26+
27+
def get_guard_call_attributes(attrs: Dict[str, Any], *args, **kwargs) -> Dict[str, Any]:
28+
guard_self = safe_get(args, 0)
29+
if guard_self is not None:
30+
attrs = get_guard_attributes(attrs, guard_self)
31+
32+
llm_api_str = "" # noqa
33+
llm_api = safe_get(args, 1, kwargs.get("llm_api"))
34+
if llm_api:
35+
llm_api_module_name = (
36+
llm_api.__module__ if hasattr(llm_api, "__module__") else ""
37+
)
38+
llm_api_name = (
39+
llm_api.__name__ if hasattr(llm_api, "__name__") else type(llm_api).__name__
40+
)
41+
llm_api_str = f"{llm_api_module_name}.{llm_api_name}"
42+
attrs["llm_api"] = llm_api_str if llm_api_str else "None"
43+
44+
return attrs
45+
46+
47+
def add_attributes(name: str, span: Span, origin: str, *args, **kwargs):
48+
attrs = {"origin": origin}
49+
if origin == "Guard.__call__":
50+
attrs = get_guard_call_attributes(attrs, *args, **kwargs)
51+
52+
for key, value in attrs.items():
53+
span.set_attribute(key, value)
54+
55+
56+
def trace(*, name: str, origin: str, is_parent: Optional[bool] = False):
57+
def decorator(fn: Callable[..., Any]):
58+
@wraps(fn)
59+
def wrapper(*args, **kwargs):
60+
hub_telemetry = HubTelemetry()
61+
if hub_telemetry._enabled and hub_telemetry._tracer is not None:
62+
context = (
63+
hub_telemetry.extract_current_context() if not is_parent else None
64+
)
65+
with hub_telemetry._tracer.start_as_current_span(
66+
name, context=context
67+
) as span: # noqa
68+
if is_parent:
69+
# Inject the current context
70+
hub_telemetry.inject_current_context()
71+
72+
add_attributes(name, span, origin, *args, **kwargs)
73+
return fn(*args, **kwargs)
74+
else:
75+
return fn(*args, **kwargs)
76+
77+
return wrapper
78+
79+
return decorator

guardrails/types/primitives.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,15 @@
33

44

55
class PrimitiveTypes(str, Enum):
6-
BOOLEAN = SimpleTypes.BOOLEAN
7-
INTEGER = SimpleTypes.INTEGER
8-
NUMBER = SimpleTypes.NUMBER
9-
STRING = SimpleTypes.STRING
6+
BOOLEAN = SimpleTypes.BOOLEAN.value
7+
INTEGER = SimpleTypes.INTEGER.value
8+
NUMBER = SimpleTypes.NUMBER.value
9+
STRING = SimpleTypes.STRING.value
10+
11+
@staticmethod
12+
def is_primitive(value: str) -> bool:
13+
try:
14+
return value in [member.value for member in PrimitiveTypes]
15+
except Exception as e:
16+
print(e)
17+
return False

guardrails/utils/hub_telemetry_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ class HubTelemetry:
2323
_processor = None
2424
_tracer = None
2525
_prop = None
26-
_carrier = {}
2726
_enabled = False
2827

2928
def __new__(
@@ -56,6 +55,7 @@ def initialize_tracer(
5655
"""Initializes a tracer for Guardrails Hub."""
5756

5857
self._enabled = enabled
58+
self._carrier = {}
5959
self._service_name = service_name
6060
# self._endpoint = "http://localhost:4318/v1/traces"
6161
self._endpoint = (

tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,18 @@ def mock_runner_hub_telemetry():
4848
yield MockHubTelemetry
4949

5050

51+
@pytest.fixture(autouse=True)
52+
def mock_hub_tracing():
53+
with patch("guardrails.telemetry.hub_tracing.HubTelemetry") as MockHubTelemetry:
54+
MockHubTelemetry.return_value = MagicMock()
55+
yield MockHubTelemetry
56+
57+
5158
def pytest_collection_modifyitems(items):
5259
for item in items:
5360
if "no_hub_telemetry_mock" in item.keywords:
5461
item.fixturenames.remove("mock_guard_hub_telemetry")
5562
item.fixturenames.remove("mock_validator_base_hub_telemetry")
5663
item.fixturenames.remove("mock_validator_service_hub_telemetry")
5764
item.fixturenames.remove("mock_runner_hub_telemetry")
65+
item.fixturenames.remove("mock_hub_tracing")

0 commit comments

Comments
 (0)